boyuia commited on
Commit
c19d729
·
verified ·
1 Parent(s): 04c654d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -15
app.py CHANGED
@@ -2,7 +2,8 @@ import gradio as gr
2
  import torch
3
  import torch.nn as nn
4
  from torch.nn import functional as F
5
- import json # <-- Added this import
 
6
 
7
  # --- Model Definition (same as before) ---
8
  # NOTE: The model class MUST be defined in your app.py file
@@ -149,23 +150,35 @@ model.to(device)
149
 
150
 
151
  # --- Gradio UI & Inference function ---
152
- def generate_text(prompt, max_new_tokens):
 
 
 
 
 
153
  # Encode the prompt text into tokens.
154
- context = torch.tensor(encode(prompt), dtype=torch.long, device=device).unsqueeze(0)
 
 
 
 
155
  # Generate new tokens.
156
  generated_text_indices = model.generate(context, max_new_tokens=max_new_tokens)
157
  # Decode the tokens back into text.
158
- return decode(generated_text_indices[0].tolist())
159
-
160
- demo = gr.Interface(
161
- fn=generate_text,
162
- inputs=[
163
- gr.Textbox(label="Prompt", placeholder="Enter your text prompt here..."),
164
- gr.Slider(1, 100, value=20, step=1, label="Number of new tokens to generate"),
165
- ],
166
- outputs="text",
167
- title="Tiny Language Model",
168
- description="A simple character-level language model trained in PyTorch."
 
 
 
169
  )
170
 
171
- demo.launch()
 
2
  import torch
3
  import torch.nn as nn
4
  from torch.nn import functional as F
5
+ import json
6
+ import os # <-- Added for file path checks
7
 
8
  # --- Model Definition (same as before) ---
9
  # NOTE: The model class MUST be defined in your app.py file
 
150
 
151
 
152
  # --- Gradio UI & Inference function ---
153
def generate_text_chat(message, history, max_new_tokens=50):
    """Generate a character-level continuation of the user's message.

    Gradio ``ChatInterface`` callback: receives the latest user message and
    the prior chat turns. The model is not conversation-aware, so only the
    most recent message is used as the prompt.

    Args:
        message: The user's prompt text.
        history: Prior chat turns supplied by ``gr.ChatInterface``; unused.
        max_new_tokens: Number of new tokens to sample. Defaults to 50,
            matching the previous hard-coded value.

    Returns:
        Only the newly generated text, with the echoed prompt removed.
    """
    prompt = message

    # Encode character-by-character. Characters missing from the vocabulary
    # fall back to index 0 — they are substituted, never dropped.
    encoded_prompt = [stoi.get(c, 0) for c in prompt]
    if not encoded_prompt:
        # Because unknown characters map to 0 rather than being removed,
        # an empty encoding can only mean an empty message.
        return "Prompt is empty."

    context = torch.tensor(encoded_prompt, dtype=torch.long, device=device).unsqueeze(0)

    # model.generate returns the prompt tokens followed by the new ones.
    generated_text_indices = model.generate(context, max_new_tokens=max_new_tokens)

    # Decode the tokens back into text.
    generated_text = decode(generated_text_indices[0].tolist())

    # Character-level tokenization is one token per character, so slicing by
    # len(prompt) strips exactly the echoed prompt from the output.
    return generated_text[len(prompt):]
172
+
173
+ # Using gr.ChatInterface for a conversational experience
174
# Conversational front-end: wire the inference callback into a chat UI.
_chat_window = gr.Chatbot(height="500px")
_input_box = gr.Textbox(placeholder="Ask me anything...", container=False, scale=7)

demo = gr.ChatInterface(
    generate_text_chat,
    chatbot=_chat_window,
    textbox=_input_box,
    title="Tiny Language Model Chat",
    description="A simple character-level language model trained in PyTorch, now with a chat interface.",
    theme="soft",
)

demo.launch()