# MATH-GPT / app.py
# Author: rverma0631 (Hugging Face Space) — commit 4887fbf (verified)
import torch
import torch.nn as nn
from torch.nn import functional as F
import tiktoken
import gradio as gr
import os
# Run on GPU when available, otherwise fall back to CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
class Head(nn.Module):
    """One causal self-attention head.

    Projects the input into key/query/value spaces, computes masked
    scaled dot-product attention, and returns the attention-weighted values.
    """

    def __init__(self, head_size, n_embd, block_size, dropout):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Lower-triangular mask buffer enforces causality (no attending ahead).
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        batch, seq_len, channels = x.shape
        keys = self.key(x)
        queries = self.query(x)
        # NOTE(review): scales by the full embedding dim C rather than the
        # conventional head_size — kept as-is so inference matches the
        # scaling the checkpoint was trained with.
        scores = (queries @ keys.transpose(-2, -1)) * channels ** -0.5
        scores = scores.masked_fill(self.tril[:seq_len, :seq_len] == 0, float('-inf'))
        attn = self.dropout(F.softmax(scores, dim=-1))
        return attn @ self.value(x)
class MultiHeadAttention(nn.Module):
    """Several attention heads in parallel, concatenated and re-projected."""

    def __init__(self, num_heads, head_size, n_embd, block_size, dropout):
        super().__init__()
        head_list = [Head(head_size, n_embd, block_size, dropout) for _ in range(num_heads)]
        self.heads = nn.ModuleList(head_list)
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Each head attends independently; outputs are stacked on the channel dim
        # and mixed back together by the projection.
        combined = torch.cat([head(x) for head in self.heads], dim=-1)
        return self.dropout(self.proj(combined))
class FeedFoward(nn.Module):
    """Position-wise MLP: expand 4x, GELU, project back, dropout.

    (The original 'FeedFoward' spelling is kept because other classes in
    this file reference it by that name.)
    """

    def __init__(self, n_embd, dropout):
        super().__init__()
        hidden = 4 * n_embd
        self.net = nn.Sequential(
            nn.Linear(n_embd, hidden),
            nn.GELU(),
            nn.Linear(hidden, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)
class Block(nn.Module):
    """One transformer decoder block.

    Pre-norm self-attention followed by a pre-norm feed-forward MLP,
    each wrapped in a residual connection.
    """

    def __init__(self, n_embd, n_head, block_size, dropout):
        super().__init__()
        per_head = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, per_head, n_embd, block_size, dropout)
        self.ffwd = FeedFoward(n_embd, dropout)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        return x + self.ffwd(self.ln2(x))
