Spaces: Running on Zero
| from __future__ import annotations | |
| import json | |
| import re | |
| try: | |
| import spaces | |
| _gpu_decorator = spaces.GPU | |
| except ImportError: | |
| spaces = None # type: ignore[assignment] | |
| _gpu_decorator = lambda f: f # noqa: E731 | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from src.tools import TOOL_SCHEMAS, execute_tool | |
| MODEL_ID = "openbmb/MiniCPM5-1B" | |
| _PERSONA = ( | |
| "You are Pip, a cheerful RPG companion sprite who helps adventurers stay productive. " | |
| "You speak in a warm, whimsical tone — encouraging, occasionally using light fantasy metaphors " | |
| "(quests, XP, loot) — but never annoying. Keep replies concise (2-4 sentences) unless the user " | |
| "asks for more. You help with focus, tasks, and motivation.\n\n" | |
| "CRITICAL TOOL USE RULES — follow these exactly:\n" | |
| "- When the user wants to START, BEGIN, or RESUME a timer or focus session → call start_timer.\n" | |
| "- When the user wants to STOP, PAUSE, or END the timer → call stop_timer.\n" | |
| "- When the user wants to RESET or RESTART the timer → call reset_timer.\n" | |
| "- When the user wants to ADD a task/todo/quest → call add_todo.\n" | |
| "- When the user wants to mark a task DONE/COMPLETE/finished → call complete_todo " | |
| "(identify it by its id from the quest list).\n" | |
| "- When the user wants to DELETE/REMOVE a task → call remove_todo.\n" | |
| "- NEVER describe an action in words without also calling the appropriate tool.\n" | |
| "- ALWAYS call the tool first, then reply. Do not skip the tool call." | |
| ) | |
| _TOOL_CALL_RE = re.compile(r'<tool_call>\s*(\{.*\})\s*</tool_call>', re.DOTALL) | |
| _FUNCTION_TAG_RE = re.compile(r'<function\s+name=["\'](\w+)["\']>\s*(.*?)\s*</function>', re.DOTALL) | |
| _PARAM_TAG_RE = re.compile(r'<param\s+name=["\'](\w+)["\']>\s*(.*?)\s*</param>', re.DOTALL) | |
| # Maps tool name → the single string param to use when content isn't JSON | |
| _TOOL_STRING_PARAM = { | |
| "add_todo": "task", | |
| "complete_todo": "task", | |
| "remove_todo": "task", | |
| } | |
| _model = None | |
| _tokenizer = None | |
| def _load(): | |
| global _model, _tokenizer | |
| if _model is not None: | |
| return | |
| _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True) | |
| _model = AutoModelForCausalLM.from_pretrained( | |
| MODEL_ID, | |
| torch_dtype="auto", | |
| device_map="auto", | |
| trust_remote_code=True, | |
| ) | |
def _build_system(timer_state: dict | None, todos: list[dict] | None = None) -> str:
    """Assemble the system prompt: persona, then optional timer and quest-list context.

    Args:
        timer_state: Timer snapshot. A falsy value (None or an empty dict) omits
            the timer line. Reads "running", "mode", "duration_minutes" and
            "phase", defaulting to stopped / "pomodoro" / 25 / "work".
        todos: Quest list. None omits the section entirely; an empty list renders
            "(empty)". Each item must provide "done", "id" and "task" keys.

    Returns:
        The complete system prompt string.
    """
    prompt = _PERSONA

    if timer_state:
        running = timer_state.get("running")
        prompt += (
            f"\n\nCurrent timer status: {'running' if running else 'stopped'}, "
            f"mode={timer_state.get('mode', 'pomodoro')}, "
            f"duration={timer_state.get('duration_minutes', 25)}min, "
            f"phase={timer_state.get('phase', 'work')}."
        )

    if todos is None:
        return prompt
    if not todos:
        return prompt + "\n\nCurrent quest list: (empty)"

    rows = [
        f" - [{'x' if item['done'] else ' '}] (id={item['id']}) {item['task']}"
        for item in todos
    ]
    return prompt + "\n\nCurrent quest list:\n" + "\n".join(rows)
