Spaces:
Running
Running
| import os | |
| import time | |
| import gc | |
| import sys | |
| import threading | |
| from itertools import islice | |
| from datetime import datetime | |
| import re | |
| import gradio as gr | |
| import torch | |
| from transformers import pipeline, TextIteratorStreamer | |
| from transformers import AutoTokenizer | |
| from ddgs import DDGS | |
| from config import MODELS # Import from config file | |
# Cross-thread stop signal: the Stop button (set_cancel_flag) sets it and the
# streaming loop in chat_response polls it between chunks.
cancel_event = threading.Event()

# Hugging Face access token from the environment; empty string when HF_TOKEN is unset.
access_token = os.getenv('HF_TOKEN', '')

# model_name -> loaded text-generation pipeline, filled lazily by load_pipeline
# so each model is only loaded once per process.
PIPELINES = dict()
def load_pipeline(model_name):
    """
    Load (once) and return a transformers text-generation pipeline.

    The dtype is tried in the order bfloat16 -> float16 -> float32. If every
    explicit dtype fails, a final attempt lets transformers choose the dtype,
    and any error from that attempt propagates to the caller. Successful
    pipelines are cached in PIPELINES keyed by ``model_name``.

    Args:
        model_name: Key into MODELS; its ``"repo_id"`` is the hub repo to load.

    Returns:
        The cached or newly created pipeline.

    Raises:
        KeyError: if ``model_name`` is not in MODELS.
        Exception: whatever the final ``pipeline(...)`` call raises.
    """
    if model_name in PIPELINES:
        return PIPELINES[model_name]

    repo = MODELS[model_name]["repo_id"]
    # An empty string is not a usable token; None lets huggingface_hub fall
    # back to its own credential lookup instead.
    token = access_token or None
    tokenizer = AutoTokenizer.from_pretrained(repo, token=token)

    # Shared by every attempt, including the final fallback (which previously
    # omitted the token and so could fail on gated/private repos).
    common_kwargs = dict(
        task="text-generation",
        model=repo,
        tokenizer=tokenizer,
        trust_remote_code=True,
        device_map="auto",
        use_cache=True,
        token=token,
    )

    for dtype in (torch.bfloat16, torch.float16, torch.float32):
        try:
            pipe = pipeline(dtype=dtype, **common_kwargs)
        except Exception as exc:
            # Best-effort: try the next dtype, but leave a trace of why.
            print(f"load_pipeline: loading {repo} with {dtype} failed: {exc}")
            continue
        PIPELINES[model_name] = pipe
        return pipe

    # Final fallback: no explicit dtype; errors here are not swallowed.
    pipe = pipeline(**common_kwargs)
    PIPELINES[model_name] = pipe
    return pipe
def retrieve_context(query, max_results=6, max_chars=50):
    """
    Fetch up to ``max_results`` DuckDuckGo text results for ``query``.

    Each result is formatted as ``"N. Title - body"``, with the body cut to
    ``max_chars`` characters. N starts at 1, matching the ``[citation:N]``
    markers the search prompt asks the model to use.

    The call itself is blocking (chat_response runs it on a background
    thread). It is best-effort: any failure is logged and an empty list is
    returned.

    Returns:
        A list of formatted result strings (possibly empty).
    """
    try:
        with DDGS() as ddgs:
            hits = ddgs.text(query, region="wt-wt", safesearch="off", timelimit="y")
            return [
                # `or` guards against keys that are present but None, which
                # previously raised on slicing and discarded every result.
                f"{n}. {hit.get('title') or 'No Title'} - {(hit.get('body') or '')[:max_chars]}"
                for n, hit in enumerate(islice(hits, max_results), start=1)
            ]
    except Exception as exc:
        print(f"Web search failed for {query!r}: {exc}")
        return []
def _is_thought_message(message):
    """Return True for assistant "thought" bubbles, i.e. messages whose metadata carries a title."""
    metadata = message.get('metadata') or {}
    return bool(metadata.get('title'))


