# NanoCalc-1M — use.py
# Interactive inference demo for the NanoCalc-1M character-level math model.
import torch
from transformers import T5Config, T5ForConditionalGeneration
import os
# ============================================================
# 1. SETUP & LOADING
# ============================================================
# Checkpoint path (produced by the training script) and inference device.
SAVE_PATH = "model.pt"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Bail out early with a clear message if the checkpoint file is missing.
if not os.path.exists(SAVE_PATH):
    print(f"Error: File {SAVE_PATH} not found!")
    exit()

# weights_only=True restricts torch.load unpickling to safe types; T5Config
# must be explicitly allow-listed so the config object stored inside the
# checkpoint can still be deserialized.
torch.serialization.add_safe_globals([T5Config])
checkpoint = torch.load(SAVE_PATH, map_location=DEVICE, weights_only=True)

# Character-level vocabulary and the special-token ids saved alongside the
# weights; the rest of the script relies on these module-level names.
char2id = checkpoint["char2id"]
id2char = checkpoint["id2char"]
PAD_ID = char2id["<pad>"]
BOS_ID = char2id["<bos>"]
EOS_ID = char2id["<eos>"]

# Rebuild the T5 model from the stored config, load the trained weights,
# and switch to eval mode (disables dropout) for inference.
config = checkpoint["config"]
model = T5ForConditionalGeneration(config).to(DEVICE)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()
print(f"Model loaded (Accuracy: {checkpoint['accuracy']:.2f}% from epoch {checkpoint['epoch']})")
# ============================================================
# 2. HELPER FUNCTIONS
# ============================================================
def encode(text, max_len=20):
    """Map *text* to a fixed-length list of token ids.

    Each character is looked up in ``char2id`` (characters missing from
    the vocabulary fall back to ``PAD_ID``), an EOS marker is appended,
    and the sequence is truncated or right-padded with ``PAD_ID`` so the
    result is exactly ``max_len`` entries long.
    """
    ids = [char2id.get(ch, PAD_ID) for ch in text]
    ids.append(EOS_ID)
    # Truncate first, then pad the remainder up to the fixed length.
    ids = ids[:max_len]
    missing = max_len - len(ids)
    return ids + [PAD_ID] * missing
def decode(token_ids):
    """Turn a sequence of generated token ids back into a string.

    Decoding stops at the first EOS token; PAD and BOS tokens are
    skipped, and any id not present in ``id2char`` renders as ``"?"``.
    """
    chars = []
    for token in token_ids:
        if token == EOS_ID:
            break
        if token != PAD_ID and token != BOS_ID:
            chars.append(id2char.get(token, "?"))
    return "".join(chars)
def solve(expression):
    """Run greedy generation on *expression* and return the decoded answer.

    A trailing ``"="`` is appended when missing, matching the prompt
    format the model expects. Generation is deterministic
    (``do_sample=False``) and capped at 12 new tokens.
    """
    prompt = expression if expression.endswith("=") else expression + "="
    encoded = torch.tensor([encode(prompt)], dtype=torch.long).to(DEVICE)
    # Mask out padding positions so attention ignores them.
    mask = (encoded != PAD_ID).long()
    with torch.no_grad():
        output = model.generate(
            input_ids=encoded,
            attention_mask=mask,
            max_new_tokens=12,
            eos_token_id=EOS_ID,
            pad_token_id=PAD_ID,
            do_sample=False,
        )
    return decode(output[0].cpu().tolist())
# ============================================================
# 3. INTERACTIVE MODE
# ============================================================
# Simple REPL: read an arithmetic expression, show the model's answer, and
# (when Python can evaluate the expression) compare it to the true result.
print("\n--- Mini Math Model interactive ---")
print("Enter an arithmetic task (e.g. 15*15) or type 'exit' to quit this.")
while True:
    user_input = input("\nTask > ").strip().replace(" ", "")
    if user_input.lower() in ("exit", "quit", "q"):
        break
    # Require at least one arithmetic operator before querying the model.
    if not any(op in user_input for op in "+-*/"):
        print("Input an arithmetic task!")
        continue
    prediction = solve(user_input)
    try:
        # The model was presumably trained on integer division, so map
        # '/' to floor division before computing the reference value.
        calc_input = user_input.replace("/", "//")
        # SECURITY: eval() on raw user input executes arbitrary Python.
        # Tolerable for a local demo; never expose this to untrusted input.
        true_val = str(eval(calc_input))
        status = "✅" if prediction == true_val else "❌"
        print(f"Model: {prediction} | Correct: {true_val} {status}")
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # still propagate. eval can legitimately fail (SyntaxError,
        # ZeroDivisionError, ...) — then show only the model's prediction.
        print(f"Model: {prediction}")