| from __future__ import annotations |
|
|
| import json |
| import time |
| import uuid |
| from typing import Any, Callable |
|
|
| from .cli import compose_generation_context |
|
|
|
|
def _request_value(request: dict[str, Any], *keys: str, default: Any) -> Any:
    """Return the first non-``None`` value stored under any of *keys*, else *default*.

    OpenAI-compatible clients frequently send optional fields explicitly as
    ``null`` (e.g. ``"max_tokens": null``). Treating ``None`` as "absent" lets
    such requests fall through to the next alias / the default instead of
    crashing in ``int(None)`` or becoming the literal string ``"None"``.
    """
    for key in keys:
        value = request.get(key)
        if value is not None:
            return value
    return default


def build_chat_completion_response(model: Any, request: dict[str, Any]) -> dict[str, Any]:
    """Run a Reframr model behind an OpenAI-style chat-completions shape.

    Args:
        model: Object exposing ``generate_text(context, **decoding_options)``;
            its return value is converted with ``str`` and stripped.
        request: OpenAI-style request dictionary. Keys read (``None`` values
            are treated as missing, and aliases are tried left to right):
            ``model``, ``prompt``, ``system``, ``messages``,
            ``tool_results``/``toolResults``,
            ``max_tokens``/``max_completion_tokens`` (default 120),
            ``reasoning_mode``/``reasoningMode``,
            ``temperature`` (default 0.58), ``top_k``/``decode_top_k``
            (default 64), ``top_p``/``decode_top_p`` (default 0.92) and
            ``repetition_penalty`` (default 1.25).

    Returns:
        A ``chat.completion`` dictionary with a single choice. When the
        generated text parses as a ``<tool_call>`` (see ``parse_tool_call``)
        the message carries ``tool_calls`` with empty content and the finish
        reason is ``"tool_calls"``; otherwise the message content is the
        generated text and the finish reason is ``"stop"``. ``usage`` counts
        are whitespace-word approximations, not tokenizer counts.

    Raises:
        ValueError: if a numeric decoding option cannot be converted.
    """
    model_name = str(_request_value(request, "model", default="reframr"))
    context = compose_generation_context(
        str(_request_value(request, "prompt", default="")),
        system=str(_request_value(request, "system", default="")),
        messages=request.get("messages"),
        tool_results=_request_value(request, "tool_results", "toolResults", default=None),
    )
    generated_text = str(
        model.generate_text(
            context,
            max_tokens=int(_request_value(request, "max_tokens", "max_completion_tokens", default=120)),
            reasoning_mode=_request_value(request, "reasoning_mode", "reasoningMode", default=None),
            temperature=float(_request_value(request, "temperature", default=0.58)),
            top_k=int(_request_value(request, "top_k", "decode_top_k", default=64)),
            top_p=float(_request_value(request, "top_p", "decode_top_p", default=0.92)),
            repetition_penalty=float(_request_value(request, "repetition_penalty", default=1.25)),
        )
    ).strip()
    tool_call = parse_tool_call(generated_text)
    if tool_call is None:
        message = {"role": "assistant", "content": generated_text}
        finish_reason = "stop"
    else:
        message = {"role": "assistant", "content": "", "tool_calls": [tool_call]}
        finish_reason = "tool_calls"
    prompt_tokens = _approx_token_count(context)
    completion_tokens = _approx_token_count(generated_text)
    return {
        "id": f"chatcmpl-{uuid.uuid4().hex}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": model_name,
        "choices": [
            {
                "index": 0,
                "message": message,
                "finish_reason": finish_reason,
            }
        ],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
    }
