focus-buddy / src /llm.py
pocanman's picture
add debug time controls and MiniCPM5 LLM support
5578b51
Raw
History Blame Contribute Delete
7.95 kB
from __future__ import annotations
import json
import re
# Optional Hugging Face Spaces integration: when the `spaces` package can be
# imported, `spaces.GPU` decorates the generation function below; otherwise an
# identity decorator is used so the module also runs outside Spaces.
try:
    import spaces

    _gpu_decorator = spaces.GPU
except ImportError:
    spaces = None  # type: ignore[assignment]

    def _gpu_decorator(f):
        """Return *f* unchanged (stand-in for ``spaces.GPU``)."""
        return f
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from src.tools import TOOL_SCHEMAS, execute_tool
# Hugging Face Hub id of the chat model loaded by _load().
MODEL_ID = "openbmb/MiniCPM5-1B"

# Base system prompt: the "Pip" persona followed by explicit tool-use rules.
# _build_system() appends the live timer state and quest list to this text.
# NOTE(review): the tool names referenced here are assumed to match the
# schemas in src.tools.TOOL_SCHEMAS — confirm when renaming or adding tools.
_PERSONA = (
    "You are Pip, a cheerful RPG companion sprite who helps adventurers stay productive. "
    "You speak in a warm, whimsical tone — encouraging, occasionally using light fantasy metaphors "
    "(quests, XP, loot) — but never annoying. Keep replies concise (2-4 sentences) unless the user "
    "asks for more. You help with focus, tasks, and motivation.\n\n"
    "CRITICAL TOOL USE RULES — follow these exactly:\n"
    "- When the user wants to START, BEGIN, or RESUME a timer or focus session → call start_timer.\n"
    "- When the user wants to STOP, PAUSE, or END the timer → call stop_timer.\n"
    "- When the user wants to RESET or RESTART the timer → call reset_timer.\n"
    "- When the user wants to ADD a task/todo/quest → call add_todo.\n"
    "- When the user wants to mark a task DONE/COMPLETE/finished → call complete_todo "
    "(identify it by its id from the quest list).\n"
    "- When the user wants to DELETE/REMOVE a task → call remove_todo.\n"
    "- NEVER describe an action in words without also calling the appropriate tool.\n"
    "- ALWAYS call the tool first, then reply. Do not skip the tool call."
)
# Format 1: <tool_call>{"name": ..., "arguments": ...}</tool_call>.
# The JSON capture is non-greedy and anchored on the closing tag: nested braces
# still work, and when the model emits several tool calls only the first one
# is captured (a greedy ``.*`` would span from the first call to the last and
# yield invalid JSON).
_TOOL_CALL_RE = re.compile(r'<tool_call>\s*(\{.*?\})\s*</tool_call>', re.DOTALL)
# Format 2: <function name="tool">...</function>; the body may hold <param>
# tags, a JSON object, or plain text.
_FUNCTION_TAG_RE = re.compile(r'<function\s+name=["\'](\w+)["\']>\s*(.*?)\s*</function>', re.DOTALL)
# A single <param name="key">value</param> entry inside a <function> body.
_PARAM_TAG_RE = re.compile(r'<param\s+name=["\'](\w+)["\']>\s*(.*?)\s*</param>', re.DOTALL)
# Maps tool name → the single string param to use when content isn't JSON
_TOOL_STRING_PARAM = {
    "add_todo": "task",
    "complete_todo": "task",
    "remove_todo": "task",
}
# Lazily-initialised model singletons; both stay None until _load() first runs.
_model = None
_tokenizer = None


def _load():
    """Load the tokenizer and model for MODEL_ID on first call; later calls do nothing.

    The tokenizer is loaded before the model. Both are fetched with
    ``trust_remote_code=True``; the model additionally uses
    ``torch_dtype="auto"`` and ``device_map="auto"``.
    """
    global _model, _tokenizer
    if _model is None:
        shared = {"trust_remote_code": True}
        _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, **shared)
        _model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID, torch_dtype="auto", device_map="auto", **shared
        )
