Spaces:
Sleeping
Sleeping
File size: 10,376 Bytes
383bfb8 2a2cec1 383bfb8 6f74e93 383bfb8 96836c8 383bfb8 6f74e93 383bfb8 2a2cec1 383bfb8 9658342 383bfb8 51e6305 9658342 51e6305 383bfb8 9658342 383bfb8 9658342 383bfb8 9658342 383bfb8 bf8c161 383bfb8 bf8c161 383bfb8 bf8c161 383bfb8 bf8c161 383bfb8 bf8c161 383bfb8 bf8c161 383bfb8 9658342 383bfb8 9658342 383bfb8 9658342 383bfb8 9658342 383bfb8 bf8c161 383bfb8 bf8c161 383bfb8 bf8c161 383bfb8 93ac856 383bfb8 93ac856 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 |
import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
import numpy as np
import sys
import os
# Add current directory to path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from core.network import Network, MLMTransformerPretrain
from model.backbone import get_visual_backbone
from model.encoder import get_encoder
from model.decoder import get_decoder
from datasets.preprossing import SN, SrcLang, TgtLang
from datasets.utils import get_combined_text, get_var_arg, get_text_index
from datasets.operators import normalize_exp
import datasets.diagram_aug as T_diagram
# Configuration class
class Config:
    """Fixed hyper-parameters and file paths for the demo network.

    Every setting is a plain instance attribute, so the object can be passed
    anywhere a ``cfg`` with these attribute names is expected.
    """

    def __init__(self):
        settings = vars(self)

        # Visual backbone
        settings.update(
            visual_backbone="ResNet10",
            diagram_size=128,
            pretrain_vis_path='',  # no pretrained visual weights configured
        )

        # Text encoder
        settings.update(
            encoder_type="gru",
            encoder_layers=2,
            encoder_embedding_size=256,
            encoder_hidden_size=512,
            max_input_len=400,
        )

        # Expression decoder
        settings.update(
            decoder_type="rnn_decoder",
            decoder_layers=2,
            decoder_embedding_size=512,
            decoder_hidden_size=512,
            max_output_len=40,
        )

        # General
        settings.update(
            dropout_rate=0.2,
            beam_size=10,
            # MLM pretraining is switched off: it caused dimension mismatch
            # issues in the demo.
            use_MLM_pretrain=False,
            MLM_pretrain_path='./LM_MODEL.pth',
            pretrain_emb_path='',
        )

        # Dataset
        settings.update(without_stru=False)

        # Minimal logger stand-in: a throwaway class whose ``info`` attribute
        # prints its single argument, so ``cfg.logger.info(msg)`` works.
        self.logger = type(
            'obj',
            (object,),
            {'info': lambda x: print(x)},
        )
# Initialize model
def load_model():
    """Build the network and load whatever checkpoint weights fit it.

    The vocabularies are read from ``./vocab``. If ``./LM_MODEL.pth`` exists,
    only the entries whose key and tensor shape both match the freshly built
    model are copied in; any error while loading is reported and the model
    keeps its initial weights.

    Returns:
        A ``(model, src_lang, tgt_lang, cfg)`` tuple, with the model moved to
        the selected device and switched to eval mode.
    """
    cfg = Config()

    # Source / target vocabularies
    src_lang = SrcLang('./vocab/vocab_src.txt')
    tgt_lang = TgtLang('./vocab/vocab_tgt.txt')

    model = Network(cfg, src_lang, tgt_lang)

    # Prefer the GPU when one is visible to torch
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    checkpoint_path = './LM_MODEL.pth'
    if os.path.exists(checkpoint_path):
        try:
            # NOTE(review): torch.load unpickles the file; only point this at
            # checkpoints from a trusted source.
            checkpoint = torch.load(checkpoint_path, map_location=device)

            # The checkpoint is either a wrapper holding 'state_dict' or the
            # state dict itself.
            if 'state_dict' in checkpoint:
                state_dict = checkpoint['state_dict']
            else:
                state_dict = checkpoint

            # Keep only tensors the current model can accept as-is.
            current_weights = model.state_dict()
            compatible = {}
            for key, tensor in state_dict.items():
                if key in current_weights and tensor.shape == current_weights[key].shape:
                    compatible[key] = tensor

            current_weights.update(compatible)
            model.load_state_dict(current_weights, strict=False)
            print(f"Loaded {len(compatible)}/{len(state_dict)} parameters from checkpoint")
        except Exception as e:
            print(f"Warning: Could not load full model weights: {e}")
            print("Continuing with randomly initialized weights")

    model = model.to(device)
    model.eval()
    return model, src_lang, tgt_lang, cfg
