Spaces:
Sleeping
Sleeping
| import inspect | |
| import os | |
| import threading | |
| import gradio as gr | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer | |
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")


def env_int(name, default, minimum=None):
    """Read an integer setting from the environment.

    Args:
        name: Environment variable to read.
        default: Value used when the variable is unset (int or numeric string).
        minimum: Optional lower bound. Smaller values are clamped up to it so a
            nonsensical setting (e.g. ``N_THREADS=0`` or a negative history
            length) cannot propagate into torch / slicing code.

    Returns:
        The parsed (and possibly clamped) integer.

    Raises:
        ValueError: If the variable is set but is not a valid integer.
    """
    value = int(os.getenv(name, str(default)))
    if minimum is not None and value < minimum:
        return minimum
    return value


# Hugging Face model repo to load (safetensors weights, run on CPU).
MODEL_ID = os.getenv("MODEL_ID", "Qwen/Qwen3-0.6B")
# Upper bound on generated tokens per reply.
MAX_NEW_TOKENS = env_int("MAX_NEW_TOKENS", 4096, minimum=1)
# Prompt token budget; 0 means the prompt is not truncated.
MAX_INPUT_TOKENS = env_int("MAX_INPUT_TOKENS", 1536, minimum=0)
# Number of previous user/assistant pairs sent back to the model.
MAX_HISTORY_TURNS = env_int("MAX_HISTORY_TURNS", 3, minimum=0)
# Intra-op CPU threads for PyTorch; defaults to all visible cores.
N_THREADS = env_int("N_THREADS", max(1, os.cpu_count() or 1), minimum=1)
DEFAULT_SYSTEM_PROMPT = os.getenv(
    "SYSTEM_PROMPT",
    "You are a helpful assistant. Keep answers clear and concise.",
)
# Preset prompts shown in the "Preset prompt" dropdown. Every preset carries the
# same five fields, in the order load_preset() unpacks them into the UI:
#   system            -> system prompt textbox
#   prompt            -> user message textbox
#   thinking          -> initial state of the "Enable thinking" checkbox
#   sample_reasoning  -> read-only example of reasoning output
#   sample_answer     -> read-only example of the final answer
PRESETS = {
    "Math": dict(
        system="You are a careful math tutor. Think through the problem, then give a short final answer.",
        prompt="Solve: If 2x^2 - 7x + 3 = 0, what are the real solutions?",
        thinking=True,
        sample_reasoning="The discriminant is 49 - 24 = 25, so the roots are easy to compute with the quadratic formula.",
        sample_answer="The real solutions are x = 3 and x = 1/2.",
    ),
    "Coding": dict(
        system="You are a Python assistant. Prefer short, readable code.",
        prompt="Write a Python function that merges two sorted lists into one sorted list.",
        thinking=True,
        sample_reasoning="Use two pointers. Compare the current elements, append the smaller one, then append the leftovers.",
        sample_answer="Here is a compact merge function plus a tiny example.",
    ),
    "Structured output": dict(
        system="Return compact JSON and avoid extra commentary.",
        prompt="Extract JSON from: Call Mina by Friday, priority high, budget about $2400, topic is launch video edits.",
        thinking=False,
        sample_reasoning="Reasoning is disabled here so the output stays short and machine-friendly.",
        sample_answer='{"person":"Mina","deadline":"Friday","priority":"high","budget_usd":2400,"topic":"launch video edits"}',
    ),
    "Function calling style": dict(
        system="You are an assistant that plans tool use when it helps. If a tool would help, say what tool you would call and with which arguments.",
        prompt="Pretend you have tools. For 18.75 * 42 - 199 and converting 12 km to miles, explain which tool calls you would make, then give the result.",
        thinking=True,
        sample_reasoning="I would use a calculator tool for the arithmetic and a unit-conversion tool for the distance conversion.",
        sample_answer="Calculator(18.75 * 42 - 199) -> 588.5\nConvert(12 km -> miles) -> about 7.46 miles",
    ),
    "Creative writing": dict(
        system="Write vivid, tight prose.",
        prompt="Write a two-sentence opening for a sci-fi heist story set on a drifting museum ship.",
        thinking=False,
        sample_reasoning="Reasoning is disabled for a faster clean draft.",
        sample_answer="By the time the museum ship crossed into the dead zone, every priceless relic aboard had started broadcasting a heartbeat. Nia took that as her cue to cut the lights and steal the one artifact already trying to escape.",
    ),
}
# Size PyTorch's CPU thread pools: N_THREADS intra-op threads, and an inter-op
# pool clamped to the range [1, 2].
torch.set_num_threads(N_THREADS)
try:
    torch.set_num_interop_threads(min(2, max(1, N_THREADS)))