def format_conversation(history, system_prompt, tokenizer):
    """
    Render the system prompt and chat history into one prompt string.

    Thought bubbles (messages with a metadata title, as chat_response creates
    for ``<think>`` content) are dropped. This keeps the model from being fed
    its earlier reasoning as if it were a reply, and it avoids back-to-back
    assistant turns. Only ``role`` and ``content`` are forwarded, so UI-only
    keys never reach the chat template.

    If the tokenizer has a chat template, it is applied with
    ``add_generation_prompt=True`` and ``enable_thinking=True``. Otherwise a
    plain ``User: ...`` / ``Assistant: ...`` transcript is built for base LMs.

    Args:
        history: List of ``{'role', 'content', ...}`` message dicts.
        system_prompt: System text; surrounding whitespace is stripped.
        tokenizer: Object exposing ``chat_template`` / ``apply_chat_template``.

    Returns:
        The prompt text (untokenized).
    """
    turns = [
        {'role': message['role'], 'content': message['content']}
        for message in history
        if not _is_thought_message(message)
    ]

    if getattr(tokenizer, "chat_template", None):
        messages = [{"role": "system", "content": system_prompt.strip()}] + turns
        return tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True, enable_thinking=True
        )

    # Fallback for base LMs without a chat template.
    speakers = {'user': "User: ", 'assistant': "Assistant: "}
    prompt = system_prompt.strip() + "\n"
    for message in turns:
        speaker = speakers.get(message['role'])
        if speaker is not None:
            prompt += speaker + str(message['content'] or '').strip() + "\n"
    if not prompt.strip().endswith("Assistant:"):
        prompt += "Assistant: "
    return prompt
def get_duration(user_msg, chat_history, system_prompt, enable_search, max_results, max_chars, model_name, max_tokens, temperature, top_k, top_p, repeat_penalty, search_timeout):
    """
    Heuristic time budget, in seconds, for one chat_response call.

    The parameter list mirrors chat_response's, presumably so it can serve as
    a per-call duration callback; in the visible code it is only called by
    update_duration_estimate. Only ``model_name``, ``max_tokens`` and
    ``enable_search`` affect the result.

    The model size comes from get_model_size, so an unknown model falls back
    to the same 4.0B default used elsewhere instead of raising KeyError.
    """
    model_size = get_model_size(model_name)
    # Models of 2B+ parameters get a larger base budget plus a compilation buffer.
    use_aot = model_size >= 2
    base_duration = 40 if use_aot else 20
    token_duration = max_tokens * 0.005
    search_duration = 10 if enable_search else 0
    aot_compilation_buffer = 20 if use_aot else 0
    return base_duration + token_duration + search_duration + aot_compilation_buffer
def get_model_size(model_name):
    """Return the ``params_b`` entry (billions of parameters) for ``model_name``, or 4.0 if absent."""
    model_config = MODELS.get(model_name, {})
    return model_config.get("params_b", 4.0)
def _run_web_search(query, max_results, max_chars, timeout):
    """
    Run retrieve_context on a daemon thread and wait at most ``timeout`` seconds.

    Returns ``(results, debug_markdown)``. ``results`` is a snapshot taken
    right after the wait. A search that outlives the timeout keeps running
    and may extend its own list later, but that cannot leak into the prompt
    after the debug panel has already reported "no results".
    """
    collected = []
    worker = threading.Thread(
        target=lambda: collected.extend(
            retrieve_context(query, int(max_results), int(max_chars))
        ),
        daemon=True,
    )
    worker.start()
    worker.join(timeout=float(timeout))
    results = list(collected)
    if results:
        debug = "### Search results merged into prompt\n\n" + "\n".join(
            f"- {r}" for r in results
        )
    else:
        debug = "*No web search results found.*"
    return results, debug