|
|
|
|
def iter_chat_completion_chunks(
    model: Any,
    request: dict[str, Any],
    *,
    chunk_size: int = 12,
) -> Any:
    """Yield OpenAI-style streaming chunk dictionaries for a Reframr response.

    The whole completion is generated up front by
    ``build_chat_completion_response``; the finished result is then replayed
    as a sequence of ``chat.completion.chunk`` dictionaries:

    1. a role chunk (``{"role": "assistant"}``);
    2. either one chunk carrying every tool call, or the text content cut into
       pieces of at most *chunk_size* characters;
    3. a final empty-delta chunk carrying the finish reason.

    Args:
        model: Passed through to ``build_chat_completion_response``.
        request: Passed through to ``build_chat_completion_response``.
        chunk_size: Maximum characters per content chunk; values below 1 are
            treated as 1.

    Yields:
        Chunk dictionaries sharing the id, model name and creation timestamp
        of the underlying response.
    """
    full_response = build_chat_completion_response(model, request)
    model_name = str(full_response["model"])
    response_id = str(full_response["id"])
    created = int(full_response["created"])
    choice = full_response["choices"][0]
    message = choice["message"]
    yield _stream_chunk(
        response_id,
        model_name,
        created,
        {"role": "assistant"},
        finish_reason=None,
    )
    tool_calls = message.get("tool_calls") if isinstance(message, dict) else None
    if isinstance(tool_calls, list) and tool_calls:
        # The OpenAI chunk schema requires an ``index`` on every
        # ``delta.tool_calls`` entry; streaming clients use it to merge
        # fragments of the same call. Non-dict entries are passed through.
        indexed_calls = [
            {"index": position, **call} if isinstance(call, dict) else call
            for position, call in enumerate(tool_calls)
        ]
        yield _stream_chunk(
            response_id,
            model_name,
            created,
            {"tool_calls": indexed_calls},
            finish_reason=None,
        )
    else:
        content = str(message.get("content", "")) if isinstance(message, dict) else ""
        for part in _split_stream_content(content, chunk_size=max(1, int(chunk_size))):
            yield _stream_chunk(
                response_id,
                model_name,
                created,
                {"content": part},
                finish_reason=None,
            )
    yield _stream_chunk(
        response_id,
        model_name,
        created,
        {},
        finish_reason=str(choice.get("finish_reason", "stop")),
    )
|
|
|
|
def iter_sse_chat_completion(
    model: Any,
    request: dict[str, Any],
    *,
    chunk_size: int = 12,
) -> Any:
    """Yield Server-Sent Events frames for a streamed chat completion.

    Each chunk from ``iter_chat_completion_chunks`` is serialised as compact
    JSON (non-ASCII characters kept as-is) inside a ``data:`` frame, and the
    stream ends with the ``data: [DONE]`` sentinel frame.
    """
    chunk_stream = iter_chat_completion_chunks(model, request, chunk_size=chunk_size)
    for payload in chunk_stream:
        encoded = json.dumps(payload, ensure_ascii=False, separators=(",", ":"))
        yield "data: " + encoded + "\n\n"
    yield "data: [DONE]\n\n"
|
|
|
|
def run_tool_loop(
    model: Any,
    request: dict[str, Any],
    *,
    tools: dict[str, Callable[[dict[str, Any]], Any]],
    max_rounds: int = 3,
) -> dict[str, Any]:
    """Run chat completions, executing registered tools when the model asks.

    The conversation starts from the dict entries of ``request["messages"]``
    (copied, so the caller's list is not mutated). In every round except the
    last, when the model answers with tool calls, the assistant tool-call
    message and one ``"tool"`` message per call (JSON-encoded result of
    ``_execute_tool_call``) are appended and the model is asked again.

    Args:
        model: Passed through to ``build_chat_completion_response``.
        request: Base request; its ``messages`` is replaced each round with
            the growing conversation.
        tools: Mapping of tool name to a callable taking the argument dict.
        max_rounds: Maximum number of model calls; values below 1 count as 1.

    Returns:
        The first response that does not request tools, or the response of
        the final round. If that final response still requests tools, its
        calls are returned *unexecuted*: no later round could consume their
        output, so running them would only cause side effects.
    """
    rounds = max(1, int(max_rounds))
    # ``or []`` also covers an explicit ``"messages": None``.
    messages = [dict(message) for message in (request.get("messages") or []) if isinstance(message, dict)]
    current_request = dict(request)
    for _ in range(rounds - 1):
        current_request["messages"] = messages
        response = build_chat_completion_response(model, current_request)
        choice = response["choices"][0]
        if choice.get("finish_reason") != "tool_calls":
            return response
        tool_calls = choice["message"].get("tool_calls", [])
        if not isinstance(tool_calls, list) or not tool_calls:
            return response
        messages.append({"role": "assistant", "content": "", "tool_calls": tool_calls})
        for tool_call in tool_calls:
            tool_result = _execute_tool_call(tool_call, tools)
            call = tool_call if isinstance(tool_call, dict) else {}
            function_payload = call.get("function")
            function = function_payload if isinstance(function_payload, dict) else {}
            messages.append(
                {
                    "role": "tool",
                    "tool_call_id": str(call.get("id", "")),
                    "name": str(function.get("name", "tool")),
                    # Tools may return arbitrary objects; stringify anything
                    # JSON cannot encode rather than abort the whole loop.
                    "content": json.dumps(tool_result, ensure_ascii=False, separators=(",", ":"), default=str),
                }
            )
    # Final round: its response is returned whatever the finish reason.
    current_request["messages"] = messages
    return build_chat_completion_response(model, current_request)
