import os
import time
import gc
import traceback
from queue import Queue
from threading import Thread, Event
from itertools import islice
from datetime import datetime
import re  # for parsing <think> blocks

import gradio as gr
import torch
from transformers import pipeline, TextIteratorStreamer
from transformers import AutoTokenizer, AutoModelForCausalLM
from duckduckgo_search import DDGS

from transformers import modeling_utils
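
# Guard for transformers builds where ALL_PARALLEL_STYLES is missing or set to None,
# which can otherwise break model loading when a device_map is passed.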
if not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None:
    modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]

# import spaces  # Import spaces early to enable ZeroGPU support

# Optional: disable GPU visibility if you wish to force CPU usage
os.environ["CUDA_VISIBLE_DEVICES"] = ""

if torch.cuda.is_available():
    device = "auto"
else:
    device = "cpu"

# ------------------------------
# Global Cancellation Event
# ------------------------------
cancel_event = Event()

# ------------------------------
# Torch-Compatible Model Definitions with Adjusted Descriptions
# ------------------------------
MODELS = {
    "Yee-R1-mini": {"repo_id": "sds-ai/Yee-R1-mini", "description": "Yee (小熠), an AI data-security expert."},
    "secgpt-mini": {"repo_id": "clouditera/secgpt-mini", "description": "SecGPT is an open-source large model released by Clouditera (云起无垠) in 2023, built for network-security scenarios and aimed at using AI to improve the efficiency and effectiveness of security protection."},
    "Qwen3-0.6B": {"repo_id": "Qwen/Qwen3-0.6B", "description": "Dense causal language model with 0.6 B total parameters (0.44 B non-embedding), 28 transformer layers, 16 query heads & 8 KV heads, native 32 768-token context window, dual-mode generation, full multilingual & agentic capabilities."},
    "Qwen3-1.7B": {"repo_id": "Qwen/Qwen3-1.7B", "description": "Dense causal language model with 1.7 B total parameters (1.4 B non-embedding), 28 layers, 16 query heads & 8 KV heads, 32 768-token context, stronger reasoning vs. the 0.6 B variant, dual-mode inference, instruction following across 100+ languages."},
}

# Global cache for pipelines to avoid re-loading.
PIPELINES = {}

def load_pipeline(model_name):
    """
    Load and cache a transformers pipeline for text generation.
    Tries bfloat16, falls back to float16 or float32 if unsupported.
    """
    global PIPELINES
    if model_name in PIPELINES:
        return PIPELINES[model_name]
    repo = MODELS[model_name]["repo_id"]
    if model_name == "secgpt-mini":
        tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True, device_map=device, subfolder="models")
        model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True, device_map=device, subfolder="models")
    else:
        tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True, device_map=device)
        model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True, device_map=device)
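
    # Try dtypes in order of preference: bfloat16, then float16, then float32.
    # On hardware that lacks half-precision support, pipeline creation raises and
    # we simply fall through to the next dtype.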
    for dtype in (torch.bfloat16, torch.float16, torch.float32):
        try:
            pipe = pipeline(
                task="text-generation",
                model=model,
                tokenizer=tokenizer,
                trust_remote_code=True,
                torch_dtype=dtype,
                device_map=device,
            )
            PIPELINES[model_name] = pipe
            return pipe
        except Exception:
            continue
    # Final fallback: build the pipeline explicitly in float32.
    pipe = pipeline(
        task="text-generation",
        model=model,
        tokenizer=tokenizer,
        trust_remote_code=True,
        torch_dtype=torch.float32,
        device_map=device,
    )
    PIPELINES[model_name] = pipe
    return pipe
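
# Illustrative usage (any model key from MODELS works):
#   pipe = load_pipeline("Qwen3-0.6B")  # builds and caches the pipeline
#   pipe = load_pipeline("Qwen3-0.6B")  # returns the cached instance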

def retrieve_context(query, max_results=6, max_chars=600):
    """
    Retrieve search snippets from DuckDuckGo (runs in background).
    Returns a list of result strings.
    """
    try:
        with DDGS() as ddgs:
            return [f"{i+1}. {r.get('title', 'No Title')} - {r.get('body', '')[:max_chars]}"
                    for i, r in enumerate(islice(ddgs.text(query, region="wt-wt", safesearch="off", timelimit="y"), max_results))]
    except Exception:
        return []

def format_conversation(history, system_prompt, tokenizer):
    if history is None:
        history = []
    if hasattr(tokenizer, "chat_template") and tokenizer.chat_template:
        messages = [{"role": "system", "content": system_prompt.strip()}] + history
        return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, enable_thinking=True)
    else:
        # Fallback for base LMs without a chat template
        prompt = system_prompt.strip() + "\n"
        for msg in history:
            if msg['role'] == 'user':
                prompt += "User: " + msg['content'].strip() + "\n"
            elif msg['role'] == 'assistant':
                prompt += "Assistant: " + msg['content'].strip() + "\n"
        if not prompt.strip().endswith("Assistant:"):
            prompt += "Assistant: "
        return prompt

def chat_response(user_msg, chat_history, system_prompt,
                  enable_search, max_results, max_chars,
                  model_name, max_tokens, temperature,
                  top_k, top_p, repeat_penalty, search_timeout):
    """
    Generate streaming chat responses, optionally enriched by a background web search.
    """
    cancel_event.clear()
    history = list(chat_history) if chat_history else []
    history.append({'role': 'user', 'content': user_msg})

    # Launch web search in the background if enabled
    debug = ''
    search_results = []
    if enable_search:
        debug = 'Search task started.'
        thread_search = Thread(
            target=lambda: search_results.extend(
                retrieve_context(user_msg, int(max_results), int(max_chars))
            )
        )
        thread_search.daemon = True
        thread_search.start()
    else:
        debug = 'Web search disabled.'

    enriched = system_prompt
    try:
        # Wait up to `search_timeout` seconds for snippets, then merge them into the system prompt
        if enable_search:
            thread_search.join(timeout=float(search_timeout))
            if len(search_results) > 0:
                debug = "### Search results merged into prompt\n\n" + "\n".join(
                    f"- {r}" for r in search_results
                )
                enriched = system_prompt.strip() + "\n\nRelevant context:\n" + "\n".join(search_results)
            else:
                debug = "*No web search results found.*"
                enriched = system_prompt

        pipe = load_pipeline(model_name)
        prompt = format_conversation(history, enriched, pipe.tokenizer)
        prompt_debug = f"\n\n--- Prompt Preview ---\n```\n{prompt}\n```"

        streamer = TextIteratorStreamer(pipe.tokenizer,
                                        skip_prompt=True,
                                        skip_special_tokens=True)
        gen_thread = Thread(
            target=pipe,
            args=(prompt,),
            kwargs={
                'max_new_tokens': max_tokens,
                'temperature': temperature,
                'top_k': top_k,
                'top_p': top_p,
                'repetition_penalty': repeat_penalty,
                'streamer': streamer,
                'return_full_text': False,
            }
        )
        gen_thread.start()
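
        # Models running in "thinking" mode emit a `<think> ... </think>` block before
        # the visible answer; the loop below routes that block into a separate
        # "💭 Thought" chat message and streams the remainder as the answer.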
        # Buffers for thought vs. answer
        thought_buf = ''
        answer_buf = ''
        in_thought = False

        # Stream tokens
        for chunk in streamer:
            if cancel_event.is_set():
                break
            text = chunk

            # Detect start of thinking
            if not in_thought and '<think>' in text:
                in_thought = True
                # Insert thought placeholder
                history.append({
                    'role': 'assistant',
                    'content': '',
                    'metadata': {'title': '💭 Thought'}
                })
                # Capture text after the opening tag
                after = text.split('<think>', 1)[1]
                thought_buf += after
                # If the closing tag arrives in the same chunk
                if '</think>' in thought_buf:
                    before, after2 = thought_buf.split('</think>', 1)
                    history[-1]['content'] = before.strip()
                    in_thought = False
                    # Start the answer buffer
                    answer_buf = after2
                    history.append({'role': 'assistant', 'content': answer_buf})
                else:
                    history[-1]['content'] = thought_buf
                yield history, debug
                continue

            # Continue streaming the thought
            if in_thought:
                thought_buf += text
                if '</think>' in thought_buf:
                    before, after2 = thought_buf.split('</think>', 1)
                    history[-1]['content'] = before.strip()
                    in_thought = False
                    # Start the answer buffer
                    answer_buf = after2
                    history.append({'role': 'assistant', 'content': answer_buf})
                else:
                    history[-1]['content'] = thought_buf
                yield history, debug
                continue

            # Stream the answer
            if not answer_buf:
                history.append({'role': 'assistant', 'content': ''})
            answer_buf += text
            history[-1]['content'] = answer_buf
            yield history, debug

        gen_thread.join()
        yield history, debug + prompt_debug
    except Exception:
        history.append({'role': 'assistant', 'content': f"Error: {traceback.format_exc()}"})
        yield history, debug
    finally:
        gc.collect()


def cancel_generation():
    cancel_event.set()
    return 'Generation cancelled.'


def update_default_prompt(enable_search):
    today = datetime.now().strftime('%Y-%m-%d')
    return f"You are a helpful assistant. Today is {today}."

# ------------------------------
# Gradio UI
# ------------------------------
with gr.Blocks(title="Yee R1 Demo") as demo:
    gr.Markdown("## Yee-R1 Demo")
    gr.Markdown("Yee (小熠), an AI data-security expert")
    with gr.Row():
        with gr.Column(scale=3):
            model_dd = gr.Dropdown(label="Select Model", choices=list(MODELS.keys()), value=list(MODELS.keys())[0])
            search_chk = gr.Checkbox(label="Enable Web Search", value=False)
            sys_prompt = gr.Textbox(label="System Prompt", lines=3, value=update_default_prompt(search_chk.value))
            gr.Markdown("### Generation Parameters")
            max_tok = gr.Slider(64, 16384, value=4096, step=32, label="Max Tokens")
            temp = gr.Slider(0.1, 2.0, value=0.6, step=0.1, label="Temperature")
            k = gr.Slider(1, 100, value=20, step=1, label="Top-K")
            p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-P")
            rp = gr.Slider(1.0, 2.0, value=1.0, step=0.1, label="Repetition Penalty")
            gr.Markdown("### Web Search Settings")
            mr = gr.Number(value=6, precision=0, label="Max Results")
            mc = gr.Number(value=600, precision=0, label="Max Chars/Result")
            st = gr.Slider(minimum=0.0, maximum=30.0, step=0.5, value=5.0, label="Search Timeout (s)")
            clr = gr.Button("Clear Chat")
            cnl = gr.Button("Cancel Generation")
        with gr.Column(scale=7):
            chat = gr.Chatbot(type="messages")
            txt = gr.Textbox(placeholder="Type your message and press Enter...")
            dbg = gr.Markdown()

    search_chk.change(fn=update_default_prompt, inputs=search_chk, outputs=sys_prompt)
    clr.click(fn=lambda: ([], "", ""), outputs=[chat, txt, dbg])
    cnl.click(fn=cancel_generation, outputs=dbg)
    txt.submit(fn=chat_response,
               inputs=[txt, chat, sys_prompt, search_chk, mr, mc,
                       model_dd, max_tok, temp, k, p, rp, st],
               outputs=[chat, dbg])

demo.launch()