except RuntimeError:
    # Best-effort: PyTorch may refuse to change the inter-op pool (e.g. if it
    # was already configured); keep whatever setting is in effect.
    pass
# Lazily-created tokenizer/model pair; both stay None until get_model() loads them.
_tokenizer = _model = None
# _load_lock guards the one-time load in get_model(); _generate_lock is held
# around model.generate() so only one generation runs at a time.
_load_lock, _generate_lock = threading.Lock(), threading.Lock()
def make_chatbot(label, height=520):
    """Build a ``gr.Chatbot`` with the given label and pixel height.

    When the installed Gradio's ``Chatbot`` constructor has a ``type``
    parameter, ``type="messages"`` is passed so the widget accepts the
    ``{"role": ..., "content": ...}`` dicts this app produces. Versions
    without that parameter are called without it.
    """
    chatbot_params = inspect.signature(gr.Chatbot.__init__).parameters
    options = dict(label=label, height=height)
    if "type" in chatbot_params:
        options["type"] = "messages"
    return gr.Chatbot(**options)
def get_model():
    """Return the shared ``(tokenizer, model)`` pair, loading it on first use.

    The fast path returns the cached pair without locking. Otherwise
    ``_load_lock`` is taken and the check is repeated, so concurrent first
    requests load the weights only once. The tokenizer and model are fully
    built, and the model switched to eval mode, in local variables *before*
    being published to the module globals. A thread on the unlocked fast
    path therefore never sees a model that has not had ``eval()`` called.
    A failed load also leaves both globals unset, so the next call retries
    cleanly.

    Returns:
        ``(tokenizer, model)`` loaded from ``MODEL_ID`` in float32.
    """
    global _tokenizer, _model
    if _model is not None and _tokenizer is not None:
        return _tokenizer, _model
    with _load_lock:
        if _model is None or _tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_ID,
                torch_dtype=torch.float32,
            )
            model.eval()
            # Publish the tokenizer first and the model last. The unlocked fast
            # path requires both, so it only succeeds once everything is ready.
            _tokenizer = tokenizer
            _model = model
        return _tokenizer, _model
def clone_messages(messages):
    """Return a fresh list containing a shallow copy of each message dict.

    A falsy ``messages`` (``None`` or an empty list) yields ``[]``.
    """
    if not messages:
        return []
    return list(map(dict, messages))
def load_preset(name):
    """Return the UI values for the preset called ``name``.

    The result is the tuple ``(system, prompt, thinking, sample_reasoning,
    sample_answer)``, in the order of the outputs wired to the preset
    dropdown. An unknown name raises ``KeyError``.
    """
    chosen = PRESETS[name]
    field_order = ("system", "prompt", "thinking", "sample_reasoning", "sample_answer")
    return tuple(chosen[field] for field in field_order)
def clear_all():
    """Reset both chat panes, the model history state and the message box.

    Each call returns new empty lists, so no state is shared between resets.
    """
    reasoning_pane, answer_pane, history = [], [], []
    return reasoning_pane, answer_pane, history, ""
