| import torch |
| from transformers import T5Config, T5ForConditionalGeneration |
| import os |
|
|
| |
| |
| |
|
|
SAVE_PATH = "model.pt"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Fail fast with a readable message instead of an unhandled FileNotFoundError.
if not os.path.exists(SAVE_PATH):
    print(f"Error: File {SAVE_PATH} not found!")
    # Fix: use SystemExit instead of exit() — exit() is injected by the `site`
    # module and is not guaranteed to exist when run with `python -S` or frozen.
    raise SystemExit(1)

# Allow-list T5Config so torch.load(weights_only=True) can safely unpickle it.
torch.serialization.add_safe_globals([T5Config])

checkpoint = torch.load(SAVE_PATH, map_location=DEVICE, weights_only=True)

# Character-level vocabulary and the special-token ids used by encode()/decode().
char2id = checkpoint["char2id"]
id2char = checkpoint["id2char"]
PAD_ID = char2id["<pad>"]
BOS_ID = char2id["<bos>"]
EOS_ID = char2id["<eos>"]

# Rebuild the model from the stored config + weights and freeze it for inference.
config = checkpoint["config"]
model = T5ForConditionalGeneration(config).to(DEVICE)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

print(f"Model loaded (Accuracy: {checkpoint['accuracy']:.2f}% from epoch {checkpoint['epoch']})")
|
|
| |
| |
| |
|
|
def encode(text, max_len=20, vocab=None, pad_id=None, eos_id=None):
    """Encode *text* character-by-character into a fixed-length id sequence.

    Each character is looked up in *vocab*; unknown characters map to the
    pad id.  An EOS id is appended, then the sequence is padded (or
    truncated) to exactly *max_len* tokens.

    Fix: when the sequence must be truncated, the last slot is forced to
    EOS — the old `tokens[:max_len]` silently dropped the end marker.

    vocab / pad_id / eos_id default to the module-level
    char2id / PAD_ID / EOS_ID, so existing call sites are unchanged.
    """
    vocab = char2id if vocab is None else vocab
    pad_id = PAD_ID if pad_id is None else pad_id
    eos_id = EOS_ID if eos_id is None else eos_id

    tokens = [vocab.get(c, pad_id) for c in text]
    tokens.append(eos_id)

    if len(tokens) > max_len:
        # Keep the EOS marker even when the input is too long.
        tokens = tokens[:max_len - 1] + [eos_id]
    tokens += [pad_id] * (max_len - len(tokens))
    return tokens
|
|
def decode(token_ids, id_map=None, pad_id=None, bos_id=None, eos_id=None):
    """Decode a sequence of token ids back into a string.

    Stops at the first EOS id, skips pad/BOS ids, and renders unknown ids
    as "?".  id_map / pad_id / bos_id / eos_id default to the module-level
    id2char / PAD_ID / BOS_ID / EOS_ID (backward-compatible generalization,
    mirroring encode()), so existing call sites are unchanged.
    """
    id_map = id2char if id_map is None else id_map
    pad_id = PAD_ID if pad_id is None else pad_id
    bos_id = BOS_ID if bos_id is None else bos_id
    eos_id = EOS_ID if eos_id is None else eos_id

    result = []
    for tid in token_ids:
        if tid == eos_id:
            break
        if tid in (pad_id, bos_id):
            continue
        result.append(id_map.get(tid, "?"))
    return "".join(result)
|
|
def solve(expression):
    """Run the model on an arithmetic *expression* and return the predicted answer string."""
    # The model was trained on prompts terminated by '='; append it if missing.
    if not expression.endswith("="):
        expression += "="

    encoded = encode(expression)
    input_ids = torch.tensor([encoded], dtype=torch.long).to(DEVICE)
    # Mask out padding positions so attention ignores them.
    attention_mask = (input_ids != PAD_ID).long()

    with torch.no_grad():
        # Greedy decoding (do_sample=False) — deterministic output.
        output = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=12,
            eos_token_id=EOS_ID,
            pad_token_id=PAD_ID,
            do_sample=False,
        )

    return decode(output[0].cpu().tolist())
|
|
| |
| |
| |
|
|
print("\n--- Mini Math Model interactive ---")
print("Enter an arithmetic task (e.g. 15*15) or type 'exit' to quit this.")

# Interactive loop: read a task, let the model answer, and (when possible)
# check it against Python's own arithmetic.
while True:
    user_input = input("\nTask > ").strip().replace(" ", "")
    if user_input.lower() in ("exit", "quit", "q"):
        break

    # Reject input with no operator — the model only knows arithmetic tasks.
    if not any(op in user_input for op in "+-*/"):
        print("Input an arithmetic task!")
        continue

    prediction = solve(user_input)

    try:
        # SECURITY NOTE: eval() on raw user input is only tolerable because this
        # is a local interactive tool; never do this with untrusted/remote input.
        # '/' -> '//' so the reference uses integer division like the model.
        calc_input = user_input.replace("/", "//")
        true_val = str(eval(calc_input))
        status = "✅" if prediction == true_val else "❌"
        print(f"Model: {prediction} | Correct: {true_val} {status}")
    except Exception:
        # Fix: was a bare `except:` that also swallowed KeyboardInterrupt and
        # SystemExit; malformed expressions still fall back to model-only output.
        print(f"Model: {prediction}")