"""Gradio demo app for the PGPS neural geometric problem solver.

Loads the PGPS network once at startup, then serves a Blocks UI that
accepts a geometry diagram image plus problem text and returns the
decoded solution expression.
"""

import os
import sys

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from PIL import Image

# Add current directory to path so the project-local packages resolve
# when the app is launched from another working directory.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from core.network import Network, MLMTransformerPretrain
from model.backbone import get_visual_backbone
from model.encoder import get_encoder
from model.decoder import get_decoder
from datasets.preprossing import SN, SrcLang, TgtLang
from datasets.utils import get_combined_text, get_var_arg, get_text_index
from datasets.operators import normalize_exp
import datasets.diagram_aug as T_diagram


class Config:
    """Hard-coded demo configuration mirroring the training config schema."""

    def __init__(self):
        # Visual backbone
        self.visual_backbone = "ResNet10"
        self.diagram_size = 128
        self.pretrain_vis_path = ''  # Added missing attribute

        # Encoder
        self.encoder_type = "gru"
        self.encoder_layers = 2
        self.encoder_embedding_size = 256
        self.encoder_hidden_size = 512
        self.max_input_len = 400

        # Decoder
        self.decoder_type = "rnn_decoder"
        self.decoder_layers = 2
        self.decoder_embedding_size = 512
        self.decoder_hidden_size = 512
        self.max_output_len = 40

        # General
        self.dropout_rate = 0.2
        self.beam_size = 10
        self.use_MLM_pretrain = False  # Disabled due to dimension mismatch issues in demo
        self.MLM_pretrain_path = './LM_MODEL.pth'
        self.pretrain_emb_path = ''

        # Dataset
        self.without_stru = False

        # Logger (dummy for compatibility): a throwaway class whose `info`
        # just prints. Accessed as `cfg.logger.info(msg)` — since `logger`
        # is the class itself (not an instance), the lambda receives `msg`
        # as its single positional argument.
        self.logger = type('obj', (object,), {'info': lambda x: print(x)})


def load_model():
    """Build the PGPS network, optionally restore compatible checkpoint
    weights, and return ``(model, src_lang, tgt_lang, cfg)``.

    Checkpoint keys whose names or shapes do not match the freshly built
    model are silently skipped so the demo can still start with partially
    (or fully) random weights.
    """
    cfg = Config()

    # Load vocabularies using proper Lang classes
    src_lang = SrcLang('./vocab/vocab_src.txt')
    tgt_lang = TgtLang('./vocab/vocab_tgt.txt')

    # Create model
    model = Network(cfg, src_lang, tgt_lang)

    # Load pretrained weights if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    if os.path.exists('./LM_MODEL.pth'):
        try:
            # Load with proper device mapping
            checkpoint = torch.load('./LM_MODEL.pth', map_location=device)
            if 'state_dict' in checkpoint:
                state_dict = checkpoint['state_dict']
            else:
                state_dict = checkpoint

            # Filter out incompatible keys (missing in the model or wrong shape)
            model_dict = model.state_dict()
            filtered_dict = {k: v for k, v in state_dict.items()
                             if k in model_dict and v.shape == model_dict[k].shape}
            model_dict.update(filtered_dict)
            model.load_state_dict(model_dict, strict=False)
            print(f"Loaded {len(filtered_dict)}/{len(state_dict)} parameters from checkpoint")
        except Exception as e:
            # Best-effort restore: a corrupt/incompatible checkpoint must not
            # prevent the demo from starting.
            print(f"Warning: Could not load full model weights: {e}")
            print("Continuing with randomly initialized weights")

    model = model.to(device)
    model.eval()
    return model, src_lang, tgt_lang, cfg