def strip_non_think_specials(text):
    """Remove end-of-turn / end-of-text special tokens from ``text``.

    ``<think>`` and ``</think>`` are deliberately left in place so callers
    can still locate the reasoning block. ``None`` is treated as ``""``.
    """
    cleaned = text or ""
    for special in ("<|im_end|>", "<|endoftext|>", "<|end▁of▁sentence|>"):
        cleaned = cleaned.replace(special, "")
    return cleaned
def final_cleanup(text):
    """Return ``text`` with special tokens and think tags removed, then stripped.

    Special tokens are removed via ``strip_non_think_specials``. Then
    ``<think>`` and ``</think>`` are deleted (in that order), and surrounding
    whitespace is trimmed.
    """
    cleaned = strip_non_think_specials(text)
    for tag in ("<think>", "</think>"):
        cleaned = cleaned.replace(tag, "")
    return cleaned.strip()
def split_stream_text(raw_text, thinking):
    """Split partial model output into ``(reasoning, answer, saw_end_think)``.

    With ``thinking`` off, the whole cleaned text is the answer. With it on,
    everything before the first ``</think>`` is reasoning and everything
    after it is the answer. Until ``</think>`` has appeared, the text so far
    is reported as reasoning, with an empty answer and ``saw_end_think``
    False. Both parts are whitespace-stripped.
    """
    text = strip_non_think_specials(raw_text)
    if not thinking:
        return "", final_cleanup(text), False
    text = text.replace("<think>", "")
    reasoning, marker, answer = text.partition("</think>")
    if not marker:
        return text.strip(), "", False
    return reasoning.strip(), answer.strip(), True
def _history_window(history, max_turns):
    """Return a copy of the last ``max_turns`` user/assistant pairs of ``history``."""
    if max_turns <= 0:
        # history[-0:] would be the *whole* list, not an empty one.
        return []
    return list(history[-2 * max_turns:])


def _encode_prompt(tokenizer, system_prompt, history, message, thinking):
    """Render and tokenize the chat prompt, dropping the oldest turns until it fits.

    Returns ``(input_ids, attention_mask, kept_history)``. ``kept_history``
    is the suffix of ``history`` that made it into the prompt. If the prompt
    is still longer than ``MAX_INPUT_TOKENS`` once no history is left, it is
    left-truncated as a last resort. A ``MAX_INPUT_TOKENS`` of 0 or less
    disables truncation.
    """
    system_text = (system_prompt or "").strip() or DEFAULT_SYSTEM_PROMPT
    kept = list(history)
    while True:
        prompt = tokenizer.apply_chat_template(
            [
                {"role": "system", "content": system_text},
                *kept,
                {"role": "user", "content": message},
            ],
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=thinking,
        )
        inputs = tokenizer(prompt, return_tensors="pt")
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        too_long = MAX_INPUT_TOKENS > 0 and input_ids.shape[-1] > MAX_INPUT_TOKENS
        if not too_long or not kept:
            break
        # Drop the oldest user/assistant pair rather than cutting the
        # system prompt / chat-template header off the front.
        kept = kept[2:]
    if MAX_INPUT_TOKENS > 0:
        input_ids = input_ids[:, -MAX_INPUT_TOKENS:]
        attention_mask = attention_mask[:, -MAX_INPUT_TOKENS:]
    return input_ids, attention_mask, kept


def _cancel_criteria(stop_event):
    """Build stopping criteria that end ``model.generate`` once ``stop_event`` is set."""
    from transformers import StoppingCriteria, StoppingCriteriaList

    class _StopOnEvent(StoppingCriteria):
        def __call__(self, input_ids, scores, **kwargs):
            # One boolean flag per batch row.
            return torch.full(
                (input_ids.shape[0],),
                stop_event.is_set(),
                dtype=torch.bool,
                device=input_ids.device,
            )

    return StoppingCriteriaList([_StopOnEvent()])