def _build_system_prompt(system_prompt, search_results):
    """
    Return this turn's system prompt.

    Without search results this is ``system_prompt`` stripped. With results,
    the grounding and citation rules are appended, and the numbered results
    are placed one per line. Interpolating the list directly would embed its
    Python repr, brackets and quotes included.
    """
    base = system_prompt.strip()
    if not search_results:
        return base
    cur_date = datetime.now().strftime('%Y-%m-%d')
    sources = "\n".join(search_results)
    return base + f"""
# SEARCH CONTEXT (TRUSTED SOURCES ONLY)
Below are web search results. Treat them as the ONLY source of truth for answering.
{sources}
RULES (VERY IMPORTANT):
- Do NOT use outside knowledge. Do NOT guess or fill missing information.
- If the answer is not clearly supported by the search results, say: "Not enough information in the provided sources."
- Every factual statement must be directly supported by at least one citation [citation:X].
- Do NOT add explanations, examples, or background that are not explicitly present in the sources.
- Do NOT paraphrase beyond what is necessary for clarity.
- If sources conflict, mention the conflict and cite both.
- If multiple sources are used, distribute citations per sentence, not only at the end.
CITATION RULES:
- Use inline citations like this: [citation:1]
- If multiple sources support a sentence: [citation:1][citation:3]
- Never place all citations only at the end.
ANSWER POLICY:
- Be concise and strictly grounded.
- No speculation, no assumptions, no "likely", no "probably".
- If the user requests a list, only include items explicitly found in sources.
- If sources are insufficient, stop and ask for more data instead of guessing.
DATE CONTEXT:
- Today is {cur_date} (use only for time reference, not for assumptions).
USER QUESTION:
"""


def _make_stop_criteria(*events):
    """Build a StoppingCriteriaList that makes generate() stop once any of ``events`` is set."""
    from transformers import StoppingCriteria, StoppingCriteriaList

    class _EventStop(StoppingCriteria):
        def __call__(self, input_ids, scores, **kwargs):
            stop = any(event.is_set() for event in events)
            # One flag per sequence in the batch, as StoppingCriteriaList expects.
            return torch.full((input_ids.shape[0],), stop, dtype=torch.bool, device=input_ids.device)

    return StoppingCriteriaList([_EventStop()])


def chat_response(user_msg, chat_history, system_prompt,
                  enable_search, max_results, max_chars,
                  model_name, max_tokens, temperature,
                  top_k, top_p, repeat_penalty, search_timeout):
    """
    Stream a chat reply, optionally grounded in a background web search.

    Yields ``(history, debug_markdown)`` tuples. The first yield holds the
    history with the new user message; later yields carry the growing
    assistant reply. The final yield appends a prompt preview to the debug text.

    Generation runs on a worker thread that feeds a TextIteratorStreamer.
    Text between ``<think>`` and ``</think>`` is shown as a separate
    "💭 Thought" assistant message. Generation stops early when either of
    these happens:
      * the module-level ``cancel_event`` is set, or
      * this generator is closed or finishes.
    An exception raised inside the worker is re-raised here and reported as
    an "Error: ..." assistant message instead of hanging the stream.
    """
    # Clear any stale cancellation left over from a previous run.
    cancel_event.clear()
    history = list(chat_history or [])
    history.append({'role': 'user', 'content': user_msg})

    if enable_search:
        search_results, debug = _run_web_search(user_msg, max_results, max_chars, search_timeout)
    else:
        search_results, debug = [], 'Web search disabled.'

    # Per-call stop flag, set in `finally`. Whenever this generator exits
    # (normal end, cancel, close, error), the worker thread stops generating
    # instead of running to max_new_tokens in the background.
    stop_generation = threading.Event()
    try:
        enriched = _build_system_prompt(system_prompt, search_results)
        pipe = load_pipeline(model_name)
        prompt = format_conversation(history, enriched, pipe.tokenizer)
        prompt_debug = f"\n\n--- Prompt Preview ---\n```\n{prompt}\n```"

        streamer = TextIteratorStreamer(pipe.tokenizer,
                                        skip_prompt=True,
                                        skip_special_tokens=True)
        gen_kwargs = {
            # Slider values may arrive as floats; these must be integers.
            'max_new_tokens': int(max_tokens),
            # Without do_sample the sampling knobs below may be ignored,
            # depending on the model's generation_config.
            'do_sample': True,
            'temperature': temperature,
            'top_k': int(top_k),
            'top_p': top_p,
            'repetition_penalty': repeat_penalty,
            'streamer': streamer,
            'return_full_text': False,
            'stopping_criteria': _make_stop_criteria(stop_generation, cancel_event),
        }
        gen_errors = []

        def _generate():
            # If generation fails the streamer would never receive its end
            # signal and the loop below would block forever; record the error
            # and end the stream explicitly.
            try:
                pipe(prompt, **gen_kwargs)
            except Exception as exc:
                gen_errors.append(exc)
                streamer.end()

        gen_thread = threading.Thread(target=_generate)
        gen_thread.start()

        # Buffers for thought vs answer
        thought_buf = ''
        answer_buf = ''
        in_thought = False
        assistant_message_started = False

        def absorb_thought(fragment):
            """Add ``fragment`` to the open thought; on ``</think>``, start the answer message."""
            nonlocal thought_buf, answer_buf, in_thought
            thought_buf += fragment
            if '</think>' in thought_buf:
                thought, rest = thought_buf.split('</think>', 1)
                history[-1]['content'] = thought.strip()
                in_thought = False
                answer_buf = rest
                history.append({'role': 'assistant', 'content': answer_buf.strip()})
            else:
                history[-1]['content'] = thought_buf

        # First yield contains the user message
        yield history, debug

        for chunk in streamer:
            if cancel_event.is_set():
                stop_generation.set()
                if assistant_message_started and history[-1]['role'] == 'assistant':
                    history[-1]['content'] += " [Generation Canceled]"
                yield history, debug
                break
            if not chunk:
                continue

            if in_thought:
                absorb_thought(chunk)
            elif '<think>' in chunk:
                in_thought = True
                thought_buf = ''  # a new thought must not inherit the previous one
                history.append({'role': 'assistant', 'content': '', 'metadata': {'title': '💭 Thought'}})
                assistant_message_started = True
                absorb_thought(chunk.split('<think>', 1)[1])
            else:
                if not assistant_message_started:
                    history.append({'role': 'assistant', 'content': ''})
                    assistant_message_started = True
                answer_buf += chunk
                history[-1]['content'] = answer_buf.strip()
            yield history, debug

        gen_thread.join()
        if gen_errors:
            raise gen_errors[0]
        yield history, debug + prompt_debug
    except GeneratorExit:
        # Consumer closed the stream (e.g. UI cancellation); `finally` stops the worker.
        print("Chat response cancelled.")
        return
    except Exception as e:
        history.append({'role': 'assistant', 'content': f"Error: {e}"})
        yield history, debug
    finally:
        stop_generation.set()
        gc.collect()
