from __future__ import annotations import json import time import uuid from typing import Any, Callable from .cli import compose_generation_context def build_chat_completion_response(model: Any, request: dict[str, Any]) -> dict[str, Any]: """Run a Reframr model behind an OpenAI-style chat-completions shape.""" model_name = str(request.get("model", "reframr")) context = compose_generation_context( str(request.get("prompt", "")), system=str(request.get("system", "")), messages=request.get("messages"), tool_results=request.get("tool_results", request.get("toolResults")), ) generated_text = str( model.generate_text( context, max_tokens=int(request.get("max_tokens", request.get("max_completion_tokens", 120))), reasoning_mode=request.get("reasoning_mode", request.get("reasoningMode")), temperature=float(request.get("temperature", 0.58)), top_k=int(request.get("top_k", request.get("decode_top_k", 64))), top_p=float(request.get("top_p", request.get("decode_top_p", 0.92))), repetition_penalty=float(request.get("repetition_penalty", 1.25)), ) ).strip() tool_call = parse_tool_call(generated_text) if tool_call is None: message = {"role": "assistant", "content": generated_text} finish_reason = "stop" else: message = {"role": "assistant", "content": "", "tool_calls": [tool_call]} finish_reason = "tool_calls" prompt_tokens = _approx_token_count(context) completion_tokens = _approx_token_count(generated_text) return { "id": f"chatcmpl-{uuid.uuid4().hex}", "object": "chat.completion", "created": int(time.time()), "model": model_name, "choices": [ { "index": 0, "message": message, "finish_reason": finish_reason, } ], "usage": { "prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens, "total_tokens": prompt_tokens + completion_tokens, }, } def iter_chat_completion_chunks( model: Any, request: dict[str, Any], *, chunk_size: int = 12, ) -> Any: """Yield OpenAI-style streaming chunk dictionaries for a Reframr response.""" full_response = build_chat_completion_response(model, request) model_name = str(full_response["model"]) response_id = str(full_response["id"]) created = int(full_response["created"]) choice = full_response["choices"][0] message = choice["message"] yield _stream_chunk( response_id, model_name, created, {"role": "assistant"}, finish_reason=None, ) tool_calls = message.get("tool_calls") if isinstance(message, dict) else None if isinstance(tool_calls, list) and tool_calls: yield _stream_chunk( response_id, model_name, created, {"tool_calls": tool_calls}, finish_reason=None, ) else: content = str(message.get("content", "")) if isinstance(message, dict) else "" for part in _split_stream_content(content, chunk_size=max(1, int(chunk_size))): yield _stream_chunk( response_id, model_name, created, {"content": part}, finish_reason=None, ) yield _stream_chunk( response_id, model_name, created, {}, finish_reason=str(choice.get("finish_reason", "stop")), ) def iter_sse_chat_completion( model: Any, request: dict[str, Any], *, chunk_size: int = 12, ) -> Any: for chunk in iter_chat_completion_chunks(model, request, chunk_size=chunk_size): yield f"data: {json.dumps(chunk, ensure_ascii=False, separators=(',', ':'))}\n\n" yield "data: [DONE]\n\n" def run_tool_loop( model: Any, request: dict[str, Any], *, tools: dict[str, Callable[[dict[str, Any]], Any]], max_rounds: int = 3, ) -> dict[str, Any]: """Run chat completions, executing registered tools when the model asks.""" messages = [dict(message) for message in request.get("messages", []) if isinstance(message, dict)] current_request = dict(request) last_response: dict[str, Any] | None = None for _ in range(max(1, int(max_rounds))): current_request["messages"] = messages last_response = build_chat_completion_response(model, current_request) choice = last_response["choices"][0] message = choice["message"] if choice.get("finish_reason") != "tool_calls": return last_response tool_calls = message.get("tool_calls", []) if not isinstance(tool_calls, list) or not tool_calls: return last_response messages.append({"role": "assistant", "content": "", "tool_calls": tool_calls}) for tool_call in tool_calls: tool_result = _execute_tool_call(tool_call, tools) function_payload = tool_call.get("function", {}) if isinstance(tool_call, dict) else {} tool_name = str(function_payload.get("name", "tool")) messages.append( { "role": "tool", "tool_call_id": str(tool_call.get("id", "")) if isinstance(tool_call, dict) else "", "name": tool_name, "content": json.dumps(tool_result, ensure_ascii=False, separators=(",", ":")), } ) return last_response if last_response is not None else build_chat_completion_response(model, request) def parse_tool_call(text: str) -> dict[str, Any] | None: stripped = text.strip() marker = "" if not stripped.startswith(marker): return None payload = stripped[len(marker) :].strip() if not payload: return _tool_call_payload("tool", {}) name, _, raw_arguments = payload.partition(" ") name = name.strip() or "tool" arguments = _normalize_tool_arguments(raw_arguments.strip()) return _tool_call_payload(name, arguments) def _execute_tool_call( tool_call: Any, tools: dict[str, Callable[[dict[str, Any]], Any]], ) -> dict[str, Any]: if not isinstance(tool_call, dict): return {"ok": False, "error": "tool_call must be an object"} function_payload = tool_call.get("function", {}) function = function_payload if isinstance(function_payload, dict) else {} tool_name = str(function.get("name", "")) arguments = _normalize_tool_arguments(str(function.get("arguments", ""))) tool = tools.get(tool_name) if tool is None: return {"ok": False, "error": f"tool not registered: {tool_name}"} try: result = tool(arguments) except Exception as exc: # pragma: no cover - defensive surface for app tools. return {"ok": False, "error": str(exc)} if isinstance(result, dict): return result return {"ok": True, "content": result} def _tool_call_payload(name: str, arguments: dict[str, Any]) -> dict[str, Any]: return { "id": f"call_{uuid.uuid4().hex[:12]}", "type": "function", "function": { "name": name, "arguments": json.dumps(arguments, ensure_ascii=False, separators=(",", ":")), }, } def _stream_chunk( response_id: str, model_name: str, created: int, delta: dict[str, Any], *, finish_reason: str | None, ) -> dict[str, Any]: return { "id": response_id, "object": "chat.completion.chunk", "created": created, "model": model_name, "choices": [ { "index": 0, "delta": delta, "finish_reason": finish_reason, } ], } def _split_stream_content(content: str, *, chunk_size: int) -> list[str]: if not content: return [] chunks: list[str] = [] start = 0 while start < len(content): chunks.append(content[start : start + chunk_size]) start += chunk_size return chunks def _normalize_tool_arguments(raw_arguments: str) -> dict[str, Any]: if not raw_arguments: return {} try: parsed = json.loads(raw_arguments) except json.JSONDecodeError: return {"input": raw_arguments} if isinstance(parsed, dict): return parsed return {"input": parsed} def _approx_token_count(text: str) -> int: return len([part for part in text.split() if part])