def _snapshot(reasoning_chat, answer_chat, model_history):
    """Package the four Gradio outputs as copies (the trailing "" clears the input box)."""
    return clone_messages(reasoning_chat), clone_messages(answer_chat), list(model_history), ""


def respond_stream(
    message,
    system_prompt,
    thinking,
    model_history,
    reasoning_chat,
    answer_chat,
):
    """Stream one chat turn into the "Reasoning" and "Assistant" panes.

    This generator is the Gradio event handler. Every yield is
    ``(reasoning_chat, answer_chat, model_history, "")``.

    Args:
        message: The user's new message. Blank input yields the state unchanged.
        system_prompt: System prompt text; ``DEFAULT_SYSTEM_PROMPT`` is used if blank.
        thinking: Whether the chat template is asked for a reasoning block.
        model_history: Prior user/assistant messages fed back to the model.
        reasoning_chat: Messages shown in the reasoning pane.
        answer_chat: Messages shown in the answer pane.

    Generation runs in a background thread while holding ``_generate_lock``.
    If this generator is closed before generation finishes, a stop event
    ends that generation early instead of letting it run to
    ``MAX_NEW_TOKENS``. Errors are reported in the answer pane, not raised.
    """
    message = (message or "").strip()
    if not message:
        yield _snapshot(reasoning_chat, answer_chat, model_history or [])
        return

    model_history = list(model_history or [])
    reasoning_chat = clone_messages(reasoning_chat)
    answer_chat = clone_messages(answer_chat)
    reasoning_chat.append({"role": "user", "content": message})
    reasoning_chat.append(
        {
            "role": "assistant",
            "content": "(thinking...)" if thinking else "(reasoning disabled)",
        }
    )
    answer_chat.append({"role": "user", "content": message})
    answer_chat.append({"role": "assistant", "content": ""})
    yield _snapshot(reasoning_chat, answer_chat, model_history)

    stop_event = threading.Event()
    try:
        tokenizer, model = get_model()
        input_ids, attention_mask, kept_history = _encode_prompt(
            tokenizer,
            system_prompt,
            _history_window(model_history, MAX_HISTORY_TURNS),
            message,
            thinking,
        )
        streamer = TextIteratorStreamer(
            tokenizer,
            skip_prompt=True,
            skip_special_tokens=False,
            clean_up_tokenization_spaces=False,
            timeout=None,
        )
        generation_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "max_new_tokens": MAX_NEW_TOKENS,
            "do_sample": True,
            "temperature": 0.6 if thinking else 0.7,
            "top_p": 0.95 if thinking else 0.8,
            "top_k": 20,
            "pad_token_id": tokenizer.eos_token_id,
            "streamer": streamer,
            "stopping_criteria": _cancel_criteria(stop_event),
        }
        generation_error = {}

        def run_generation():
            try:
                with _generate_lock:
                    # The consumer may have gone away while we waited for the lock.
                    if not stop_event.is_set():
                        model.generate(**generation_kwargs)
            except Exception as exc:
                generation_error["message"] = str(exc) or type(exc).__name__
                # Unblock the consumer loop below; with timeout=None it
                # would otherwise wait for text that never arrives.
                streamer.on_finalized_text("", stream_end=True)

        thread = threading.Thread(target=run_generation, daemon=True)
        thread.start()

        raw_text = ""
        saw_end_think = False
        for chunk in streamer:
            raw_text += chunk
            reasoning_text, answer_text, saw_end_now = split_stream_text(raw_text, thinking)
            saw_end_think = saw_end_think or saw_end_now
            if not thinking:
                reasoning_chat[-1]["content"] = "(reasoning disabled)"
            elif saw_end_think:
                reasoning_chat[-1]["content"] = reasoning_text or "(no reasoning text returned)"
            else:
                reasoning_chat[-1]["content"] = reasoning_text or "(thinking...)"
            answer_chat[-1]["content"] = answer_text
            yield _snapshot(reasoning_chat, answer_chat, model_history)
        thread.join()

        if generation_error:
            reasoning_chat[-1]["content"] = ""
            answer_chat[-1]["content"] = f"Error while running the local CPU model: {generation_error['message']}"
            yield _snapshot(reasoning_chat, answer_chat, model_history)
            return

        reasoning_text, answer_text, saw_end_think = split_stream_text(raw_text, thinking)
        if thinking and not saw_end_think:
            # The think block was never closed: show everything as the answer.
            reasoning_text = ""
            answer_text = final_cleanup(raw_text)
        if thinking:
            reasoning_chat[-1]["content"] = reasoning_text or "(no reasoning text returned)"
        else:
            reasoning_chat[-1]["content"] = "(reasoning disabled)"
        answer_chat[-1]["content"] = answer_text or "(empty response)"
        # Only the visible answer is fed back to the model, not the reasoning.
        model_history = kept_history + [
            {"role": "user", "content": message},
            {"role": "assistant", "content": answer_chat[-1]["content"]},
        ]
        yield _snapshot(reasoning_chat, answer_chat, model_history)
    except Exception as exc:
        reasoning_chat[-1]["content"] = ""
        answer_chat[-1]["content"] = f"Error while preparing the local CPU model: {exc}"
        yield _snapshot(reasoning_chat, answer_chat, model_history)
    finally:
        # Runs on completion, on error, and when the generator is closed early
        # (GeneratorExit). In the last case it makes the background generate() stop.
        stop_event.set()
