Update app.py
Browse files
app.py
CHANGED
|
@@ -2,7 +2,8 @@ import gradio as gr
|
|
| 2 |
import torch
|
| 3 |
import torch.nn as nn
|
| 4 |
from torch.nn import functional as F
|
| 5 |
-
import json
|
|
|
|
| 6 |
|
| 7 |
# --- Model Definition (same as before) ---
|
| 8 |
# NOTE: The model class MUST be defined in your app.py file
|
|
@@ -149,23 +150,35 @@ model.to(device)
|
|
| 149 |
|
| 150 |
|
| 151 |
# --- Gradio UI & Inference function ---
|
| 152 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
# Encode the prompt text into tokens.
|
| 154 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
# Generate new tokens.
|
| 156 |
generated_text_indices = model.generate(context, max_new_tokens=max_new_tokens)
|
| 157 |
# Decode the tokens back into text.
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
|
|
|
|
|
|
|
|
|
| 169 |
)
|
| 170 |
|
| 171 |
-
demo.launch()
|
|
|
|
| 2 |
import torch
|
| 3 |
import torch.nn as nn
|
| 4 |
from torch.nn import functional as F
|
| 5 |
+
import json
|
| 6 |
+
import os # <-- Added for file path checks
|
| 7 |
|
| 8 |
# --- Model Definition (same as before) ---
|
| 9 |
# NOTE: The model class MUST be defined in your app.py file
|
|
|
|
| 150 |
|
| 151 |
|
| 152 |
# --- Gradio UI & Inference function ---
|
| 153 |
+
def generate_text_chat(message, history):
    """Gradio chat callback: continue ``message`` with the character-level model.

    Parameters
    ----------
    message : str
        The user's latest chat message, used verbatim as the prompt.
    history : list
        Prior turns supplied by ``gr.ChatInterface``; unused here because the
        model has no conversational memory.

    Returns
    -------
    str
        Only the newly generated continuation (the prompt prefix is stripped).
    """
    prompt = message
    # Fixed generation budget; adjust to taste.
    max_new_tokens = 50

    # Guard the truly-empty case up front. NOTE: stoi.get(c, 0) maps unknown
    # characters to token 0 rather than dropping them, so the encoded prompt
    # can only be empty when the prompt itself is empty.
    if not prompt:
        return "Prompt is empty or contains unknown characters."

    # Encode the prompt text into tokens (unknown chars fall back to token 0).
    encoded_prompt = [stoi.get(c, 0) for c in prompt]
    context = torch.tensor(encoded_prompt, dtype=torch.long, device=device).unsqueeze(0)

    # Inference only: disable autograd so generation does not build a graph
    # and waste memory/time tracking gradients.
    with torch.no_grad():
        generated_text_indices = model.generate(context, max_new_tokens=max_new_tokens)

    # Decode the tokens back into text.
    generated_text = decode(generated_text_indices[0].tolist())

    # Return only the newly generated part, removing the original prompt.
    return generated_text[len(prompt):]
|
| 172 |
+
|
| 173 |
# --- Chat UI wiring ---
# Build the display and input components first, then hand them to
# gr.ChatInterface, which connects them to the model-backed callback.
chat_window = gr.Chatbot(height="500px")
input_box = gr.Textbox(placeholder="Ask me anything...", container=False, scale=7)

demo = gr.ChatInterface(
    fn=generate_text_chat,
    chatbot=chat_window,
    textbox=input_box,
    title="Tiny Language Model Chat",
    description="A simple character-level language model trained in PyTorch, now with a chat interface.",
    theme="soft",
)

demo.launch()