def _build_system(timer_state: dict | None, todos: list[dict] | None = None) -> str:
    """Compose the system prompt from the persona plus the current app state.

    Args:
        timer_state: Timer snapshot. When truthy, a one-line summary is appended
            from its ``running``, ``mode``, ``duration_minutes`` and ``phase``
            keys (missing keys default to "pomodoro", 25 and "work").
        todos: Quest items, each read via its ``done``, ``id`` and ``task``
            keys. ``None`` omits the quest section; an empty list is rendered
            as "(empty)".

    Returns:
        The full system prompt string.
    """
    sections = [_PERSONA]

    if timer_state:
        state_word = "running" if timer_state.get("running") else "stopped"
        sections.append(
            "\n\nCurrent timer status: "
            f"{state_word}, "
            f"mode={timer_state.get('mode', 'pomodoro')}, "
            f"duration={timer_state.get('duration_minutes', 25)}min, "
            f"phase={timer_state.get('phase', 'work')}."
        )

    if todos is None:
        return "".join(sections)

    if not todos:
        sections.append("\n\nCurrent quest list: (empty)")
        return "".join(sections)

    checklist = []
    for item in todos:
        box = "x" if item["done"] else " "
        checklist.append(f" - [{box}] (id={item['id']}) {item['task']}")
    sections.append("\n\nCurrent quest list:\n" + "\n".join(checklist))
    return "".join(sections)
# Building blocks for the explicit todo-command patterns: an optional article
# ("add a task: ...") and the todo noun, singular or plural, incl. "to-do".
_ARTICLE_PATTERN = r"(?:(?:a|an|the|my)\s+)?"
_TODO_NOUN_PATTERN = r"(?:tasks?|quests?|to-?dos?)"

# The leading \b keeps verbs from matching inside longer words, e.g.
# "incomplete task: x" or "undone quest: y" must not complete anything.
_COMPLETE_RE = re.compile(
    r"\b(?:complete|finish|mark(?:\s+as)?\s+done|done|completed?)\s+"
    + _ARTICLE_PATTERN
    + _TODO_NOUN_PATTERN
    + r"?[:\s]+(.+)",
    re.IGNORECASE,
)
_ADD_RE = re.compile(
    r"\b(?:add|create|new)\s+"
    + _ARTICLE_PATTERN
    + r"(?:new\s+)?"
    + _TODO_NOUN_PATTERN
    + r"[:\s]+(.+)",
    re.IGNORECASE,
)
_REMOVE_RE = re.compile(
    r"\b(?:remove|delete)\s+"
    + _ARTICLE_PATTERN
    + _TODO_NOUN_PATTERN
    + r"[:\s]+(.+)",
    re.IGNORECASE,
)

# Listing order only breaks ties; otherwise the command that starts earliest
# in the message wins ("add task: finish quest: x" adds, it doesn't complete).
_TODO_COMMANDS = (
    ("complete_todo", _COMPLETE_RE),
    ("add_todo", _ADD_RE),
    ("remove_todo", _REMOVE_RE),
)

# Keywords are matched as whole words (plus common inflections) against the
# lower-cased message, so "friend"/"weekend" are not "end", "good"/"going" are
# not "go", "health" is not "halt", and "question" is not "quest".
_TODO_WORD_RE = re.compile(r"\b(?:to-?dos?|tasks?|quests?)\b")
_RESET_WORD_RE = re.compile(
    r"\b(?:reset(?:s|ting)?|restart(?:s|ed|ing)?|start(?:ing)? over)\b"
)
_STOP_WORD_RE = re.compile(
    r"\b(?:stop(?:s|ped|ping)?|paus(?:e|es|ed|ing)|end(?:s|ed|ing)?"
    r"|cancel(?:s|ed|led|ing|ling)?|halt(?:s|ed|ing)?)\b"
)
_START_WORD_RE = re.compile(
    r"\b(?:start(?:s|ed|ing)?|begin(?:s|ning)?|go|launch(?:es|ed|ing)?"
    r"|kick(?:ing)? off|let['\u2019]?s focus|pomodoros?|free ?flow)\b"
)


