"""Pip, an RPG companion sprite: local-LLM chat with timer/todo tool calling.

The model (``MODEL_ID``) is loaded lazily on first use. Each ``chat`` turn runs
one generation, tries to parse a tool call out of the model output, falls back
to keyword inference over the user's last message, executes the tool through
``src.tools.execute_tool`` and, when a tool ran, generates a follow-up reply.
"""

from __future__ import annotations

import json
import re

try:
    import spaces

    _gpu_decorator = spaces.GPU
except ImportError:
    spaces = None  # type: ignore[assignment]
    _gpu_decorator = lambda f: f  # noqa: E731

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from src.tools import TOOL_SCHEMAS, execute_tool

MODEL_ID = "openbmb/MiniCPM5-1B"

_PERSONA = (
    "You are Pip, a cheerful RPG companion sprite who helps adventurers stay productive. "
    "You speak in a warm, whimsical tone — encouraging, occasionally using light fantasy metaphors "
    "(quests, XP, loot) — but never annoying. Keep replies concise (2-4 sentences) unless the user "
    "asks for more. You help with focus, tasks, and motivation.\n\n"
    "CRITICAL TOOL USE RULES — follow these exactly:\n"
    "- When the user wants to START, BEGIN, or RESUME a timer or focus session → call start_timer.\n"
    "- When the user wants to STOP, PAUSE, or END the timer → call stop_timer.\n"
    "- When the user wants to RESET or RESTART the timer → call reset_timer.\n"
    "- When the user wants to ADD a task/todo/quest → call add_todo.\n"
    "- When the user wants to mark a task DONE/COMPLETE/finished → call complete_todo "
    "(identify it by its id from the quest list).\n"
    "- When the user wants to DELETE/REMOVE a task → call remove_todo.\n"
    "- NEVER describe an action in words without also calling the appropriate tool.\n"
    "- ALWAYS call the tool first, then reply. Do not skip the tool call."
)

# Tool-call markup emitted by the model.
# NOTE(review): tag spellings follow the Hermes / XML-function tool-call style
# described by the comments in _parse_tool_call; confirm against the MiniCPM5
# chat template.
# Format 1: <tool_call>{"name": "...", "arguments": {...}}</tool_call>
# (lazy body + required closing tag so several calls don't merge into one match)
_TOOL_CALL_RE = re.compile(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", re.DOTALL)
# Format 2 (MiniCPM5): <function=NAME>ARGS_OR_JSON</function>
_FUNCTION_TAG_RE = re.compile(r"<function=(\w+)>\s*(.*?)\s*</function>", re.DOTALL)
# Arguments inside format 2: <parameter=KEY>VALUE</parameter>
_PARAM_TAG_RE = re.compile(r"<parameter=(\w+)>\s*(.*?)\s*</parameter>", re.DOTALL)
# Reasoning blocks that can leak through despite enable_thinking=False.
_THINK_RE = re.compile(r"<think>.*?</think>", re.DOTALL)

# Maps tool name → the single string param to use when content isn't JSON
_TOOL_STRING_PARAM = {
    "add_todo": "task",
    "complete_todo": "task",
    "remove_todo": "task",
}

# Lazily-initialised model/tokenizer singletons (see _load).
_model = None
_tokenizer = None


def _load() -> None:
    """Load the tokenizer and model on first call; later calls return immediately."""
    global _model, _tokenizer
    if _model is not None:
        return
    _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    _model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype="auto",
        device_map="auto",
        trust_remote_code=True,
    )