# Process image and text
def _empty_sn():
    """Return an ``SN`` whose five list fields are fresh, empty lists."""
    blank = SN()
    blank.word_list = []
    blank.clause_list = []
    blank.token = []
    blank.sect_tag = []
    blank.class_tag = []
    return blank


def process_input(image, text_input, model, src_lang, tgt_lang, cfg):
    """Run one diagram + problem text through the model and format the result.

    Args:
        image: Diagram handed to the ``T_diagram`` transform pipeline
            (presumably a PIL image, as produced by the Gradio input).
        text_input: Problem text; split on whitespace into tokens.
        model: Network called as
            ``model(diagram, text_dict, var_dict, exp_dict, is_train=False)``.
        src_lang: Source vocabulary; ``word2index`` is used by the fallback
            tokenisation.
        tgt_lang: Target vocabulary; ``index2word`` maps output indices to
            symbols.
        cfg: Configuration object; ``diagram_size`` is read here and the
            object is forwarded to the dataset helpers.

    Returns:
        A human-readable string: the decoded expression (plus the evaluation
        result when available), or ``"Could not generate solution"`` when the
        model returns ``None``.
    """
    # Run on whatever device the model's parameters live on.
    device = next(model.parameters()).device

    # Transform image
    diagram_transform = T_diagram.Compose([
        T_diagram.Resize(cfg.diagram_size),
        T_diagram.CenterCrop(cfg.diagram_size),
        T_diagram.ToTensor(),
        T_diagram.Normalize(),
    ])
    diagram = diagram_transform(image).unsqueeze(0).to(device)

    # Build a simple text structure; blank text becomes a single [PAD] token.
    has_text = bool(text_input.strip())
    text_sn = SN()
    text_sn.word_list = text_input.split() if has_text else ["[PAD]"]
    text_sn.clause_list = [text_input] if has_text else ["[PAD]"]
    text_sn.token = text_sn.word_list
    text_sn.sect_tag = ["[PROB]"] * len(text_sn.word_list)
    text_sn.class_tag = ["[GEN]"] * len(text_sn.word_list)

    # No diagram parsing is available in the demo, so both parsing inputs are
    # left empty.
    parsing_stru = _empty_sn()
    parsing_sem = _empty_sn()

    # Combine text; on failure fall back to the raw problem-text tokens.
    combine_text = SN()
    try:
        get_combined_text(text_sn, parsing_stru, parsing_sem, combine_text, cfg)
    except Exception as e:  # best-effort: keep the demo running
        print(f"Warning: get_combined_text failed ({e!r}); using raw text tokens")
        combine_text.token = text_sn.token
        combine_text.sect_tag = text_sn.sect_tag
        combine_text.class_tag = text_sn.class_tag

    # Map tokens/tags to vocabulary indices.
    try:
        text_token, text_sect_tag, text_class_tag = get_text_index(combine_text, src_lang)
    except Exception as e:  # best-effort: simple vocabulary lookup instead
        print(f"Warning: get_text_index failed ({e!r}); using plain vocabulary lookup")
        text_token = [src_lang.word2index.get(w, 0) for w in combine_text.token]
        text_sect_tag = [1] * len(text_token)  # Default to [PROB]
        text_class_tag = [1] * len(text_token)  # Default to [GEN]

    # The model needs at least one token.
    if len(text_token) == 0:
        text_token = [0]  # PAD token
        text_sect_tag = [0]
        text_class_tag = [0]

    # Batch of one, moved to the model's device.
    text_dict = {
        'token': torch.LongTensor([text_token]).to(device),
        'sect_tag': torch.LongTensor([text_sect_tag]).to(device),
        'class_tag': torch.LongTensor([text_class_tag]).to(device),
        'len': torch.LongTensor([len(text_token)]).to(device),
    }

    # Variables and arguments found in the combined text
    var_arg_positions, var_values, arg_values = get_var_arg(combine_text, cfg)
    var_dict = {
        'pos': torch.LongTensor([var_arg_positions]).to(device),
        'len': torch.LongTensor([len(var_arg_positions)]).to(device),
        'var_value': var_values,
        'arg_value': arg_values,
    }

    # Dummy expression dict: inference has no ground-truth expression.
    exp_dict = {
        'exp': torch.LongTensor([[1]]).to(device),  # SOS token
        'len': torch.LongTensor([1]).to(device),
        'answer': 0,
    }

    with torch.no_grad():
        outputs = model(diagram, text_dict, var_dict, exp_dict, is_train=False)

    if outputs is None:
        return "Could not generate solution"

    # Convert output indices to symbols, stopping at the first EOS.
    output_symbols = []
    for idx in outputs[0]:
        # Normalise to a plain int so tensor elements work both for the
        # comparison and as a list index / dict key.
        idx = int(idx)
        if idx < len(tgt_lang.index2word):
            symbol = tgt_lang.index2word[idx]
            if symbol == 'EOS':
                break
            if symbol not in ['PAD', 'SOS']:
                output_symbols.append(symbol)
    expression = ' '.join(output_symbols)

    # Evaluation is best-effort; the expression is still reported on failure.
    try:
        result = eval_expression(expression, var_values, arg_values)
        return f"Expression: {expression}\nResult: {result}"
    except Exception:
        return f"Expression: {expression}\n(Could not evaluate)"