def update_default_prompt(enable_search):
    """
    Return the default text for the System Prompt box.

    ``enable_search`` is accepted for call-site compatibility, but the text
    is the same either way. Search-specific rules are appended per request
    by chat_response, not baked into the editable prompt.
    """
    return "You are a helpful assistant."
def update_duration_estimate(model_name, enable_search, max_results, max_chars, max_tokens, search_timeout):
    """
    Build the Markdown for the "estimated time" panel.

    Any failure is rendered as a warning line instead of being raised, so a
    bad setting never breaks the UI callback.
    """
    try:
        # Empty message/history/prompt and fixed sampling values fill the
        # slots get_duration shares with chat_response; its result depends
        # only on the model, token budget and search toggle.
        seconds = get_duration(
            "", [], "",
            enable_search, max_results, max_chars, model_name,
            max_tokens, 0.7, 40, 0.9, 1.2, search_timeout,
        )
        params_b = get_model_size(model_name)
        search_label = 'Enabled' if enable_search else 'Disabled'
        return (
            f"⏱️ **Estimated GPU Time: {seconds:.1f} seconds**\n\n"
            f"📊 **Model Size:** {params_b:.1f}B parameters\n"
            f"🔍 **Web Search:** {search_label}"
        )
    except Exception as err:
        return f"⚠️ Error calculating estimate: {err}"
# ------------------------------
# Gradio UI
# ------------------------------
# Model preselected in the dropdown; also used for the initial duration estimate.
DEFAULT_MODEL = "Qwen3-1.7B"

