# PGPS: Neural Geometric Problem Solver — Gradio demo app (Hugging Face Spaces).
import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
import numpy as np
import sys
import os

# Make the project packages importable regardless of the launch directory.
# NOTE: this must run BEFORE the project-local imports below.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from core.network import Network, MLMTransformerPretrain
from model.backbone import get_visual_backbone
from model.encoder import get_encoder
from model.decoder import get_decoder
from datasets.preprossing import SN, SrcLang, TgtLang
from datasets.utils import get_combined_text, get_var_arg, get_text_index
from datasets.operators import normalize_exp
import datasets.diagram_aug as T_diagram
class _PrintLogger:
    """Minimal logger stand-in exposing only the ``info`` method.

    Replaces the fragile ``type('obj', (object,), {'info': lambda x: print(x)})``
    one-liner: that expression assigned the *class* itself, so instantiating it
    would have broken ``info`` (the lambda would then receive ``self``).
    """

    @staticmethod
    def info(msg):
        print(msg)


class Config:
    """Hyper-parameter container mirroring the training configuration.

    The values are the defaults expected by ``Network`` for the demo
    checkpoint; ``logger`` satisfies ``.info(...)`` calls made inside the
    model code without configuring the ``logging`` module.
    """

    def __init__(self):
        # --- Visual backbone ---
        self.visual_backbone = "ResNet10"
        self.diagram_size = 128
        self.pretrain_vis_path = ''  # no separate visual pretrain in the demo

        # --- Encoder ---
        self.encoder_type = "gru"
        self.encoder_layers = 2
        self.encoder_embedding_size = 256
        self.encoder_hidden_size = 512
        self.max_input_len = 400

        # --- Decoder ---
        self.decoder_type = "rnn_decoder"
        self.decoder_layers = 2
        self.decoder_embedding_size = 512
        self.decoder_hidden_size = 512
        self.max_output_len = 40

        # --- General ---
        self.dropout_rate = 0.2
        self.beam_size = 10
        # Disabled due to dimension mismatch issues in the demo checkpoint.
        self.use_MLM_pretrain = False
        self.MLM_pretrain_path = './LM_MODEL.pth'
        self.pretrain_emb_path = ''

        # --- Dataset ---
        self.without_stru = False

        # Dummy logger for compatibility with model code that calls .info().
        self.logger = _PrintLogger()
def load_model():
    """Build the PGPS network and load checkpoint weights when available.

    Returns:
        Tuple ``(model, src_lang, tgt_lang, cfg)`` where ``model`` is in
        eval mode on the best available device. Missing or partially
        incompatible checkpoints degrade gracefully to random init.
    """
    cfg = Config()

    # Load vocabularies using the project's Lang classes.
    src_lang = SrcLang('./vocab/vocab_src.txt')
    tgt_lang = TgtLang('./vocab/vocab_tgt.txt')

    model = Network(cfg, src_lang, tgt_lang)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Single source of truth for the checkpoint location
    # (the path was previously hard-coded twice alongside the cfg attribute).
    ckpt_path = cfg.MLM_pretrain_path
    if os.path.exists(ckpt_path):
        try:
            # map_location keeps CUDA checkpoints loadable on CPU-only hosts.
            checkpoint = torch.load(ckpt_path, map_location=device)
            if 'state_dict' in checkpoint:
                state_dict = checkpoint['state_dict']
            else:
                state_dict = checkpoint
            # Keep only parameters whose name AND shape match this model,
            # so architecture drift does not abort the whole load.
            model_dict = model.state_dict()
            filtered_dict = {k: v for k, v in state_dict.items()
                             if k in model_dict and v.shape == model_dict[k].shape}
            model_dict.update(filtered_dict)
            model.load_state_dict(model_dict, strict=False)
            print(f"Loaded {len(filtered_dict)}/{len(state_dict)} parameters from checkpoint")
        except Exception as e:
            print(f"Warning: Could not load full model weights: {e}")
            print("Continuing with randomly initialized weights")

    model = model.to(device)
    model.eval()
    return model, src_lang, tgt_lang, cfg
