manthilaffs commited on
Commit
15d1bb2
·
verified ·
1 Parent(s): 9c5c378

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -197
app.py CHANGED
@@ -1,8 +1,7 @@
1
  import gradio as gr
2
  import torch
3
  import spaces
4
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
5
- from threading import Thread
6
 
7
  model = None
8
  tokenizer = None
@@ -15,8 +14,10 @@ alpaca_prompt = """පහත දැක්වෙන්නේ යම් කාර
15
  ### ප්‍රතිචාරය:
16
  {}"""
17
 
18
- def load_model():
 
19
  global model, tokenizer
 
20
  if model is None:
21
  tokenizer = AutoTokenizer.from_pretrained("manthilaffs/Gamunu-4B-Instruct-Alpha")
22
  model = AutoModelForCausalLM.from_pretrained(
@@ -25,18 +26,11 @@ def load_model():
25
  device_map="auto",
26
  )
27
  model.eval()
28
-
29
- @spaces.GPU
30
- def generate_response(message, history, enable_history=False, max_new_tokens=1024):
31
- global model, tokenizer
32
-
33
- load_model()
34
 
35
  # Add history only if enabled
36
  if enable_history and history:
37
  prev = "\n".join(
38
- [f"User: {h['content']}\nGamunu: {h.get('content', '')}"
39
- for h in history if h.get('role') == 'assistant']
40
  )
41
  context = f"{prev}\n\n{message}"
42
  else:
@@ -51,62 +45,18 @@ def generate_response(message, history, enable_history=False, max_new_tokens=102
51
 
52
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
53
 
54
- # Initialize the streamer
55
- streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
56
 
57
- # Generation parameters
58
- generation_kwargs = dict(
59
- **inputs,
60
- max_new_tokens=max_new_tokens,
61
- streamer=streamer,
62
- do_sample=True,
63
- )
64
-
65
- # Start generation in a separate thread
66
- thread = Thread(target=model.generate, kwargs=generation_kwargs)
67
- thread.start()
68
-
69
- # Stream the response
70
- full_text = ""
71
 
72
- for new_text in streamer:
73
- full_text += new_text
74
-
75
- # Check if we've reached the response section and extract it
76
- if "### ප්‍රතිචාරය:" in full_text:
77
- response_text = full_text.split("### ප්‍රතිචාරය:")[-1].strip()
78
- yield response_text
79
- else:
80
- # Still building up to the response marker, yield what we have
81
- yield full_text.strip()
82
-
83
- # Make sure thread completes
84
- thread.join()
85
-
86
- # Final yield with cleaned response
87
- if "### ප්‍රතිචාරය:" in full_text:
88
- final_response = full_text.split("### ප්‍රතිචාරය:")[-1].strip()
89
- else:
90
- final_response = full_text.strip()
91
 
92
- # Ensure we yield at least once
93
- if final_response:
94
- yield final_response
95
 
96
- # Custom CSS for styling with copy button
97
  custom_css = """
98
- /* Container width constraints for PC screens */
99
- .gradio-container {
100
- max-width: 1200px !important;
101
- margin: 0 auto !important;
102
- }
103
-
104
- /* Chat interface max width */
105
- .contain {
106
- max-width: 900px !important;
107
- margin: 0 auto !important;
108
- }
109
-
110
  #splash-screen {
111
  position: fixed;
112
  top: 0;
@@ -143,42 +93,6 @@ custom_css = """
143
  75% { transform: rotateY(540deg) scale(1.2); opacity: 0.8; }
144
  }
145
 
146
- /* Copy button styling */
147
- .message-wrap.bot {
148
- position: relative;
149
- }
150
-
151
- .message-wrap.bot:hover .copy-button {
152
- opacity: 1;
153
- }
154
-
155
- .copy-button {
156
- position: absolute;
157
- top: 8px;
158
- right: 8px;
159
- opacity: 0;
160
- transition: opacity 0.2s;
161
- background: rgba(102, 126, 234, 0.1);
162
- border: 1px solid rgba(102, 126, 234, 0.3);
163
- border-radius: 6px;
164
- padding: 6px 10px;
165
- cursor: pointer;
166
- font-size: 12px;
167
- color: #667eea;
168
- z-index: 10;
169
- }
170
-
171
- .copy-button:hover {
172
- background: rgba(102, 126, 234, 0.2);
173
- border-color: rgba(102, 126, 234, 0.5);
174
- }
175
-
176
- .copy-button.copied {
177
- background: rgba(34, 197, 94, 0.2);
178
- border-color: rgba(34, 197, 94, 0.5);
179
- color: #22c55e;
180
- }
181
-
182
  /* Smaller font sizes for chat */