with gr.Blocks(
    title="LLM Inference",
    theme=gr.themes.Soft(
        primary_hue="blue",
        secondary_hue="blue",
        neutral_hue="slate",
        radius_size="lg",
        font=[gr.themes.GoogleFont("Syne"), "Arial", "sans-serif"]
    ),
    css="""
.duration-estimate { background: linear-gradient(135deg, #667eea15 0%, #764ba215 100%); border-left: 4px solid #667eea; padding: 12px; border-radius: 8px; margin: 16px 0; }
.chatbot { border-radius: 12px; box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1); }
button.primary { font-weight: 600; }
.gradio-accordion { margin-bottom: 12px; }
"""
) as demo:
    # Header
    gr.Markdown("""
# 🧠 CPU LLM Inference
""")
    with gr.Row():
        # Left Panel - Configuration
        with gr.Column(scale=3):
            # Core Settings (Always Visible)
            with gr.Group():
                gr.Markdown("### ⚙️ Core Settings")
                model_dd = gr.Dropdown(
                    label="🤖 Model",
                    choices=list(MODELS.keys()),
                    value=DEFAULT_MODEL,
                    info="Select the language model to use"
                )
                search_chk = gr.Checkbox(
                    label="🔍 Enable Web Search",
                    value=False,
                    info="Augment responses with real-time web data"
                )
                sys_prompt = gr.Textbox(label="📝 System Prompt", lines=3, value=update_default_prompt(False), placeholder="Define the assistant's behavior and personality...")
            # Duration Estimate
            duration_display = gr.Markdown(
                value=update_duration_estimate(DEFAULT_MODEL, False, 4, 50, 1024, 5.0),
                elem_classes="duration-estimate"
            )
            # Advanced Settings (Collapsible)
            with gr.Accordion("🎛️ Advanced Generation Parameters", open=False):
                max_tok = gr.Slider(
                    64, 16384, value=1024, step=32,
                    label="Max Tokens",
                    info="Maximum length of generated response"
                )
                temp = gr.Slider(
                    0.1, 2.0, value=0.7, step=0.1,
                    label="Temperature",
                    info="Higher = more creative, Lower = more focused"
                )
                with gr.Row():
                    k = gr.Slider(
                        1, 100, value=40, step=1,
                        label="Top-K",
                        info="Number of top tokens to consider"
                    )
                    p = gr.Slider(
                        0.1, 1.0, value=0.9, step=0.05,
                        label="Top-P",
                        info="Nucleus sampling threshold"
                    )
                rp = gr.Slider(
                    1.0, 2.0, value=1.2, step=0.1,
                    label="Repetition Penalty",
                    info="Penalize repeated tokens"
                )
            # Web Search Settings (Collapsible)
            with gr.Accordion("🌐 Web Search Settings", open=False, visible=False) as search_settings:
                mr = gr.Number(
                    value=4, precision=0,
                    label="Max Results",
                    info="Number of search results to retrieve"
                )
                mc = gr.Number(
                    value=50, precision=0,
                    label="Max Chars/Result",
                    info="Character limit per search result"
                )
                st = gr.Slider(
                    minimum=0.0, maximum=30.0, step=0.5, value=5.0,
                    label="Search Timeout (s)",
                    info="Maximum time to wait for search results"
                )
            # Actions
            with gr.Row():
                clr = gr.Button("🗑️ Clear Chat", variant="secondary", scale=1)
        # Right Panel - Chat Interface
        with gr.Column(scale=7):
            # NOTE(review): bubble_full_width / show_copy_button / type="messages"
            # are version-sensitive Gradio options — confirm against the pinned version.
            chat = gr.Chatbot(
                type="messages",
                height=600,
                label="💬 Conversation",
                show_copy_button=True,
                avatar_images=(None, "🤖"),
                bubble_full_width=False
            )
            # Input Area
            with gr.Row():
                txt = gr.Textbox(
                    placeholder="💭 Type your message here... (Press Enter to send)",
                    scale=9,
                    container=False,
                    show_label=False,
                    lines=1,
                    max_lines=5
                )
                with gr.Column(scale=1, min_width=120):
                    submit_btn = gr.Button("📤 Send", variant="primary", size="lg")
                    cancel_btn = gr.Button("⏹️ Stop", variant="stop", visible=False, size="lg")
            # Example Prompts
            gr.Examples(
                examples=[
                    ["Explain quantum computing in simple terms"],
                    ["Write a Python function to calculate fibonacci numbers"],
                    ["What are the latest developments in AI? (Enable web search)"],
                    ["Tell me a creative story about a time traveler"],
                    ["Help me debug this code: def add(a,b): return a+b+1"]
                ],
                inputs=txt,
                label="💡 Example Prompts"
            )
            # Debug/Status Info (Collapsible)
            with gr.Accordion("🔍 Debug Info", open=False):
                dbg = gr.Markdown()
    # Footer
    gr.Markdown("""
---
💡 **Tips:**
- Use **Advanced Parameters** to fine-tune creativity and response length
- Enable **Web Search** for real-time, up-to-date information
- Try different **models** for various tasks (reasoning, coding, general chat)
- Click the **Copy** button on responses to save them to your clipboard
""", elem_classes="footer")

    # --- Event Listeners ---
    # Group all inputs for cleaner event handling (order matches chat_response's parameters).
    chat_inputs = [txt, chat, sys_prompt, search_chk, mr, mc, model_dd, max_tok, temp, k, p, rp, st]
    # Group all UI components that can be updated.
    ui_components = [chat, dbg, txt, submit_btn, cancel_btn]

    def submit_and_manage_ui(user_msg, chat_history, *args):
        """
        Drive one chat turn: lock the input controls, stream chat_response
        into the chat/debug panes, then unlock the controls.

        The unlock happens in `finally` unless the generator was closed
        (cancellation), in which case reset_ui_after_cancel does it.
        """
        if not user_msg.strip():
            yield {}
            return
        # Update UI to "generating" state
        yield {
            txt: gr.update(value="", interactive=False),
            submit_btn: gr.update(interactive=False),
            cancel_btn: gr.update(visible=True),
        }
        cancelled = False
        try:
            backend_args = [user_msg, chat_history] + list(args)
            for response_chunk in chat_response(*backend_args):
                yield {
                    chat: response_chunk[0],
                    dbg: response_chunk[1],
                }
        except GeneratorExit:
            cancelled = True
            print("Generation cancelled by user.")
            raise
        except Exception as e:
            print(f"An error occurred during generation: {e}")
            error_history = (chat_history or []) + [
                {'role': 'user', 'content': user_msg},
                {'role': 'assistant', 'content': f"**An error occurred:** {str(e)}"}
            ]
            yield {chat: error_history}
        finally:
            # A closed generator must not yield again.
            if not cancelled:
                print("Resetting UI state.")
                yield {
                    txt: gr.update(interactive=True),
                    submit_btn: gr.update(interactive=True),
                    cancel_btn: gr.update(visible=False),
                }

    def set_cancel_flag():
        """Called by the cancel button, sets the global event."""
        cancel_event.set()
        print("Cancellation signal sent.")

    def reset_ui_after_cancel():
        """
        Re-enable the input controls after a cancellation.

        cancel_event is deliberately left set: clearing it here raced the
        streaming loop that polls it, so a cancel could go unnoticed.
        chat_response clears it at the start of the next turn.
        """
        print("UI reset after cancellation.")
        return {
            txt: gr.update(interactive=True),
            submit_btn: gr.update(interactive=True),
            cancel_btn: gr.update(visible=False),
        }

    # Events for submitting text via Enter key or the Send button.
    submit_event = txt.submit(
        fn=submit_and_manage_ui,
        inputs=chat_inputs,
        outputs=ui_components,
    )
    click_event = submit_btn.click(
        fn=submit_and_manage_ui,
        inputs=chat_inputs,
        outputs=ui_components,
    )
    # Event for the "Cancel" button: must cancel both submission paths
    # (previously only the Enter-key event was cancellable).
    cancel_btn.click(
        fn=set_cancel_flag,
        cancels=[submit_event, click_event]
    ).then(
        fn=reset_ui_after_cancel,
        outputs=ui_components
    )

    # Listeners for updating the duration estimate
    duration_inputs = [model_dd, search_chk, mr, mc, max_tok, st]
    for component in duration_inputs:
        component.change(fn=update_duration_estimate, inputs=duration_inputs, outputs=duration_display)

    # Toggle web search settings visibility
    def toggle_search_settings(enabled):
        """Show the web-search settings accordion only while search is enabled."""
        return gr.update(visible=enabled)

    # Only the visibility changes: writing update_default_prompt() into
    # sys_prompt here discarded any custom prompt the user had typed, and the
    # default text is identical for both checkbox states anyway.
    search_chk.change(
        fn=toggle_search_settings,
        inputs=search_chk,
        outputs=search_settings
    )

    # Clear chat action
    clr.click(fn=lambda: ([], "", ""), outputs=[chat, txt, dbg])

if __name__ == "__main__":
    demo.launch()