def _empty_parse_sn():
    """Return an SN whose text fields are all empty (demo ships no diagram parser)."""
    sn = SN()
    sn.word_list = []
    sn.clause_list = []
    sn.token = []
    sn.sect_tag = []
    sn.class_tag = []
    return sn


def process_input(image, text_input, model, src_lang, tgt_lang, cfg):
    """Run the full preprocessing + inference pipeline on one example.

    Args:
        image: PIL image of the geometry diagram.
        text_input: raw problem statement typed by the user.
        model, src_lang, tgt_lang, cfg: objects produced by ``load_model``.

    Returns:
        Human-readable string with the decoded expression (and result when
        evaluation succeeds), or an error message.
    """
    device = next(model.parameters()).device

    # Diagram preprocessing mirrors the evaluation-time transform chain.
    diagram_transform = T_diagram.Compose([
        T_diagram.Resize(cfg.diagram_size),
        T_diagram.CenterCrop(cfg.diagram_size),
        T_diagram.ToTensor(),
        T_diagram.Normalize()
    ])
    diagram = diagram_transform(image).unsqueeze(0).to(device)

    # Minimal SN text structure built straight from the user's text.
    text_sn = SN()
    text_sn.word_list = text_input.split() if text_input.strip() else ["[PAD]"]
    text_sn.clause_list = [text_input] if text_input.strip() else ["[PAD]"]
    text_sn.token = text_sn.word_list
    text_sn.sect_tag = ["[PROB]"] * len(text_sn.word_list)
    text_sn.class_tag = ["[GEN]"] * len(text_sn.word_list)

    # Empty parsing structures (structural/semantic parses are unavailable here).
    parsing_stru = _empty_parse_sn()
    parsing_sem = _empty_parse_sn()

    # Combine text; fall back to the raw tokens if the helper rejects the input.
    combine_text = SN()
    try:
        get_combined_text(text_sn, parsing_stru, parsing_sem, combine_text, cfg)
    except Exception:  # was a bare `except:` — do not swallow SystemExit/KeyboardInterrupt
        combine_text.token = text_sn.token
        combine_text.sect_tag = text_sn.sect_tag
        combine_text.class_tag = text_sn.class_tag

    # Index the tokens; fall back to a direct vocab lookup on failure.
    try:
        text_token, text_sect_tag, text_class_tag = get_text_index(combine_text, src_lang)
    except Exception:  # was a bare `except:`
        text_token = [src_lang.word2index.get(w, 0) for w in combine_text.token]
        text_sect_tag = [1] * len(text_token)  # Default to [PROB]
        text_class_tag = [1] * len(text_token)  # Default to [GEN]

    # Guarantee at least one token so the encoder gets a non-empty sequence.
    if len(text_token) == 0:
        text_token = [0]  # PAD token
        text_sect_tag = [0]
        text_class_tag = [0]

    # Batch-of-one tensors on the model's device.
    text_dict = {
        'token': torch.LongTensor([text_token]).to(device),
        'sect_tag': torch.LongTensor([text_sect_tag]).to(device),
        'class_tag': torch.LongTensor([text_class_tag]).to(device),
        'len': torch.LongTensor([len(text_token)]).to(device)
    }

    # Variables and numeric arguments extracted from the combined text.
    var_arg_positions, var_values, arg_values = get_var_arg(combine_text, cfg)
    var_dict = {
        'pos': torch.LongTensor([var_arg_positions]).to(device),
        'len': torch.LongTensor([len(var_arg_positions)]).to(device),
        'var_value': var_values,
        'arg_value': arg_values
    }

    # Dummy expression dict for inference (decoder only needs the SOS seed).
    exp_dict = {
        'exp': torch.LongTensor([[1]]).to(device),  # SOS token
        'len': torch.LongTensor([1]).to(device),
        'answer': 0
    }

    with torch.no_grad():
        outputs = model(diagram, text_dict, var_dict, exp_dict, is_train=False)

    if outputs is not None:
        # Map output indices back to target-vocabulary symbols, stopping at EOS.
        output_symbols = []
        for idx in outputs[0]:
            if idx < len(tgt_lang.index2word):
                symbol = tgt_lang.index2word[idx]
                if symbol == 'EOS':
                    break
                if symbol not in ['PAD', 'SOS']:
                    output_symbols.append(symbol)
        expression = ' '.join(output_symbols)
        # Best-effort evaluation; the expression alone is still useful output.
        try:
            result = eval_expression(expression, var_values, arg_values)
            return f"Expression: {expression}\nResult: {result}"
        except Exception:  # was a bare `except:`
            return f"Expression: {expression}\n(Could not evaluate)"
    return "Could not generate solution"
