LH-Tech-AI committed on
Commit
8ab44df
·
verified ·
1 Parent(s): 76a1be0

Create use.py

Browse files
Files changed (1) hide show
  1. use.py +98 -0
use.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import T5Config, T5ForConditionalGeneration
3
+ import os
4
+
5
+ # ============================================================
6
+ # 1. SETUP & LOADING
7
+ # ============================================================
8
+
9
# Checkpoint file read at startup, and the device the model runs on
# (CUDA when torch reports it as available, otherwise CPU).
SAVE_PATH = "model.pt"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

if not os.path.exists(SAVE_PATH):
    print(f"Error: File {SAVE_PATH} not found!")
    # Raise SystemExit directly instead of calling the site-provided exit():
    # it is always available and reports the failure with a non-zero status.
    raise SystemExit(1)

# The checkpoint stores a T5Config object; allow-list that class so the file
# can still be loaded with the restricted weights_only=True unpickler.
torch.serialization.add_safe_globals([T5Config])

checkpoint = torch.load(SAVE_PATH, map_location=DEVICE, weights_only=True)

# Character-level vocabulary saved alongside the weights.
char2id = checkpoint["char2id"]
id2char = checkpoint["id2char"]
PAD_ID = char2id["<pad>"]
BOS_ID = char2id["<bos>"]
EOS_ID = char2id["<eos>"]

# Rebuild the architecture from the stored config, then restore the weights
# and switch to inference mode.
config = checkpoint["config"]
model = T5ForConditionalGeneration(config).to(DEVICE)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

print(f"Model loaded (Accuracy: {checkpoint['accuracy']:.2f}% from epoch {checkpoint['epoch']})")
32
+
33
+ # ============================================================
34
+ # 2. HELPER FUNCTIONS
35
+ # ============================================================
36
+
37
def encode(text, max_len=20):
    """Convert *text* into a list of exactly ``max_len`` token ids.

    Each character is looked up in ``char2id``; characters that are not in
    the vocabulary fall back to ``PAD_ID``. The sequence is terminated with
    ``EOS_ID`` and right-padded with ``PAD_ID``.

    Args:
        text: Input string, encoded character by character.
        max_len: Total length of the returned list, terminator included.

    Returns:
        A list of ``max_len`` integer token ids (empty if ``max_len < 1``).
    """
    # Reserve the final slot for EOS so that truncating an over-long input
    # never removes the terminator.
    # NOTE(review): assumes the model expects EOS-terminated inputs even for
    # truncated text - confirm against the training script.
    keep = max(max_len - 1, 0)
    tokens = [char2id.get(ch, PAD_ID) for ch in text[:keep]]
    tokens.append(EOS_ID)
    tokens = tokens[:max_len]  # only has an effect when max_len < 1
    tokens += [PAD_ID] * (max_len - len(tokens))
    return tokens
46
+
47
def decode(token_ids):
    """Turn a sequence of token ids back into a string.

    Reading stops at the first ``EOS_ID``. ``PAD_ID`` and ``BOS_ID`` are
    dropped, and any id missing from ``id2char`` is rendered as ``"?"``.

    Args:
        token_ids: Iterable of integer token ids.

    Returns:
        The decoded text.
    """
    skipped = (PAD_ID, BOS_ID)
    chars = []
    for token in token_ids:
        if token == EOS_ID:
            break
        if token not in skipped:
            chars.append(id2char.get(token, "?"))
    return "".join(chars)
54
+
55
def solve(expression):
    """Ask the model for the result of an arithmetic expression.

    A trailing ``"="`` is appended when the caller did not supply one. The
    prompt is encoded, run through ``model.generate`` with sampling disabled
    and at most 12 new tokens, and the first generated sequence is decoded.

    Args:
        expression: Arithmetic task as a string, e.g. ``"15*15"``.

    Returns:
        The model's answer as decoded text.
    """
    prompt = expression if expression.endswith("=") else expression + "="

    input_ids = torch.tensor([encode(prompt)], dtype=torch.long).to(DEVICE)
    # Padding positions are excluded from attention.
    attention_mask = (input_ids != PAD_ID).long()

    generation_kwargs = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=12,
        eos_token_id=EOS_ID,
        pad_token_id=PAD_ID,
        do_sample=False,
    )
    with torch.no_grad():
        generated = model.generate(**generation_kwargs)

    first_sequence = generated[0]
    return decode(first_sequence.cpu().tolist())
73
+
74
+ # ============================================================
75
+ # 3. INTERACTIVE MODE
76
+ # ============================================================
77
+
78
# Characters allowed in the locally evaluated reference expression.
_ARITHMETIC_CHARS = frozenset("0123456789.+-*/%()")


def _reference_result(task):
    """Return the locally computed result of *task* as a string, or None.

    None means the reference value could not be computed (unsupported
    characters, syntax error, division by zero, ...); the caller then prints
    the model output without a correctness verdict.
    """
    # solve() accepts a trailing "=", so tolerate it here as well.
    expr = task.rstrip("=")
    if not expr or not _ARITHMETIC_CHARS.issuperset(expr):
        return None
    # The model's answer is compared against floor division. Collapse an
    # already typed "//" first so it does not become "////".
    expr = expr.replace("//", "/").replace("/", "//")
    try:
        # NOTE(security): eval() on typed input. It is only reached for
        # strings made of digits and arithmetic symbols, and builtins are
        # removed, but it should not be exposed to untrusted callers as-is.
        return str(eval(expr, {"__builtins__": {}}, {}))
    except Exception:
        # Best effort: any evaluation failure just disables the check.
        # Unlike a bare except, this lets KeyboardInterrupt/SystemExit through.
        return None


print("\n--- Mini Math Model interactive ---")
print("Enter an arithmetic task (e.g. 15*15) or type 'exit' to quit this.")

while True:
    try:
        user_input = input("\nTask > ").strip().replace(" ", "")
    except (EOFError, KeyboardInterrupt):
        # Closed stdin or Ctrl-C at the prompt: leave cleanly, no traceback.
        print()
        break
    if user_input.lower() in ("exit", "quit", "q"):
        break

    if not any(op in user_input for op in "+-*/"):
        print("Input an arithmetic task!")
        continue

    prediction = solve(user_input)

    true_val = _reference_result(user_input)
    if true_val is None:
        print(f"Model: {prediction}")
    else:
        status = "✅" if prediction == true_val else "❌"
        print(f"Model: {prediction} | Correct: {true_val} {status}")