def _build_system(timer_state: dict | None, todos: list[dict] | None = None) -> str:
    """Build the system prompt: persona plus the current timer and quest-list state.

    Args:
        timer_state: Reads ``running``, ``mode``, ``duration_minutes`` and
            ``phase``; missing keys fall back to stopped / ``pomodoro`` / 25 /
            ``work``. A falsy value omits the timer line.
        todos: Each entry must provide ``done``, ``id`` and ``task``. ``None``
            omits the quest list entirely; an empty list renders ``(empty)``.

    Returns:
        The full system prompt string.
    """
    parts = [_PERSONA]
    if timer_state:
        status = "running" if timer_state.get("running") else "stopped"
        mode = timer_state.get("mode", "pomodoro")
        dur = timer_state.get("duration_minutes", 25)
        phase = timer_state.get("phase", "work")
        parts.append(
            f"\n\nCurrent timer status: {status}, mode={mode}, "
            f"duration={dur}min, phase={phase}."
        )
    if todos is not None:
        if todos:
            lines = "\n".join(
                f" - [{'x' if t['done'] else ' '}] (id={t['id']}) {t['task']}"
                for t in todos
            )
            parts.append(f"\n\nCurrent quest list:\n{lines}")
        else:
            parts.append("\n\nCurrent quest list: (empty)")
    return "".join(parts)


_COMPLETE_RE = re.compile(
    r'(?:complete|finish|mark(?:\s+as)?\s+done|done|completed?)\s+(?:task|quest|todo)?[:\s]+(.+)',
    re.IGNORECASE,
)
_ADD_RE = re.compile(
    r'(?:add|create|new)\s+(?:task|quest|todo)[:\s]+(.+)',
    re.IGNORECASE,
)
_REMOVE_RE = re.compile(
    r'(?:remove|delete)\s+(?:task|quest|todo)[:\s]+(.+)',
    re.IGNORECASE,
)

# Todo regexes in priority order (complete beats add beats remove).
_TODO_PATTERNS = (
    (_COMPLETE_RE, "complete_todo"),
    (_ADD_RE, "add_todo"),
    (_REMOVE_RE, "remove_todo"),
)
# Substring guard: deliberately broad, since its only job is to suppress timer inference.
_TODO_GUARD_WORDS = ("todo", "to-do", "task", "quest")
# Timer keywords are matched as whole words so that e.g. "friend"/"weekend"
# don't trigger "end" and "good" doesn't trigger "go".
_RESET_WORDS_RE = re.compile(r"\b(?:reset|restart|start over)\b")
_STOP_WORDS_RE = re.compile(r"\b(?:stop|pause|end|cancel|halt)\b")
_START_WORDS_RE = re.compile(
    r"\b(?:start|begin|go|launch|kick off|let's focus|pomodoro|freeflow|free flow)\b"
)


def _infer_tool_from_user_message(text: str) -> tuple[str, dict] | None:
    """Keyword fallback: if the LLM skipped the tool call, infer it from the user's words.

    Todo operations are tried first (complete, add, remove) and yield
    ``{"task": <captured text>}``. Otherwise, unless the message mentions
    todos/tasks/quests, timer keywords map to reset/stop/start (checked in that
    order) with empty arguments.

    Returns:
        ``(tool_name, arguments)`` or ``None`` when nothing matches.
    """
    t = (text or "").strip()

    for pattern, tool_name in _TODO_PATTERNS:
        m = pattern.search(t)
        if m:
            task = m.group(1).strip()
            # e.g. "add task:   " captures only whitespace; don't create a blank todo.
            if task:
                return tool_name, {"task": task}

    lower = t.lower()
    # Don't misfire timer when the message is clearly about todos/quests.
    if any(w in lower for w in _TODO_GUARD_WORDS):
        return None
    if _RESET_WORDS_RE.search(lower):
        return "reset_timer", {}
    if _STOP_WORDS_RE.search(lower):
        return "stop_timer", {}
    if _START_WORDS_RE.search(lower):
        return "start_timer", {}
    return None