def eval_expression(expr, var_values, arg_values):
    """Placeholder evaluator that hands the expression back unchanged.

    ``var_values`` and ``arg_values`` are accepted for interface
    compatibility but are not used yet; a real evaluator still has to be
    implemented.
    """
    unevaluated = expr
    return unevaluated
# Gradio interface
def predict(image, text):
    """Gradio callback: solve the problem described by ``image`` and ``text``.

    The optional ``simple_inference`` module is tried first; if it is missing
    or fails, the failure is logged and ``process_input`` is used instead.
    Errors from the fallback are turned into a message rather than raised.

    Args:
        image: Uploaded diagram, or ``None`` when nothing was uploaded.
        text: Problem text; empty or whitespace-only text is replaced by a
            default prompt.

    Returns:
        The solution string, a prompt to upload an image, or an error
        message containing the tail of the traceback.
    """
    if image is None:
        return "Please upload a geometry diagram image"

    if not text or text.strip() == "":
        text = "Find the value of x"  # Default text if empty

    # Preferred path: the lightweight inference helper, when it is available.
    try:
        from simple_inference import simple_process_input
        return simple_process_input(image, text, model, src_lang, tgt_lang, cfg)
    except Exception as primary_error:
        # Do not swallow this silently: it is the only trace of why the
        # preferred path was skipped.
        print(f"Warning: simple inference unavailable or failed: {primary_error!r}")

    # Fallback: the full preprocessing pipeline in this file.
    try:
        return process_input(image, text, model, src_lang, tgt_lang, cfg)
    except Exception as fallback_error:
        import traceback
        error_details = traceback.format_exc()
        # Show last 500 chars of traceback
        return f"Error processing input: {str(fallback_error)}\n\nDetails:\n{error_details[-500:]}"
# Load model on startup.
# This runs once at import time. The four names bound here are module-level
# globals that predict() reads on every request, so they must exist before the
# UI starts serving.
print("Loading PGPS model...")
model, src_lang, tgt_lang, cfg = load_model()
print("Model loaded successfully!")
# Create Gradio interface with v5+ compatible syntax
with gr.Blocks(title="PGPS: Neural Geometric Problem Solver") as demo:
    # Page header
    gr.Markdown("# PGPS: Neural Geometric Problem Solver")
    gr.Markdown("Upload a geometry diagram and provide the problem text to get a solution.")

    with gr.Row():
        # Left column: diagram, problem text and the solve button
        with gr.Column():
            image_input = gr.Image(type="pil", label="Geometry Diagram", height=300)
            text_input = gr.Textbox(
                lines=3,
                placeholder="Enter the geometry problem text here...\nExample: Find the angle x if angle ABC is 60 degrees",
                label="Problem Text",
            )
            submit_btn = gr.Button("Solve", variant="primary")

        # Right column: model answer
        with gr.Column():
            output = gr.Textbox(label="Solution", lines=10, max_lines=20)

    # Text-only examples: each row leaves the image slot empty (None).
    example_prompts = [
        "Find the value of angle x if angle ABC is 60 degrees and angle BCD is 90 degrees",
        "Calculate the area of triangle ABC if AB = 5, BC = 7, and angle B = 60 degrees",
        "In triangle PQR, if angle P = 45 degrees and angle Q = 60 degrees, find angle R",
        "Find the perimeter of a rectangle with length 8 and width 5",
    ]
    gr.Examples(
        examples=[[None, prompt] for prompt in example_prompts],
        inputs=[image_input, text_input],
        outputs=output,
        fn=predict,
        cache_examples=False,
    )

    # Clicking the button and submitting the textbox both run predict().
    for trigger in (submit_btn.click, text_input.submit):
        trigger(
            fn=predict,
            inputs=[image_input, text_input],
            outputs=output,
        )
# Start the Gradio server only when this file is executed as a script, not
# when it is imported.
if __name__ == "__main__":
    demo.launch()