pgps-demo / simple_inference.py
asdfasdfdsafdsa's picture
Fix tensor dimension mismatch by disabling MLM pretrain for demo
96836c8 verified
import torch
import torch.nn as nn
from PIL import Image
import numpy as np
# Fallback index for words missing from ``src_lang.word2index`` (the demo
# treats index 0 as the PAD token).
_UNKNOWN_WORD_INDEX = 0
# Tag ids assigned to every word on this simplified path; the demo labels
# them [PROB] (section tag) and [GEN] (class tag).
_DEFAULT_SECT_TAG = 1
_DEFAULT_CLASS_TAG = 1
# Target-side id the decoder is seeded with (the demo's SOS token).
_SOS_INDEX = 1
# Decoding stops at the first of these symbols and never emits the others.
_STOP_SYMBOLS = ('[EOS]', '[PAD]')
_SKIP_SYMBOLS = ('[SOS]',)


def _encode_text(text_input, src_lang, device):
    """Whitespace-tokenise ``text_input`` and build the model's ``text_dict``.

    ``None`` or blank text falls back to the single word ``"problem"``. Words
    not found in ``src_lang.word2index`` map to ``_UNKNOWN_WORD_INDEX``.

    Returns a dict with ``token`` of shape [1, 1, seq_len], ``sect_tag`` and
    ``class_tag`` of shape [1, seq_len], and ``len`` of shape [1]. All are
    long tensors on ``device``.
    """
    words = (text_input or "").split() or ["problem"]
    vocab = src_lang.word2index
    indices = [vocab[w] if w in vocab else _UNKNOWN_WORD_INDEX for w in words]
    seq_len = len(indices)

    # The model takes tokens as [batch, sub-words per token, seq_len] and sums
    # over the sub-word axis after embedding. Each word is a single sub-word
    # here, so the middle dimension is 1.
    token = torch.tensor([[indices]], dtype=torch.long, device=device)
    tag_shape = (1, seq_len)
    return {
        'token': token,
        'sect_tag': torch.full(tag_shape, _DEFAULT_SECT_TAG, dtype=torch.long, device=device),
        'class_tag': torch.full(tag_shape, _DEFAULT_CLASS_TAG, dtype=torch.long, device=device),
        'len': torch.tensor([seq_len], dtype=torch.long, device=device),
    }


def _decode_output(outputs, tgt_lang):
    """Convert raw model output into the demo's result string.

    A tuple output contributes its first element. A tensor of rank >= 2
    contributes its first batch row, and a 0-dim tensor is treated as a single
    id. Ids outside ``[0, len(tgt_lang.index2word))`` are skipped. Decoding
    stops at ``[EOS]``/``[PAD]``, and ``[SOS]`` is never emitted.

    Unexpected output shapes or types raise; the caller reports the error.
    """
    if isinstance(outputs, tuple):
        outputs = outputs[0]
    if isinstance(outputs, torch.Tensor):
        if outputs.dim() > 1:
            outputs = outputs[0]
        elif outputs.dim() == 0:
            outputs = outputs.reshape(1)
        output_indices = outputs.cpu().numpy()
    else:
        output_indices = outputs

    vocab = tgt_lang.index2word
    vocab_size = len(vocab)
    symbols = []
    for idx in output_indices:
        # The lower bound matters: without it a negative id would wrap around
        # to the end of a list vocabulary and decode as a real symbol.
        if not 0 <= idx < vocab_size:
            continue
        symbol = vocab[idx]
        if symbol in _STOP_SYMBOLS:
            break
        if symbol not in _SKIP_SYMBOLS:
            symbols.append(symbol)

    if symbols:
        return f"Generated expression: {' '.join(symbols)}"
    return "No solution generated (empty output)"


def simple_process_input(image, text_input, model, src_lang, tgt_lang, cfg):
    """Run one diagram + problem-text example through ``model`` and decode it.

    This simplified inference path bypasses the full text-processing pipeline.
    The problem text is split on whitespace and looked up word by word in
    ``src_lang``, and no variables are detected.

    Args:
        image: Diagram input for the ``datasets.diagram_aug`` transforms
            (presumably a PIL image; confirm against the caller).
        text_input: Problem text. ``None`` or blank text falls back to the
            placeholder word ``"problem"``.
        model: Called as ``model(diagram, text_dict, var_dict, exp_dict,
            is_train=False)`` under ``torch.no_grad()``. Inputs are placed on
            the device of its first parameter.
            NOTE(review): its train/eval mode is not changed here.
        src_lang: Source vocabulary exposing ``word2index``.
        tgt_lang: Target vocabulary exposing ``index2word``.
        cfg: Config exposing ``diagram_size``.

    Returns:
        A status string. On success it is ``"Generated expression: ..."``.
        Exceptions from the model call or from decoding are returned as
        ``"Model inference error: ..."`` or ``"Output decoding error: ..."``
        messages. Errors from the diagram transform propagate.
    """
    device = next(model.parameters()).device

    # Project-local transforms, imported lazily as in the original demo.
    import datasets.diagram_aug as T_diagram
    diagram_transform = T_diagram.Compose([
        T_diagram.Resize(cfg.diagram_size),
        T_diagram.CenterCrop(cfg.diagram_size),
        T_diagram.ToTensor(),
        T_diagram.Normalize(),
    ])
    diagram = diagram_transform(image).unsqueeze(0).to(device)  # add batch dim

    text_dict = _encode_text(text_input, src_lang, device)

    batch_size = 1
    # No variables are detected on this path, so positions and lengths are
    # all zero. NOTE(review): the model presumably offsets var positions for
    # the diagram token it prepends; confirm if variables are ever added here.
    var_dict = {
        'pos': torch.zeros(batch_size, 1, dtype=torch.long, device=device),
        'len': torch.zeros(batch_size, dtype=torch.long, device=device),
        'var_value': [],
        'arg_value': [],
    }
    # Seed the decoder with a single SOS token.
    exp_dict = {
        'exp': torch.tensor([[_SOS_INDEX]], dtype=torch.long, device=device),
        'len': torch.ones(batch_size, dtype=torch.long, device=device),
        'answer': 0,
    }

    # The input dicts are built fresh per call and never read back, so any
    # in-place changes the model makes to them cannot leak out.
    with torch.no_grad():
        try:
            outputs = model(diagram, text_dict, var_dict, exp_dict, is_train=False)
        except Exception as e:  # demo boundary: report instead of crashing
            return f"Model inference error: {e}"

    if outputs is None:
        return "No output from model"
    try:
        return _decode_output(outputs, tgt_lang)
    except Exception as e:  # demo boundary: report instead of crashing
        return f"Output decoding error: {e}"