with gr.Blocks(title="Local CPU split-reasoning chat") as demo:
    gr.Markdown(
        "# Local CPU split-reasoning chat\n"
        f"Running a local safetensors model on CPU from `{MODEL_ID}`. No GGUF and no external inference provider.\n\n"
        "The first request downloads the model, so the cold start is slower."
    )
    with gr.Row():
        preset = gr.Dropdown(
            choices=list(PRESETS.keys()),
            value="Math",
            label="Preset prompt",
        )
        thinking = gr.Checkbox(label="Enable thinking", value=True)
    system_prompt = gr.Textbox(
        label="System prompt",
        value=PRESETS["Math"]["system"],
        lines=3,
    )
    user_input = gr.Textbox(
        label="Your message",
        value=PRESETS["Math"]["prompt"],
        lines=4,
    )
    with gr.Row():
        sample_reasoning = gr.Textbox(
            label="Sample reasoning",
            value=PRESETS["Math"]["sample_reasoning"],
            lines=5,
            interactive=False,
        )
        sample_answer = gr.Textbox(
            label="Sample answer",
            value=PRESETS["Math"]["sample_answer"],
            lines=5,
            interactive=False,
        )
    with gr.Row():
        send_btn = gr.Button("Send", variant="primary")
        clear_btn = gr.Button("Clear")
    with gr.Row():
        reasoning_bot = make_chatbot("Reasoning", height=520)
        answer_bot = make_chatbot("Assistant", height=520)
    # Conversation actually sent back to the model (answers only, no reasoning).
    model_history_state = gr.State([])

    preset.change(
        fn=load_preset,
        inputs=preset,
        outputs=[system_prompt, user_input, thinking, sample_reasoning, sample_answer],
    )

    chat_inputs = [user_input, system_prompt, thinking, model_history_state, reasoning_bot, answer_bot]
    chat_outputs = [reasoning_bot, answer_bot, model_history_state, user_input]
    # Clicking Send and pressing Enter in the message box run the same handler.
    for trigger in (send_btn.click, user_input.submit):
        trigger(fn=respond_stream, inputs=chat_inputs, outputs=chat_outputs)
    clear_btn.click(fn=clear_all, inputs=None, outputs=chat_outputs)

demo.queue()

if __name__ == "__main__":
    demo.launch()