183
  .message-wrap .message {
184
  font-size: 0.9rem !important;
@@ -188,15 +102,6 @@ custom_css = """
188
  font-size: 0.9rem !important;
189
  }
190
 
191
- /* Compact examples grid */
192
- .examples {
193
- max-width: 100% !important;
194
- }
195
-
196
- .examples .wrap {
197
- gap: 0.5rem !important;
198
- }
199
-
200
  /* Avatar styling */
201
  .message-wrap.user .avatar-container::before {
202
  content: "👤";
@@ -232,67 +137,6 @@ custom_css = """
232
  height: 40px !important;
233
  min-width: 40px !important;
234
  }
235
-
236
- /* Compact padding */
237
- .main {
238
- padding: 1rem !important;
239
- }
240
-
241
- /* Title styling - more compact */
242
- h1 {
243
- font-size: 1.5rem !important;
244
- margin-bottom: 0.5rem !important;
245
- }
246
-
247
- /* Streaming cursor effect */
248
- .message-wrap.bot.streaming .message::after {
249
- content: '▊';
250
- animation: blink 1s infinite;
251
- margin-left: 2px;
252
- }
253
-
254
- @keyframes blink {
255
- 0%, 50% { opacity: 1; }
256
- 51%, 100% { opacity: 0; }
257
- }
258
- """
259
-
260
- # JavaScript for copy functionality
261
- copy_js = """
262
- <script>
263
- function addCopyButtons() {
264
- // Remove existing copy buttons first
265
- document.querySelectorAll('.copy-button').forEach(btn => btn.remove());
266
-
267
- // Add copy buttons to bot messages
268
- const botMessages = document.querySelectorAll('.message-wrap.bot .message');
269
- botMessages.forEach((message, index) => {
270
- if (!message.querySelector('.copy-button')) {
271
- const copyBtn = document.createElement('button');
272
- copyBtn.className = 'copy-button';
273
- copyBtn.innerHTML = '📋 Copy';
274
- copyBtn.onclick = function(e) {
275
- e.stopPropagation();
276
- const text = message.innerText;
277
- navigator.clipboard.writeText(text).then(() => {
278
- copyBtn.innerHTML = '✅ Copied!';
279
- copyBtn.classList.add('copied');
280
- setTimeout(() => {
281
- copyBtn.innerHTML = '📋 Copy';
282
- copyBtn.classList.remove('copied');
283
- }, 2000);
284
- });
285
- };
286
- message.parentElement.style.position = 'relative';
287
- message.parentElement.appendChild(copyBtn);
288
- }
289
- });
290
- }
291
-
292
- // Run on load and periodically to catch new messages
293
- setInterval(addCopyButtons, 1000);
294
- window.addEventListener('load', addCopyButtons);
295
- </script>
296
  """
297
 
298
  # Splash screen HTML
@@ -308,39 +152,28 @@ splash_html = """
308
  """
309
 
310
  # ---------------- UI ----------------
311
- with gr.Blocks(css=custom_css, head=copy_js) as demo:
312
  gr.HTML(splash_html)
313
 
314
- enable_history = gr.State(value=False)
315
- max_new_tokens = gr.State(value=512)
316
-
317
- with gr.Row():
318
- with gr.Column():
319
- chat = gr.ChatInterface(
320
- fn=generate_response,
321
- title="🧠 Gamunu 4B Instruct - Demo",
322
- theme=gr.themes.Default(text_size="sm"),
323
- type="messages", # Use new messages format
324
- examples=[
325
- ["හෙලෝ ගැමුණු! මම සමන්, ඔයාට කොහොමද?"],
326
- ["ෆොටෝසින්තසිස් ක්‍රියාවලිය පැහැදිලි කරන්න."],
327
- ["මෙම වාක්‍යය සිංහලයට පරිවර්තනය කරන්න: 'The sun rises in the east.'"],
328
- ["'completed' තත්ත්වයේ ඇති වාර්තා ගණන ගණනය කිරීමට දත්ත සමුදා විමසුමක් (database query) ගොඩනඟන්න."],
329
- ["ඔබ ගුරුවරයෙකු ලෙස ක්‍රියාකරන්න. ශිෂ්‍යයාට ඉතිහාසය උගන්වන්න."],
330
- ["පහත ප්‍රකාශය ප්‍රංශ භාෂාවට පරිවර්තනය කරන්න. Laughter is the best medicine."],
331
- ["ඝන වස්තුවක හා ද්‍රවයක පරිමාවන්හි වෙනස පැහැදිලි කරන්න."],
332
- ["වෙස් මුහුණු කලාවේ ප්‍රධාන අංග මොනවාද? වර්තමානයේ මෙම කලාව ප්‍රචලිතව පවතින ප්‍රදේශ මොනවාද?"]
333
- ],
334
- additional_inputs=[enable_history, max_new_tokens]
335
- )
336
 
337
  with gr.Accordion("⚙️ Advanced Settings", open=False):
338
- history_checkbox = gr.Checkbox(label="Enable chat history", value=False)
339
- tokens_slider = gr.Slider(64, 1024, value=512, step=32, label="🔢 Max New Tokens")
340
-
341
- # Update state when controls change
342
- history_checkbox.change(fn=lambda x: x, inputs=history_checkbox, outputs=enable_history)
343
- tokens_slider.change(fn=lambda x: x, inputs=tokens_slider, outputs=max_new_tokens)
344
 
345
  gr.Markdown("""