| _COMPLETE_RE = re.compile( | |
| r'(?:complete|finish|mark(?:\s+as)?\s+done|done|completed?)\s+(?:task|quest|todo)?[:\s]+(.+)', | |
| re.IGNORECASE, | |
| ) | |
| _ADD_RE = re.compile( | |
| r'(?:add|create|new)\s+(?:task|quest|todo)[:\s]+(.+)', | |
| re.IGNORECASE, | |
| ) | |
| _REMOVE_RE = re.compile( | |
| r'(?:remove|delete)\s+(?:task|quest|todo)[:\s]+(.+)', | |
| re.IGNORECASE, | |
| ) | |
| def _infer_tool_from_user_message(text: str) -> tuple[str, dict] | None: | |
| """Keyword fallback: if the LLM skipped the tool call, infer it from the user's words.""" | |
| t = text.strip() | |
| # Todo operations take priority over timer inference | |
| m = _COMPLETE_RE.search(t) | |
| if m: | |
| return "complete_todo", {"task": m.group(1).strip()} | |
| m = _ADD_RE.search(t) | |
| if m: | |
| return "add_todo", {"task": m.group(1).strip()} | |
| m = _REMOVE_RE.search(t) | |
| if m: | |
| return "remove_todo", {"task": m.group(1).strip()} | |
| lower = t.lower() | |
| # Don't misfire timer when the message is clearly about todos/quests. | |
| if any(w in lower for w in ("todo", "to-do", "task", "quest")): | |
| return None | |
| if any(w in lower for w in ("reset", "restart", "start over")): | |
| return "reset_timer", {} | |
| if any(w in lower for w in ("stop", "pause", "end", "cancel", "halt")): | |
| return "stop_timer", {} | |
| if any(w in lower for w in ("start", "begin", "go", "launch", "kick off", "let's focus", "pomodoro", "freeflow", "free flow")): | |
| return "start_timer", {} | |
| return None | |
def _coerce_tool_args(name: str, args: object) -> dict:
    """Normalise a raw tool-arguments value for tool *name* into a keyword dict.

    A dict is returned as-is. A string is decoded as a JSON object when possible;
    otherwise non-empty text is mapped onto the tool's primary string parameter
    from ``_TOOL_STRING_PARAM``. Anything else (None, lists, unmapped text)
    yields ``{}``.
    """
    if isinstance(args, dict):
        return args
    if not isinstance(args, str):
        return {}
    text = args.strip()
    try:
        decoded = json.loads(text)
    except json.JSONDecodeError:
        decoded = None
    if isinstance(decoded, dict):
        return decoded
    # Plain-text content: map to the tool's primary string parameter
    param = _TOOL_STRING_PARAM.get(name)
    if param and text:
        return {param: text}
    return {}


def _parse_tool_call(text: str) -> tuple[str, dict] | None:
    """Extract the first tool call from raw model output.

    Two formats are tried in order:

    1. ``<tool_call>{"name": ..., "arguments": ...}</tool_call>`` — a JSON
       payload. ``arguments`` may be an object or a JSON-encoded string.
    2. ``<function name="tool">body</function>`` (MiniCPM5), where the body is
       ``<param name="k">v</param>`` tags, a JSON object, or plain text.

    A malformed format-1 block does not stop format 2 from being tried.

    Returns:
        ``(tool_name, arguments_dict)``, or ``None`` when no tool call is found.
    """
    m = _TOOL_CALL_RE.search(text)
    if m:
        try:
            payload = json.loads(m.group(1))
            name = payload["name"]
        except (json.JSONDecodeError, KeyError):
            pass  # malformed payload: fall through to the <function> format
        else:
            return name, _coerce_tool_args(name, payload.get("arguments"))

    m = _FUNCTION_TAG_RE.search(text)
    if not m:
        return None
    name = m.group(1)
    content = m.group(2).strip()
    # Try <param name="...">value</param> tags first (MiniCPM5 format)
    params = _PARAM_TAG_RE.findall(content)
    if params:
        return name, dict(params)
    return name, _coerce_tool_args(name, content)