class GPTLanguageModel(nn.Module):
    """Decoder-only GPT: token + learned positional embeddings, a stack of
    transformer blocks, a final LayerNorm, and a linear LM head.

    Args:
        vocab_size: size of the token vocabulary.
        n_embd: embedding / model width.
        block_size: maximum context length.
        n_layer: number of transformer blocks.
        n_head: attention heads per block.
        dropout: dropout probability used throughout.
    """

    def __init__(self, vocab_size, n_embd, block_size, n_layer, n_head, dropout):
        super().__init__()
        self.block_size = block_size
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        """Return (logits, loss); loss is None when targets are omitted.

        idx: (B, T) token ids with T <= block_size.
        targets: optional (B, T) next-token ids for cross-entropy.
        When targets are given, logits are returned flattened to (B*T, C).
        """
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx)
        # Fix: derive the device from the input tensor instead of the
        # module-level `device` global, so the model works wherever its
        # inputs actually live (CPU/GPU) regardless of global state.
        pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device))
        x = tok_emb + pos_emb
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    @torch.no_grad()  # sampling never needs gradients; avoids graph buildup
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """Autoregressively append max_new_tokens sampled tokens to idx.

        idx: (B, T) conditioning ids; returns (B, T + max_new_tokens).
        temperature: softmax temperature (> 0); higher is more random.
        top_k: when given, restrict sampling to the k most likely tokens.
        """
        for _ in range(max_new_tokens):
            # Crop the context to the most recent block_size tokens.
            idx_cond = idx[:, -self.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                # Mask out everything below the k-th largest logit.
                logits[logits < v[:, [-1]]] = -float('Inf')
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
def load_model(checkpoint_path):
    """Restore a GPTLanguageModel from a checkpoint and return (model, config).

    The checkpoint file must contain a 'config' dict of hyperparameters and a
    'model_state_dict' entry. The model is moved to the global `device` and
    put in eval mode with dropout disabled.
    """
    state = torch.load(checkpoint_path, map_location=device)
    cfg = state['config']
    model = GPTLanguageModel(
        vocab_size=cfg['vocab_size'],
        n_embd=cfg['n_embd'],
        block_size=cfg['block_size'],
        n_layer=cfg['n_layer'],
        n_head=cfg['n_head'],
        dropout=0.0,  # inference only: dropout off
    )
    model.load_state_dict(state['model_state_dict'])
    model = model.to(device)
    model.eval()
    return model, cfg
# Lazily-initialized globals: the model and tokenizer are populated by
# initialize_model() when the user clicks the load button in the UI.
model = None
enc = None
def initialize_model():
    """Load model.pt from the working directory into the module globals.

    Returns a human-readable status string for the UI status box; never
    raises (failures are reported in the returned string instead).
    """
    global model, enc
    checkpoint_path = "model.pt"
    if not os.path.exists(checkpoint_path):
        # Fixed a stray f-prefix: this message has no placeholders.
        return "❌ model.pt not found in current directory!"
    try:
        model, config = load_model(checkpoint_path)
        enc = tiktoken.get_encoding("gpt2")
        param_count = sum(p.numel() for p in model.parameters()) / 1e6
        return f"✅ Model loaded successfully!\nParameters: {param_count:.1f}M\nDevice: {device}"
    except Exception as e:
        # Broad catch is deliberate: this is a UI boundary and the error text
        # is surfaced in the status box rather than crashing the app.
        return f"❌ Error loading model: {str(e)}"
def generate_response(message, history, temperature, top_k, max_tokens):
    """Append (message, model_reply) to the chat history and return it.

    Requires the module globals `model` and `enc` to be initialized first;
    all failures are reported inline in the chat rather than raised.
    """
    global model, enc
    if model is None or enc is None:
        return history + [("Please load a model first!", "")]
    if not message.strip():
        return history + [("", "Please enter a message!")]
    try:
        prompt_ids = enc.encode(message, disallowed_special=())
        context = torch.tensor([prompt_ids], dtype=torch.long, device=device)
        with torch.no_grad():
            generated = model.generate(
                context,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_k=top_k if top_k > 0 else None,
            )
        decoded = enc.decode(generated[0].tolist())
        # Strip the echoed prompt so only the continuation is shown.
        reply = decoded[len(message):].strip()
        if not reply:
            reply = "I couldn't generate a meaningful response. Try adjusting the parameters or rephrasing your question."
        history.append((message, reply))
    except Exception as e:
        history.append((message, f"Error generating response: {str(e)}"))
    return history
def clear_conversation():
    """Reset the chatbot: an empty history clears the displayed chat."""
    return []
# Inline CSS: fixed-height chat pane plus a centered, width-capped layout.
css = """
#chatbot {
height: 500px;
}
.gradio-container {
max-width: 900px;
margin: auto;
}
"""
# --- Gradio UI ---------------------------------------------------------------
# Two-column layout: chat pane on the left; model loading, status, and sampling
# parameter sliders on the right. Event wiring is at the bottom of the block.
with gr.Blocks(css=css, title="Math Model Chat") as demo:
    gr.Markdown("# 🧮 Math Model Chat Interface")
    gr.Markdown("Place your trained model as `model.pt` in the same directory and click Load!")
    with gr.Row():
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(
                elem_id="chatbot",
                show_copy_button=True,
                bubble_full_width=False
            )
            with gr.Row():
                msg = gr.Textbox(
                    placeholder="Ask me anything about math...",
                    show_label=False,
                    scale=4
                )
                submit_btn = gr.Button("Send", variant="primary", scale=1)
            with gr.Row():
                clear_btn = gr.Button("Clear Chat", variant="secondary")
        with gr.Column(scale=1):
            gr.Markdown("### Model Settings")
            load_btn = gr.Button("Load model.pt", variant="primary", size="lg")
            status = gr.Textbox(label="Status", interactive=False, lines=3)
            gr.Markdown("### Generation Parameters")
            # Slider values are passed straight through to model.generate().
            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=0.8,
                step=0.1,
                label="Temperature",
                info="Higher = more creative"
            )
            top_k = gr.Slider(
                minimum=0,
                maximum=500,
                value=200,
                step=10,
                label="Top-k",
                info="0 = disabled, lower = more focused"
            )
            max_tokens = gr.Slider(
                minimum=50,
                maximum=500,
                value=200,
                step=10,
                label="Max Tokens",
                info="Maximum response length"
            )
            gr.Markdown("### Example Questions")
            # Clicking an example fills the message box (it does not auto-send).
            examples = gr.Examples(
                examples=[
                    "What is the derivative of x²?",
                    "Solve the integral ∫x dx",
                    "Explain the Pythagorean theorem",
                    "What is the quadratic formula?",
                    "How do you find the area of a circle?",
                ],
                inputs=msg
            )
    # Event wiring: the load button reports into the status box.
    load_btn.click(
        fn=initialize_model,
        outputs=[status]
    )
    # Sending a message (button click or Enter) generates a reply into the
    # chatbot, then clears the input box via the chained .then() callback.
    submit_btn.click(
        fn=generate_response,
        inputs=[msg, chatbot, temperature, top_k, max_tokens],
        outputs=[chatbot]
    ).then(
        fn=lambda: "",
        outputs=[msg]
    )
    msg.submit(
        fn=generate_response,
        inputs=[msg, chatbot, temperature, top_k, max_tokens],
        outputs=[chatbot]
    ).then(
        fn=lambda: "",
        outputs=[msg]
    )
    clear_btn.click(
        fn=clear_conversation,
        outputs=[chatbot]
    )
if __name__ == "__main__":
    # NOTE(review): share=True requests a public Gradio tunnel even though the
    # server binds only to localhost — confirm that is intended for deployment.
    demo.launch(
        share=True,
        server_name="127.0.0.1",
        server_port=7860,
        show_error=True
    )