pgps-demo / app.py
asdfasdfdsafdsa's picture
Fix tensor dimension mismatch by disabling MLM pretrain for demo
96836c8 verified
import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
import numpy as np
import sys
import os
# Add current directory to path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from core.network import Network, MLMTransformerPretrain
from model.backbone import get_visual_backbone
from model.encoder import get_encoder
from model.decoder import get_decoder
from datasets.preprossing import SN, SrcLang, TgtLang
from datasets.utils import get_combined_text, get_var_arg, get_text_index
from datasets.operators import normalize_exp
import datasets.diagram_aug as T_diagram
# Configuration class
class Config:
    """Fixed settings bundle passed to the network builder for the demo.

    All values are plain instance attributes assigned in ``__init__``;
    nothing is read from disk or from the environment.
    """

    def __init__(self):
        # --- diagram / visual backbone ---
        self.visual_backbone = "ResNet10"
        self.diagram_size = 128
        self.pretrain_vis_path = ''  # kept empty; attribute must exist

        # --- text encoder ---
        self.encoder_type = "gru"
        self.encoder_layers = 2
        self.encoder_embedding_size = 256
        self.encoder_hidden_size = 512
        self.max_input_len = 400

        # --- expression decoder ---
        self.decoder_type = "rnn_decoder"
        self.decoder_layers = 2
        self.decoder_embedding_size = 512
        self.decoder_hidden_size = 512
        self.max_output_len = 40

        # --- shared settings ---
        self.dropout_rate = 0.2
        self.beam_size = 10
        # MLM pretraining is switched off: it caused dimension mismatch
        # issues in the demo.
        self.use_MLM_pretrain = False
        self.MLM_pretrain_path = './LM_MODEL.pth'
        self.pretrain_emb_path = ''

        # --- dataset ---
        self.without_stru = False

        # --- logging stub ---
        # A throwaway class whose only attribute, ``info``, prints its
        # argument; it exists purely for compatibility with code that
        # calls ``cfg.logger.info(...)``.
        def _echo(x):
            print(x)

        self.logger = type('obj', (object,), {'info': _echo})
