import torch
import torch.nn as nn
from torch.nn import functional as F
import tiktoken
import gradio as gr
import os

device = 'cuda' if torch.cuda.is_available() else 'cpu'
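# ----------------------------------------------------------------------
# Model definition: a small decoder-only transformer (nanoGPT-style) with
# causal self-attention heads, pre-LayerNorm blocks, and 4x-expansion MLPs.
# ----------------------------------------------------------------------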
class Head(nn.Module):
    """One head of causal self-attention."""

    def __init__(self, head_size, n_embd, block_size, dropout):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # lower-triangular mask so each position attends only to the past
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)    # (B, T, head_size)
        q = self.query(x)  # (B, T, head_size)
        # NOTE: this scales by the full embedding dim C rather than head_size
        # as in the standard transformer; kept as-is so inference matches
        # however the checkpoint was trained
        wei = q @ k.transpose(-2, -1) * C**-0.5  # (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        v = self.value(x)  # (B, T, head_size)
        out = wei @ v      # (B, T, head_size)
        return out
class MultiHeadAttention(nn.Module):
    """Several attention heads in parallel; outputs are concatenated and projected."""

    def __init__(self, num_heads, head_size, n_embd, block_size, dropout):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size, n_embd, block_size, dropout) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)  # (B, T, num_heads * head_size)
        out = self.dropout(self.proj(out))
        return out
class FeedForward(nn.Module):
    """Position-wise MLP with the usual 4x hidden expansion."""

    def __init__(self, n_embd, dropout):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)
class Block(nn.Module):
    """Transformer block: attention then MLP, each behind a pre-LayerNorm residual."""

    def __init__(self, n_embd, n_head, block_size, dropout):
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size, n_embd, block_size, dropout)
        self.ffwd = FeedForward(n_embd, dropout)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x
class GPTLanguageModel(nn.Module):
    """Decoder-only transformer language model."""

    def __init__(self, vocab_size, n_embd, block_size, n_layer, n_head, dropout):
        super().__init__()
        self.block_size = block_size
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx)  # (B, T, n_embd)
        pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device))  # (T, n_embd)
        x = tok_emb + pos_emb  # position embeddings broadcast over the batch
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)  # (B, T, vocab_size)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.block_size:]  # crop context to the last block_size tokens
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature  # keep only the last time step
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')  # mask everything below the k-th logit
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
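# ----------------------------------------------------------------------
# Checkpoint loading and inference helpers
# ----------------------------------------------------------------------
# Example of the raw generation API (a sketch; assumes a checkpoint has
# already been loaded with load_model() below, and uses a made-up prompt):
#
#   model, config = load_model("model.pt")
#   enc = tiktoken.get_encoding("gpt2")
#   ctx = torch.tensor([enc.encode("2 + 2 =")], dtype=torch.long, device=device)
#   out = model.generate(ctx, max_new_tokens=20, temperature=0.8, top_k=40)
#   print(enc.decode(out[0].tolist()))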
def load_model(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    config = checkpoint['config']
    model = GPTLanguageModel(
        vocab_size=config['vocab_size'],
        n_embd=config['n_embd'],
        block_size=config['block_size'],
        n_layer=config['n_layer'],
        n_head=config['n_head'],
        dropout=0.0,  # disable dropout for inference
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    model.eval()
    return model, config
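# load_model() expects a checkpoint saved at training time along these lines
# (a sketch; the hyperparameter values are placeholders, not the real ones):
#
#   torch.save({
#       'model_state_dict': model.state_dict(),
#       'config': {'vocab_size': 50257, 'n_embd': 384, 'block_size': 256,
#                  'n_layer': 6, 'n_head': 6},
#   }, 'model.pt')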
model = None
enc = None

def initialize_model():
    global model, enc
    checkpoint_path = "model.pt"
    if not os.path.exists(checkpoint_path):
        return "❌ model.pt not found in current directory!"
    try:
        model, config = load_model(checkpoint_path)
        enc = tiktoken.get_encoding("gpt2")
        param_count = sum(p.numel() for p in model.parameters()) / 1e6
        return f"✅ Model loaded successfully!\nParameters: {param_count:.1f}M\nDevice: {device}"
    except Exception as e:
        return f"❌ Error loading model: {e}"
def generate_response(message, history, temperature, top_k, max_tokens):
    global model, enc
    if model is None or enc is None:
        return history + [(message, "Please load a model first!")]
    if not message.strip():
        return history + [(message, "Please enter a message!")]
    try:
        tokens = enc.encode(message, disallowed_special=())
        context = torch.tensor([tokens], dtype=torch.long, device=device)
        with torch.no_grad():
            generated = model.generate(
                context,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_k=top_k if top_k > 0 else None,
            )
        full_response = enc.decode(generated[0].tolist())
        # the gpt2 encoding round-trips text, so this slices off the echoed prompt
        response = full_response[len(message):].strip()
        if not response:
            response = "I couldn't generate a meaningful response. Try adjusting the parameters or rephrasing your question."
        history.append((message, response))
        return history
    except Exception as e:
        history.append((message, f"Error generating response: {e}"))
        return history
def clear_conversation():
    return []

css = """
#chatbot {
    height: 500px;
}
.gradio-container {
    max-width: 900px;
    margin: auto;
}
"""
with gr.Blocks(css=css, title="Math Model Chat") as demo:
    gr.Markdown("# 🧮 Math Model Chat Interface")
    gr.Markdown("Place your trained model as `model.pt` in the same directory and click Load!")
    with gr.Row():
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(
                elem_id="chatbot",
                show_copy_button=True,
                bubble_full_width=False,
            )
            with gr.Row():
                msg = gr.Textbox(
                    placeholder="Ask me anything about math...",
                    show_label=False,
                    scale=4,
                )
                submit_btn = gr.Button("Send", variant="primary", scale=1)
            with gr.Row():
                clear_btn = gr.Button("Clear Chat", variant="secondary")
        with gr.Column(scale=1):
            gr.Markdown("### Model Settings")
            load_btn = gr.Button("Load model.pt", variant="primary", size="lg")
            status = gr.Textbox(label="Status", interactive=False, lines=3)
            gr.Markdown("### Generation Parameters")
            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=0.8,
                step=0.1,
                label="Temperature",
                info="Higher = more creative",
            )
            top_k = gr.Slider(
                minimum=0,
                maximum=500,
                value=200,
                step=10,
                label="Top-k",
                info="0 = disabled, lower = more focused",
            )
            max_tokens = gr.Slider(
                minimum=50,
                maximum=500,
                value=200,
                step=10,
                label="Max Tokens",
                info="Maximum response length",
            )
            gr.Markdown("### Example Questions")
            examples = gr.Examples(
                examples=[
                    "What is the derivative of x²?",
                    "Solve the integral ∫x dx",
                    "Explain the Pythagorean theorem",
                    "What is the quadratic formula?",
                    "How do you find the area of a circle?",
                ],
                inputs=msg,
            )
    load_btn.click(
        fn=initialize_model,
        outputs=[status],
    )
    submit_btn.click(
        fn=generate_response,
        inputs=[msg, chatbot, temperature, top_k, max_tokens],
        outputs=[chatbot],
    ).then(
        fn=lambda: "",
        outputs=[msg],
    )
    msg.submit(
        fn=generate_response,
        inputs=[msg, chatbot, temperature, top_k, max_tokens],
        outputs=[chatbot],
    ).then(
        fn=lambda: "",
        outputs=[msg],
    )
    clear_btn.click(
        fn=clear_conversation,
        outputs=[chatbot],
    )
if __name__ == "__main__":
    demo.launch(
        share=True,
        server_name="127.0.0.1",
        server_port=7860,
        show_error=True,
    )
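# To run locally (assuming the dependencies are installed and this file is
# saved as app.py, as is conventional on a Hugging Face Space):
#   pip install torch tiktoken gradio
#   python app.py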