def eval_expression(expr, var_values, arg_values):
    """Placeholder evaluator: echo *expr* back unchanged.

    A real implementation would substitute ``var_values`` and ``arg_values``
    into the expression and compute a numeric answer; the demo only needs
    the symbolic expression itself.
    """
    return expr
def predict(image, text):
    """Gradio handler: solve one geometry problem from (diagram, text).

    Tries the lightweight ``simple_inference`` path first and falls back to
    the full ``process_input`` pipeline. Any failure is returned as text
    rather than raised, so the UI never crashes.

    Args:
        image: PIL image or None (None yields an instruction message).
        text: problem statement; empty text gets a default prompt.

    Returns:
        Solution string, instruction, or a truncated error report.
    """
    if image is None:
        return "Please upload a geometry diagram image"
    if not text or text.strip() == "":
        text = "Find the value of x"  # Default text if empty
    try:
        # Preferred path: simplified inference helper shipped with the demo.
        from simple_inference import simple_process_input
        return simple_process_input(image, text, model, src_lang, tgt_lang, cfg)
    except Exception:  # previously bound an unused `e` here
        # Fallback to the original full pipeline.
        try:
            return process_input(image, text, model, src_lang, tgt_lang, cfg)
        except Exception as e2:
            import traceback
            error_details = traceback.format_exc()
            # Show only the tail of the traceback to keep the UI readable.
            return f"Error processing input: {str(e2)}\n\nDetails:\n{error_details[-500:]}"
# Load the model once at import time so both Gradio handlers share it.
print("Loading PGPS model...")
model, src_lang, tgt_lang, cfg = load_model()
print("Model loaded successfully!")
# Create Gradio interface with v5+ compatible syntax.
# NOTE: component declaration order defines the page layout — keep it intact.
with gr.Blocks(title="PGPS: Neural Geometric Problem Solver") as demo:
    gr.Markdown("# PGPS: Neural Geometric Problem Solver")
    gr.Markdown("Upload a geometry diagram and provide the problem text to get a solution.")
    with gr.Row():
        # Left column: inputs (diagram image, problem text, solve button).
        with gr.Column():
            image_input = gr.Image(
                type="pil",
                label="Geometry Diagram",
                height=300
            )
            text_input = gr.Textbox(
                lines=3,
                placeholder="Enter the geometry problem text here...\nExample: Find the angle x if angle ABC is 60 degrees",
                label="Problem Text"
            )
            submit_btn = gr.Button("Solve", variant="primary")
        # Right column: solution output.
        with gr.Column():
            output = gr.Textbox(
                label="Solution",
                lines=10,
                max_lines=20
            )
    # Text-only examples (no bundled diagram images, hence image=None).
    gr.Examples(
        examples=[
            [None, "Find the value of angle x if angle ABC is 60 degrees and angle BCD is 90 degrees"],
            [None, "Calculate the area of triangle ABC if AB = 5, BC = 7, and angle B = 60 degrees"],
            [None, "In triangle PQR, if angle P = 45 degrees and angle Q = 60 degrees, find angle R"],
            [None, "Find the perimeter of a rectangle with length 8 and width 5"]
        ],
        inputs=[image_input, text_input],
        outputs=output,
        fn=predict,
        cache_examples=False
    )
    # Event handlers: both the Solve button and Enter in the text box run predict.
    submit_btn.click(
        fn=predict,
        inputs=[image_input, text_input],
        outputs=output
    )
    text_input.submit(
        fn=predict,
        inputs=[image_input, text_input],
        outputs=output
    )
if __name__ == "__main__":
    # Launch the Gradio server (blocking) when run as a script.
    demo.launch()