# Initialize model
def load_model():
    """Build the network and load whichever checkpoint weights fit it.

    The vocabularies are read from ``./vocab/``. The checkpoint path is
    taken from ``cfg.MLM_pretrain_path``. Only checkpoint entries whose
    key exists in the model and whose shape matches are loaded; everything
    else keeps its initial value. Any failure while reading the checkpoint
    is reported and the model is used as initialised.

    Returns:
        tuple: ``(model, src_lang, tgt_lang, cfg)`` where ``model`` has
        been moved to the selected device and put into eval mode.
    """
    cfg = Config()

    # Load vocabularies using the project Lang classes
    src_lang = SrcLang('./vocab/vocab_src.txt')
    tgt_lang = TgtLang('./vocab/vocab_tgt.txt')

    model = Network(cfg, src_lang, tgt_lang)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Single source of truth for the checkpoint location (same value the
    # original hard-coded in two places).
    checkpoint_path = cfg.MLM_pretrain_path
    if os.path.exists(checkpoint_path):
        try:
            # map_location keeps CUDA-saved tensors loadable on CPU-only hosts
            checkpoint = torch.load(checkpoint_path, map_location=device)
            if 'state_dict' in checkpoint:
                state_dict = checkpoint['state_dict']
            else:
                state_dict = checkpoint

            # Keep only entries the model can actually accept
            model_dict = model.state_dict()
            filtered_dict = {
                k: v for k, v in state_dict.items()
                if k in model_dict and v.shape == model_dict[k].shape
            }
            # strict=False: keys absent from filtered_dict keep their
            # current (initial) values
            model.load_state_dict(filtered_dict, strict=False)
            print(f"Loaded {len(filtered_dict)}/{len(state_dict)} parameters from checkpoint")
        except Exception as e:
            # Best effort: a broken checkpoint must not stop the demo
            print(f"Warning: Could not load full model weights: {e}")
            print("Continuing with randomly initialized weights")
    else:
        # Previously this case was silent, hiding that the demo was
        # running on untrained weights.
        print(f"Warning: checkpoint {checkpoint_path} not found")
        print("Continuing with randomly initialized weights")

    model = model.to(device)
    model.eval()
    return model, src_lang, tgt_lang, cfg
# Process image and text
def _empty_sn():
    """Return an ``SN`` whose word/clause/token/tag lists are all empty."""
    sn = SN()
    sn.word_list = []
    sn.clause_list = []
    sn.token = []
    sn.sect_tag = []
    sn.class_tag = []
    return sn


def _decode_symbols(indices, tgt_lang):
    """Turn predicted indices into symbols, stopping at EOS and dropping PAD/SOS."""
    vocab_size = len(tgt_lang.index2word)
    symbols = []
    for idx in indices:
        # Normalise to a plain int so the lookup behaves the same whether
        # the model yields Python ints or 0-dim tensors (the exact output
        # type is not visible from this file).
        idx = int(idx)
        if idx >= vocab_size:
            continue
        symbol = tgt_lang.index2word[idx]
        if symbol == 'EOS':
            break
        if symbol not in ('PAD', 'SOS'):
            symbols.append(symbol)
    return symbols


def process_input(image, text_input, model, src_lang, tgt_lang, cfg):
    """Run one diagram + problem text through the model and format the result.

    Args:
        image: Diagram image handed to the ``T_diagram`` transform pipeline.
        text_input: Problem statement as a string (may be blank).
        model: Network to run; its first parameter decides the device.
        src_lang: Source vocabulary (``word2index`` is used in the fallback).
        tgt_lang: Target vocabulary (``index2word`` is used for decoding).
        cfg: Configuration object (``diagram_size`` is read here; it is
            also forwarded to the dataset helpers).

    Returns:
        str: ``"Expression: ...\\nResult: ..."`` on success,
        ``"Expression: ...\\n(Could not evaluate)"`` if evaluation raises,
        or ``"Could not generate solution"`` when the model returns ``None``.
    """
    device = next(model.parameters()).device

    # Transform image
    diagram_transform = T_diagram.Compose([
        T_diagram.Resize(cfg.diagram_size),
        T_diagram.CenterCrop(cfg.diagram_size),
        T_diagram.ToTensor(),
        T_diagram.Normalize()
    ])
    diagram = diagram_transform(image).unsqueeze(0).to(device)

    # Build a minimal text structure; blank input degrades to a single PAD
    has_text = bool(text_input.strip())
    text_sn = SN()
    text_sn.word_list = text_input.split() if has_text else ["[PAD]"]
    text_sn.clause_list = [text_input] if has_text else ["[PAD]"]
    text_sn.token = text_sn.word_list
    text_sn.sect_tag = ["[PROB]"] * len(text_sn.word_list)
    text_sn.class_tag = ["[GEN]"] * len(text_sn.word_list)

    # The demo has no diagram parsing output, so both parsing inputs are empty
    parsing_stru = _empty_sn()
    parsing_sem = _empty_sn()

    # Combine text; fall back to the raw problem text if the helper fails
    combine_text = SN()
    try:
        get_combined_text(text_sn, parsing_stru, parsing_sem, combine_text, cfg)
    except Exception as err:
        # Exception (not a bare except) so Ctrl-C / SystemExit still propagate
        print(f"Warning: get_combined_text failed, using raw text: {err}")
        combine_text.token = text_sn.token
        combine_text.sect_tag = text_sn.sect_tag
        combine_text.class_tag = text_sn.class_tag

    # Get text indices; fall back to a direct vocabulary lookup
    try:
        text_token, text_sect_tag, text_class_tag = get_text_index(combine_text, src_lang)
    except Exception as err:
        print(f"Warning: get_text_index failed, using simple lookup: {err}")
        text_token = [src_lang.word2index.get(w, 0) for w in combine_text.token]
        text_sect_tag = [1] * len(text_token)  # Default to [PROB]
        text_class_tag = [1] * len(text_token)  # Default to [GEN]

    # Ensure at least one token so the tensors below are never empty
    if len(text_token) == 0:
        text_token = [0]  # PAD token
        text_sect_tag = [0]
        text_class_tag = [0]

    # Batch of one: wrap each sequence in an outer list
    text_dict = {
        'token': torch.LongTensor([text_token]).to(device),
        'sect_tag': torch.LongTensor([text_sect_tag]).to(device),
        'class_tag': torch.LongTensor([text_class_tag]).to(device),
        'len': torch.LongTensor([len(text_token)]).to(device)
    }

    # Get variables and arguments
    var_arg_positions, var_values, arg_values = get_var_arg(combine_text, cfg)
    var_dict = {
        'pos': torch.LongTensor([var_arg_positions]).to(device),
        'len': torch.LongTensor([len(var_arg_positions)]).to(device),
        'var_value': var_values,
        'arg_value': arg_values
    }

    # Dummy expression dict for inference
    exp_dict = {
        'exp': torch.LongTensor([[1]]).to(device),  # SOS token
        'len': torch.LongTensor([1]).to(device),
        'answer': 0
    }

    with torch.no_grad():
        outputs = model(diagram, text_dict, var_dict, exp_dict, is_train=False)

    if outputs is None:
        return "Could not generate solution"

    expression = ' '.join(_decode_symbols(outputs[0], tgt_lang))

    try:
        # Simple evaluation (this would need more sophisticated handling in production)
        result = eval_expression(expression, var_values, arg_values)
    except Exception:
        return f"Expression: {expression}\n(Could not evaluate)"
    return f"Expression: {expression}\nResult: {result}"
def eval_expression(expr, var_values, arg_values):
    """Placeholder evaluator that hands the expression back unchanged.

    ``var_values`` and ``arg_values`` are accepted for interface
    compatibility but are not consulted; no numeric evaluation happens.
    """
    return expr
# Gradio interface
def predict(image, text):
    """Gradio callback: solve the problem described by ``image`` and ``text``.

    ``simple_inference.simple_process_input`` is tried first; if importing
    or running it raises, the failure is reported on stdout and
    ``process_input`` is used instead.

    Args:
        image: Uploaded diagram, or ``None`` when nothing was uploaded.
        text: Problem text; blank or ``None`` is replaced by a default prompt.

    Returns:
        str: The solver's output, a prompt to upload an image, or an error
        message containing the last 500 characters of the traceback.
    """
    if image is None:
        return "Please upload a geometry diagram image"
    if not text or text.strip() == "":
        text = "Find the value of x"  # Default text if empty

    # First choice: the simple inference path
    try:
        from simple_inference import simple_process_input
        return simple_process_input(image, text, model, src_lang, tgt_lang, cfg)
    except Exception as e:
        # Report instead of silently discarding, so a missing module or a
        # crash in the simple path can be diagnosed from the logs.
        print(f"Warning: simple inference failed, falling back to process_input: {e}")

    # Fallback: the original method
    try:
        return process_input(image, text, model, src_lang, tgt_lang, cfg)
    except Exception as e2:
        import traceback
        error_details = traceback.format_exc()
        # Show only the last 500 chars of the traceback
        return f"Error processing input: {str(e2)}\n\nDetails:\n{error_details[-500:]}"
# Build the model once at import time; predict() reads these module globals
print("Loading PGPS model...")
model, src_lang, tgt_lang, cfg = load_model()
print("Model loaded successfully!")

# Assemble the Blocks UI (v5+ compatible syntax)
with gr.Blocks(title="PGPS: Neural Geometric Problem Solver") as demo:
    gr.Markdown("# PGPS: Neural Geometric Problem Solver")
    gr.Markdown("Upload a geometry diagram and provide the problem text to get a solution.")

    with gr.Row():
        # Left column: inputs and the submit button
        with gr.Column():
            image_input = gr.Image(type="pil", label="Geometry Diagram", height=300)
            text_input = gr.Textbox(
                lines=3,
                placeholder="Enter the geometry problem text here...\nExample: Find the angle x if angle ABC is 60 degrees",
                label="Problem Text",
            )
            submit_btn = gr.Button("Solve", variant="primary")
        # Right column: solver output
        with gr.Column():
            output = gr.Textbox(label="Solution", lines=10, max_lines=20)

    # Text-only examples: each row leaves the image slot empty (None)
    gr.Examples(
        examples=[
            [None, _prompt]
            for _prompt in (
                "Find the value of angle x if angle ABC is 60 degrees and angle BCD is 90 degrees",
                "Calculate the area of triangle ABC if AB = 5, BC = 7, and angle B = 60 degrees",
                "In triangle PQR, if angle P = 45 degrees and angle Q = 60 degrees, find angle R",
                "Find the perimeter of a rectangle with length 8 and width 5",
            )
        ],
        inputs=[image_input, text_input],
        outputs=output,
        fn=predict,
        cache_examples=False,
    )

    # Same handler for the button click and for submitting the textbox
    for _trigger in (submit_btn.click, text_input.submit):
        _trigger(
            fn=predict,
            inputs=[image_input, text_input],
            outputs=output,
        )

if __name__ == "__main__":
    demo.launch()