import os
import time
import gc
import threading
from itertools import islice
from datetime import datetime

import gradio as gr
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from duckduckgo_search import DDGS
import spaces
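

# Event used by the "Cancel Generation" button to ask an in-progress
# response to stop streaming.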
cancel_event = threading.Event()
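

# Available models: UI display name -> Hugging Face repo id and a short description.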
MODELS = {
    "Gemma-3-4B-IT": {
        "repo_id": "unsloth/gemma-3-4b-it",
        "description": "Gemma-3-4B-IT"
    },
    "Llama-3.2-Taiwan-3B-Instruct": {
        "repo_id": "lianghsun/Llama-3.2-Taiwan-3B-Instruct",
        "description": "Llama-3.2-Taiwan-3B-Instruct"
    },
    "MiniCPM3-4B": {
        "repo_id": "openbmb/MiniCPM3-4B",
        "description": "MiniCPM3-4B"
    },
    "Qwen2.5-3B-Instruct": {
        "repo_id": "Qwen/Qwen2.5-3B-Instruct",
        "description": "Qwen2.5-3B-Instruct"
    },
    "Qwen2.5-7B-Instruct": {
        "repo_id": "Qwen/Qwen2.5-7B-Instruct",
        "description": "Qwen2.5-7B-Instruct"
    },

    "Phi-4-mini-Instruct": {
        "repo_id": "unsloth/Phi-4-mini-instruct",
        "description": "Phi-4-mini-Instruct"
    },
    "Meta-Llama-3.1-8B-Instruct": {
        "repo_id": "MaziyarPanahi/Meta-Llama-3.1-8B-Instruct",
        "description": "Meta-Llama-3.1-8B-Instruct"
    },
    "DeepSeek-R1-Distill-Llama-8B": {
        "repo_id": "unsloth/DeepSeek-R1-Distill-Llama-8B",
        "description": "DeepSeek-R1-Distill-Llama-8B"
    },
    "Mistral-7B-Instruct-v0.3": {
        "repo_id": "MaziyarPanahi/Mistral-7B-Instruct-v0.3",
        "description": "Mistral-7B-Instruct-v0.3"
    },
    "Qwen2.5-Coder-7B-Instruct": {
        "repo_id": "Qwen/Qwen2.5-Coder-7B-Instruct",
        "description": "Qwen2.5-Coder-7B-Instruct"
    },
}
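

# Cache of loaded text-generation pipelines, keyed by model display name.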
PIPELINES = {}


def load_pipeline(model_name):
    """
    Load and cache a transformers pipeline for chat/text-generation.
    Uses the model's repo_id from MODELS and caches the pipeline for future use.
    """
    global PIPELINES
    if model_name in PIPELINES:
        return PIPELINES[model_name]
    selected_model = MODELS[model_name]

    pipe = pipeline(
        task="text-generation",
        model=selected_model["repo_id"],
        tokenizer=selected_model["repo_id"],
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        device_map="auto"
    )
    PIPELINES[model_name] = pipe
    return pipe


def retrieve_context(query, max_results=6, max_chars_per_result=600):
    """
    Retrieve recent web search context for the given query using DuckDuckGo.
    Returns a formatted string with search results.
    """
    try:
        with DDGS() as ddgs:
            results = list(islice(ddgs.text(query, region="wt-wt", safesearch="off", timelimit="y"), max_results))
        context = ""
        for i, result in enumerate(results, start=1):
            title = result.get("title", "No Title")
            snippet = result.get("body", "")[:max_chars_per_result]
            context += f"Result {i}:\nTitle: {title}\nSnippet: {snippet}\n\n"
        return context.strip()
    except Exception:
        # Fail soft: a search problem should never block the chat itself.
        return ""


def format_conversation(conversation, system_prompt):
    """
    Converts a list of conversation messages (each a dict with 'role' and 'content')
    and a system prompt into a single plain-text string.
    This prevents raw role labels from being passed to the model.
    """
    prompt = system_prompt.strip() + "\n"

    for msg in conversation:
        if msg["role"] == "user":
            prompt += "User: " + msg["content"].strip() + "\n"
        elif msg["role"] == "assistant":
            prompt += "Assistant: " + msg["content"].strip() + "\n"
        elif msg["role"] == "system":
            prompt += msg["content"].strip() + "\n"

    # Make sure the prompt ends with the assistant cue so the model continues from there.
    if not prompt.strip().endswith("Assistant:"):
        prompt += "Assistant: "
    return prompt
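

# Core handler: builds the prompt, optionally adds web-search context, and
# streams the model's reply back to the Chatbot.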
@spaces.GPU(duration=60)
def chat_response(user_message, chat_history, system_prompt, enable_search,
                  max_results, max_chars, model_name, max_tokens, temperature,
                  top_k, top_p, repeat_penalty):
    """
    Generate a chat response by utilizing a transformers pipeline with streaming.

    - Appends the user's message to the conversation history.
    - Optionally retrieves web search context and inserts it as an additional system message.
    - Converts the conversation into a formatted prompt to avoid leaking role labels.
    - Uses the cached pipeline's underlying model and tokenizer with a streamer to yield tokens as they are generated.
    - Yields updated conversation history token by token.
    """
    cancel_event.clear()

    # Copy the incoming history and append the new user turn.
    conversation = list(chat_history) if chat_history else []
    conversation.append({"role": "user", "content": user_message})

    debug_message = ""
    if enable_search:
        debug_message = "Initiating web search..."
        yield conversation, debug_message

        # Run the search in a worker thread so a slow network call cannot
        # stall the response for more than a couple of seconds.
        search_result = [""]

        def do_search():
            search_result[0] = retrieve_context(user_message, max_results, max_chars)

        search_thread = threading.Thread(target=do_search)
        search_thread.start()
        search_thread.join(timeout=2)
        retrieved_context = search_result[0]
        if retrieved_context:
            debug_message = f"Web search results:\n\n{retrieved_context}"
            # Insert the context as a system message just before the latest user turn.
            conversation.insert(len(conversation) - 1, {"role": "system", "content": f"Web search context:\n{retrieved_context}"})
        else:
            debug_message = "Web search returned no results or timed out."
    else:
        debug_message = "Web search disabled."

    # Placeholder assistant turn; its content is filled in as tokens stream back.
    conversation.append({"role": "assistant", "content": ""})

    try:
        # Flatten the conversation (plus system prompt) into a plain-text prompt.
        prompt_text = format_conversation(conversation, system_prompt)

        # Reuse the cached pipeline's tokenizer and model directly for streaming.
        pipe = load_pipeline(model_name)
        tokenizer = pipe.tokenizer
        model = pipe.model

        # Tokenize the prompt and move it onto the model's device.
        model_inputs = tokenizer(prompt_text, return_tensors="pt").to(model.device)

        # The streamer yields decoded text fragments as soon as they are generated.
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

        gen_kwargs = {
            "input_ids": model_inputs.input_ids,
            "attention_mask": model_inputs.attention_mask,
            "max_new_tokens": int(max_tokens),
            "do_sample": True,  # required so temperature/top-k/top-p actually take effect
            "temperature": temperature,
            "top_k": int(top_k),
            "top_p": top_p,
            "repetition_penalty": repeat_penalty,
            "streamer": streamer,
        }
        # Run generation on a background thread so tokens can be streamed to the UI.
        thread = threading.Thread(target=model.generate, kwargs=gen_kwargs)
        thread.start()

        # Stream partial text into the last (assistant) message as it arrives.
        assistant_text = ""
        for new_text in streamer:
            if cancel_event.is_set():
                # Stop updating the UI; the background generate call finishes on its own.
                break
            assistant_text += new_text
            conversation[-1]["content"] = assistant_text
            yield conversation, debug_message

        thread.join()
    except Exception as e:
        conversation[-1]["content"] = f"Error: {e}"
        yield conversation, debug_message
    finally:
        gc.collect()


def cancel_generation():
    """Signal chat_response to stop streaming the current reply."""
    cancel_event.set()
    return "Cancellation requested."


def update_default_prompt(enable_search):
    today = datetime.now().strftime('%Y-%m-%d')
    if enable_search:
        return f"You are a helpful assistant. Today is {today}. Please leverage the latest web data when responding to queries."
    else:
        return f"You are a helpful assistant. Today is {today}."
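

# Gradio UI: model and parameter controls on the left, chat and search debug on the right.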
with gr.Blocks(title="LLM Inference with ZeroGPU") as demo:
    gr.Markdown("## 🧠 ZeroGPU LLM Inference with Web Search")
    gr.Markdown("Interact with the model. Select your model, set your system prompt, and adjust parameters on the left.")

    with gr.Row():
        with gr.Column(scale=3):
            default_model = list(MODELS.keys())[0] if MODELS else "No models available"
            model_dropdown = gr.Dropdown(
                label="Select Model",
                choices=list(MODELS.keys()) if MODELS else [],
                value=default_model,
                info="Choose from available models."
            )
            enable_search_checkbox = gr.Checkbox(label="Enable Web Search", value=True,
                                                 info="Include recent search context to improve answers.")
            system_prompt_text = gr.Textbox(label="System Prompt",
                                            value=update_default_prompt(enable_search_checkbox.value),
                                            lines=3,
                                            info="Define the base context for the AI's responses.")
            gr.Markdown("### Generation Parameters")
            max_tokens_slider = gr.Slider(label="Max Tokens", minimum=64, maximum=1024, value=1024, step=32,
                                          info="Maximum tokens for the response.")
            temperature_slider = gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, value=0.7, step=0.1,
                                           info="Controls the randomness of the output.")
            top_k_slider = gr.Slider(label="Top-K", minimum=1, maximum=100, value=40, step=1,
                                     info="Limits token candidates to the top-k tokens.")
            top_p_slider = gr.Slider(label="Top-P (Nucleus Sampling)", minimum=0.1, maximum=1.0, value=0.95, step=0.05,
                                     info="Limits token candidates to a cumulative probability threshold.")
            repeat_penalty_slider = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.1, step=0.1,
                                              info="Penalizes token repetition to improve diversity.")
            gr.Markdown("### Web Search Settings")
            max_results_number = gr.Number(label="Max Search Results", value=10, precision=0,
                                           info="Maximum number of search results to retrieve.")
            max_chars_number = gr.Number(label="Max Chars per Result", value=2000, precision=0,
                                         info="Maximum characters to retrieve per search result.")
            clear_button = gr.Button("Clear Chat")
            cancel_button = gr.Button("Cancel Generation")
        with gr.Column(scale=7):
            chatbot = gr.Chatbot(label="Chat", type="messages")
            msg_input = gr.Textbox(label="Your Message", placeholder="Enter your message and press Enter")
            search_debug = gr.Markdown(label="Web Search Debug")
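
    # Keep the default system prompt in sync with the web-search toggle.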
    enable_search_checkbox.change(
        fn=update_default_prompt,
        inputs=[enable_search_checkbox],
        outputs=[system_prompt_text]
    )

    def clear_chat():
        return [], "", ""

    clear_button.click(fn=clear_chat, outputs=[chatbot, msg_input, search_debug])
    cancel_button.click(fn=cancel_generation, outputs=search_debug)

    msg_input.submit(
        fn=chat_response,
        inputs=[msg_input, chatbot, system_prompt_text, enable_search_checkbox,
                max_results_number, max_chars_number, model_dropdown,
                max_tokens_slider, temperature_slider, top_k_slider, top_p_slider, repeat_penalty_slider],
        outputs=[chatbot, search_debug],
    )

demo.launch()