def _generate(messages: list[dict]) -> str:
    """Run one sampled generation over *messages* and return the cleaned reply text.

    The chat template is rendered with ``TOOL_SCHEMAS`` so the model may emit
    tool-call markup. Special tokens are kept while decoding so that markup
    survives; afterwards only end-of-turn tokens and ``<think>...</think>``
    blocks are removed.

    NOTE(review): ``_gpu_decorator`` (``spaces.GPU`` when importable) is defined
    at module level but not applied here; confirm the caller runs this under the
    GPU decorator on ZeroGPU.
    """
    _load()
    prompt = _tokenizer.apply_chat_template(
        messages,
        tools=TOOL_SCHEMAS,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )
    # The rendered template already contains every special token the model needs;
    # letting the tokenizer add its own (e.g. a second BOS) can degrade output.
    inputs = _tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(_model.device)
    with torch.inference_mode():
        output_ids = _model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=_tokenizer.eos_token_id,
        )
    # Keep only the newly generated continuation, not the echoed prompt.
    new_tokens = output_ids[0][inputs["input_ids"].shape[1]:]
    # skip_special_tokens=False preserves <tool_call> tags
    raw = _tokenizer.decode(new_tokens, skip_special_tokens=False)
    # Remove end-of-turn markers, including the tokenizer's own EOS token in case
    # it differs from the hard-coded ChatML ones.
    for tok in ("<|im_end|>", "<|endoftext|>", _tokenizer.eos_token):
        if tok:
            raw = raw.replace(tok, "")
    # strip any <think> blocks that slip through despite enable_thinking=False
    raw = re.sub(r"<think>.*?</think>", "", raw, flags=re.DOTALL)
    return raw.strip()
# Any tool-call markup (either supported format) left in a user-facing reply.
_TOOL_MARKUP_RE = re.compile(
    r'<tool_call>.*?</tool_call>|<function\s+name=["\'][^"\']*["\']>.*?</function>',
    re.DOTALL,
)


def chat(
    messages: list[dict],
    timer_state: dict | None = None,
    todos: list[dict] | None = None,
) -> tuple[str, str | None]:
    """Return (reply_text, js_command_or_None).

    Generates a reply for *messages* with a system prompt built from
    *timer_state* and *todos*. If the model emitted a tool call, or one can be
    inferred from the last message's keywords, the tool is executed. A second
    generation then produces the user-facing reply, and the tool's JS command
    is returned alongside it. Otherwise the first reply is returned with
    ``None``.
    """
    system = _build_system(timer_state, todos)
    full_messages = [{"role": "system", "content": system}] + messages
    raw = _generate(full_messages)
    print(f"[llm] raw output: {raw!r}")

    tool_call = _parse_tool_call(raw)
    print(f"[llm] tool_call parsed: {tool_call}")
    if not tool_call:
        tool_call = _infer_tool_from_user_message(messages[-1]["content"] if messages else "")
    if not tool_call:
        return raw, None

    tool_name, tool_args = tool_call
    tool_result, js_cmd = execute_tool(tool_name, tool_args)
    follow_up = full_messages + [
        {"role": "assistant", "content": raw},
        {"role": "tool", "content": tool_result, "name": tool_name},
    ]
    reply = _generate(follow_up)
    # The follow-up pass still renders TOOL_SCHEMAS, so the model may emit another
    # tool call. That call is never executed here, so keep its markup out of the UI.
    # If nothing is left, fall back to the tool's own result text instead of an
    # empty message.
    cleaned = _TOOL_MARKUP_RE.sub("", reply).strip()
    return (cleaned or str(tool_result)), js_cmd