def _infer_tool_from_user_message(text: str) -> tuple[str, dict] | None:
    """Keyword fallback: if the LLM skipped the tool call, infer it from the user's words.

    Explicit todo commands ("add task: ...", "complete quest: ...",
    "delete todo: ...") are tried first; if several are present, the one that
    starts earliest in the message wins. Otherwise, unless the message mentions
    todos/tasks/quests at all, whole-word timer keywords select reset, stop or
    start, in that priority order.

    Args:
        text: The user's latest message; ``None`` or empty yields ``None``.

    Returns:
        ``(tool_name, arguments)`` for the inferred tool, or ``None`` when
        nothing was recognised.
    """
    if not text:
        return None
    t = text.strip()

    # Todo operations take priority over timer inference.
    best = None  # (start offset, tool name, task text)
    for tool_name, pattern in _TODO_COMMANDS:
        m = pattern.search(t)
        if m is None:
            continue
        task = m.group(1).strip()
        if task and (best is None or m.start() < best[0]):
            best = (m.start(), tool_name, task)
    if best is not None:
        return best[1], {"task": best[2]}

    lower = t.lower()
    # Don't misfire timer when the message is clearly about todos/quests.
    if _TODO_WORD_RE.search(lower):
        return None
    if _RESET_WORD_RE.search(lower):
        return "reset_timer", {}
    if _STOP_WORD_RE.search(lower):
        return "stop_timer", {}
    if _START_WORD_RE.search(lower):
        return "start_timer", {}
    return None
def _parse_tool_call(text: str) -> tuple[str, dict] | None:
    """Extract the first tool call from raw model output.

    Two formats are recognised:

    1. ``<tool_call>{"name": ..., "arguments": ...}</tool_call>``. ``arguments``
       may be a JSON object or a JSON-encoded string; a missing or ``null``
       value becomes ``{}``. Calls whose name is not a string or whose
       arguments do not decode to an object are rejected.
    2. ``<function name="tool">...</function>`` (MiniCPM5). The body is read
       as ``<param name="k">v</param>`` tags, else a JSON object, else plain
       text mapped to the tool's string parameter via ``_TOOL_STRING_PARAM``.

    Returns:
        ``(tool_name, arguments)``, or ``None`` if no usable call is found.
    """
    # Format 1: <tool_call>{"name": "...", "arguments": {...}}</tool_call>
    m = _TOOL_CALL_RE.search(text)
    if m:
        try:
            payload = json.loads(m.group(1))
            name = payload["name"]
            args = payload.get("arguments") or {}
            # Some models emit OpenAI-style arguments encoded as a JSON string.
            if isinstance(args, str):
                args = json.loads(args)
        except (json.JSONDecodeError, KeyError, TypeError, AttributeError):
            pass
        else:
            # Only hand well-typed calls to the tool executor; anything else
            # falls through to format 2 / the caller's keyword fallback.
            if isinstance(name, str) and isinstance(args, dict):
                return name, args

    # Format 2: <function name="tool_name">args_or_json</function> (MiniCPM5)
    m = _FUNCTION_TAG_RE.search(text)
    if m:
        name = m.group(1)
        content = m.group(2).strip()
        if content:
            # Try <param name="...">value</param> tags first (MiniCPM5 format)
            params = _PARAM_TAG_RE.findall(content)
            if params:
                return name, dict(params)
            try:
                args = json.loads(content)
                if isinstance(args, dict):
                    return name, args
            except json.JSONDecodeError:
                pass
            # Plain-text content: map to the tool's primary string parameter
            param = _TOOL_STRING_PARAM.get(name)
            if param:
                return name, {param: content}
        return name, {}
    return None