|
|
|
|
def parse_tool_call(text: str) -> dict[str, Any] | None:
    """Parse a ``<tool_call>NAME ARGS`` completion into an OpenAI tool-call dict.

    The (stripped) text must start with ``<tool_call>``; otherwise ``None`` is
    returned. An optional trailing ``</tool_call>`` is ignored. The first
    whitespace-delimited word after the tag is the tool name and the rest is
    the argument text, normalised by ``_normalize_tool_arguments`` (JSON
    object, or ``{"input": ...}`` otherwise). An empty payload yields a call
    to a tool named ``"tool"`` with no arguments.
    """
    opening, closing = "<tool_call>", "</tool_call>"
    stripped = text.strip()
    if not stripped.startswith(opening):
        return None
    payload = stripped[len(opening) :].strip()
    if payload.endswith(closing):
        payload = payload[: -len(closing)].strip()
    if not payload:
        return _tool_call_payload("tool", {})
    # Split on the first run of *any* whitespace so "name\n{...}" or
    # "name\t{...}" parse the same way as "name {...}".
    name, *rest = payload.split(None, 1)
    raw_arguments = rest[0].strip() if rest else ""
    return _tool_call_payload(name, _normalize_tool_arguments(raw_arguments))
|
|
|
|
| def _execute_tool_call( |
| tool_call: Any, |
| tools: dict[str, Callable[[dict[str, Any]], Any]], |
| ) -> dict[str, Any]: |
| if not isinstance(tool_call, dict): |
| return {"ok": False, "error": "tool_call must be an object"} |
| function_payload = tool_call.get("function", {}) |
| function = function_payload if isinstance(function_payload, dict) else {} |
| tool_name = str(function.get("name", "")) |
| arguments = _normalize_tool_arguments(str(function.get("arguments", ""))) |
| tool = tools.get(tool_name) |
| if tool is None: |
| return {"ok": False, "error": f"tool not registered: {tool_name}"} |
| try: |
| result = tool(arguments) |
| except Exception as exc: |
| return {"ok": False, "error": str(exc)} |
| if isinstance(result, dict): |
| return result |
| return {"ok": True, "content": result} |
|
|
|
|
| def _tool_call_payload(name: str, arguments: dict[str, Any]) -> dict[str, Any]: |
| return { |
| "id": f"call_{uuid.uuid4().hex[:12]}", |
| "type": "function", |
| "function": { |
| "name": name, |
| "arguments": json.dumps(arguments, ensure_ascii=False, separators=(",", ":")), |
| }, |
| } |
|
|
|
|
| def _stream_chunk( |
| response_id: str, |
| model_name: str, |
| created: int, |
| delta: dict[str, Any], |
| *, |
| finish_reason: str | None, |
| ) -> dict[str, Any]: |
| return { |
| "id": response_id, |
| "object": "chat.completion.chunk", |
| "created": created, |
| "model": model_name, |
| "choices": [ |
| { |
| "index": 0, |
| "delta": delta, |
| "finish_reason": finish_reason, |
| } |
| ], |
| } |
|
|
|
|
| def _split_stream_content(content: str, *, chunk_size: int) -> list[str]: |
| if not content: |
| return [] |
| chunks: list[str] = [] |
| start = 0 |
| while start < len(content): |
| chunks.append(content[start : start + chunk_size]) |
| start += chunk_size |
| return chunks |
|
|
|
|
| def _normalize_tool_arguments(raw_arguments: str) -> dict[str, Any]: |
| if not raw_arguments: |
| return {} |
| try: |
| parsed = json.loads(raw_arguments) |
| except json.JSONDecodeError: |
| return {"input": raw_arguments} |
| if isinstance(parsed, dict): |
| return parsed |
| return {"input": parsed} |
|
|
|
|
| def _approx_token_count(text: str) -> int: |
| return len([part for part in text.split() if part]) |
|
|