Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from PIL import Image | |
| import numpy as np | |
def _encode_words(text_input, src_lang):
    """Map the whitespace-separated words of *text_input* to source-vocabulary ids.

    ``None``, empty, or whitespace-only input falls back to the single
    placeholder word ``"problem"`` so the result is never empty.  Words that
    are missing from ``src_lang.word2index`` map to index 0 (labelled as the
    PAD token by the original author -- confirm against the vocabulary).
    """
    # `or ""` tolerates None (e.g. an unset UI field) instead of raising
    # AttributeError; the trailing `or [...]` covers blank input.
    words = (text_input or "").split() or ["problem"]
    return [src_lang.word2index.get(word, 0) for word in words]


def _decode_symbols(outputs, tgt_lang):
    """Turn raw model output into a list of target-vocabulary symbols.

    A tuple is reduced to its first element.  A tensor with more than one
    dimension is reduced to its first row (assumed to be the single batch
    item -- confirm against the model).  Decoding stops at the first
    ``[EOS]`` or ``[PAD]``; ``[SOS]`` and out-of-range ids are skipped.
    """
    if isinstance(outputs, tuple):
        outputs = outputs[0]

    if isinstance(outputs, torch.Tensor):
        first_sequence = outputs[0] if outputs.dim() > 1 else outputs
        output_indices = first_sequence.tolist()
    else:
        output_indices = outputs

    vocab_size = len(tgt_lang.index2word)
    symbols = []
    for raw_idx in output_indices:
        # Normalise numpy / tensor scalars to a plain int so the lookup
        # behaves the same for list- and dict-backed vocabularies.
        idx = int(raw_idx)
        # The lower bound stops a negative id from silently indexing
        # from the end of a list-backed vocabulary.
        if not 0 <= idx < vocab_size:
            continue
        symbol = tgt_lang.index2word[idx]
        if symbol in ('[EOS]', '[PAD]'):
            break
        if symbol != '[SOS]':
            symbols.append(symbol)
    return symbols


def simple_process_input(image, text_input, model, src_lang, tgt_lang, cfg):
    """Run one image + text sample through *model* and describe the result.

    Simplified inference path that bypasses the project's full text
    processing: the text is split on whitespace and looked up word by word.

    Args:
        image: Input diagram, passed to the ``datasets.diagram_aug``
            transform pipeline (presumably a PIL image -- confirm).
        text_input: Problem text.  ``None`` or blank text is replaced by the
            placeholder word ``"problem"``.
        model: Network called as
            ``model(diagram, text_dict, var_dict, exp_dict, is_train=False)``.
            Its first parameter determines the device used for all tensors.
        src_lang: Source vocabulary exposing ``word2index``.
        tgt_lang: Target vocabulary exposing ``index2word``.
        cfg: Configuration object; only ``cfg.diagram_size`` is read.

    Returns:
        A human-readable string: the generated expression, a notice that
        nothing was generated, or an error message when inference or
        decoding raised.

    Raises:
        Errors raised while preparing the inputs (image transform, model
        without parameters) are not caught and propagate to the caller.
    """
    device = next(model.parameters()).device

    # Project-local transforms, imported at call time as in the original.
    import datasets.diagram_aug as T_diagram
    diagram_transform = T_diagram.Compose([
        T_diagram.Resize(cfg.diagram_size),
        T_diagram.CenterCrop(cfg.diagram_size),
        T_diagram.ToTensor(),
        T_diagram.Normalize()
    ])
    # Add the batch dimension before moving to the model's device.
    diagram = diagram_transform(image).unsqueeze(0).to(device)

    text_indices = _encode_words(text_input, src_lang)
    text_len = len(text_indices)
    batch_size = 1

    def long_tensor(values):
        # Allocate directly on the target device instead of CPU + copy.
        return torch.tensor(values, dtype=torch.long, device=device)

    # All three dicts are built fresh on every call, so the model may modify
    # them in place without affecting any other state.
    text_dict = {
        # Shape [batch, 1, seq_len]: per the original author's notes the
        # model expects a sub-word axis in the middle, here always size 1.
        'token': long_tensor([text_indices]).unsqueeze(1),
        # Tag id 1 for every position; the original comments label these
        # as [PROB] and [GEN] respectively (confirm against the tag vocab).
        'sect_tag': long_tensor([[1] * text_len]),
        'class_tag': long_tensor([[1] * text_len]),
        'len': long_tensor([text_len]),
    }

    # No variables are detected on this simplified path.
    var_dict = {
        'pos': torch.zeros(batch_size, 1, dtype=torch.long, device=device),
        'len': torch.zeros(batch_size, dtype=torch.long, device=device),
        'var_value': [],
        'arg_value': [],
    }

    # Decoder input holding only id 1 (the SOS token per the original comment).
    exp_dict = {
        'exp': long_tensor([[1]]),
        'len': torch.ones(batch_size, dtype=torch.long, device=device),
        'answer': 0,
    }

    # Broad catches below are deliberate: this function reports failures as
    # text rather than raising.
    try:
        with torch.no_grad():
            outputs = model(diagram, text_dict, var_dict, exp_dict, is_train=False)
    except Exception as e:
        return f"Model inference error: {e}"

    if outputs is None:
        return "No output from model"

    try:
        expression = ' '.join(_decode_symbols(outputs, tgt_lang))
    except Exception as e:
        return f"Output decoding error: {e}"

    if expression:
        return f"Generated expression: {expression}"
    return "No solution generated (empty output)"