346
  ---
 
1
  import gradio as gr
2
  import torch
3
  import spaces
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer
 
5
 
6
  model = None
7
  tokenizer = None
 
14
  ### ප්‍රතිචාරය:
15
  {}"""
16
 
17
+ @spaces.GPU
18
+ def infer(message, history, enable_history=False, max_new_tokens=1024):
19
  global model, tokenizer
20
+
21
  if model is None:
22
  tokenizer = AutoTokenizer.from_pretrained("manthilaffs/Gamunu-4B-Instruct-Alpha")
23
  model = AutoModelForCausalLM.from_pretrained(
 
26
  device_map="auto",
27
  )
28
  model.eval()
 
 
 
 
 
 
29
 
30
  # Add history only if enabled
31
  if enable_history and history:
32
  prev = "\n".join(
33
+ [f"User: {h[0]}\nGamunu: {h[1]}" for h in history if h[1] is not None]
 
34
  )
35
  context = f"{prev}\n\n{message}"
36
  else:
 
45
 
46
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
47
 
48
+ with torch.inference_mode():
49
+ outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
50
 
51
+ text = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
+ if "### ප්‍රතිචාරය:" in text:
54
+ text = text.split("### ප්‍රතිචාරය:")[-1].strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
+ return text
 
 
57
 
58
+ # Custom CSS for styling
59
  custom_css = """
 
 
 
 
 
 
 
 
 
 
 
 
60
  #splash-screen {
61
  position: fixed;
62
  top: 0;
 
93
  75% { transform: rotateY(540deg) scale(1.2); opacity: 0.8; }
94
  }
95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  /* Smaller font sizes for chat */
97
  .message-wrap .message {
98
  font-size: 0.9rem !important;
 
102
  font-size: 0.9rem !important;
103
  }
104
 
 
 
 
 
 
 
 
 
 
105
  /* Avatar styling */
106
  .message-wrap.user .avatar-container::before {
107
  content: "👤";
 
137
  height: 40px !important;
138
  min-width: 40px !important;
139
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  """
141
 
142
  # Splash screen HTML
 
152
  """
153
 
154
  # ---------------- UI ----------------
155
+ with gr.Blocks(css=custom_css) as demo:
156
  gr.HTML(splash_html)
157
 
158
+ chat = gr.ChatInterface(
159
+ fn=lambda message, history: infer(message, history, enable_history.value, max_new_tokens.value),
160
+ title="🧠 Gamunu 4B Instruct - Demo",
161
+ theme=gr.themes.Default(text_size="sm"),
162
+ examples=[
163
+ ["හෙලෝ ගැමුණු! මම සමන්, ඔයාට කොහොමද?"],
164
+ ["ෆොටෝසින්තසිස් ක්‍රියාවලිය පැහැදිලි කරන්න."],
165
+ ["මෙම වාක්‍යය සිංහලයට පරිවර්තනය කරන්න: 'The sun rises in the east.'"],
166
+ ["'completed' තත්ත්වයේ ඇති වාර්තා ගණන ගණනය කිරීමට දත්ත සමුදා විමසුමක් (database query) ගොඩනඟන්න."],
167
+ ["ඔබ ගුරුවරයෙකු ලෙස ක්‍රියාකරන්න. ශිෂ්‍යයාට ඉතිහාසය උගන්වන්න."],
168
+ ["පහත ප්‍රකාශය ප්‍රංශ භාෂාවට පරිවර්���නය කරන්න. Laughter is the best medicine."],
169
+ ["ඝන වස්තුවක හා ද්‍රවයක පරිමාවන්හි වෙනස පැහැදිලි කරන්න."],
170
+ ["වෙස් මුහුණු කලාවේ ප්‍රධාන අංග මොනවාද? වර්තමානයේ මෙම කලාව ප්‍රචලිතව පවතින ප්‍රදේශ මොනවාද?"]
171
+ ]
172
+ )
 
 
 
 
 
 
 
173
 
174
  with gr.Accordion("⚙️ Advanced Settings", open=False):
175
+ enable_history = gr.Checkbox(label="Enable chat history", value=False)
176
+ max_new_tokens = gr.Slider(64, 1024, value=512, step=32, label="🔢 Max New Tokens")
 
 
 
 
177
 
178
  gr.Markdown("""
179
  ---