def process_input(image, text_input, model, src_lang, tgt_lang, cfg):
    """Run one inference pass: preprocess *image* and *text_input*, feed them
    through *model*, and return a human-readable solution string.

    Falls back to simplified text processing whenever the dataset helpers
    (:func:`get_combined_text`, :func:`get_text_index`) reject the ad-hoc
    demo input.
    """
    # Get device from the model itself so tensors land on the right device.
    device = next(model.parameters()).device

    # Transform image into the normalized square tensor the backbone expects.
    diagram_transform = T_diagram.Compose([
        T_diagram.Resize(cfg.diagram_size),
        T_diagram.CenterCrop(cfg.diagram_size),
        T_diagram.ToTensor(),
        T_diagram.Normalize()
    ])
    diagram = diagram_transform(image).unsqueeze(0).to(device)

    # Process text input: create a simple text structure. Empty input is
    # padded so downstream code always sees at least one token.
    text_sn = SN()
    text_sn.word_list = text_input.split() if text_input.strip() else ["[PAD]"]
    text_sn.clause_list = [text_input] if text_input.strip() else ["[PAD]"]
    text_sn.token = text_sn.word_list
    text_sn.sect_tag = ["[PROB]"] * len(text_sn.word_list)
    text_sn.class_tag = ["[GEN]"] * len(text_sn.word_list)

    # Create empty parsing structures (will be filled with defaults)
    parsing_stru = SN()
    parsing_stru.word_list = []
    parsing_stru.clause_list = []
    parsing_stru.token = []
    parsing_stru.sect_tag = []
    parsing_stru.class_tag = []

    parsing_sem = SN()
    parsing_sem.word_list = []
    parsing_sem.clause_list = []
    parsing_sem.token = []
    parsing_sem.sect_tag = []
    parsing_sem.class_tag = []

    # Combine text - but if get_combined_text fails, use fallback
    combine_text = SN()
    try:
        get_combined_text(text_sn, parsing_stru, parsing_sem, combine_text, cfg)
    except Exception:
        # Fallback if get_combined_text fails: use the raw problem text only.
        combine_text.token = text_sn.token
        combine_text.sect_tag = text_sn.sect_tag
        combine_text.class_tag = text_sn.class_tag

    # Get text indices - ensure we have at least one token
    try:
        text_token, text_sect_tag, text_class_tag = get_text_index(combine_text, src_lang)
    except Exception:
        # Fallback to simple processing; unknown words map to index 0.
        text_token = [src_lang.word2index.get(w, 0) for w in combine_text.token]
        text_sect_tag = [1] * len(text_token)  # Default to [PROB]
        text_class_tag = [1] * len(text_token)  # Default to [GEN]

    # Ensure minimum length
    if len(text_token) == 0:
        text_token = [0]  # PAD token
        text_sect_tag = [0]
        text_class_tag = [0]

    # Convert to tensors and move to device
    text_dict = {
        'token': torch.LongTensor([text_token]).to(device),
        'sect_tag': torch.LongTensor([text_sect_tag]).to(device),
        'class_tag': torch.LongTensor([text_class_tag]).to(device),
        'len': torch.LongTensor([len(text_token)]).to(device)
    }

    # Get variables and arguments
    var_arg_positions, var_values, arg_values = get_var_arg(combine_text, cfg)
    var_dict = {
        'pos': torch.LongTensor([var_arg_positions]).to(device),
        'len': torch.LongTensor([len(var_arg_positions)]).to(device),
        'var_value': var_values,
        'arg_value': arg_values
    }

    # Create dummy expression dict for inference
    exp_dict = {
        'exp': torch.LongTensor([[1]]).to(device),  # SOS token
        'len': torch.LongTensor([1]).to(device),
        'answer': 0
    }

    # Run inference
    with torch.no_grad():
        outputs = model(diagram, text_dict, var_dict, exp_dict, is_train=False)

    # Decode outputs
    if outputs is not None:
        # Convert output indices to symbols, stopping at EOS and skipping
        # the special PAD/SOS markers.
        output_symbols = []
        for idx in outputs[0]:
            if idx < len(tgt_lang.index2word):
                symbol = tgt_lang.index2word[idx]
                if symbol == 'EOS':
                    break
                if symbol not in ['PAD', 'SOS']:
                    output_symbols.append(symbol)

        expression = ' '.join(output_symbols)

        # Try to evaluate the expression
        try:
            # Simple evaluation (this would need more sophisticated handling in production)
            result = eval_expression(expression, var_values, arg_values)
            return f"Expression: {expression}\nResult: {result}"
        except Exception:
            return f"Expression: {expression}\n(Could not evaluate)"

    return "Could not generate solution"


def eval_expression(expr, var_values, arg_values):
    """Placeholder evaluator — returns *expr* unchanged.

    A real implementation would substitute ``var_values``/``arg_values``
    and compute the numeric answer.
    """
    # This is a simplified evaluator - would need proper implementation
    # For now, just return the expression
    return expr


def predict(image, text):
    """Gradio handler: validate inputs, then try the simplified inference
    path first and fall back to :func:`process_input` on failure."""
    if image is None:
        return "Please upload a geometry diagram image"
    if not text or text.strip() == "":
        text = "Find the value of x"  # Default text if empty

    try:
        # Try the simple inference first
        from simple_inference import simple_process_input
        return simple_process_input(image, text, model, src_lang, tgt_lang, cfg)
    except Exception:
        # Fallback to original method
        try:
            return process_input(image, text, model, src_lang, tgt_lang, cfg)
        except Exception as e2:
            import traceback
            error_details = traceback.format_exc()
            # Show last 500 chars of traceback
            return f"Error processing input: {str(e2)}\n\nDetails:\n{error_details[-500:]}"


# Load model on startup
print("Loading PGPS model...")
model, src_lang, tgt_lang, cfg = load_model()
print("Model loaded successfully!")

# Create Gradio interface with v5+ compatible syntax
with gr.Blocks(title="PGPS: Neural Geometric Problem Solver") as demo:
    gr.Markdown("# PGPS: Neural Geometric Problem Solver")
    gr.Markdown("Upload a geometry diagram and provide the problem text to get a solution.")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(
                type="pil",
                label="Geometry Diagram",
                height=300
            )
            text_input = gr.Textbox(
                lines=3,
                placeholder="Enter the geometry problem text here...\nExample: Find the angle x if angle ABC is 60 degrees",
                label="Problem Text"
            )
            submit_btn = gr.Button("Solve", variant="primary")
        with gr.Column():
            output = gr.Textbox(
                label="Solution",
                lines=10,
                max_lines=20
            )

    # Examples
    gr.Examples(
        examples=[
            [None, "Find the value of angle x if angle ABC is 60 degrees and angle BCD is 90 degrees"],
            [None, "Calculate the area of triangle ABC if AB = 5, BC = 7, and angle B = 60 degrees"],
            [None, "In triangle PQR, if angle P = 45 degrees and angle Q = 60 degrees, find angle R"],
            [None, "Find the perimeter of a rectangle with length 8 and width 5"]
        ],
        inputs=[image_input, text_input],
        outputs=output,
        fn=predict,
        cache_examples=False
    )

    # Event handlers
    submit_btn.click(
        fn=predict,
        inputs=[image_input, text_input],
        outputs=output
    )
    text_input.submit(
        fn=predict,
        inputs=[image_input, text_input],
        outputs=output
    )

if __name__ == "__main__":
    demo.launch()