"""Interactive command-line tester for a character-level T5 arithmetic model.

Loads a checkpoint from ``SAVE_PATH``, rebuilds the model, and then lets the
user type arithmetic tasks. Each model prediction is compared against the
result computed by Python (with ``/`` treated as integer division).
"""

import os
import re
import sys

import torch
from transformers import T5Config, T5ForConditionalGeneration

# ============================================================
# 1. SETUP & LOADING
# ============================================================
SAVE_PATH = "model.pt"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

if not os.path.exists(SAVE_PATH):
    # sys.exit(str) writes the message to stderr and exits with status 1;
    # the site-provided exit() is not guaranteed to exist and returned 0.
    sys.exit(f"Error: File {SAVE_PATH} not found!")

# The checkpoint stores a T5Config object, which must be allow-listed before
# it can be unpickled with weights_only=True.
torch.serialization.add_safe_globals([T5Config])
checkpoint = torch.load(SAVE_PATH, map_location=DEVICE, weights_only=True)

char2id = checkpoint["char2id"]
id2char = checkpoint["id2char"]

# NOTE(review): all three lookups use the empty string as key, so PAD_ID,
# BOS_ID and EOS_ID end up identical. This looks like the special-token names
# (probably angle-bracketed tags) were stripped from the source; restore the
# keys used by the training script before relying on encode()/decode().
PAD_ID = char2id[""]
BOS_ID = char2id[""]
EOS_ID = char2id[""]

config = checkpoint["config"]
model = T5ForConditionalGeneration(config).to(DEVICE)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

print(f"Model loaded (Accuracy: {checkpoint['accuracy']:.2f}% from epoch {checkpoint['epoch']})")


# ============================================================
# 2. HELPER FUNCTIONS
# ============================================================
def encode(text, max_len=20):
    """Convert ``text`` into a list of exactly ``max_len`` token ids.

    Each character is looked up in ``char2id``; characters missing from the
    vocabulary are mapped to ``PAD_ID``. ``EOS_ID`` is appended, the sequence
    is cut to ``max_len`` and then right-padded with ``PAD_ID``.

    Note that for inputs of ``max_len`` or more characters the truncation
    removes the trailing ``EOS_ID``.
    """
    tokens = [char2id.get(c, PAD_ID) for c in text]
    tokens.append(EOS_ID)
    # Truncate first, then pad up to the fixed length.
    tokens = tokens[:max_len]
    tokens += [PAD_ID] * (max_len - len(tokens))
    return tokens


def decode(token_ids):
    """Convert a sequence of token ids back into a string.

    Decoding stops at the first ``EOS_ID``. ``PAD_ID`` and ``BOS_ID`` are
    skipped, and ids missing from ``id2char`` are rendered as ``"?"``.
    """
    result = []
    for tid in token_ids:
        if tid == EOS_ID:
            break
        if tid in (PAD_ID, BOS_ID):
            continue
        result.append(id2char.get(tid, "?"))
    return "".join(result)


def solve(expression):
    """Return the model's answer for ``expression`` as a string.

    A trailing ``"="`` is appended when missing. Generation runs with
    sampling disabled and produces at most 12 new tokens.
    """
    if not expression.endswith("="):
        expression += "="

    input_ids = torch.tensor([encode(expression)], dtype=torch.long, device=DEVICE)
    attention_mask = (input_ids != PAD_ID).long()

    with torch.no_grad():
        generated = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=12,
            eos_token_id=EOS_ID,
            pad_token_id=PAD_ID,
            do_sample=False,
        )

    return decode(generated[0].cpu().tolist())


# Only digits, the four operators, parentheses and the decimal point may
# reach eval(); no identifier can be formed from these characters.
_ARITHMETIC_RE = re.compile(r"[0-9+\-*/().]+")


def _reference_answer(expression):
    """Return Python's result for ``expression`` as a string, or ``None``.

    ``None`` means the expression could not be checked (unsupported
    characters, syntax error, division by zero, ...). ``/`` is evaluated as
    integer division and a trailing ``"="`` is ignored.
    """
    expression = expression.rstrip("=")
    if not _ARITHMETIC_RE.fullmatch(expression):
        return None
    try:
        # SECURITY: eval() on user input. The whitelist above plus the empty
        # builtins prevent name lookups and calls, but expressions such as
        # chained "**" can still be expensive to compute.
        return str(eval(expression.replace("/", "//"), {"__builtins__": {}}, {}))
    except Exception:
        return None


# ============================================================
# 3. INTERACTIVE MODE
# ============================================================
def main():
    """Run the interactive prompt until the user quits or input ends."""
    print("\n--- Mini Math Model interactive ---")
    print("Enter an arithmetic task (e.g. 15*15) or type 'exit' to quit this.")

    while True:
        try:
            user_input = input("\nTask > ").strip().replace(" ", "")
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C: leave the loop instead of dumping a traceback.
            print()
            break

        if user_input.lower() in ("exit", "quit", "q"):
            break

        if not any(op in user_input for op in "+-*/"):
            print("Input an arithmetic task!")
            continue

        prediction = solve(user_input)
        true_val = _reference_answer(user_input)

        if true_val is None:
            # No reference value available; show the prediction only.
            print(f"Model: {prediction}")
        else:
            status = "✅" if prediction == true_val else "❌"
            print(f"Model: {prediction} | Correct: {true_val} {status}")


if __name__ == "__main__":
    main()