Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -4,6 +4,7 @@
|
|
| 4 |
# - Loads final.pt from repo OpenTransformer/AGILLM2-fast-training
|
| 5 |
# - Qwen tokenizer + chat template
|
| 6 |
# - Optional local CLI REPL when run in a terminal
|
|
|
|
| 7 |
|
| 8 |
from __future__ import annotations
|
| 9 |
import os, sys, time, math, argparse
|
|
@@ -235,6 +236,15 @@ def render_chat(messages: List[Dict[str, str]], add_generation_prompt: bool = Tr
|
|
| 235 |
out.append("Assistant:")
|
| 236 |
return "\n".join(out)
|
| 237 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 238 |
def _apply_no_repeat_ngram(logits: torch.Tensor, ids: torch.Tensor, n: int):
|
| 239 |
if n <= 0 or ids.size(1) < n - 1: return logits
|
| 240 |
prefix = ids[0, -(n - 1):].tolist()
|
|
@@ -319,30 +329,49 @@ def launch_gradio(core, ar_h):
|
|
| 319 |
import gradio as gr
|
| 320 |
with gr.Blocks() as demo:
|
| 321 |
gr.Markdown("### OpenTransformer / AGILLM2 — Chat")
|
|
|
|
| 322 |
with gr.Row():
|
| 323 |
temp = gr.Slider(0.1, 1.5, value=0.9, step=0.05, label="Temperature")
|
| 324 |
topp = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
|
| 325 |
topk = gr.Slider(0, 200, value=50, step=1, label="Top-k")
|
| 326 |
mxnt = gr.Slider(16, 1024, value=200, step=8, label="Max new tokens")
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 331 |
|
| 332 |
def _chat(history, user_msg, t, p, k, mxt, sys_p):
|
|
|
|
| 333 |
messages = [{"role":"system","content":sys_p}]
|
| 334 |
for u,a in history or []:
|
| 335 |
messages.append({"role":"user","content":u})
|
| 336 |
messages.append({"role":"assistant","content":a})
|
| 337 |
messages.append({"role":"user","content":user_msg})
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 341 |
history = (history or []) + [(user_msg, reply)]
|
| 342 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 343 |
|
| 344 |
-
|
| 345 |
-
|
|
|
|
| 346 |
|
| 347 |
demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")))
|
| 348 |
|
|
@@ -363,6 +392,9 @@ def run_cli(core, ar_h):
|
|
| 363 |
dt = time.time()-t0
|
| 364 |
print(f"Bot: {reply}\n[{len(tok.encode(reply))} tok in {dt:.2f}s]")
|
| 365 |
history.append((user, reply))
|
|
|
|
|
|
|
|
|
|
| 366 |
except KeyboardInterrupt:
|
| 367 |
print("\nbye."); break
|
| 368 |
|
|
|
|
| 4 |
# - Loads final.pt from repo OpenTransformer/AGILLM2-fast-training
|
| 5 |
# - Qwen tokenizer + chat template
|
| 6 |
# - Optional local CLI REPL when run in a terminal
|
| 7 |
+
# - Adds a "Raw transcript" tab with "User:" / "Assistant:" lines
|
| 8 |
|
| 9 |
from __future__ import annotations
|
| 10 |
import os, sys, time, math, argparse
|
|
|
|
| 236 |
out.append("Assistant:")
|
| 237 |
return "\n".join(out)
|
| 238 |
|
| 239 |
+
def render_raw(history: List[Tuple[str, str]] | None, sys_prompt: str) -> str:
    """Render the conversation as a plain-text transcript.

    Emits an optional leading ``System: ...`` line (only when *sys_prompt*
    is non-empty), then an alternating ``User: ...`` / ``Assistant: ...``
    pair for every (user, assistant) tuple in *history*, joined with
    newlines. An empty or ``None`` history yields just the system line
    (or the empty string).
    """
    parts: List[str] = [f"System: {sys_prompt}"] if sys_prompt else []
    for user_turn, bot_turn in history or []:
        parts.extend((f"User: {user_turn}", f"Assistant: {bot_turn}"))
    return "\n".join(parts)
|
| 247 |
+
|
| 248 |
def _apply_no_repeat_ngram(logits: torch.Tensor, ids: torch.Tensor, n: int):
|
| 249 |
if n <= 0 or ids.size(1) < n - 1: return logits
|
| 250 |
prefix = ids[0, -(n - 1):].tolist()
|
|
|
|
| 329 |
def launch_gradio(core, ar_h):
    """Launch the Gradio chat UI.

    Builds a Blocks app with sampling-parameter sliders, an editable
    system prompt, a "Chat" tab (Chatbot + message box + Send button)
    and a read-only "Raw transcript" tab mirroring the conversation as
    "User:/Assistant:" lines, then serves it on 0.0.0.0 at $PORT
    (default 7860).

    core, ar_h: model/decoder handles passed straight through to
    chat_decode — opaque here; see chat_decode for their contract.
    """
    import gradio as gr
    with gr.Blocks() as demo:
        gr.Markdown("### OpenTransformer / AGILLM2 — Chat")

        # Sampling controls shared by every generated turn.
        with gr.Row():
            temp = gr.Slider(0.1, 1.5, value=0.9, step=0.05, label="Temperature")
            topp = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
            topk = gr.Slider(0, 200, value=50, step=1, label="Top-k")
            mxnt = gr.Slider(16, 1024, value=200, step=8, label="Max new tokens")
        sys_prompt = gr.Textbox(value="You are a helpful, concise assistant.", label="System prompt")

        with gr.Tabs():
            with gr.TabItem("Chat"):
                chatbot = gr.Chatbot(height=520, label="Conversation")
                msg = gr.Textbox(placeholder="Say something useful…", label="Message")
                submit = gr.Button("Send", variant="primary")
            with gr.TabItem("Raw transcript"):
                # Read-only mirror of the chat, regenerated after every turn.
                raw = gr.Textbox(lines=24, label="Raw transcript (User:/Assistant:)", interactive=False)

        clear = gr.Button("Clear", variant="secondary")

        def _chat(history, user_msg, t, p, k, mxt, sys_p):
            # Build messages from history + new user message
            messages = [{"role":"system","content":sys_p}]
            for u,a in history or []:
                messages.append({"role":"user","content":u})
                messages.append({"role":"assistant","content":a})
            messages.append({"role":"user","content":user_msg})

            # Decode one reply with the UI's sampling settings; fp8 is
            # disabled here with CPU-safe fallback enabled.
            reply = chat_decode(
                core, ar_h, messages,
                max_new=int(mxt), T=float(t),
                greedy=False, top_k=int(k), top_p=float(p),
                use_fp8=False, fp8_fallback=True
            )
            history = (history or []) + [(user_msg, reply)]
            transcript = render_raw(history, sys_p)
            # Returns: updated history, cleared input box, refreshed transcript.
            return history, "", transcript

        # Wire up events: submit via button or enter
        msg.submit(_chat, [chatbot, msg, temp, topp, topk, mxnt, sys_prompt], [chatbot, msg, raw], queue=False)
        submit.click(_chat, [chatbot, msg, temp, topp, topk, mxnt, sys_prompt], [chatbot, msg, raw], queue=False)

        def _clear():
            # Reset chat history, message box and transcript together.
            return [], "", ""
        clear.click(_clear, inputs=None, outputs=[chatbot, msg, raw], queue=False)

    # NOTE(review): binds to all interfaces — expected for a Spaces container.
    demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")))
|
| 377 |
|
|
|
|
| 392 |
dt = time.time()-t0
|
| 393 |
print(f"Bot: {reply}\n[{len(tok.encode(reply))} tok in {dt:.2f}s]")
|
| 394 |
history.append((user, reply))
|
| 395 |
+
# Also show raw transcript line by line in CLI
|
| 396 |
+
print("\n--- RAW ---")
|
| 397 |
+
print(render_raw(history, "You are a helpful, concise assistant."))
|
| 398 |
except KeyboardInterrupt:
|
| 399 |
print("\nbye."); break
|
| 400 |
|