def _parse_tool_call(text: str) -> tuple[str, dict] | None:
    """Extract ``(tool_name, arguments)`` from model output, or ``None`` if absent.

    Format 1: ``<tool_call>{"name": "...", "arguments": {...}}</tool_call>``;
    ``arguments`` may also be a JSON-encoded string.
    Format 2 (MiniCPM5): ``<function=NAME>ARGS_OR_JSON</function>`` where the
    body is ``<parameter=KEY>VALUE</parameter>`` tags, a JSON object, or plain
    text mapped to the tool's primary string parameter.
    """
    m = _TOOL_CALL_RE.search(text)
    if m:
        try:
            payload = json.loads(m.group(1))
            name = payload["name"]
            args = payload.get("arguments") or {}
        except (json.JSONDecodeError, KeyError, TypeError, AttributeError):
            pass  # malformed payload: fall through to format 2
        else:
            if isinstance(args, str):
                try:
                    args = json.loads(args)
                except json.JSONDecodeError:
                    args = {}
            if isinstance(name, str):
                return name, args if isinstance(args, dict) else {}

    m = _FUNCTION_TAG_RE.search(text)
    if not m:
        return None
    name = m.group(1)
    content = m.group(2).strip()
    if not content:
        return name, {}

    # Try value tags first (MiniCPM5 format)
    params = _PARAM_TAG_RE.findall(content)
    if params:
        return name, dict(params)

    try:
        args = json.loads(content)
    except json.JSONDecodeError:
        pass
    else:
        if isinstance(args, dict):
            return name, args

    # Plain-text content: map to the tool's primary string parameter
    param = _TOOL_STRING_PARAM.get(name)
    if param:
        return name, {param: content}
    return name, {}


@_gpu_decorator
def _generate(messages: list[dict]) -> str:
    """Run one sampled generation over ``messages`` (with tool schemas) and return cleaned text.

    Special end tokens and any ``<think>`` blocks are removed; other markup
    (such as tool-call tags) is preserved for ``_parse_tool_call``.
    """
    _load()
    prompt = _tokenizer.apply_chat_template(
        messages,
        tools=TOOL_SCHEMAS,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )
    inputs = _tokenizer(prompt, return_tensors="pt").to(_model.device)
    with torch.inference_mode():
        output_ids = _model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=_tokenizer.eos_token_id,
        )
    # Keep only the newly generated tokens, not the echoed prompt.
    new_tokens = output_ids[0][inputs["input_ids"].shape[1]:]
    # skip_special_tokens=False preserves tool-call tags
    raw = _tokenizer.decode(new_tokens, skip_special_tokens=False)
    for tok in ("<|im_end|>", "<|endoftext|>"):
        raw = raw.replace(tok, "")
    # strip any <think> blocks that slip through despite enable_thinking=False
    raw = _THINK_RE.sub("", raw)
    return raw.strip()


def chat(
    messages: list[dict],
    timer_state: dict | None = None,
    todos: list[dict] | None = None,
) -> tuple[str, str | None]:
    """Return (reply_text, js_command_or_None).

    Generates a reply to ``messages`` (prefixed by a system prompt built from
    ``timer_state``/``todos``). If the model emitted a tool call, or one can be
    inferred from the last message's content, the tool is executed and a
    second generation produces the final reply; the tool's JS command is
    returned alongside it.
    """
    system = _build_system(timer_state, todos)
    full_messages = [{"role": "system", "content": system}] + messages
    raw = _generate(full_messages)
    print(f"[llm] raw output: {raw!r}")

    tool_call = _parse_tool_call(raw)
    print(f"[llm] tool_call parsed: {tool_call}")
    if not tool_call:
        tool_call = _infer_tool_from_user_message(messages[-1]["content"] if messages else "")
    if not tool_call:
        return raw, None

    tool_name, tool_args = tool_call
    tool_result, js_cmd = execute_tool(tool_name, tool_args)
    follow_up = full_messages + [
        {"role": "assistant", "content": raw},
        {"role": "tool", "content": tool_result, "name": tool_name},
    ]
    reply = _generate(follow_up)
    return reply, js_cmd