File size: 2,950 Bytes
8ab44df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import ast
import operator
import os

import torch
from transformers import T5Config, T5ForConditionalGeneration

# ============================================================
# 1. SETUP & LOADING
# ============================================================

SAVE_PATH = "model.pt"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Fail fast with a non-zero exit status when the checkpoint is missing.
# `raise SystemExit(msg)` prints msg to stderr and exits with status 1. The
# bare `exit()` builtin is only injected by the `site` module (absent under
# `python -S`) and exited with status 0, hiding the failure from shells/CI.
if not os.path.exists(SAVE_PATH):
    raise SystemExit(f"Error: File {SAVE_PATH} not found!")

# torch.load(weights_only=True) only unpickles allow-listed types; the
# checkpoint's "config" entry is a T5Config object, so register it first.
torch.serialization.add_safe_globals([T5Config])

checkpoint = torch.load(SAVE_PATH, map_location=DEVICE, weights_only=True)

# Character-level vocabulary stored alongside the weights, plus the ids of
# the special tokens used by encode/decode/solve.
char2id = checkpoint["char2id"]
id2char = checkpoint["id2char"]
PAD_ID = char2id["<pad>"]
BOS_ID = char2id["<bos>"]
EOS_ID = char2id["<eos>"]

# Rebuild the architecture from the saved config, then load the weights.
config = checkpoint["config"]
model = T5ForConditionalGeneration(config).to(DEVICE)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()  # inference mode for layers such as dropout

print(f"Model loaded (Accuracy: {checkpoint['accuracy']:.2f}% from epoch {checkpoint['epoch']})")

# ============================================================
# 2. HELPER FUNCTIONS
# ============================================================

def encode(text, max_len=20):
    """Convert *text* into a fixed-length list of token ids for the encoder.

    Each character is looked up in ``char2id``. Characters missing from the
    vocabulary map to ``PAD_ID``.
    NOTE(review): because ``solve`` masks PAD positions, such characters are
    effectively invisible to the model.

    An ``EOS_ID`` terminator follows the characters, and the list is
    right-padded with ``PAD_ID``. Text longer than ``max_len - 1`` characters
    is truncated *before* EOS is appended, so EOS is always present whenever
    ``max_len >= 1``. (Previously, long inputs had their EOS cut off.)

    Args:
        text: Input string, e.g. ``"15*15="``.
        max_len: Exact length of the returned list. Values ``<= 0`` yield an
            empty list.

    Returns:
        A list of exactly ``max(max_len, 0)`` ints.
    """
    limit = max(max_len, 0)
    # Reserve one slot for the EOS terminator.
    ids = [char2id.get(c, PAD_ID) for c in text[: max(limit - 1, 0)]]
    ids.append(EOS_ID)
    ids = ids[:limit]  # only drops EOS when limit == 0
    # Right-pad to the fixed length.
    ids += [PAD_ID] * (limit - len(ids))
    return ids

def decode(token_ids):
    """Turn a sequence of generated token ids back into text.

    Reading stops at the first ``EOS_ID``. ``PAD_ID`` and ``BOS_ID`` entries
    are skipped, and any id missing from ``id2char`` is rendered as ``"?"``.
    """
    pieces = []
    for token in token_ids:
        if token == EOS_ID:
            break
        if token not in (PAD_ID, BOS_ID):
            pieces.append(id2char.get(token, "?"))
    return "".join(pieces)

def solve(expression, *, max_new_tokens=12):
    """Ask the model to evaluate *expression* and return its answer as text.

    A trailing ``"="`` is appended when missing. This presumably matches the
    prompt format used in training (TODO confirm against the training script).

    Args:
        expression: Arithmetic task such as ``"15*15"`` or ``"15*15="``.
        max_new_tokens: Upper bound on generated decoder tokens, passed
            straight to ``model.generate``. The default of 12 keeps the
            original behaviour. Raise it if answers come back truncated.

    Returns:
        The decoded model output, with PAD/BOS removed and cut at EOS.
    """
    if not expression.endswith("="):
        expression += "="

    input_ids = torch.tensor([encode(expression)], dtype=torch.long, device=DEVICE)
    # Attend only to real tokens; PAD positions are masked out.
    attention_mask = (input_ids != PAD_ID).long()

    with torch.no_grad():  # inference only: no autograd bookkeeping
        generated = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            eos_token_id=EOS_ID,
            pad_token_id=PAD_ID,
            do_sample=False,  # no sampling -> repeatable answers
        )

    return decode(generated[0].cpu().tolist())

# ============================================================
# 3. INTERACTIVE MODE
# ============================================================

# Operators accepted by _reference_answer. "/" is evaluated as floor division,
# matching the original `replace("/", "//")` before eval.
_BINARY_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.floordiv,
    ast.FloorDiv: operator.floordiv,
}
_UNARY_OPS = {ast.UAdd: operator.pos, ast.USub: operator.neg}


def _reference_answer(task):
    """Compute the ground-truth answer for *task* without using ``eval``.

    Only int/float literals, binary ``+ - * / //`` and unary ``+``/``-`` are
    accepted. ``/`` is floor division. A single trailing ``"="`` is ignored.
    Replaces ``eval()``, which executed arbitrary user-typed Python.

    Raises:
        SyntaxError: *task* is not a valid Python expression.
        ValueError: *task* contains anything besides the allowed arithmetic.
        ArithmeticError: e.g. ZeroDivisionError for ``5/0``.
    """
    if task.endswith("="):
        task = task[:-1]

    def _eval(node):
        # type() rather than isinstance() so bool literals are rejected.
        if isinstance(node, ast.Constant) and type(node.value) in (int, float):
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in _BINARY_OPS:
            return _BINARY_OPS[type(node.op)](_eval(node.left), _eval(node.right))
        if isinstance(node, ast.UnaryOp) and type(node.op) in _UNARY_OPS:
            return _UNARY_OPS[type(node.op)](_eval(node.operand))
        raise ValueError(f"unsupported expression: {ast.dump(node)}")

    return _eval(ast.parse(task, mode="eval").body)


print("\n--- Mini Math Model interactive ---")
print("Enter an arithmetic task (e.g. 15*15) or type 'exit' to quit this.")

while True:
    try:
        user_input = input("\nTask > ").strip().replace(" ", "")
    except (EOFError, KeyboardInterrupt):
        # Ctrl-D / Ctrl-C at the prompt: leave cleanly instead of a traceback.
        print()
        break
    if user_input.lower() in ("exit", "quit", "q"):
        break

    if not any(op in user_input for op in "+-*/"):
        print("Input an arithmetic task!")
        continue

    prediction = solve(user_input)

    try:
        true_val = str(_reference_answer(user_input))
    except (SyntaxError, ValueError, ArithmeticError, RecursionError):
        # No reference answer available (malformed or unsupported input,
        # division by zero, ...): show the model's output on its own.
        print(f"Model: {prediction}")
    else:
        status = "✅" if prediction == true_val else "❌"
        print(f"Model: {prediction} | Correct: {true_val} {status}")