Spaces:
Sleeping
Sleeping
| # SPDX-License-Identifier: GPL-3.0-or-later | |
| # | |
| # Toolify: Empower any LLM with function calling capabilities. | |
| # Copyright (C) 2025 FunnyCups (https://github.com/funnycups) | |
| import os | |
| import re | |
| import json | |
| import uuid | |
| import asyncio | |
| import httpx | |
| import secrets | |
| import string | |
| import traceback | |
| import time | |
| import random | |
| import logging | |
| import tiktoken | |
| import xml.etree.ElementTree as ET | |
| from typing import List, Dict, Any, Optional, Literal, Union | |
| from fastapi import FastAPI, Request, Header, HTTPException, Depends | |
| from fastapi.responses import JSONResponse, StreamingResponse | |
| from pydantic import BaseModel, ValidationError | |
| from config_loader import config_loader | |
| logger = logging.getLogger(__name__) | |
| # Token Counter for counting tokens | |
| class TokenCounter: | |
| """Token counter using tiktoken""" | |
| # Model prefix to encoding mapping (from tiktoken source) | |
| MODEL_PREFIX_TO_ENCODING = { | |
| "o1-": "o200k_base", | |
| "o3-": "o200k_base", | |
| "o4-mini-": "o200k_base", | |
| # chat | |
| "gpt-5-": "o200k_base", | |
| "gpt-4.5-": "o200k_base", | |
| "gpt-4.1-": "o200k_base", | |
| "chatgpt-4o-": "o200k_base", | |
| "gpt-4o-": "o200k_base", | |
| "gpt-4-": "cl100k_base", | |
| "gpt-3.5-turbo-": "cl100k_base", | |
| "gpt-35-turbo-": "cl100k_base", # Azure deployment name | |
| "gpt-oss-": "o200k_harmony", | |
| # fine-tuned | |
| "ft:gpt-4o": "o200k_base", | |
| "ft:gpt-4": "cl100k_base", | |
| "ft:gpt-3.5-turbo": "cl100k_base", | |
| "ft:davinci-002": "cl100k_base", | |
| "ft:babbage-002": "cl100k_base", | |
| } | |
| def __init__(self): | |
| self.encoders = {} | |
| def get_encoder(self, model: str): | |
| """Get or create encoder for the model""" | |
| if model not in self.encoders: | |
| encoding = None | |
| # First try to get encoding from model name directly | |
| try: | |
| self.encoders[model] = tiktoken.encoding_for_model(model) | |
| return self.encoders[model] | |
| except KeyError: | |
| pass | |
| # Try to find encoding by prefix matching | |
| for prefix, enc_name in self.MODEL_PREFIX_TO_ENCODING.items(): | |
| if model.startswith(prefix): | |
| encoding = enc_name | |
| break | |
| # Default to o200k_base for newer models | |
| if encoding is None: | |
| logger.warning(f"Model {model} not found in prefix mapping, using o200k_base encoding") | |
| encoding = "o200k_base" | |
| try: | |
| self.encoders[model] = tiktoken.get_encoding(encoding) | |
| except Exception as e: | |
| logger.warning(f"Failed to get encoding {encoding} for model {model}: {e}. Falling back to cl100k_base") | |
| self.encoders[model] = tiktoken.get_encoding("cl100k_base") | |
| return self.encoders[model] | |
| def count_tokens(self, messages: list, model: str = "gpt-3.5-turbo") -> int: | |
| """Count tokens in message list""" | |
| encoder = self.get_encoder(model) | |
| # All modern chat models use similar token counting | |
| return self._count_chat_tokens(messages, encoder, model) | |
| def _count_chat_tokens(self, messages: list, encoder, model: str) -> int: | |
| """Accurate token calculation for chat models | |
| Based on OpenAI's token counting documentation: | |
| - Each message has a fixed overhead | |
| - Content tokens are counted per message | |
| - Special tokens for message formatting | |
| """ | |
| # Token overhead varies by model | |
| if model.startswith(("gpt-3.5-turbo", "gpt-35-turbo")): | |
| # gpt-3.5-turbo uses different message overhead | |
| tokens_per_message = 4 # <|start|>role<|separator|>content<|end|> | |
| tokens_per_name = -1 # Name is omitted if not present | |
| else: | |
| # Most models including gpt-4, gpt-4o, o1, etc. | |
| tokens_per_message = 3 | |
| tokens_per_name = 1 | |
| num_tokens = 0 | |
| for message in messages: | |
| num_tokens += tokens_per_message | |
| # Count tokens for each field in the message | |
| for key, value in message.items(): | |
| if key == "content": | |
| # Handle case where content might be a list (multimodal messages) | |
| if isinstance(value, list): | |
| for item in value: | |
| if isinstance(item, dict) and item.get("type") == "text": | |
| content_text = item.get("text", "") | |
| num_tokens += len(encoder.encode(content_text, disallowed_special=())) | |
| # Note: Image tokens are not counted here as they have fixed costs | |
| elif isinstance(value, str): | |
| num_tokens += len(encoder.encode(value, disallowed_special=())) | |
| elif key == "name": | |
| num_tokens += tokens_per_name | |
| if isinstance(value, str): | |
| num_tokens += len(encoder.encode(value, disallowed_special=())) | |
| elif key == "role": | |
| # Role is already counted in tokens_per_message | |
| pass | |
| elif isinstance(value, str): | |
| # Other string fields | |
| num_tokens += len(encoder.encode(value, disallowed_special=())) | |
| # Every reply is primed with assistant role | |
| num_tokens += 3 | |
| return num_tokens | |
| def count_text_tokens(self, text: str, model: str = "gpt-3.5-turbo") -> int: | |
| """Count tokens in plain text""" | |
| encoder = self.get_encoder(model) | |
| return len(encoder.encode(text, disallowed_special=())) | |
| # Global token counter instance | |
| token_counter = TokenCounter() | |
| def generate_random_trigger_signal() -> str: | |
| """Generate a random, self-closing trigger signal like <Function_AB1c_Start/>.""" | |
| chars = string.ascii_letters + string.digits | |
| random_str = ''.join(secrets.choice(chars) for _ in range(4)) | |
| return f"<Function_{random_str}_Start/>" | |
| try: | |
| app_config = config_loader.load_config() | |
| log_level_str = app_config.features.log_level | |
| if log_level_str == "DISABLED": | |
| log_level = logging.CRITICAL + 1 | |
| else: | |
| log_level = getattr(logging, log_level_str, logging.INFO) | |
| logging.basicConfig( | |
| level=log_level, | |
| format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', | |
| datefmt='%Y-%m-%d %H:%M:%S' | |
| ) | |
| logger.info(f"✅ Configuration loaded successfully: {config_loader.config_path}") | |
| logger.info(f"📊 Configured {len(app_config.upstream_services)} upstream services") | |
| logger.info(f"🔑 Configured {len(app_config.client_authentication.allowed_keys)} client keys") | |
| MODEL_TO_SERVICE_MAPPING, ALIAS_MAPPING = config_loader.get_model_to_service_mapping() | |
| DEFAULT_SERVICE = config_loader.get_default_service() | |
| ALLOWED_CLIENT_KEYS = config_loader.get_allowed_client_keys() | |
| GLOBAL_TRIGGER_SIGNAL = generate_random_trigger_signal() | |
| logger.info(f"🎯 Configured {len(MODEL_TO_SERVICE_MAPPING)} model mappings") | |
| if ALIAS_MAPPING: | |
| logger.info(f"🔄 Configured {len(ALIAS_MAPPING)} model aliases: {list(ALIAS_MAPPING.keys())}") | |
| logger.info(f"🔄 Default service: {DEFAULT_SERVICE['name']}") | |
| except Exception as e: | |
| logger.error(f"❌ Configuration loading failed: {type(e).__name__}") | |
| logger.error(f"❌ Error details: {str(e)}") | |
| logger.error("💡 Please ensure config.yaml file exists and is properly formatted") | |
| exit(1) | |
| def build_tool_call_index_from_messages(messages: List[Dict[str, Any]]) -> Dict[str, Dict[str, str]]: | |
| """ | |
| Build tool_call_id -> {name, arguments} index from message history. | |
| This replaces the server-side mapping by extracting tool calls from assistant messages. | |
| Args: | |
| messages: List of message dicts from the request | |
| Returns: | |
| Dict mapping tool_call_id to {name, arguments} | |
| """ | |
| index = {} | |
| for msg in messages: | |
| if isinstance(msg, dict) and msg.get("role") == "assistant": | |
| tool_calls = msg.get("tool_calls") | |
| if tool_calls and isinstance(tool_calls, list): | |
| for tc in tool_calls: | |
| if isinstance(tc, dict): | |
| tc_id = tc.get("id") | |
| func = tc.get("function", {}) | |
| if tc_id and isinstance(func, dict): | |
| name = func.get("name", "") | |
| arguments = func.get("arguments", "{}") | |
| if not isinstance(arguments, str): | |
| try: | |
| arguments = json.dumps(arguments, ensure_ascii=False) | |
| except Exception: | |
| arguments = str(arguments) | |
| if name: | |
| index[tc_id] = { | |
| "name": name, | |
| "arguments": arguments | |
| } | |
| logger.debug(f"🔧 Indexed tool_call_id: {tc_id} -> {name}") | |
| logger.debug(f"🔧 Built tool_call index with {len(index)} entries") | |
| return index | |
| def get_fc_error_retry_prompt(original_response: str, error_details: str) -> str: | |
| custom_template = app_config.features.fc_error_retry_prompt_template | |
| if custom_template: | |
| return custom_template.format( | |
| original_response=original_response, | |
| error_details=error_details | |
| ) | |
| return f"""Your previous response attempted to make a function call but the format was invalid or could not be parsed. | |
| **Your original response:** | |
| ``` | |
| {original_response} | |
| ``` | |
| **Error details:** | |
| {error_details} | |
| **Instructions:** | |
| Please retry and output the function call in the correct XML format. Remember: | |
| 1. Start with the trigger signal on its own line | |
| 2. Immediately follow with the <function_calls> XML block | |
| 3. Use <args_json> with valid JSON for parameters | |
| 4. Do not add any text after </function_calls> | |
| Please provide the corrected function call now. DO NOT OUTPUT ANYTHING ELSE.""" | |
| def _schema_type_name(v: Any) -> str: | |
| if v is None: | |
| return "null" | |
| if isinstance(v, bool): | |
| return "boolean" | |
| if isinstance(v, int) and not isinstance(v, bool): | |
| return "integer" | |
| if isinstance(v, float): | |
| return "number" | |
| if isinstance(v, str): | |
| return "string" | |
| if isinstance(v, list): | |
| return "array" | |
| if isinstance(v, dict): | |
| return "object" | |
| return type(v).__name__ | |
| def _validate_value_against_schema(value: Any, schema: Dict[str, Any], path: str = "args", depth: int = 0) -> List[str]: | |
| """Best-effort JSON Schema validation for tool arguments. | |
| Intentional subset: | |
| - type, properties, required, additionalProperties | |
| - items (array) | |
| - enum, const | |
| - anyOf/oneOf/allOf (basic) | |
| - pattern/minLength/maxLength (string) | |
| Returns a list of human-readable errors. | |
| """ | |
| if schema is None: | |
| schema = {} | |
| if depth > 8: | |
| return [] # prevent pathological recursion | |
| errors: List[str] = [] | |
| # Combinators | |
| if isinstance(schema.get("allOf"), list): | |
| for idx, sub in enumerate(schema["allOf"]): | |
| errors.extend(_validate_value_against_schema(value, sub or {}, f"{path}.allOf[{idx}]", depth + 1)) | |
| return errors | |
| if isinstance(schema.get("anyOf"), list): | |
| option_errors = [ | |
| _validate_value_against_schema(value, sub or {}, path, depth + 1) | |
| for sub in schema["anyOf"] | |
| ] | |
| if not any(len(e) == 0 for e in option_errors): | |
| errors.append(f"{path}: value does not satisfy anyOf options") | |
| return errors | |
| if isinstance(schema.get("oneOf"), list): | |
| option_errors = [ | |
| _validate_value_against_schema(value, sub or {}, path, depth + 1) | |
| for sub in schema["oneOf"] | |
| ] | |
| ok_count = sum(1 for e in option_errors if len(e) == 0) | |
| if ok_count != 1: | |
| errors.append(f"{path}: value must satisfy exactly one oneOf option (matched {ok_count})") | |
| return errors | |
| # enum/const | |
| if "const" in schema: | |
| if value != schema.get("const"): | |
| errors.append(f"{path}: expected const={schema.get('const')!r}, got {value!r}") | |
| return errors | |
| enum_vals = schema.get("enum") | |
| if isinstance(enum_vals, list): | |
| if value not in enum_vals: | |
| errors.append(f"{path}: expected one of {enum_vals!r}, got {value!r}") | |
| return errors | |
| stype = schema.get("type") | |
| if stype is None: | |
| # If schema omits type but has object keywords, treat as object. | |
| if any(k in schema for k in ("properties", "required", "additionalProperties")): | |
| stype = "object" | |
| # type checks | |
| def _type_ok(t: str) -> bool: | |
| if t == "object": | |
| return isinstance(value, dict) | |
| if t == "array": | |
| return isinstance(value, list) | |
| if t == "string": | |
| return isinstance(value, str) | |
| if t == "boolean": | |
| return isinstance(value, bool) | |
| if t == "integer": | |
| return isinstance(value, int) and not isinstance(value, bool) | |
| if t == "number": | |
| return (isinstance(value, (int, float)) and not isinstance(value, bool)) | |
| if t == "null": | |
| return value is None | |
| return True | |
| if isinstance(stype, str): | |
| if not _type_ok(stype): | |
| errors.append(f"{path}: expected type '{stype}', got '{_schema_type_name(value)}'") | |
| return errors | |
| elif isinstance(stype, list): | |
| if not any(_type_ok(t) for t in stype if isinstance(t, str)): | |
| errors.append(f"{path}: expected type in {stype!r}, got '{_schema_type_name(value)}'") | |
| return errors | |
| # string constraints | |
| if isinstance(value, str): | |
| min_len = schema.get("minLength") | |
| max_len = schema.get("maxLength") | |
| if isinstance(min_len, int) and len(value) < min_len: | |
| errors.append(f"{path}: string shorter than minLength={min_len}") | |
| if isinstance(max_len, int) and len(value) > max_len: | |
| errors.append(f"{path}: string longer than maxLength={max_len}") | |
| pat = schema.get("pattern") | |
| if isinstance(pat, str): | |
| try: | |
| if re.search(pat, value) is None: | |
| errors.append(f"{path}: string does not match pattern {pat!r}") | |
| except re.error: | |
| # ignore invalid patterns in schema | |
| pass | |
| # object | |
| if isinstance(value, dict): | |
| props = schema.get("properties") | |
| if props is None: | |
| props = {} | |
| if not isinstance(props, dict): | |
| props = {} | |
| required = schema.get("required") | |
| if required is None: | |
| required = [] | |
| if not isinstance(required, list): | |
| required = [] | |
| required = [k for k in required if isinstance(k, str)] | |
| for k in required: | |
| if k not in value: | |
| errors.append(f"{path}: missing required property '{k}'") | |
| additional = schema.get("additionalProperties", True) | |
| for k, v in value.items(): | |
| if k in props: | |
| errors.extend(_validate_value_against_schema(v, props.get(k) or {}, f"{path}.{k}", depth + 1)) | |
| else: | |
| if additional is False: | |
| errors.append(f"{path}: unexpected property '{k}'") | |
| elif isinstance(additional, dict): | |
| errors.extend(_validate_value_against_schema(v, additional, f"{path}.{k}", depth + 1)) | |
| # array | |
| if isinstance(value, list): | |
| items = schema.get("items") | |
| if isinstance(items, dict): | |
| for i, v in enumerate(value): | |
| errors.extend(_validate_value_against_schema(v, items, f"{path}[{i}]", depth + 1)) | |
| return errors | |
| def validate_parsed_tools(parsed_tools: List[Dict[str, Any]], tools: List["Tool"]) -> Optional[str]: | |
| """Validate parsed tool calls against declared tools definitions. | |
| Returns a single error string if invalid, else None. | |
| """ | |
| tools = tools or [] | |
| allowed = {t.function.name: (t.function.parameters or {}) for t in tools if t and t.function and t.function.name} | |
| allowed_names = sorted(list(allowed.keys())) | |
| for idx, call in enumerate(parsed_tools or []): | |
| name = (call or {}).get("name") | |
| args = (call or {}).get("args") | |
| if not isinstance(name, str) or not name: | |
| return f"Tool call #{idx + 1}: missing tool name" | |
| if name not in allowed: | |
| return ( | |
| f"Tool call #{idx + 1}: unknown tool '{name}'. " | |
| f"Allowed tools: {allowed_names}" | |
| ) | |
| if not isinstance(args, dict): | |
| return f"Tool call #{idx + 1} '{name}': arguments must be a JSON object, got {_schema_type_name(args)}" | |
| schema = allowed[name] or {} | |
| errs = _validate_value_against_schema(args, schema, path=f"{name}") | |
| if errs: | |
| # Keep message short but actionable | |
| preview = "; ".join(errs[:6]) | |
| more = f" (+{len(errs) - 6} more)" if len(errs) > 6 else "" | |
| return f"Tool call #{idx + 1} '{name}': schema validation failed: {preview}{more}" | |
| return None | |
| def _prompt_schema_type_name(schema: Any) -> str: | |
| """Return a compact type label for prompt display.""" | |
| if not isinstance(schema, dict): | |
| return "any" | |
| stype = schema.get("type") | |
| if isinstance(stype, str): | |
| return stype | |
| if isinstance(stype, list): | |
| parts = [t for t in stype if isinstance(t, str)] | |
| return " | ".join(parts) if parts else "any" | |
| if any(k in schema for k in ("properties", "required", "additionalProperties")): | |
| return "object" | |
| if "items" in schema: | |
| return "array" | |
| if isinstance(schema.get("anyOf"), list): | |
| return "anyOf" | |
| if isinstance(schema.get("oneOf"), list): | |
| return "oneOf" | |
| if isinstance(schema.get("allOf"), list): | |
| return "allOf" | |
| return "any" | |
| def _prompt_schema_dump(value: Any) -> str: | |
| try: | |
| return json.dumps(value, ensure_ascii=False) | |
| except Exception: | |
| return str(value) | |
| def _collect_prompt_schema_constraints(schema: Dict[str, Any]) -> Dict[str, Any]: | |
| constraints: Dict[str, Any] = {} | |
| for key in [ | |
| "minimum", "maximum", "exclusiveMinimum", "exclusiveMaximum", | |
| "minLength", "maxLength", "pattern", "format", | |
| "minItems", "maxItems", "uniqueItems", | |
| "minProperties", "maxProperties", "multipleOf" | |
| ]: | |
| if key in schema: | |
| constraints[key] = schema.get(key) | |
| if _prompt_schema_type_name(schema) == "array": | |
| items = schema.get("items") or {} | |
| if isinstance(items, dict): | |
| item_type = _prompt_schema_type_name(items) | |
| if item_type != "any": | |
| constraints["items.type"] = item_type | |
| return constraints | |
| def _append_prompt_schema_body( | |
| lines: List[str], | |
| schema: Any, | |
| is_required: Optional[bool], | |
| indent_level: int, | |
| depth: int = 0 | |
| ) -> None: | |
| schema_dict = schema if isinstance(schema, dict) else {} | |
| indent = " " * indent_level | |
| if depth > 8: | |
| lines.append(f"{indent}- note: nested schema omitted after depth 8") | |
| return | |
| lines.append(f"{indent}- type: {_prompt_schema_type_name(schema_dict)}") | |
| if is_required is not None: | |
| lines.append(f"{indent}- required: {'Yes' if is_required else 'No'}") | |
| description = schema_dict.get("description") | |
| if description: | |
| lines.append(f"{indent}- description: {description}") | |
| enum_vals = schema_dict.get("enum") | |
| if enum_vals is not None: | |
| lines.append(f"{indent}- enum: {_prompt_schema_dump(enum_vals)}") | |
| if "const" in schema_dict: | |
| lines.append(f"{indent}- const: {_prompt_schema_dump(schema_dict.get('const'))}") | |
| default_val = schema_dict.get("default") | |
| if default_val is not None: | |
| lines.append(f"{indent}- default: {_prompt_schema_dump(default_val)}") | |
| examples_val = schema_dict.get("examples") or schema_dict.get("example") | |
| if examples_val is not None: | |
| lines.append(f"{indent}- examples: {_prompt_schema_dump(examples_val)}") | |
| constraints = _collect_prompt_schema_constraints(schema_dict) | |
| if constraints: | |
| lines.append(f"{indent}- constraints: {_prompt_schema_dump(constraints)}") | |
| props_raw = schema_dict.get("properties") | |
| props = props_raw if isinstance(props_raw, dict) else {} | |
| required_raw = schema_dict.get("required") | |
| required_list = required_raw if isinstance(required_raw, list) else [] | |
| required_list = [k for k in required_list if isinstance(k, str)] | |
| if required_list: | |
| lines.append(f"{indent}- required properties: {', '.join(required_list)}") | |
| if props: | |
| lines.append(f"{indent}- properties:") | |
| for child_name, child_schema in props.items(): | |
| child_indent = " " * (indent_level + 1) | |
| child_name_text = str(child_name) | |
| lines.append(f"{child_indent}- {child_name_text}:") | |
| _append_prompt_schema_body( | |
| lines, | |
| child_schema, | |
| child_name_text in required_list, | |
| indent_level + 2, | |
| depth + 1 | |
| ) | |
| items = schema_dict.get("items") | |
| if isinstance(items, dict): | |
| lines.append(f"{indent}- items:") | |
| _append_prompt_schema_body(lines, items, None, indent_level + 1, depth + 1) | |
| elif isinstance(items, list) and items: | |
| lines.append(f"{indent}- items:") | |
| for idx, item_schema in enumerate(items): | |
| item_indent = " " * (indent_level + 1) | |
| lines.append(f"{item_indent}- item[{idx}]:") | |
| _append_prompt_schema_body(lines, item_schema, None, indent_level + 2, depth + 1) | |
| additional = schema_dict.get("additionalProperties", True) | |
| if additional is False: | |
| lines.append(f"{indent}- additionalProperties: false") | |
| elif isinstance(additional, dict): | |
| lines.append(f"{indent}- additionalProperties:") | |
| _append_prompt_schema_body(lines, additional, None, indent_level + 1, depth + 1) | |
| for keyword in ("anyOf", "oneOf", "allOf"): | |
| options = schema_dict.get(keyword) | |
| if isinstance(options, list) and options: | |
| lines.append(f"{indent}- {keyword}:") | |
| for idx, option_schema in enumerate(options, start=1): | |
| option_indent = " " * (indent_level + 1) | |
| lines.append(f"{option_indent}- option {idx}:") | |
| _append_prompt_schema_body( | |
| lines, | |
| option_schema, | |
| None, | |
| indent_level + 2, | |
| depth + 1 | |
| ) | |
| async def attempt_fc_parse_with_retry( | |
| content: str, | |
| trigger_signal: str, | |
| messages: List[Dict[str, Any]], | |
| upstream_url: str, | |
| headers: Dict[str, str], | |
| model: str, | |
| tools: List["Tool"], | |
| timeout: int | |
| ) -> Optional[List[Dict[str, Any]]]: | |
| """ | |
| Attempt to parse function calls from content. If parsing fails and retry is enabled, | |
| send error details back to the model for correction. | |
| Returns parsed tool calls or None if parsing ultimately fails. | |
| """ | |
| def _parse_and_validate(current_content: str) -> tuple[Optional[List[Dict[str, Any]]], Optional[str]]: | |
| parsed = parse_function_calls_xml(current_content, trigger_signal) | |
| if not parsed: | |
| return None, None | |
| validation_error = validate_parsed_tools(parsed, tools) | |
| if validation_error: | |
| return None, validation_error | |
| return parsed, None | |
| if not app_config.features.enable_fc_error_retry: | |
| parsed, _err = _parse_and_validate(content) | |
| return parsed | |
| max_attempts = app_config.features.fc_error_retry_max_attempts | |
| current_content = content | |
| current_messages = messages.copy() | |
| for attempt in range(max_attempts): | |
| parsed_tools, validation_error = _parse_and_validate(current_content) | |
| if parsed_tools: | |
| if attempt > 0: | |
| logger.info(f"✅ Function call parsing succeeded on retry attempt {attempt + 1}") | |
| return parsed_tools | |
| # IMPORTANT: only treat this as a tool-call attempt if the trigger signal appears | |
| # outside of blocks. Otherwise models that mention the trigger | |
| # inside think can spuriously trigger retries. | |
| if find_last_trigger_signal_outside_think(current_content, trigger_signal) == -1: | |
| logger.debug("🔧 No trigger signal found outside <think> blocks; not a function call attempt") | |
| return None | |
| if attempt >= max_attempts - 1: | |
| logger.warning(f"⚠️ Function call parsing failed after {max_attempts} attempts") | |
| return None | |
| # Classify the failure type to choose the right retry strategy | |
| failure_type = _classify_fc_failure(current_content, trigger_signal) | |
| if failure_type == "no_fc": | |
| return None | |
| error_details = validation_error or _diagnose_fc_parse_error(current_content, trigger_signal) | |
| if failure_type == "truncated": | |
| # Output was cut off — ask model to continue from where it stopped | |
| retry_prompt = get_fc_continuation_prompt(current_content, error_details) | |
| logger.info(f"🔄 Function call output truncated, requesting continuation {attempt + 2}/{max_attempts}") | |
| else: | |
| # Syntax error with closed tags — ask model to rewrite entirely | |
| retry_prompt = get_fc_error_retry_prompt(current_content, error_details) | |
| logger.info(f"🔄 Function call syntax error, requesting rewrite {attempt + 2}/{max_attempts}") | |
| logger.debug(f"🔧 Failure type: {failure_type}, error details: {error_details}") | |
| retry_messages = current_messages + [ | |
| {"role": "assistant", "content": current_content}, | |
| {"role": "user", "content": retry_prompt} | |
| ] | |
| try: | |
| retry_response = await http_client.post( | |
| upstream_url, | |
| json={"model": model, "messages": retry_messages, "stream": False}, | |
| headers=headers, | |
| timeout=timeout | |
| ) | |
| retry_response.raise_for_status() | |
| retry_json = retry_response.json() | |
| if retry_json.get("choices") and len(retry_json["choices"]) > 0: | |
| retry_content = retry_json["choices"][0].get("message", {}).get("content", "") | |
| logger.debug(f"🔧 Received retry response, length: {len(retry_content)}") | |
| if failure_type == "truncated" and _is_continuation_response(retry_content, trigger_signal): | |
| # Model chose Option A: continuation — merge with truncated content | |
| current_content = _merge_truncated_and_continuation(current_content, retry_content) | |
| logger.info(f"🔧 Merged continuation, total length: {len(current_content)}") | |
| else: | |
| # Model chose Option B (full rewrite) or it was a syntax_error retry | |
| current_content = retry_content | |
| current_messages = retry_messages | |
| else: | |
| logger.warning(f"⚠️ Retry response has no valid choices") | |
| return None | |
| except Exception as e: | |
| logger.error(f"❌ Retry request failed: {e}") | |
| return None | |
| return None | |
| def _diagnose_fc_parse_error(content: str, trigger_signal: str) -> str: | |
| """Diagnose why function call parsing failed and return error description.""" | |
| errors = [] | |
| if trigger_signal not in content: | |
| errors.append(f"Trigger signal '{trigger_signal[:30]}...' not found in response") | |
| return "; ".join(errors) | |
| cleaned = remove_think_blocks(content) | |
| if "<function_calls>" not in cleaned: | |
| errors.append("Missing <function_calls> tag after trigger signal") | |
| elif "</function_calls>" not in cleaned: | |
| errors.append("Missing closing </function_calls> tag") | |
| if "<function_call>" not in cleaned: | |
| errors.append("No <function_call> blocks found inside <function_calls>") | |
| elif "</function_call>" not in cleaned: | |
| errors.append("Missing closing </function_call> tag") | |
| fc_match = re.search(r"<function_calls>([\s\S]*?)</function_calls>", cleaned) | |
| if fc_match: | |
| fc_content = fc_match.group(1) | |
| if "<tool>" not in fc_content: | |
| errors.append("Missing <tool> tag inside function_call") | |
| if "<args_json>" not in fc_content and "<args>" not in fc_content: | |
| errors.append("Missing <args_json> or <args> tag inside function_call") | |
| args_json_match = re.search(r"<args_json>([\s\S]*?)</args_json>", fc_content) | |
| if args_json_match: | |
| args_content = args_json_match.group(1).strip() | |
| cdata_match = re.search(r"<!\[CDATA\[([\s\S]*?)\]\]>", args_content) | |
| json_to_parse = cdata_match.group(1) if cdata_match else args_content | |
| try: | |
| parsed = json.loads(json_to_parse) | |
| if not isinstance(parsed, dict): | |
| errors.append(f"args_json must be a JSON object, got {type(parsed).__name__}") | |
| except json.JSONDecodeError as e: | |
| errors.append(f"Invalid JSON in args_json: {str(e)}") | |
| if not errors: | |
| errors.append("XML structure appears correct but parsing failed for unknown reason") | |
| return "; ".join(errors) | |
| def _classify_fc_failure(content: str, trigger_signal: str) -> str: | |
| """Classify function call failure type. | |
| Returns: | |
| 'no_fc' - no trigger signal found outside think blocks | |
| 'truncated' - trigger signal and opening tag found but no closing tag (output was cut off) | |
| 'syntax_error' - tags are closed but content is malformed | |
| """ | |
| if find_last_trigger_signal_outside_think(content, trigger_signal) == -1: | |
| return "no_fc" | |
| cleaned = remove_think_blocks(content) | |
| pos = find_last_trigger_signal_outside_think(cleaned, trigger_signal) | |
| if pos == -1: | |
| return "no_fc" | |
| after_trigger = cleaned[pos:] | |
| has_open = "<" + "function_calls>" in after_trigger | |
| has_close = "</" + "function_calls>" in after_trigger | |
| if not has_open: | |
| return "syntax_error" | |
| if has_open and not has_close: | |
| return "truncated" | |
| return "syntax_error" | |
| def get_fc_continuation_prompt(truncated_content: str, error_details: str) -> str: | |
| """Generate a prompt asking the model to continue its truncated function call output.""" | |
| # Show only the tail to save tokens | |
| tail = truncated_content[-1500:] | |
| fci_close = "</" + "function_call>" | |
| fc_close = "</" + "function_calls>" | |
| return ( | |
| "Your previous response was cut off before the function call XML was complete.\n" | |
| "\n" | |
| "**Your truncated response (ending abruptly):**\n" | |
| "```\n" | |
| f"{tail}\n" | |
| "```\n" | |
| "\n" | |
| f"**What happened:** {error_details}\n" | |
| "\n" | |
| "**You have two options:**\n" | |
| "\n" | |
| "**Option A (PREFERRED \u2014 Continue writing):**\n" | |
| "Output ONLY the exact continuation from where you were cut off. Rules:\n" | |
| "- Start EXACTLY from the next character after the cutoff point \u2014 do not repeat ANY text, not even a single character\n" | |
| "- If the cutoff happened mid-word, start from the next character of that word, never repeat the partial character/word\n" | |
| "- Do NOT output any trigger signal or opening tags that were already present\n" | |
| f"- End with the proper closing tags ({fci_close}, {fc_close} as needed)\n" | |
| "- Do NOT add any explanation before or after\n" | |
| "\n" | |
| "**Option B (Only if you made an error earlier):**\n" | |
| "Start fresh with the complete function call from the trigger signal. " | |
| "Output the trigger signal on its own line, followed by the complete " | |
| "function_calls block.\n" | |
| "\n" | |
| "Choose Option A unless you believe your previous output contained errors that need correction." | |
| ) | |
| def _is_continuation_response(retry_content: str, trigger_signal: str) -> bool: | |
| """Determine if the model's retry response is a continuation or a full rewrite.""" | |
| cleaned = retry_content.strip() | |
| if trigger_signal in cleaned: | |
| return False | |
| fc_open = "<" + "function_calls>" | |
| if cleaned.lstrip().startswith(fc_open): | |
| return False | |
| return True | |
| def _merge_truncated_and_continuation(truncated: str, continuation: str) -> str: | |
| """Merge truncated content with its continuation.""" | |
| return truncated.rstrip("\n") + continuation.lstrip("\n") | |
| def format_tool_result_for_ai(tool_name: str, tool_arguments: str, result_content: str) -> str: | |
| """ | |
| Format tool call results for AI understanding with complete context. | |
| Args: | |
| tool_name: Name of the tool that was called | |
| tool_arguments: Arguments passed to the tool (JSON string) | |
| result_content: Execution result from the tool | |
| Returns: | |
| Formatted text for upstream model | |
| """ | |
| formatted_text = f"""Tool execution result: | |
| - Tool name: {tool_name} | |
| - Tool arguments: {tool_arguments} | |
| - Execution result: | |
| <tool_result> | |
| {result_content} | |
| </tool_result>""" | |
| logger.debug(f"🔧 Formatted tool result for {tool_name}") | |
| return formatted_text | |
| def format_assistant_tool_calls_for_ai(tool_calls: List[Dict[str, Any]], trigger_signal: str) -> str: | |
| """Format assistant tool calls into AI-readable string format.""" | |
| logger.debug(f"🔧 Formatting assistant tool calls. Count: {len(tool_calls)}") | |
| def _wrap_cdata(text: str) -> str: | |
| # Avoid illegal ']]>' sequence inside CDATA by splitting. | |
| safe = (text or "").replace("]]>", "]]]]><![CDATA[>") | |
| return f"<![CDATA[{safe}]]>" | |
| xml_calls_parts = [] | |
| for tool_call in tool_calls: | |
| function_info = tool_call.get("function", {}) | |
| name = function_info.get("name", "") | |
| arguments_val = function_info.get("arguments", "{}") | |
| # Strict: assistant.tool_calls must carry JSON-object arguments (or a JSON string representing an object). | |
| try: | |
| if isinstance(arguments_val, dict): | |
| args_dict = arguments_val | |
| elif isinstance(arguments_val, str): | |
| parsed = json.loads(arguments_val or "{}") | |
| if not isinstance(parsed, dict): | |
| raise ValueError(f"arguments must be a JSON object, got {type(parsed).__name__}") | |
| args_dict = parsed | |
| else: | |
| raise ValueError(f"arguments must be a JSON object or JSON string, got {type(arguments_val).__name__}") | |
| except Exception as e: | |
| raise HTTPException( | |
| status_code=400, | |
| detail=f"Invalid assistant.tool_calls arguments for tool '{name}': {e}" | |
| ) | |
| args_payload = json.dumps(args_dict, ensure_ascii=False) | |
| xml_call = ( | |
| f"<function_call>\n" | |
| f"<tool>{name}</tool>\n" | |
| f"<args_json>{_wrap_cdata(args_payload)}</args_json>\n" | |
| f"</function_call>" | |
| ) | |
| xml_calls_parts.append(xml_call) | |
| all_calls = "\n".join(xml_calls_parts) | |
| final_str = f"{trigger_signal}\n<function_calls>\n{all_calls}\n</function_calls>" | |
| logger.debug("🔧 Assistant tool calls formatted successfully.") | |
| return final_str | |
| def get_function_call_prompt_template(trigger_signal: str) -> str: | |
| """ | |
| Generate prompt template based on dynamic trigger signal | |
| """ | |
| custom_template = app_config.features.prompt_template | |
| if custom_template: | |
| logger.info("🔧 Using custom prompt template from configuration") | |
| return custom_template.format( | |
| trigger_signal=trigger_signal, | |
| tools_list="{tools_list}" | |
| ) | |
| return f""" | |
| You have access to the following available tools to help solve problems: | |
| {{tools_list}} | |
| **IMPORTANT CONTEXT NOTES:** | |
| 1. You can call MULTIPLE tools in a single response if needed. | |
| 2. Even though you can call multiple tools, you MUST respect the user's later constraints and preferences (e.g., the user may request no tools, only one tool, or a specific tool/workflow). | |
| 3. The conversation context may already contain tool execution results from previous function calls. Review the conversation history carefully to avoid unnecessary duplicate tool calls. | |
| 4. When tool execution results are present in the context, they will be formatted with XML tags like <tool_result>...</tool_result> for easy identification. | |
| 5. This is the ONLY format you can use for tool calls, and any deviation will result in failure. | |
| When you need to use tools, you **MUST** strictly follow this format. Do NOT include any extra text, explanations, or dialogue on the first and second lines of the tool call syntax: | |
| 1. When starting tool calls, begin on a new line with exactly: | |
| {trigger_signal} | |
| No leading or trailing spaces, output exactly as shown above. The trigger signal MUST be on its own line and appear only once. Do not output a trigger signal for each tool call. | |
| 2. Starting from the second line, **immediately** follow with the complete <function_calls> XML block. | |
| 3. For multiple tool calls, include multiple <function_call> blocks within the same <function_calls> wrapper, not separate blocks. Output the trigger signal only once, then one <function_calls> with all <function_call> children. | |
| 4. Do not add any text or explanation after the closing </function_calls> tag. | |
| STRICT ARGUMENT KEY RULES: | |
| - You MUST use parameter keys EXACTLY as defined (case- and punctuation-sensitive). Do NOT rename, add, or remove characters. | |
| - If a key starts with a hyphen (e.g., "-i", "-C"), you MUST keep the leading hyphen in the JSON key. Never convert "-i" to "i" or "-C" to "C". | |
| - The <tool> tag must contain the exact name of a tool from the list. Any other tool name is invalid. | |
| - The <args_json> tag must contain a single JSON object with all required arguments for that tool. | |
| - You MAY wrap the JSON content inside <![CDATA[...]]> to avoid XML escaping issues. | |
| CORRECT Example (multiple tool calls): | |
| ...response content (optional)... | |
| {trigger_signal} | |
| <function_calls> | |
| <function_call> | |
| <tool>Grep</tool> | |
| <args_json><![CDATA[{{"-i": true, "-C": 2, "path": "."}}]]></args_json> | |
| </function_call> | |
| <function_call> | |
| <tool>search</tool> | |
| <args_json><![CDATA[{{"keywords": ["Python Document", "how to use python"]}}]]></args_json> | |
| </function_call> | |
| </function_calls> | |
| INCORRECT Example (extra text + wrong key names — DO NOT DO THIS): | |
| ...response content (optional)... | |
| {trigger_signal} | |
| I will call the tools for you. | |
| <function_calls> | |
| <function_call> | |
| <tool>Grep</tool> | |
| <args> | |
| <i>true</i> | |
| <C>2</C> | |
| <path>.</path> | |
| </args> | |
| </function_call> | |
| </function_calls> | |
| INCORRECT Example (output non-XML format — DO NOT DO THIS): | |
| ...response content (optional)... | |
| ```json | |
| {{"files":[{{"path":"system.py"}}]}} | |
| ``` | |
| Now please be ready to strictly follow the above specifications. | |
| """ | |
| class ToolFunction(BaseModel): | |
| name: str | |
| description: Optional[str] = None | |
| parameters: Dict[str, Any] | |
| class Tool(BaseModel): | |
| type: Literal["function"] | |
| function: ToolFunction | |
| class Message(BaseModel): | |
| role: str | |
| content: Optional[str] = None | |
| tool_calls: Optional[List[Dict[str, Any]]] = None | |
| tool_call_id: Optional[str] = None | |
| name: Optional[str] = None | |
| class Config: | |
| extra = "allow" | |
| class ToolChoice(BaseModel): | |
| type: Literal["function"] | |
| function: Dict[str, str] | |
| class ChatCompletionRequest(BaseModel): | |
| model: str | |
| messages: List[Dict[str, Any]] | |
| tools: Optional[List[Tool]] = None | |
| tool_choice: Optional[Union[str, ToolChoice]] = None | |
| stream: Optional[bool] = False | |
| stream_options: Optional[Dict[str, Any]] = None | |
| temperature: Optional[float] = None | |
| max_tokens: Optional[int] = None | |
| top_p: Optional[float] = None | |
| frequency_penalty: Optional[float] = None | |
| presence_penalty: Optional[float] = None | |
| n: Optional[int] = None | |
| stop: Optional[Union[str, List[str]]] = None | |
| class Config: | |
| extra = "allow" | |
| def generate_function_prompt(tools: List[Tool], trigger_signal: str) -> tuple[str, str]: | |
| """ | |
| Generate injected system prompt based on tools definition in client request. | |
| Returns: (prompt_content, trigger_signal) | |
| Raises: | |
| HTTPException: If tool schema validation fails (e.g., required keys not in properties) | |
| """ | |
| tools_list_str = [] | |
| for i, tool in enumerate(tools): | |
| func = tool.function | |
| name = func.name | |
| description = func.description or "" | |
| # Robustly read JSON Schema fields + validate basic types | |
| schema: Dict[str, Any] = func.parameters or {} | |
| props_raw = schema.get("properties", {}) | |
| if props_raw is None: | |
| props_raw = {} | |
| if not isinstance(props_raw, dict): | |
| raise HTTPException( | |
| status_code=400, | |
| detail=f"Tool '{name}': 'properties' must be an object, got {type(props_raw).__name__}" | |
| ) | |
| props: Dict[str, Any] = props_raw | |
| required_raw = schema.get("required", []) | |
| if required_raw is None: | |
| required_raw = [] | |
| if not isinstance(required_raw, list): | |
| raise HTTPException( | |
| status_code=400, | |
| detail=f"Tool '{name}': 'required' must be a list, got {type(required_raw).__name__}" | |
| ) | |
| non_string_required = [k for k in required_raw if not isinstance(k, str)] | |
| if non_string_required: | |
| raise HTTPException( | |
| status_code=400, | |
| detail=f"Tool '{name}': 'required' entries must be strings, got {non_string_required}" | |
| ) | |
| required_list: List[str] = required_raw | |
| missing_keys = [key for key in required_list if key not in props] | |
| if missing_keys: | |
| raise HTTPException( | |
| status_code=400, | |
| detail=f"Tool '{name}': required parameters {missing_keys} are not defined in properties" | |
| ) | |
| # Brief summary line: name (type) | |
| params_summary = ", ".join([ | |
| f"{p_name} ({_prompt_schema_type_name(p_info)})" for p_name, p_info in props.items() | |
| ]) or "None" | |
| # Build detailed parameter spec for prompt injection (default enabled) | |
| detail_lines: List[str] = [] | |
| for p_name, p_info in props.items(): | |
| detail_lines.append(f"- {p_name}:") | |
| _append_prompt_schema_body( | |
| detail_lines, | |
| p_info, | |
| p_name in required_list, | |
| indent_level=1 | |
| ) | |
| detail_block = "\n".join(detail_lines) if detail_lines else "(no parameter details)" | |
| desc_block = f"```\n{description}\n```" if description else "None" | |
| tools_list_str.append( | |
| f"{i + 1}. <tool name=\"{name}\">\n" | |
| f" Description:\n{desc_block}\n" | |
| f" Parameters summary: {params_summary}\n" | |
| f" Required parameters: {', '.join(required_list) if required_list else 'None'}\n" | |
| f" Parameter details:\n{detail_block}" | |
| ) | |
| prompt_template = get_function_call_prompt_template(trigger_signal) | |
| prompt_content = prompt_template.replace("{tools_list}", "\n\n".join(tools_list_str)) | |
| return prompt_content, trigger_signal | |
| def remove_think_blocks(text: str) -> str: | |
| """ | |
| Temporarily remove all <think>...</think> blocks for XML parsing | |
| Supports nested think tags | |
| Note: This function is only used for temporary parsing and does not affect the original content returned to the user | |
| """ | |
| while '<think>' in text and '</think>' in text: | |
| start_pos = text.find('<think>') | |
| if start_pos == -1: | |
| break | |
| pos = start_pos + 7 | |
| depth = 1 | |
| while pos < len(text) and depth > 0: | |
| if text[pos:pos+7] == '<think>': | |
| depth += 1 | |
| pos += 7 | |
| elif text[pos:pos+8] == '</think>': | |
| depth -= 1 | |
| pos += 8 | |
| else: | |
| pos += 1 | |
| if depth == 0: | |
| text = text[:start_pos] + text[pos:] | |
| else: | |
| break | |
| return text | |
| def find_last_trigger_signal_outside_think(text: str, trigger_signal: str) -> int: | |
| """ | |
| Find the last occurrence position of trigger_signal that is NOT inside any <think>...</think> block. | |
| Returns -1 if not found. | |
| """ | |
| if not text or not trigger_signal: | |
| return -1 | |
| i = 0 | |
| think_depth = 0 | |
| last_pos = -1 | |
| while i < len(text): | |
| if text.startswith("<think>", i): | |
| think_depth += 1 | |
| i += 7 | |
| continue | |
| if text.startswith("</think>", i): | |
| think_depth = max(0, think_depth - 1) | |
| i += 8 | |
| continue | |
| if think_depth == 0 and text.startswith(trigger_signal, i): | |
| last_pos = i | |
| # Move forward by 1 to allow overlapping search (not expected, but safe) | |
| i += 1 | |
| continue | |
| i += 1 | |
| return last_pos | |
| class StreamingFunctionCallDetector: | |
| """Enhanced streaming function call detector, supports dynamic trigger signals, avoids misjudgment within <think> tags | |
| Core features: | |
| 1. Avoid triggering tool call detection within <think> blocks | |
| 2. Normally output <think> block content to the user | |
| 3. Supports nested think tags | |
| """ | |
| def __init__(self, trigger_signal: str): | |
| self.trigger_signal = trigger_signal | |
| self.reset() | |
| def reset(self): | |
| self.content_buffer = "" | |
| self.state = "detecting" # detecting, tool_parsing | |
| self.in_think_block = False | |
| self.think_depth = 0 | |
| self.signal = self.trigger_signal | |
| self.signal_len = len(self.signal) | |
| def process_chunk(self, delta_content: str) -> tuple[bool, str]: | |
| """ | |
| Process streaming content chunk | |
| Returns: (is_tool_call_detected, content_to_yield) | |
| """ | |
| if not delta_content: | |
| return False, "" | |
| self.content_buffer += delta_content | |
| content_to_yield = "" | |
| if self.state == "tool_parsing": | |
| return False, "" | |
| if delta_content: | |
| logger.debug(f"🔧 Processing chunk: {repr(delta_content[:50])}{'...' if len(delta_content) > 50 else ''}, buffer length: {len(self.content_buffer)}, think state: {self.in_think_block}") | |
| i = 0 | |
| while i < len(self.content_buffer): | |
| skip_chars = self._update_think_state(i) | |
| if skip_chars > 0: | |
| for j in range(skip_chars): | |
| if i + j < len(self.content_buffer): | |
| content_to_yield += self.content_buffer[i + j] | |
| i += skip_chars | |
| continue | |
| if not self.in_think_block and self._can_detect_signal_at(i): | |
| if self.content_buffer[i:i+self.signal_len] == self.signal: | |
| logger.debug(f"🔧 Improved detector: detected trigger signal in non-think block! Signal: {self.signal[:20]}...") | |
| logger.debug(f"🔧 Trigger signal position: {i}, think state: {self.in_think_block}, think depth: {self.think_depth}") | |
| self.state = "tool_parsing" | |
| self.content_buffer = self.content_buffer[i:] | |
| return True, content_to_yield | |
| remaining_len = len(self.content_buffer) - i | |
| if remaining_len < self.signal_len or remaining_len < 8: | |
| break | |
| content_to_yield += self.content_buffer[i] | |
| i += 1 | |
| self.content_buffer = self.content_buffer[i:] | |
| return False, content_to_yield | |
| def _update_think_state(self, pos: int): | |
| """Update think tag state, supports nesting""" | |
| remaining = self.content_buffer[pos:] | |
| if remaining.startswith('<think>'): | |
| self.think_depth += 1 | |
| self.in_think_block = True | |
| logger.debug(f"🔧 Entering think block, depth: {self.think_depth}") | |
| return 7 | |
| elif remaining.startswith('</think>'): | |
| self.think_depth = max(0, self.think_depth - 1) | |
| self.in_think_block = self.think_depth > 0 | |
| logger.debug(f"🔧 Exiting think block, depth: {self.think_depth}") | |
| return 8 | |
| return 0 | |
| def _can_detect_signal_at(self, pos: int) -> bool: | |
| """Check if signal can be detected at the specified position""" | |
| return (pos + self.signal_len <= len(self.content_buffer) and | |
| not self.in_think_block) | |
| def finalize(self) -> Optional[List[Dict[str, Any]]]: | |
| """Final processing when stream ends""" | |
| if self.state == "tool_parsing": | |
| return parse_function_calls_xml(self.content_buffer, self.trigger_signal) | |
| return None | |
| def parse_function_calls_xml(xml_string: str, trigger_signal: str) -> Optional[List[Dict[str, Any]]]: | |
| """ | |
| Enhanced XML parsing function, supports dynamic trigger signals | |
| 1. Retain <think>...</think> blocks (they should be returned normally to the user) | |
| 2. Temporarily remove think blocks only when parsing function_calls to prevent think content from interfering with XML parsing | |
| 3. Find the last occurrence of the trigger signal | |
| 4. Start parsing function_calls from the last trigger signal | |
| """ | |
| logger.debug(f"🔧 Improved parser starting processing, input length: {len(xml_string) if xml_string else 0}") | |
| logger.debug(f"🔧 Using trigger signal: {trigger_signal[:20]}...") | |
| if not xml_string or trigger_signal not in xml_string: | |
| logger.debug(f"🔧 Input is empty or doesn't contain trigger signal") | |
| return None | |
| cleaned_content = remove_think_blocks(xml_string) | |
| logger.debug(f"🔧 Content length after temporarily removing think blocks: {len(cleaned_content)}") | |
| signal_positions = [] | |
| start_pos = 0 | |
| while True: | |
| pos = cleaned_content.find(trigger_signal, start_pos) | |
| if pos == -1: | |
| break | |
| signal_positions.append(pos) | |
| start_pos = pos + 1 | |
| if not signal_positions: | |
| logger.debug(f"🔧 No trigger signal found in cleaned content") | |
| return None | |
| logger.debug(f"🔧 Found {len(signal_positions)} trigger signal positions: {signal_positions}") | |
| chosen_signal_index = None | |
| chosen_signal_pos = None | |
| calls_content_match = None | |
| for idx in range(len(signal_positions) - 1, -1, -1): | |
| pos = signal_positions[idx] | |
| sub = cleaned_content[pos:] | |
| m = re.search(r"<function_calls>([\s\S]*?)</function_calls>", sub) | |
| if m: | |
| chosen_signal_index = idx | |
| chosen_signal_pos = pos | |
| calls_content_match = m | |
| logger.debug(f"🔧 Using trigger signal index {idx} at pos {pos}; content preview: {repr(sub[:100])}") | |
| break | |
| if calls_content_match is None: | |
| logger.debug(f"🔧 No function_calls tag found after any trigger signal (triggers={len(signal_positions)})") | |
| return None | |
| calls_xml = calls_content_match.group(0) | |
| calls_content = calls_content_match.group(1) | |
| logger.debug(f"🔧 function_calls content: {repr(calls_content)}") | |
| def _coerce_value(v: str): | |
| try: | |
| return json.loads(v) | |
| except Exception: | |
| return v | |
| def _parse_args_json_payload(payload: str) -> Optional[Dict[str, Any]]: | |
| """Strict args_json parsing with markdown fence removal. | |
| - Empty payload -> {} | |
| - Strips markdown code fences (```json ... ```) if present | |
| - Must be valid JSON and MUST decode to an object (dict) | |
| - Any invalid / non-object payload -> None (treated as parse failure) | |
| """ | |
| if payload is None: | |
| return {} | |
| s = payload.strip() | |
| if not s: | |
| return {} | |
| # Remove accidental markdown fences if model emits them | |
| if s.startswith("```"): | |
| s = re.sub(r"^```(?:json)?\s*", "", s) | |
| s = re.sub(r"\s*```$", "", s) | |
| try: | |
| parsed = json.loads(s) | |
| except Exception as e: | |
| logger.debug(f"🔧 Invalid JSON in args_json: {type(e).__name__}: {e}") | |
| return None | |
| if not isinstance(parsed, dict): | |
| logger.debug(f"🔧 args_json must decode to an object, got {type(parsed).__name__}") | |
| return None | |
| return parsed | |
| def _extract_cdata_text(raw: str) -> str: | |
| if raw is None: | |
| return "" | |
| if "<![CDATA[" not in raw: | |
| return raw | |
| parts = re.findall(r"<!\[CDATA\[(.*?)\]\]>", raw, flags=re.DOTALL) | |
| return "".join(parts) if parts else raw | |
| results: List[Dict[str, Any]] = [] | |
| # Primary path: strict XML parse (requires model to output valid XML) | |
| try: | |
| root = ET.fromstring(calls_xml) | |
| for i, fc in enumerate(root.findall("function_call")): | |
| tool_el = fc.find("tool") | |
| name = (tool_el.text or "").strip() if tool_el is not None else "" | |
| if not name: | |
| logger.debug(f"🔧 No tool tag found in function_call #{i+1}") | |
| continue | |
| args: Dict[str, Any] = {} | |
| args_json_el = fc.find("args_json") | |
| if args_json_el is not None: | |
| parsed_args = _parse_args_json_payload(args_json_el.text or "") | |
| if parsed_args is None: | |
| logger.debug(f"🔧 Invalid args_json in function_call #{i+1}; treating as parse failure") | |
| return None | |
| args = parsed_args | |
| else: | |
| # Legacy fallback: <args><k>json</k></args> | |
| args_el = fc.find("args") | |
| if args_el is not None: | |
| for child in list(args_el): | |
| args[child.tag] = _coerce_value(child.text or "") | |
| result = {"name": name, "args": args} | |
| results.append(result) | |
| logger.debug(f"🔧 Added tool call: {result}") | |
| logger.debug(f"🔧 Final parsing result (XML): {results}") | |
| return results if results else None | |
| except Exception as e: | |
| logger.debug(f"🔧 XML library parse failed, falling back to regex parser: {type(e).__name__}: {e}") | |
| # Fallback path: regex parse (more tolerant to malformed XML) | |
| call_blocks = re.findall(r"<function_call>([\s\S]*?)</function_call>", calls_content) | |
| logger.debug(f"🔧 Found {len(call_blocks)} function_call blocks") | |
| for i, block in enumerate(call_blocks): | |
| logger.debug(f"🔧 Processing function_call #{i+1}: {repr(block)}") | |
| tool_match = re.search(r"<tool>(.*?)</tool>", block) | |
| if not tool_match: | |
| logger.debug(f"🔧 No tool tag found in block #{i+1}") | |
| continue | |
| name = tool_match.group(1).strip() | |
| args: Dict[str, Any] = {} | |
| args_json_match = re.search(r"<args_json>([\s\S]*?)</args_json>", block) | |
| if args_json_match: | |
| raw_payload = args_json_match.group(1) | |
| payload = _extract_cdata_text(raw_payload) | |
| parsed_args = _parse_args_json_payload(payload) | |
| if parsed_args is None: | |
| logger.debug(f"🔧 Invalid args_json in function_call #{i+1} (regex path); treating as parse failure") | |
| return None | |
| args = parsed_args | |
| else: | |
| # Legacy fallback | |
| args_block_match = re.search(r"<args>([\s\S]*?)</args>", block) | |
| if args_block_match: | |
| args_content_inner = args_block_match.group(1) | |
| arg_matches = re.findall(r"<([^\s>/]+)>([\s\S]*?)</\1>", args_content_inner) | |
| for k, v in arg_matches: | |
| args[k] = _coerce_value(v) | |
| result = {"name": name, "args": args} | |
| results.append(result) | |
| logger.debug(f"🔧 Added tool call: {result}") | |
| logger.debug(f"🔧 Final parsing result (regex): {results}") | |
| return results if results else None | |
| def find_upstream(model_name: str) -> tuple[Dict[str, Any], str]: | |
| """Find upstream configuration by model name, handling aliases and passthrough mode.""" | |
| # Handle model passthrough mode | |
| if app_config.features.model_passthrough: | |
| logger.info("🔄 Model passthrough mode is active. Forwarding to 'openai' service.") | |
| openai_service = None | |
| for service in app_config.upstream_services: | |
| if service.name == "openai": | |
| openai_service = service.model_dump() | |
| break | |
| if openai_service: | |
| if not openai_service.get("api_key"): | |
| raise HTTPException(status_code=500, detail="Configuration error: API key not found for the 'openai' service in model passthrough mode.") | |
| # In passthrough mode, the model name from the request is used directly. | |
| return openai_service, model_name | |
| else: | |
| raise HTTPException(status_code=500, detail="Configuration error: 'model_passthrough' is enabled, but no upstream service named 'openai' was found.") | |
| # Default routing logic | |
| chosen_model_entry = model_name | |
| if model_name in ALIAS_MAPPING: | |
| chosen_model_entry = random.choice(ALIAS_MAPPING[model_name]) | |
| logger.info(f"🔄 Model alias '{model_name}' detected. Randomly selected '{chosen_model_entry}' for this request.") | |
| service = MODEL_TO_SERVICE_MAPPING.get(chosen_model_entry) | |
| if service: | |
| if not service.get("api_key"): | |
| raise HTTPException(status_code=500, detail=f"Model configuration error: API key not found for service '{service.get('name')}'.") | |
| else: | |
| logger.warning(f"⚠️ Model '{model_name}' not found in configuration, using default service") | |
| service = DEFAULT_SERVICE | |
| if not service.get("api_key"): | |
| raise HTTPException(status_code=500, detail="Service configuration error: Default API key not found.") | |
| actual_model_name = chosen_model_entry | |
| if ':' in chosen_model_entry: | |
| parts = chosen_model_entry.split(':', 1) | |
| if len(parts) == 2: | |
| _, actual_model_name = parts | |
| return service, actual_model_name | |
| app = FastAPI() | |
| http_client = httpx.AsyncClient() | |
| def _is_retriable_upstream_error(exc: Exception) -> bool: | |
| return isinstance(exc, (httpx.ConnectError, httpx.TimeoutException)) | |
| def _get_upstream_retry_attempts() -> int: | |
| return getattr(app_config.server, "upstream_retry_attempts", 3) or 1 | |
| def _get_upstream_retry_delay(attempt: int) -> float: | |
| base_delay = getattr(app_config.server, "upstream_retry_base_delay", 0.5) | |
| return base_delay * (2 ** attempt) | |
| async def _post_upstream_with_retry(url: str, json_body: dict, headers: Dict[str, str], timeout: int) -> httpx.Response: | |
| """POST to upstream with automatic retry on connection errors and timeouts.""" | |
| max_attempts = _get_upstream_retry_attempts() | |
| if max_attempts <= 0: | |
| return await http_client.post(url, json=json_body, headers=headers, timeout=timeout) | |
| last_error: Optional[Exception] = None | |
| for attempt in range(max_attempts): | |
| try: | |
| return await http_client.post(url, json=json_body, headers=headers, timeout=timeout) | |
| except Exception as exc: | |
| last_error = exc | |
| if not _is_retriable_upstream_error(exc) or attempt >= max_attempts - 1: | |
| raise | |
| delay = _get_upstream_retry_delay(attempt) | |
| logger.warning( | |
| "⚠️ Upstream POST failed with %s; retrying in %.1fs (%s/%s)", | |
| type(exc).__name__, delay, attempt + 2, max_attempts, | |
| ) | |
| await asyncio.sleep(delay) | |
| if last_error is not None: | |
| raise last_error | |
| raise RuntimeError("Unreachable upstream retry state") | |
| async def debug_middleware(request: Request, call_next): | |
| """Middleware for debugging validation errors, does not log conversation content.""" | |
| response = await call_next(request) | |
| if response.status_code == 422: | |
| logger.debug(f"🔍 Validation error detected for {request.method} {request.url.path}") | |
| logger.debug(f"🔍 Response status code: 422 (Pydantic validation failure)") | |
| return response | |
| async def validation_exception_handler(request: Request, exc: ValidationError): | |
| """Handle Pydantic validation errors with detailed error information""" | |
| logger.error(f"❌ Pydantic validation error: {exc}") | |
| logger.error(f"❌ Request URL: {request.url}") | |
| logger.error(f"❌ Error details: {exc.errors()}") | |
| for error in exc.errors(): | |
| logger.error(f"❌ Validation error location: {error.get('loc')}") | |
| logger.error(f"❌ Validation error message: {error.get('msg')}") | |
| logger.error(f"❌ Validation error type: {error.get('type')}") | |
| return JSONResponse( | |
| status_code=422, | |
| content={ | |
| "error": { | |
| "message": "Invalid request format", | |
| "type": "invalid_request_error", | |
| "code": "invalid_request" | |
| } | |
| } | |
| ) | |
| async def http_exception_handler(request: Request, exc: HTTPException): | |
| """Handle FastAPI HTTPException with OpenAI-compatible error envelope""" | |
| status = exc.status_code | |
| if status == 400: | |
| err_type = "invalid_request_error" | |
| code = "invalid_request" | |
| elif status == 401: | |
| err_type = "authentication_error" | |
| code = "unauthorized" | |
| elif status == 403: | |
| err_type = "permission_error" | |
| code = "forbidden" | |
| elif status == 429: | |
| err_type = "rate_limit_error" | |
| code = "rate_limit_exceeded" | |
| else: | |
| err_type = "server_error" | |
| code = "internal_error" | |
| return JSONResponse( | |
| status_code=status, | |
| content={ | |
| "error": { | |
| "message": str(exc.detail), | |
| "type": err_type, | |
| "code": code, | |
| } | |
| }, | |
| ) | |
| async def general_exception_handler(request: Request, exc: Exception): | |
| """Handle all uncaught exceptions""" | |
| logger.error(f"❌ Unhandled exception: {exc}") | |
| logger.error(f"❌ Request URL: {request.url}") | |
| logger.error(f"❌ Exception type: {type(exc).__name__}") | |
| logger.error(f"❌ Error stack: {traceback.format_exc()}") | |
| return JSONResponse( | |
| status_code=500, | |
| content={ | |
| "error": { | |
| "message": "Internal server error", | |
| "type": "server_error", | |
| "code": "internal_error" | |
| } | |
| } | |
| ) | |
| async def verify_api_key(authorization: str = Header(...)): | |
| """Dependency: verify client API key""" | |
| client_key = authorization.replace("Bearer ", "") | |
| if app_config.features.key_passthrough: | |
| # In passthrough mode, skip allowed_keys check | |
| return client_key | |
| if client_key not in ALLOWED_CLIENT_KEYS: | |
| raise HTTPException(status_code=401, detail="Unauthorized") | |
| return client_key | |
| def preprocess_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: | |
| """ | |
| Preprocess messages: | |
| - Convert role=tool messages to role=user text for upstream compatibility | |
| - Convert assistant.tool_calls into assistant.content (XML format) for upstream context | |
| - Convert developer->system if configured | |
| """ | |
| tool_call_index = build_tool_call_index_from_messages(messages) | |
| processed_messages: List[Dict[str, Any]] = [] | |
| for message in messages: | |
| if isinstance(message, dict): | |
| if message.get("role") == "tool": | |
| tool_call_id = message.get("tool_call_id") | |
| content = message.get("content") | |
| if not tool_call_id: | |
| raise HTTPException(status_code=400, detail="Tool message missing tool_call_id") | |
| # content may be empty string in some cases; only reject None | |
| if content is None: | |
| raise HTTPException(status_code=400, detail=f"Tool message missing content for tool_call_id={tool_call_id}") | |
| tool_info = tool_call_index.get(tool_call_id) | |
| if not tool_info: | |
| raise HTTPException( | |
| status_code=400, | |
| detail=( | |
| f"tool_call_id={tool_call_id} not found in conversation history. " | |
| f"Ensure the assistant message with this tool_call is included in the messages array." | |
| ) | |
| ) | |
| formatted_content = format_tool_result_for_ai( | |
| tool_name=tool_info["name"], | |
| tool_arguments=tool_info["arguments"], | |
| result_content=content, | |
| ) | |
| processed_messages.append({ | |
| "role": "user", | |
| "content": formatted_content | |
| }) | |
| logger.debug(f"🔧 Converted tool message to user message: tool_call_id={tool_call_id}, tool={tool_info['name']}") | |
| elif message.get("role") == "assistant" and message.get("tool_calls"): | |
| tool_calls = message.get("tool_calls", []) | |
| formatted_tool_calls_str = format_assistant_tool_calls_for_ai(tool_calls, GLOBAL_TRIGGER_SIGNAL) | |
| original_content = message.get("content") or "" | |
| final_content = f"{original_content}\n{formatted_tool_calls_str}".strip() | |
| processed_message = { | |
| "role": "assistant", | |
| "content": final_content | |
| } | |
| for key, value in message.items(): | |
| if key not in ["role", "content", "tool_calls"]: | |
| processed_message[key] = value | |
| processed_messages.append(processed_message) | |
| logger.debug("🔧 Converted assistant tool_calls to content.") | |
| elif message.get("role") == "developer": | |
| if app_config.features.convert_developer_to_system: | |
| processed_message = message.copy() | |
| processed_message["role"] = "system" | |
| processed_messages.append(processed_message) | |
| logger.debug("🔧 Converted developer message to system message for better upstream compatibility") | |
| else: | |
| processed_messages.append(message) | |
| logger.debug("🔧 Keeping developer role unchanged (based on configuration)") | |
| else: | |
| processed_messages.append(message) | |
| else: | |
| processed_messages.append(message) | |
| return processed_messages | |
| async def chat_completions( | |
| request: Request, | |
| body: ChatCompletionRequest, | |
| _api_key: str = Depends(verify_api_key) | |
| ): | |
| """Main chat completion endpoint, proxy and inject function calling capabilities.""" | |
| start_time = time.time() | |
| try: | |
| logger.debug(f"🔧 Received request, model: {body.model}") | |
| logger.debug(f"🔧 Number of messages: {len(body.messages)}") | |
| logger.debug(f"🔧 Number of tools: {len(body.tools) if body.tools else 0}") | |
| logger.debug(f"🔧 Streaming: {body.stream}") | |
| upstream, actual_model = find_upstream(body.model) | |
| upstream_url = f"{upstream['base_url']}/chat/completions" | |
| logger.debug(f"🔧 Starting message preprocessing, original message count: {len(body.messages)}") | |
| processed_messages = preprocess_messages(body.messages) | |
| logger.debug(f"🔧 Preprocessing completed, processed message count: {len(processed_messages)}") | |
| if not validate_message_structure(processed_messages): | |
| logger.error(f"❌ Message structure validation failed, but continuing processing") | |
| request_body_dict = body.model_dump(exclude_unset=True) | |
| request_body_dict["model"] = actual_model | |
| request_body_dict["messages"] = processed_messages | |
| is_fc_enabled = app_config.features.enable_function_calling | |
| has_tools_in_request = bool(body.tools) | |
| has_function_call = is_fc_enabled and has_tools_in_request | |
| logger.debug(f"🔧 Request body constructed, message count: {len(processed_messages)}") | |
| except HTTPException as e: | |
| # Preserve expected status codes (e.g., 400 for invalid tool_call_id history) | |
| logger.error(f"❌ Request rejected: status_code={e.status_code}, detail={e.detail}") | |
| return JSONResponse( | |
| status_code=e.status_code, | |
| content={ | |
| "error": { | |
| "message": str(e.detail), | |
| "type": "invalid_request_error" if e.status_code == 400 else ( | |
| "authentication_error" if e.status_code == 401 else ( | |
| "permission_error" if e.status_code == 403 else ( | |
| "rate_limit_error" if e.status_code == 429 else "server_error" | |
| ) | |
| ) | |
| ), | |
| "code": "invalid_request" if e.status_code == 400 else ( | |
| "unauthorized" if e.status_code == 401 else ( | |
| "forbidden" if e.status_code == 403 else ( | |
| "rate_limit_exceeded" if e.status_code == 429 else "internal_error" | |
| ) | |
| ) | |
| ) | |
| } | |
| } | |
| ) | |
| except Exception as e: | |
| logger.error(f"❌ Request preprocessing failed: {str(e)}") | |
| logger.error(f"❌ Error type: {type(e).__name__}") | |
| if hasattr(app_config, 'debug') and app_config.debug: | |
| logger.error(f"❌ Error stack: {traceback.format_exc()}") | |
| return JSONResponse( | |
| status_code=422, | |
| content={ | |
| "error": { | |
| "message": "Invalid request format", | |
| "type": "invalid_request_error", | |
| "code": "invalid_request" | |
| } | |
| } | |
| ) | |
| if has_function_call: | |
| logger.debug(f"🔧 Using global trigger signal for this request: {GLOBAL_TRIGGER_SIGNAL}") | |
| tools_for_request: List[Tool] = body.tools or [] | |
| function_prompt, _ = generate_function_prompt(tools_for_request, GLOBAL_TRIGGER_SIGNAL) | |
| tool_choice_prompt = safe_process_tool_choice(body.tool_choice, tools_for_request) | |
| if tool_choice_prompt: | |
| function_prompt += tool_choice_prompt | |
| system_message = {"role": "system", "content": function_prompt} | |
| request_body_dict["messages"].insert(0, system_message) | |
| if "tools" in request_body_dict: | |
| del request_body_dict["tools"] | |
| if "tool_choice" in request_body_dict: | |
| del request_body_dict["tool_choice"] | |
| elif has_tools_in_request and not is_fc_enabled: | |
| logger.info(f"🔧 Function calling is disabled by configuration, ignoring 'tools' and 'tool_choice' in request.") | |
| if "tools" in request_body_dict: | |
| del request_body_dict["tools"] | |
| if "tool_choice" in request_body_dict: | |
| del request_body_dict["tool_choice"] | |
| prompt_tokens = token_counter.count_tokens(request_body_dict["messages"], body.model) | |
| logger.info(f"📊 Request to {body.model} - Actual input tokens (including all preprocessing & injected prompts): {prompt_tokens}") | |
| headers = { | |
| "Content-Type": "application/json", | |
| "Authorization": f"Bearer {_api_key}" if app_config.features.key_passthrough else f"Bearer {upstream['api_key']}", | |
| "Accept": "application/json" if not body.stream else "text/event-stream" | |
| } | |
| logger.info(f"📝 Forwarding request to upstream: {upstream['name']}") | |
| logger.info(f"📝 Model: {request_body_dict.get('model', 'unknown')}, Messages: {len(request_body_dict.get('messages', []))}") | |
| if not body.stream: | |
| try: | |
| logger.debug(f"🔧 Sending upstream request to: {upstream_url}") | |
| logger.debug(f"🔧 has_function_call: {has_function_call}") | |
| logger.debug(f"🔧 Request body contains tools: {bool(body.tools)}") | |
| upstream_response = await _post_upstream_with_retry( | |
| upstream_url, request_body_dict, headers, app_config.server.timeout | |
| ) | |
| upstream_response.raise_for_status() | |
| response_json: Dict[str, Any] = upstream_response.json() | |
| logger.debug(f"🔧 Upstream response status code: {upstream_response.status_code}") | |
| # Count output tokens and handle usage | |
| completion_text = "" | |
| if response_json.get("choices") and len(response_json["choices"]) > 0: | |
| message = response_json["choices"][0].get("message", {}) | |
| # Extract content | |
| content = message.get("content") | |
| if content: | |
| completion_text = content | |
| # Check for reasoning_content | |
| reasoning_content = message.get("reasoning_content") | |
| if reasoning_content: | |
| completion_text = (completion_text + "\n" + reasoning_content).strip() if completion_text else reasoning_content | |
| logger.debug(f"🔧 Found reasoning_content, adding {len(reasoning_content)} chars to token count") | |
| # Calculate our estimated tokens | |
| estimated_completion_tokens = token_counter.count_text_tokens(completion_text, body.model) if completion_text else 0 | |
| estimated_prompt_tokens = prompt_tokens | |
| estimated_total_tokens = estimated_prompt_tokens + estimated_completion_tokens | |
| elapsed_time = time.time() - start_time | |
| # Check if upstream provided usage and respect it | |
| upstream_usage = response_json.get("usage", {}) | |
| if upstream_usage: | |
| # Preserve upstream's usage structure and only replace zero values | |
| final_usage = upstream_usage.copy() | |
| # Replace zero or missing values with our estimates | |
| if not final_usage.get("prompt_tokens") or final_usage.get("prompt_tokens") == 0: | |
| final_usage["prompt_tokens"] = estimated_prompt_tokens | |
| logger.debug(f"🔧 Replaced zero/missing prompt_tokens with estimate: {estimated_prompt_tokens}") | |
| if not final_usage.get("completion_tokens") or final_usage.get("completion_tokens") == 0: | |
| final_usage["completion_tokens"] = estimated_completion_tokens | |
| logger.debug(f"🔧 Replaced zero/missing completion_tokens with estimate: {estimated_completion_tokens}") | |
| if not final_usage.get("total_tokens") or final_usage.get("total_tokens") == 0: | |
| final_usage["total_tokens"] = final_usage.get("prompt_tokens", estimated_prompt_tokens) + final_usage.get("completion_tokens", estimated_completion_tokens) | |
| logger.debug(f"🔧 Replaced zero/missing total_tokens with calculated value: {final_usage['total_tokens']}") | |
| response_json["usage"] = final_usage | |
| logger.debug(f"🔧 Preserved upstream usage with replacements: {final_usage}") | |
| else: | |
| # No upstream usage, provide our estimates | |
| response_json["usage"] = { | |
| "prompt_tokens": estimated_prompt_tokens, | |
| "completion_tokens": estimated_completion_tokens, | |
| "total_tokens": estimated_total_tokens | |
| } | |
| logger.debug(f"🔧 No upstream usage found, using estimates") | |
| # Log token statistics | |
| actual_usage = response_json["usage"] | |
| logger.info("=" * 60) | |
| logger.info(f"📊 Token Usage Statistics - Model: {body.model}") | |
| logger.info(f" Input Tokens: {actual_usage.get('prompt_tokens', 0)}") | |
| logger.info(f" Output Tokens: {actual_usage.get('completion_tokens', 0)}") | |
| logger.info(f" Total Tokens: {actual_usage.get('total_tokens', 0)}") | |
| logger.info(f" Duration: {elapsed_time:.2f}s") | |
| logger.info("=" * 60) | |
| if has_function_call: | |
| content = response_json["choices"][0]["message"]["content"] | |
| logger.debug(f"🔧 Complete response content: {repr(content)}") | |
| parsed_tools = await attempt_fc_parse_with_retry( | |
| content=content, | |
| trigger_signal=GLOBAL_TRIGGER_SIGNAL, | |
| messages=request_body_dict["messages"], | |
| upstream_url=upstream_url, | |
| headers=headers, | |
| model=actual_model, | |
| tools=body.tools or [], | |
| timeout=app_config.server.timeout | |
| ) | |
| logger.debug(f"🔧 XML parsing result: {parsed_tools}") | |
| if parsed_tools: | |
| logger.debug(f"🔧 Successfully parsed {len(parsed_tools)} tool calls") | |
| estimated_completion_tokens = token_counter.count_text_tokens(content, body.model) | |
| estimated_total_tokens = estimated_prompt_tokens + estimated_completion_tokens | |
| logger.debug(f"🔧 Completion tokens: {estimated_completion_tokens}") | |
| tool_calls = [] | |
| for tool in parsed_tools: | |
| tool_call_id = f"call_{uuid.uuid4().hex}" | |
| tool_calls.append({ | |
| "id": tool_call_id, | |
| "type": "function", | |
| "function": { | |
| "name": tool["name"], | |
| "arguments": json.dumps(tool["args"]) | |
| } | |
| }) | |
| logger.debug(f"🔧 Converted tool_calls: {tool_calls}") | |
| prefix_pos = find_last_trigger_signal_outside_think(content, GLOBAL_TRIGGER_SIGNAL) | |
| prefix_text = None | |
| if prefix_pos != -1: | |
| prefix_text = content[:prefix_pos].rstrip() | |
| if prefix_text == "": | |
| prefix_text = None | |
| # Preserve extra fields from upstream message (e.g., reasoning_content, refusal, audio, annotations) | |
| original_message = response_json["choices"][0]["message"] | |
| new_message = { | |
| "role": "assistant", | |
| "content": prefix_text, | |
| "tool_calls": tool_calls, | |
| } | |
| # Copy over any extra fields that upstream returned | |
| for key in original_message: | |
| if key not in ["role", "content", "tool_calls"]: | |
| new_message[key] = original_message[key] | |
| response_json["choices"][0]["message"] = new_message | |
| response_json["choices"][0]["finish_reason"] = "tool_calls" | |
| logger.debug(f"🔧 Function call conversion completed") | |
| else: | |
| logger.debug(f"🔧 No tool calls detected, returning original content (including think blocks)") | |
| else: | |
| logger.debug(f"🔧 No function calls detected or conversion conditions not met") | |
| return JSONResponse(content=response_json) | |
| except httpx.HTTPStatusError as e: | |
| logger.error(f"❌ Upstream service response error: status_code={e.response.status_code}") | |
| logger.error(f"❌ Upstream error details: {e.response.text}") | |
| if e.response.status_code == 400: | |
| error_response = { | |
| "error": { | |
| "message": "Invalid request parameters", | |
| "type": "invalid_request_error", | |
| "code": "bad_request" | |
| } | |
| } | |
| elif e.response.status_code == 401: | |
| error_response = { | |
| "error": { | |
| "message": "Authentication failed", | |
| "type": "authentication_error", | |
| "code": "unauthorized" | |
| } | |
| } | |
| elif e.response.status_code == 403: | |
| error_response = { | |
| "error": { | |
| "message": "Access forbidden", | |
| "type": "permission_error", | |
| "code": "forbidden" | |
| } | |
| } | |
| elif e.response.status_code == 429: | |
| error_response = { | |
| "error": { | |
| "message": "Rate limit exceeded", | |
| "type": "rate_limit_error", | |
| "code": "rate_limit_exceeded" | |
| } | |
| } | |
| elif e.response.status_code >= 500: | |
| error_response = { | |
| "error": { | |
| "message": "Upstream service temporarily unavailable", | |
| "type": "service_error", | |
| "code": "upstream_error" | |
| } | |
| } | |
| else: | |
| error_response = { | |
| "error": { | |
| "message": "Request processing failed", | |
| "type": "api_error", | |
| "code": "unknown_error" | |
| } | |
| } | |
| return JSONResponse(content=error_response, status_code=e.response.status_code) | |
| else: | |
| async def stream_with_token_count(): | |
| completion_tokens = 0 | |
| completion_text = "" | |
| done_received = False | |
| stream_id = None # Keep all streamed chunks under the same id (OpenAI-compatible) | |
| upstream_usage_chunk = None # Store upstream usage chunk if any | |
| async for chunk in stream_proxy_with_fc_transform( | |
| upstream_url, | |
| request_body_dict, | |
| headers, | |
| body.model, | |
| has_function_call, | |
| GLOBAL_TRIGGER_SIGNAL, | |
| request_body_dict["messages"], | |
| tools=body.tools or [], | |
| ): | |
| # Check if this is the [DONE] marker | |
| if chunk.startswith(b"data: "): | |
| try: | |
| line_data = chunk[6:].decode('utf-8').strip() | |
| if line_data == "[DONE]": | |
| done_received = True | |
| # Don't yield the [DONE] marker yet, we'll send it after usage info | |
| break | |
| elif line_data: | |
| chunk_json = json.loads(line_data) | |
| if stream_id is None and isinstance(chunk_json, dict): | |
| stream_id = chunk_json.get("id") | |
| if chunk_json.get("object") == "chat.completion.chunk.internal": | |
| raw_fc_content = chunk_json.get("_internal_fc_raw_content", "") | |
| if raw_fc_content: | |
| completion_text += raw_fc_content | |
| logger.debug(f"🔧 Received internal FC raw content for token counting: {len(raw_fc_content)} chars") | |
| continue | |
| # Check if this chunk contains usage information | |
| if "usage" in chunk_json: | |
| upstream_usage_chunk = chunk_json | |
| logger.debug(f"🔧 Detected upstream usage data in chunk") | |
| if not ("choices" in chunk_json and len(chunk_json["choices"]) > 0): | |
| continue | |
| else: | |
| chunk_json = {k: v for k, v in chunk_json.items() if k != "usage"} | |
| chunk = f"data: {json.dumps(chunk_json)}\n\n".encode('utf-8') | |
| # Process regular content chunks | |
| if "choices" in chunk_json and len(chunk_json["choices"]) > 0: | |
| delta = chunk_json["choices"][0].get("delta", {}) | |
| # Accumulate content | |
| content = delta.get("content", "") | |
| if content: | |
| completion_text += content | |
| # Accumulate reasoning_content | |
| reasoning_content = delta.get("reasoning_content", "") | |
| if reasoning_content: | |
| completion_text += reasoning_content | |
| logger.debug(f"🔧 Found reasoning_content in stream, accumulating for token count") | |
| except (json.JSONDecodeError, KeyError, UnicodeDecodeError) as e: | |
| logger.debug(f"Failed to parse chunk for token counting: {e}") | |
| pass | |
| yield chunk | |
| # Calculate our estimated tokens | |
| estimated_completion_tokens = token_counter.count_text_tokens(completion_text, body.model) if completion_text else 0 | |
| estimated_prompt_tokens = prompt_tokens | |
| estimated_total_tokens = estimated_prompt_tokens + estimated_completion_tokens | |
| elapsed_time = time.time() - start_time | |
| # Determine final usage | |
| final_usage = None | |
| if upstream_usage_chunk and "usage" in upstream_usage_chunk: | |
| # Respect upstream usage, but replace zero values | |
| upstream_usage = upstream_usage_chunk["usage"] | |
| final_usage = upstream_usage.copy() | |
| if not final_usage.get("prompt_tokens") or final_usage.get("prompt_tokens") == 0: | |
| final_usage["prompt_tokens"] = estimated_prompt_tokens | |
| logger.debug(f"🔧 Replaced zero/missing prompt_tokens with estimate: {estimated_prompt_tokens}") | |
| if not final_usage.get("completion_tokens") or final_usage.get("completion_tokens") == 0: | |
| final_usage["completion_tokens"] = estimated_completion_tokens | |
| logger.debug(f"🔧 Replaced zero/missing completion_tokens with estimate: {estimated_completion_tokens}") | |
| if not final_usage.get("total_tokens") or final_usage.get("total_tokens") == 0: | |
| final_usage["total_tokens"] = final_usage.get("prompt_tokens", estimated_prompt_tokens) + final_usage.get("completion_tokens", estimated_completion_tokens) | |
| logger.debug(f"🔧 Replaced zero/missing total_tokens with calculated value: {final_usage['total_tokens']}") | |
| logger.debug(f"🔧 Using upstream usage with replacements: {final_usage}") | |
| else: | |
| # No upstream usage, use our estimates | |
| final_usage = { | |
| "prompt_tokens": estimated_prompt_tokens, | |
| "completion_tokens": estimated_completion_tokens, | |
| "total_tokens": estimated_total_tokens | |
| } | |
| logger.debug(f"🔧 No upstream usage found, using estimates") | |
| # Log token statistics | |
| logger.info("=" * 60) | |
| logger.info(f"📊 Token Usage Statistics - Model: {body.model}") | |
| logger.info(f" Input Tokens: {final_usage['prompt_tokens']}") | |
| logger.info(f" Output Tokens: {final_usage['completion_tokens']}") | |
| logger.info(f" Total Tokens: {final_usage['total_tokens']}") | |
| logger.info(f" Duration: {elapsed_time:.2f}s") | |
| logger.info("=" * 60) | |
| # Send usage information only if requested via stream_options.include_usage | |
| if body.stream_options and body.stream_options.get("include_usage", False): | |
| usage_chunk_to_send = { | |
| "id": (upstream_usage_chunk.get("id") if isinstance(upstream_usage_chunk, dict) else None) or stream_id or f"chatcmpl-{uuid.uuid4().hex}", | |
| "object": "chat.completion.chunk", | |
| "created": int(time.time()), | |
| "model": body.model, | |
| "choices": [], | |
| "usage": final_usage | |
| } | |
| # If upstream provided additional fields in the usage chunk, preserve them | |
| if upstream_usage_chunk: | |
| for key in upstream_usage_chunk: | |
| if key not in ["usage", "choices"] and key not in usage_chunk_to_send: | |
| usage_chunk_to_send[key] = upstream_usage_chunk[key] | |
| yield f"data: {json.dumps(usage_chunk_to_send)}\n\n".encode('utf-8') | |
| logger.debug(f"🔧 Sent usage chunk in stream: {usage_chunk_to_send['usage']}") | |
| # Send [DONE] marker if it was received | |
| if done_received: | |
| yield b"data: [DONE]\n\n" | |
| return StreamingResponse( | |
| stream_with_token_count(), | |
| media_type="text/event-stream" | |
| ) | |
| async def _attempt_streaming_fc_retry( | |
| original_content: str, | |
| trigger_signal: str, | |
| messages: List[Dict[str, Any]], | |
| url: str, | |
| headers: Dict[str, str], | |
| model: str, | |
| timeout: int, | |
| tools: Optional[List["Tool"]] = None, | |
| ) -> Optional[List[Dict[str, Any]]]: | |
| max_attempts = app_config.features.fc_error_retry_max_attempts | |
| current_content = original_content | |
| current_messages = messages.copy() | |
| def _parse_and_validate(current_content: str) -> tuple[Optional[List[Dict[str, Any]]], Optional[str]]: | |
| parsed = parse_function_calls_xml(current_content, trigger_signal) | |
| if not parsed: | |
| return None, None | |
| validation_error = validate_parsed_tools(parsed, tools or []) | |
| if validation_error: | |
| return None, validation_error | |
| return parsed, None | |
| validation_error: Optional[str] = None | |
| for attempt in range(max_attempts): | |
| # Same rule as non-streaming: avoid retrying if the trigger only appears inside <think>. | |
| if find_last_trigger_signal_outside_think(current_content, trigger_signal) == -1: | |
| logger.debug("🔧 Streaming retry: no trigger signal found outside <think> blocks; aborting retry") | |
| return None | |
| validation_error = None | |
| if attempt == 0: | |
| parsed_tools, validation_error = _parse_and_validate(current_content) | |
| if parsed_tools: | |
| return parsed_tools | |
| if attempt >= max_attempts - 1: | |
| logger.warning(f"⚠️ Streaming FC retry failed after {max_attempts} attempts") | |
| return None | |
| # Classify the failure type to choose the right retry strategy | |
| failure_type = _classify_fc_failure(current_content, trigger_signal) | |
| if failure_type == "no_fc": | |
| return None | |
| error_details = validation_error or _diagnose_fc_parse_error(current_content, trigger_signal) | |
| if failure_type == "truncated": | |
| retry_prompt = get_fc_continuation_prompt(current_content, error_details) | |
| logger.info(f"🔄 Streaming FC output truncated, requesting continuation {attempt + 2}/{max_attempts}") | |
| else: | |
| retry_prompt = get_fc_error_retry_prompt(current_content, error_details) | |
| logger.info(f"🔄 Streaming FC syntax error, requesting rewrite {attempt + 2}/{max_attempts}") | |
| logger.debug(f"🔧 Failure type: {failure_type}, error details: {error_details}") | |
| retry_messages = current_messages + [ | |
| {"role": "assistant", "content": current_content}, | |
| {"role": "user", "content": retry_prompt} | |
| ] | |
| try: | |
| retry_response = await http_client.post( | |
| url, | |
| json={"model": model, "messages": retry_messages, "stream": False}, | |
| headers=headers, | |
| timeout=timeout | |
| ) | |
| retry_response.raise_for_status() | |
| retry_json = retry_response.json() | |
| if retry_json.get("choices") and len(retry_json["choices"]) > 0: | |
| retry_content = retry_json["choices"][0].get("message", {}).get("content", "") | |
| if failure_type == "truncated" and _is_continuation_response(retry_content, trigger_signal): | |
| current_content = _merge_truncated_and_continuation(current_content, retry_content) | |
| logger.info(f"🔧 Streaming: merged continuation, total length: {len(current_content)}") | |
| else: | |
| current_content = retry_content | |
| current_messages = retry_messages | |
| parsed_tools, validation_error = _parse_and_validate(current_content) | |
| if parsed_tools: | |
| return parsed_tools | |
| else: | |
| logger.warning(f"⚠️ Streaming FC retry response has no valid choices") | |
| return None | |
| except Exception as e: | |
| logger.error(f"❌ Streaming FC retry request failed: {e}") | |
| return None | |
| return None | |
| async def stream_proxy_with_fc_transform( | |
| url: str, | |
| body: dict, | |
| headers: dict, | |
| model: str, | |
| has_fc: bool, | |
| trigger_signal: str, | |
| original_messages: Optional[List[Dict[str, Any]]] = None, | |
| tools: Optional[List["Tool"]] = None, | |
| ): | |
| """ | |
| Enhanced streaming proxy, supports dynamic trigger signals, avoids misjudgment within think tags | |
| """ | |
| logger.info(f"📝 Starting streaming response from: {url}") | |
| logger.info(f"📝 Function calling enabled: {has_fc}") | |
| if not has_fc or not trigger_signal: | |
| max_attempts = _get_upstream_retry_attempts() | |
| streamed_any_output = False | |
| for attempt in range(max_attempts): | |
| try: | |
| async with http_client.stream("POST", url, json=body, headers=headers, timeout=app_config.server.timeout) as response: | |
| async for chunk in response.aiter_bytes(): | |
| streamed_any_output = True | |
| yield chunk | |
| return | |
| except httpx.RemoteProtocolError: | |
| logger.debug("🔧 Upstream closed connection prematurely, ending stream response") | |
| return | |
| except Exception as exc: | |
| if streamed_any_output or not _is_retriable_upstream_error(exc) or attempt >= max_attempts - 1: | |
| raise | |
| delay = _get_upstream_retry_delay(attempt) | |
| logger.warning( | |
| "⚠️ Upstream stream failed with %s; retrying in %.1fs (%s/%s)", | |
| type(exc).__name__, delay, attempt + 2, max_attempts, | |
| ) | |
| await asyncio.sleep(delay) | |
| return | |
| detector = StreamingFunctionCallDetector(trigger_signal) | |
| stream_id = None | |
| pending_finish_reason = None | |
| def _ensure_stream_id(chunk_json: Optional[Dict[str, Any]] = None) -> str: | |
| nonlocal stream_id | |
| if stream_id is None: | |
| upstream_id = chunk_json.get("id") if isinstance(chunk_json, dict) else None | |
| stream_id = upstream_id or f"chatcmpl-passthrough-{uuid.uuid4().hex}" | |
| return stream_id | |
| def _prepare_tool_calls(parsed_tools: List[Dict[str, Any]]): | |
| tool_calls = [] | |
| for i, tool in enumerate(parsed_tools): | |
| tool_call_id = f"call_{uuid.uuid4().hex}" | |
| tool_calls.append({ | |
| "index": i, "id": tool_call_id, "type": "function", | |
| "function": { "name": tool["name"], "arguments": json.dumps(tool["args"]) } | |
| }) | |
| return tool_calls | |
| def _build_tool_call_sse_chunks(parsed_tools: List[Dict[str, Any]], model_id: str, raw_content: str = "") -> List[str]: | |
| tool_calls = _prepare_tool_calls(parsed_tools) | |
| chunks: List[str] = [] | |
| if raw_content: | |
| metadata_chunk = { | |
| "object": "chat.completion.chunk.internal", | |
| "_internal_fc_raw_content": raw_content | |
| } | |
| chunks.append(f"data: {json.dumps(metadata_chunk)}\n\n") | |
| tc_id = stream_id or f"chatcmpl-{uuid.uuid4().hex}" | |
| initial_chunk = { | |
| "id": tc_id, "object": "chat.completion.chunk", | |
| "created": int(time.time()), "model": model_id, | |
| "choices": [{"index": 0, "delta": {"role": "assistant", "content": None, "tool_calls": tool_calls}, "finish_reason": None}], | |
| } | |
| chunks.append(f"data: {json.dumps(initial_chunk)}\n\n") | |
| final_chunk = { | |
| "id": tc_id, "object": "chat.completion.chunk", | |
| "created": int(time.time()), "model": model_id, | |
| "choices": [{"index": 0, "delta": {}, "finish_reason": "tool_calls"}], | |
| } | |
| chunks.append(f"data: {json.dumps(final_chunk)}\n\n") | |
| chunks.append("data: [DONE]\n\n") | |
| return chunks | |
| max_attempts = _get_upstream_retry_attempts() | |
| for attempt in range(max_attempts): | |
| try: | |
| async with http_client.stream("POST", url, json=body, headers=headers, timeout=app_config.server.timeout) as response: | |
| if response.status_code != 200: | |
| error_content = await response.aread() | |
| logger.error(f"❌ Upstream service stream response error: status_code={response.status_code}") | |
| logger.error(f"❌ Upstream error details: {error_content.decode('utf-8', errors='ignore')}") | |
| if response.status_code == 401: | |
| error_message = "Authentication failed" | |
| elif response.status_code == 403: | |
| error_message = "Access forbidden" | |
| elif response.status_code == 429: | |
| error_message = "Rate limit exceeded" | |
| elif response.status_code >= 500: | |
| error_message = "Upstream service temporarily unavailable" | |
| else: | |
| error_message = "Request processing failed" | |
| error_chunk = {"error": {"message": error_message, "type": "upstream_error"}} | |
| yield f"data: {json.dumps(error_chunk)}\n\n".encode('utf-8') | |
| yield b"data: [DONE]\n\n" | |
| return | |
| async for line in response.aiter_lines(): | |
| if detector.state == "tool_parsing": | |
| if line.startswith("data:"): | |
| line_data = line[len("data: "):].strip() | |
| if line_data and line_data != "[DONE]": | |
| try: | |
| chunk_json = json.loads(line_data) | |
| delta_content = chunk_json.get("choices", [{}])[0].get("delta", {}).get("content", "") or "" | |
| detector.content_buffer += delta_content | |
| # Early termination: once </function_calls> appears, parse and finish immediately | |
| if "</function_calls>" in detector.content_buffer: | |
| logger.debug("🔧 Detected </function_calls> in stream, finalizing early...") | |
| parsed_tools = detector.finalize() | |
| if parsed_tools: | |
| validation_error = validate_parsed_tools(parsed_tools, tools or []) | |
| if validation_error: | |
| logger.info(f"🔧 Tool/schema validation failed in stream finalize: {validation_error}") | |
| parsed_tools = None | |
| if parsed_tools: | |
| logger.debug(f"🔧 Early finalize: parsed {len(parsed_tools)} tool calls") | |
| for sse in _build_tool_call_sse_chunks(parsed_tools, model, detector.content_buffer): | |
| yield sse.encode('utf-8') | |
| return | |
| else: | |
| if app_config.features.enable_fc_error_retry and original_messages: | |
| logger.info(f"🔄 Early finalize FC parsing failed, attempting retry...") | |
| retry_parsed = await _attempt_streaming_fc_retry( | |
| original_content=detector.content_buffer, | |
| trigger_signal=trigger_signal, | |
| messages=original_messages, | |
| url=url, | |
| headers=headers, | |
| model=model, | |
| timeout=app_config.server.timeout, | |
| tools=tools, | |
| ) | |
| if retry_parsed: | |
| logger.info(f"✅ Early finalize FC retry succeeded, parsed {len(retry_parsed)} tool calls") | |
| for sse in _build_tool_call_sse_chunks(retry_parsed, model, detector.content_buffer): | |
| yield sse.encode('utf-8') | |
| return | |
| else: | |
| logger.warning(f"⚠️ Early finalize FC retry also failed, ending stream") | |
| else: | |
| logger.warning( | |
| "⚠️ Early finalize detected </function_calls> but failed to parse tool calls; " | |
| "silently ending stream. buffer_len=%s preview=%r", | |
| len(detector.content_buffer), | |
| detector.content_buffer[:200], | |
| ) | |
| stop_chunk = { | |
| "id": _ensure_stream_id(), | |
| "object": "chat.completion.chunk", | |
| "created": int(time.time()), | |
| "model": model, | |
| "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}] | |
| } | |
| yield f"data: {json.dumps(stop_chunk)}\n\n".encode('utf-8') | |
| yield b"data: [DONE]\n\n" | |
| return | |
| except (json.JSONDecodeError, IndexError): | |
| pass | |
| continue | |
| if line.startswith("data:"): | |
| line_data = line[len("data: "):].strip() | |
| if not line_data or line_data == "[DONE]": | |
| continue | |
| try: | |
| chunk_json = json.loads(line_data) | |
| delta = chunk_json.get("choices", [{}])[0].get("delta", {}) | |
| delta_content = delta.get("content", "") or "" | |
| delta_reasoning = delta.get("reasoning_content", "") or "" | |
| finish_reason = chunk_json.get("choices", [{}])[0].get("finish_reason") | |
| # Forward reasoning_content directly (it's not part of function call detection) | |
| if delta_reasoning: | |
| reasoning_chunk = { | |
| "id": _ensure_stream_id(chunk_json), | |
| "object": "chat.completion.chunk", | |
| "created": int(time.time()), | |
| "model": model, | |
| "choices": [{"index": 0, "delta": {"reasoning_content": delta_reasoning}}] | |
| } | |
| yield f"data: {json.dumps(reasoning_chunk)}\n\n".encode('utf-8') | |
| if delta_content: | |
| is_detected, content_to_yield = detector.process_chunk(delta_content) | |
| if content_to_yield: | |
| yield_chunk = { | |
| "id": _ensure_stream_id(chunk_json), | |
| "object": "chat.completion.chunk", | |
| "created": int(time.time()), | |
| "model": model, | |
| "choices": [{"index": 0, "delta": {"content": content_to_yield}}] | |
| } | |
| yield f"data: {json.dumps(yield_chunk)}\n\n".encode('utf-8') | |
| if is_detected: | |
| # Tool call signal detected, switch to parsing mode | |
| continue | |
| if finish_reason: | |
| pending_finish_reason = finish_reason | |
| except (json.JSONDecodeError, IndexError): | |
| # Ensure we always yield bytes to keep stream_with_token_count() stable | |
| yield (line + "\n\n").encode("utf-8") | |
| # Stream completed successfully, break out of retry loop | |
| break | |
| except Exception as exc: | |
| if not _is_retriable_upstream_error(exc) or attempt >= max_attempts - 1: | |
| logger.error(f"❌ Failed to connect to upstream service: {exc}") | |
| logger.error(f"❌ Error type: {type(exc).__name__}") | |
| error_message = "Failed to connect to upstream service" | |
| error_chunk = {"error": {"message": error_message, "type": "connection_error"}} | |
| yield f"data: {json.dumps(error_chunk)}\n\n".encode('utf-8') | |
| yield b"data: [DONE]\n\n" | |
| return | |
| delay = _get_upstream_retry_delay(attempt) | |
| logger.warning( | |
| "⚠️ Upstream stream (FC) failed with %s; retrying in %.1fs (%s/%s)", | |
| type(exc).__name__, delay, attempt + 2, max_attempts, | |
| ) | |
| await asyncio.sleep(delay) | |
| # Reset detector state for retry | |
| detector = StreamingFunctionCallDetector(trigger_signal) | |
| stream_id = None | |
| pending_finish_reason = None | |
| continue | |
| if detector.state == "tool_parsing": | |
| logger.debug(f"🔧 Stream ended, starting to parse tool call XML...") | |
| parsed_tools = detector.finalize() | |
| if parsed_tools: | |
| validation_error = validate_parsed_tools(parsed_tools, tools or []) | |
| if validation_error: | |
| logger.info(f"🔧 Tool/schema validation failed at stream end: {validation_error}") | |
| parsed_tools = None | |
| if parsed_tools: | |
| logger.debug(f"🔧 Streaming processing: Successfully parsed {len(parsed_tools)} tool calls") | |
| for sse in _build_tool_call_sse_chunks(parsed_tools, model, detector.content_buffer): | |
| yield sse.encode("utf-8") | |
| return | |
| else: | |
| if app_config.features.enable_fc_error_retry and original_messages: | |
| logger.info(f"🔄 Streaming FC parsing failed, attempting retry with error correction...") | |
| retry_parsed = await _attempt_streaming_fc_retry( | |
| original_content=detector.content_buffer, | |
| trigger_signal=trigger_signal, | |
| messages=original_messages, | |
| url=url, | |
| headers=headers, | |
| model=model, | |
| timeout=app_config.server.timeout, | |
| tools=tools, | |
| ) | |
| if retry_parsed: | |
| logger.info(f"✅ Streaming FC retry succeeded, parsed {len(retry_parsed)} tool calls") | |
| for sse in _build_tool_call_sse_chunks(retry_parsed, model, detector.content_buffer): | |
| yield sse.encode("utf-8") | |
| return | |
| else: | |
| logger.warning(f"⚠️ Streaming FC retry also failed, falling back to text output") | |
| else: | |
| logger.warning( | |
| "⚠️ Detected tool call signal but XML parsing failed; outputting accumulated text. " | |
| "buffer_len=%s preview=%r", | |
| len(detector.content_buffer), | |
| detector.content_buffer[:300], | |
| ) | |
| if detector.content_buffer: | |
| content_chunk = { | |
| "id": _ensure_stream_id(), | |
| "object": "chat.completion.chunk", | |
| "created": int(time.time()), | |
| "model": model, | |
| "choices": [{"index": 0, "delta": {"content": detector.content_buffer}}] | |
| } | |
| yield f"data: {json.dumps(content_chunk)}\n\n".encode('utf-8') | |
| elif detector.state == "detecting" and detector.content_buffer: | |
| # If stream has ended but buffer still has remaining characters insufficient to form signal, output them | |
| final_yield_chunk = { | |
| "id": _ensure_stream_id(), "object": "chat.completion.chunk", | |
| "created": int(time.time()), "model": model, | |
| "choices": [{"index": 0, "delta": {"content": detector.content_buffer}}] | |
| } | |
| yield f"data: {json.dumps(final_yield_chunk)}\n\n".encode('utf-8') | |
| stop_chunk = { | |
| "id": _ensure_stream_id(), | |
| "object": "chat.completion.chunk", | |
| "created": int(time.time()), | |
| "model": model, | |
| "choices": [{"index": 0, "delta": {}, "finish_reason": pending_finish_reason or "stop"}] | |
| } | |
| yield f"data: {json.dumps(stop_chunk)}\n\n".encode('utf-8') | |
| yield b"data: [DONE]\n\n" | |
| def read_root(): | |
| return { | |
| "status": "Toolify is running", | |
| "config": { | |
| "upstream_services_count": len(app_config.upstream_services), | |
| "client_keys_count": len(app_config.client_authentication.allowed_keys), | |
| "models_count": len(MODEL_TO_SERVICE_MAPPING), | |
| "features": { | |
| "function_calling": app_config.features.enable_function_calling, | |
| "log_level": app_config.features.log_level, | |
| "convert_developer_to_system": app_config.features.convert_developer_to_system, | |
| "random_trigger": True | |
| } | |
| } | |
| } | |
| async def list_models(_api_key: str = Depends(verify_api_key)): | |
| """List all available models""" | |
| visible_models = set() | |
| for model_name in MODEL_TO_SERVICE_MAPPING.keys(): | |
| if ':' in model_name: | |
| parts = model_name.split(':', 1) | |
| if len(parts) == 2: | |
| alias, _ = parts | |
| visible_models.add(alias) | |
| else: | |
| visible_models.add(model_name) | |
| else: | |
| visible_models.add(model_name) | |
| models = [] | |
| for model_id in sorted(visible_models): | |
| models.append({ | |
| "id": model_id, | |
| "object": "model", | |
| "created": 1677610602, | |
| "owned_by": "openai", | |
| "permission": [], | |
| "root": model_id, | |
| "parent": None | |
| }) | |
| return { | |
| "object": "list", | |
| "data": models | |
| } | |
| def validate_message_structure(messages: List[Dict[str, Any]]) -> bool: | |
| """Validate if message structure meets requirements""" | |
| try: | |
| valid_roles = ["system", "user", "assistant", "tool"] | |
| if not app_config.features.convert_developer_to_system: | |
| valid_roles.append("developer") | |
| for i, msg in enumerate(messages): | |
| if "role" not in msg: | |
| logger.error(f"❌ Message {i} missing role field") | |
| return False | |
| if msg["role"] not in valid_roles: | |
| logger.error(f"❌ Invalid role value for message {i}: {msg['role']}") | |
| return False | |
| if msg["role"] == "tool": | |
| if "tool_call_id" not in msg: | |
| logger.error(f"❌ Tool message {i} missing tool_call_id field") | |
| return False | |
| content = msg.get("content") | |
| content_info = "" | |
| if content: | |
| if isinstance(content, str): | |
| content_info = f", content=text({len(content)} chars)" | |
| elif isinstance(content, list): | |
| text_parts = [item for item in content if isinstance(item, dict) and item.get('type') == 'text'] | |
| image_parts = [item for item in content if isinstance(item, dict) and item.get('type') == 'image_url'] | |
| content_info = f", content=multimodal(text={len(text_parts)}, images={len(image_parts)})" | |
| else: | |
| content_info = f", content={type(content).__name__}" | |
| else: | |
| content_info = ", content=empty" | |
| logger.debug(f"✅ Message {i} validation passed: role={msg['role']}{content_info}") | |
| logger.debug(f"✅ All messages validated successfully, total {len(messages)} messages") | |
| return True | |
| except Exception as e: | |
| logger.error(f"❌ Message validation exception: {e}") | |
| return False | |
| def safe_process_tool_choice(tool_choice, tools: Optional[List[Tool]] = None) -> str: | |
| """ | |
| Process tool_choice field and return additional prompt instructions. | |
| Args: | |
| tool_choice: The tool_choice value from the request (str or ToolChoice object) | |
| tools: List of available tools (for validation when specific tool is required) | |
| Returns: | |
| Additional prompt text to append to the function calling prompt | |
| Raises: | |
| HTTPException: If tool_choice specifies a tool that doesn't exist in tools list | |
| """ | |
| try: | |
| if tool_choice is None: | |
| return "" | |
| if isinstance(tool_choice, str): | |
| if tool_choice == "none": | |
| return "\n\n**IMPORTANT:** You are prohibited from using any tools in this round. Please respond like a normal chat assistant and answer the user's question directly." | |
| elif tool_choice == "auto": | |
| # Default behavior, no additional constraints | |
| return "" | |
| elif tool_choice == "required": | |
| return "\n\n**IMPORTANT:** You MUST call at least one tool in this response. Do not respond without using tools." | |
| else: | |
| logger.warning(f"⚠️ Unknown tool_choice string value: {tool_choice}") | |
| return "" | |
| # Handle ToolChoice object: {"type": "function", "function": {"name": "xxx"}} | |
| elif hasattr(tool_choice, 'function'): | |
| function_dict = tool_choice.function | |
| if not isinstance(function_dict, dict): | |
| raise HTTPException(status_code=400, detail="tool_choice.function must be an object") | |
| required_tool_name = function_dict.get("name") | |
| if not required_tool_name or not isinstance(required_tool_name, str): | |
| raise HTTPException(status_code=400, detail="tool_choice.function.name must be a non-empty string") | |
| if not tools: | |
| raise HTTPException(status_code=400, detail="tool_choice requires a non-empty tools list in the request") | |
| tool_names = [t.function.name for t in tools] | |
| if required_tool_name not in tool_names: | |
| raise HTTPException( | |
| status_code=400, | |
| detail=f"tool_choice specifies tool '{required_tool_name}' which is not in the tools list. Available tools: {tool_names}" | |
| ) | |
| return f"\n\n**IMPORTANT:** In this round, you must use ONLY the tool named `{required_tool_name}`. Generate the necessary parameters and output in the specified XML format." | |
| else: | |
| logger.warning(f"⚠️ Unsupported tool_choice type: {type(tool_choice)}") | |
| return "" | |
| except HTTPException: | |
| # Re-raise HTTPException to preserve status code | |
| raise | |
| except Exception as e: | |
| logger.error(f"❌ Error processing tool_choice: {e}") | |
| return "" | |
| if __name__ == "__main__": | |
| import uvicorn | |
| logger.info(f"🚀 Starting server on {app_config.server.host}:{app_config.server.port}") | |
| logger.info(f"⏱️ Request timeout: {app_config.server.timeout} seconds") | |
| uvicorn.run( | |
| app, | |
| host=app_config.server.host, | |
| port=app_config.server.port, | |
| log_level=app_config.features.log_level.lower() if app_config.features.log_level != "DISABLED" else "critical" | |
| ) | |