@_gpu_decorator
def _generate(messages: list[dict]) -> str:
    """Run one sampling pass over *messages* and return the cleaned reply text.

    The chat template is rendered with ``TOOL_SCHEMAS`` and thinking disabled.
    Decoding keeps special tokens so tool-call markup survives for
    ``_parse_tool_call``; end-of-turn markers and any ``<think>...</think>``
    block are removed afterwards.

    Args:
        messages: Chat messages (``role``/``content`` dicts), system prompt included.

    Returns:
        The decoded model continuation, stripped of surrounding whitespace.
    """
    _load()
    prompt = _tokenizer.apply_chat_template(
        messages,
        tools=TOOL_SCHEMAS,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )
    # The rendered template already contains the special tokens the model
    # expects; letting the tokenizer add its own (e.g. a BOS) would duplicate
    # them, as the transformers chat-template docs warn.
    inputs = _tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(_model.device)
    with torch.inference_mode():
        output_ids = _model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=_tokenizer.eos_token_id,
        )
    # Keep only the newly generated tokens, not the echoed prompt.
    new_tokens = output_ids[0][inputs["input_ids"].shape[1]:]
    # skip_special_tokens=False preserves <tool_call> tags
    raw = _tokenizer.decode(new_tokens, skip_special_tokens=False)
    # Drop end-of-turn markers, including the tokenizer's own EOS token in
    # case it is not one of the ChatML markers.
    end_markers = ["<|im_end|>", "<|endoftext|>"]
    eos = _tokenizer.eos_token
    if eos and eos not in end_markers:
        end_markers.append(eos)
    for tok in end_markers:
        raw = raw.replace(tok, "")
    # strip any <think> blocks that slip through despite enable_thinking=False
    raw = re.sub(r"<think>.*?</think>", "", raw, flags=re.DOTALL)
    return raw.strip()
# Tool-call markup in either supported format; it must never reach the user.
_TOOL_MARKUP_RE = re.compile(
    r"<tool_call>.*?</tool_call>|<function\b.*?</function>", re.DOTALL
)


def _strip_tool_markup(text: str) -> str:
    """Return *text* with any tool-call markup removed and whitespace trimmed."""
    return _TOOL_MARKUP_RE.sub("", text).strip()


def chat(
    messages: list[dict],
    timer_state: dict | None = None,
    todos: list[dict] | None = None,
) -> tuple[str, str | None]:
    """Return (reply_text, js_command_or_None).

    Generates a reply to *messages* under a system prompt built from the
    persona, *timer_state* and *todos*. If the model emitted a tool call, or
    (as a fallback) one can be inferred from the latest user message, the tool
    is executed via ``execute_tool``, its result is fed back to the model, and
    the follow-up reply is returned together with the tool's JS command.

    Args:
        messages: Conversation so far as ``role``/``content`` dicts.
        timer_state: Optional timer snapshot forwarded to ``_build_system``.
        todos: Optional quest list forwarded to ``_build_system``.

    Returns:
        ``(reply_text, js_command)``; ``js_command`` is ``None`` when no tool ran.
    """
    system = _build_system(timer_state, todos)
    full_messages = [{"role": "system", "content": system}] + messages
    raw = _generate(full_messages)
    print(f"[llm] raw output: {raw!r}")
    tool_call = _parse_tool_call(raw)
    print(f"[llm] tool_call parsed: {tool_call}")

    if not tool_call and messages:
        last = messages[-1]
        content = last.get("content")
        # Only infer from genuine user text; other roles or non-string
        # (e.g. multimodal) content must not trigger tools.
        if last.get("role") == "user" and isinstance(content, str):
            tool_call = _infer_tool_from_user_message(content)

    if tool_call:
        tool_name, tool_args = tool_call
        tool_result, js_cmd = execute_tool(tool_name, tool_args)
        follow_up = full_messages + [
            {"role": "assistant", "content": raw},
            {"role": "tool", "content": tool_result, "name": tool_name},
        ]
        reply = _strip_tool_markup(_generate(follow_up))
        # If the follow-up was empty or pure markup, show the tool's own result.
        return reply or str(tool_result), js_cmd

    # Hide unparseable tool markup; fall back to the raw text only if nothing
    # else remains, so the user never gets an empty reply.
    return _strip_tool_markup(raw) or raw, None