import os
import threading
import time

import torch
import gradio as gr
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)

import spaces
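# `spaces` supplies the @spaces.GPU decorator that ZeroGPU Spaces use to
# attach a GPU to individual inference calls (applied to stream_generate below).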

MODEL_ID = os.getenv("MODEL_ID", "yasserrmd/SoftwareArchitecture-Instruct-v1")

# -------- Load model & tokenizer --------
print(f"Loading model: {MODEL_ID}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True, trust_remote_code=True)
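# device_map="auto" and torch_dtype="auto" let accelerate place the weights on
# whatever accelerator is available (CPU as a fallback) in the checkpoint's
# native precision.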
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype="auto",
    low_cpu_mem_usage=True,
    trust_remote_code=True,
)
model.eval()

# Ensure a pad token to avoid warnings on some base models
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token

TITLE = "SoftwareArchitecture-Instruct v1 — Chat"
DESCRIPTION = (
    "An instruction-tuned LLM for **software architecture**. "
    "Built on LiquidAI/LFM2-1.2B and fine-tuned on the Software-Architecture dataset. "
    "Designed for technical professionals: accurate, detailed, and on-topic answers."
)

SAMPLES = [
    "Explain the API Gateway pattern and when to use it.",
    "CQRS vs Event Sourcing — how do they relate, and when would you combine them?",
    "Design a resilient payment workflow with retries, idempotency keys, and DLQ.",
    "Rate limiting strategies for a public REST API: token bucket vs sliding window.",
    "Multi-tenant SaaS: compare shared DB, shared schema, and dedicated DB for isolation.",
    "Blue/green vs canary deployments — trade-offs and where each fits best.",
]

def format_history_as_messages(history):
    """
    Convert Gradio chat history into OpenAI-style messages for apply_chat_template.
    history: list of (user, assistant) tuples.
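    Example: [("Hi", "Hello!")] becomes
        [{"role": "user", "content": "Hi"},
         {"role": "assistant", "content": "Hello!"}]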
| """ | |
| messages = [] | |
| for (u, a) in history: | |
| if u: | |
| messages.append({"role": "user", "content": u}) | |
| if a: | |
| messages.append({"role": "assistant", "content": a}) | |
| return messages | |

# ZeroGPU: request a GPU for the duration of each generation call.
@spaces.GPU
def stream_generate(messages, max_new_tokens, temperature, top_p, repetition_penalty, seed=None):
    if seed is not None and seed >= 0:
        torch.manual_seed(seed)

    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        tokenize=True,
        return_dict=True,
    )
    # Keep only what the model expects (no token_type_ids for causal LMs)
    allowed = {"input_ids", "attention_mask"}
    inputs = {k: v.to(model.device) for k, v in inputs.items() if k in allowed}

    # Decode tokens incrementally; skip the prompt and special tokens so only
    # new assistant text reaches the UI.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    do_sample = temperature > 0
    gen_kwargs = dict(
        **inputs,
        max_new_tokens=int(max_new_tokens),
        repetition_penalty=float(repetition_penalty),
        do_sample=do_sample,
        use_cache=True,
        streamer=streamer,
    )
    # Sampling knobs are only meaningful when sampling; passing them alongside
    # greedy decoding (temperature == 0) triggers transformers warnings.
    if do_sample:
        gen_kwargs["temperature"] = float(temperature)
        gen_kwargs["top_p"] = float(top_p)
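
    # model.generate() blocks until generation finishes, so run it on a worker
    # thread and drain decoded text from the streamer as tokens arrive.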
    thread = threading.Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()

    partial = ""
    for chunk in streamer:
        partial += chunk
        yield partial

# -------- Gradio callbacks --------
def chat_respond(user_msg, chat_history, max_new_tokens, temperature, top_p, repetition_penalty, seed):
    if not user_msg or not user_msg.strip():
        # This function is a generator, so the empty-input case must yield
        # (a plain `return value` would never reach the UI).
        yield gr.update(), gr.update()
        return

    # Add user turn
    chat_history = chat_history + [(user_msg, None)]

    # Build messages from full history
    messages = format_history_as_messages(chat_history)

    # Stream assistant output
    stream = stream_generate(
        messages=messages,
        max_new_tokens=int(max_new_tokens),
        temperature=float(temperature),
        top_p=float(top_p),
        repetition_penalty=float(repetition_penalty),
        seed=int(seed) if seed is not None else None,
    )

    # Yield progressive updates for the last assistant turn
    final_assistant_text = ""
    for chunk in stream:
        final_assistant_text = chunk
        yield gr.update(value=chat_history[:-1] + [(user_msg, final_assistant_text)]), ""

    # Ensure the final state is recorded in the history
    chat_history[-1] = (user_msg, final_assistant_text)
    yield gr.update(value=chat_history), ""

def use_sample(sample, chat_history):
    return sample, chat_history

def clear_chat():
    return []

# -------- UI --------
CUSTOM_CSS = """
:root {
  --brand: #0ea5e9; /* cyan-500 */
  --ink: #0b1220;
}
.gradio-container {
  font-family: Inter, ui-sans-serif, system-ui, -apple-system, "Segoe UI", Roboto, Helvetica, Arial, "Apple Color Emoji", "Segoe UI Emoji";
}
#title h1 {
  font-weight: 700;
  letter-spacing: -0.02em;
}
#desc {
  opacity: 0.9;
}
footer {visibility: hidden}
"""

with gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Soft(primary_hue="cyan")) as demo:
    with gr.Row():
        with gr.Column():
            gr.HTML(f"<div id='title'><h1>{TITLE}</h1></div>")
            gr.Markdown(f"<div id='desc'>{DESCRIPTION}</div>", elem_id="desc")

    with gr.Row():
        with gr.Column(scale=4):
            chat = gr.Chatbot(
                label="SoftwareArchitecture-Instruct v1",
                avatar_images=(None, None),
                height=480,
                bubble_full_width=False,
                sanitize_html=False,
            )
            with gr.Row():
                user_box = gr.Textbox(
                    placeholder="Ask about software architecture…",
                    show_label=False,
                    lines=3,
                    autofocus=True,
                    scale=4,
                )
                send_btn = gr.Button("Send", variant="primary", scale=1)
            with gr.Accordion("Generation Settings", open=False):
                max_new_tokens = gr.Slider(64, 1024, value=256, step=16, label="Max new tokens")
                temperature = gr.Slider(0.0, 1.5, value=0.3, step=0.05, label="Temperature")
                top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
                repetition_penalty = gr.Slider(1.0, 1.5, value=1.05, step=0.01, label="Repetition penalty")
                seed = gr.Number(value=-1, precision=0, label="Seed (-1 for random)")
            with gr.Row():
                clear_btn = gr.Button("Clear", variant="secondary")
                # Sample buttons
                sample_dropdown = gr.Dropdown(choices=SAMPLES, label="Samples", value=None)
                use_sample_btn = gr.Button("Use Sample")
        with gr.Column(scale=2):
            gr.Markdown("### Samples")
            gr.Markdown("\n".join([f"• {s}" for s in SAMPLES]))
            gr.Markdown("—\n**Tip:** Increase *Max new tokens* for longer, more complete answers.")

    # Events
    send_btn.click(
        chat_respond,
        inputs=[user_box, chat, max_new_tokens, temperature, top_p, repetition_penalty, seed],
        outputs=[chat, user_box],
        queue=True,
        show_progress=True,
    )
    user_box.submit(
        chat_respond,
        inputs=[user_box, chat, max_new_tokens, temperature, top_p, repetition_penalty, seed],
        outputs=[chat, user_box],
        queue=True,
        show_progress=True,
    )
    clear_btn.click(fn=clear_chat, outputs=chat)
    use_sample_btn.click(use_sample, inputs=[sample_dropdown, chat], outputs=[user_box, chat])

    gr.Markdown(
        "—\nBuilt for engineers and architects. Base model: **LiquidAI/LFM2-1.2B** · Fine-tuned on the **Software-Architecture** dataset."
    )

if __name__ == "__main__":
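    # queue() enables the request queue that delivers each yield from the
    # streaming callbacks to the browser.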
    demo.queue().launch()