| |
| |
| |
| |
|
|
| import os |
| import re |
| import json |
| import uuid |
| import httpx |
| import secrets |
| import string |
| import traceback |
| import time |
| import random |
| import logging |
| import tiktoken |
| import xml.etree.ElementTree as ET |
| from typing import List, Dict, Any, Optional, Literal, Union |
|
|
| from fastapi import FastAPI, Request, Header, HTTPException, Depends |
| from fastapi.responses import JSONResponse, StreamingResponse |
| from pydantic import BaseModel, ValidationError |
|
|
| from config_loader import config_loader |
|
|
| logger = logging.getLogger(__name__) |
|
|
| |
class TokenCounter:
    """Estimate token usage for chat messages and plain text with tiktoken.

    Encoders are resolved once per model name and cached on the instance in
    ``self.encoders``.
    """

    # Model-name prefix -> tiktoken encoding name. Consulted in insertion order
    # (first matching prefix wins) when tiktoken does not recognise the model.
    MODEL_PREFIX_TO_ENCODING = {
        "o1-": "o200k_base",
        "o3-": "o200k_base",
        "o4-mini-": "o200k_base",

        "gpt-5-": "o200k_base",
        "gpt-4.5-": "o200k_base",
        "gpt-4.1-": "o200k_base",
        "chatgpt-4o-": "o200k_base",
        "gpt-4o-": "o200k_base",
        "gpt-4-": "cl100k_base",
        "gpt-3.5-turbo-": "cl100k_base",
        "gpt-35-turbo-": "cl100k_base",
        "gpt-oss-": "o200k_harmony",

        "ft:gpt-4o": "o200k_base",
        "ft:gpt-4": "cl100k_base",
        "ft:gpt-3.5-turbo": "cl100k_base",
        "ft:davinci-002": "cl100k_base",
        "ft:babbage-002": "cl100k_base",
    }

    def __init__(self):
        # model name -> encoder object
        self.encoders = {}

    def get_encoder(self, model: str):
        """Get or create the encoder for ``model``.

        Resolution order: the instance cache, ``tiktoken.encoding_for_model``,
        the prefix table above, then ``o200k_base``. If the chosen encoding
        cannot be loaded, ``cl100k_base`` is used instead.
        """
        if model in self.encoders:
            return self.encoders[model]

        try:
            self.encoders[model] = tiktoken.encoding_for_model(model)
            return self.encoders[model]
        except KeyError:
            # tiktoken does not know this model name; fall back to the prefix table.
            pass

        encoding = next(
            (
                enc_name
                for prefix, enc_name in self.MODEL_PREFIX_TO_ENCODING.items()
                if model.startswith(prefix)
            ),
            None,
        )

        if encoding is None:
            logger.warning(f"Model {model} not found in prefix mapping, using o200k_base encoding")
            encoding = "o200k_base"

        try:
            self.encoders[model] = tiktoken.get_encoding(encoding)
        except Exception as e:
            logger.warning(f"Failed to get encoding {encoding} for model {model}: {e}. Falling back to cl100k_base")
            self.encoders[model] = tiktoken.get_encoding("cl100k_base")

        return self.encoders[model]

    def count_tokens(self, messages: list, model: str = "gpt-3.5-turbo") -> int:
        """Count tokens in a list of chat message dicts for ``model``."""
        encoder = self.get_encoder(model)
        return self._count_chat_tokens(messages, encoder, model)

    def _count_chat_tokens(self, messages: list, encoder, model: str) -> int:
        """Estimate the prompt size of a chat message list.

        Each message adds a fixed overhead, string values are encoded and
        counted, and a ``name`` field adds a per-name adjustment. The ``role``
        value itself is not counted. Three extra tokens are added once at the
        end for the reply priming.

        Args:
            messages: List of message dicts.
            encoder: Object exposing ``encode(text, disallowed_special=())``.
            model: Model name; selects the per-message and per-name overheads.

        Returns:
            The estimated token count.
        """
        if model.startswith(("gpt-3.5-turbo", "gpt-35-turbo")):
            tokens_per_message, tokens_per_name = 4, -1
        else:
            tokens_per_message, tokens_per_name = 3, 1

        def encoded_len(text: str) -> int:
            # disallowed_special=() lets special-token text be encoded as plain text.
            return len(encoder.encode(text, disallowed_special=()))

        num_tokens = 0
        for message in messages:
            num_tokens += tokens_per_message

            for key, value in message.items():
                if key == "content":
                    if isinstance(value, list):
                        # Multi-part content: only text parts contribute.
                        for item in value:
                            if isinstance(item, dict) and item.get("type") == "text":
                                content_text = item.get("text", "")
                                # A null / non-string "text" would make encode() raise;
                                # such parts contribute no tokens.
                                if isinstance(content_text, str):
                                    num_tokens += encoded_len(content_text)
                    elif isinstance(value, str):
                        num_tokens += encoded_len(value)
                elif key == "name":
                    num_tokens += tokens_per_name
                    if isinstance(value, str):
                        num_tokens += encoded_len(value)
                elif key == "role":
                    continue
                elif isinstance(value, str):
                    # Any other string field (e.g. tool_call_id) is counted as text.
                    num_tokens += encoded_len(value)

        num_tokens += 3
        return num_tokens

    def count_text_tokens(self, text: str, model: str = "gpt-3.5-turbo") -> int:
        """Count tokens in plain text using the encoder for ``model``."""
        encoder = self.get_encoder(model)
        return len(encoder.encode(text, disallowed_special=()))
|
|
| |
# Module-level TokenCounter instance; its per-model encoder cache lives on this object.
token_counter = TokenCounter()
|
|
def generate_random_trigger_signal() -> str:
    """Return a fresh self-closing trigger tag such as ``<Function_AB1c_Start/>``.

    The four-character token in the middle is drawn from ASCII letters and
    digits using the ``secrets`` module.
    """
    alphabet = string.ascii_letters + string.digits
    token = "".join([secrets.choice(alphabet) for _ in range(4)])
    return "<Function_" + token + "_Start/>"
|
|
# Load configuration once at import time and derive the module-level routing
# tables from it. Any failure here is fatal: the process exits with status 1.
try:
    app_config = config_loader.load_config()

    # Normalise the level name so that e.g. "info" and "INFO" behave the same.
    log_level_str = str(app_config.features.log_level).upper()
    if log_level_str == "DISABLED":
        # One above CRITICAL, so no standard-level record passes the threshold.
        log_level = logging.CRITICAL + 1
    else:
        log_level = getattr(logging, log_level_str, None)
        # getattr can return non-level attributes of the logging module
        # (functions, classes); only accept a real numeric level.
        if not isinstance(log_level, int):
            log_level = logging.INFO

    logging.basicConfig(
        level=log_level,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )

    logger.info(f"✅ Configuration loaded successfully: {config_loader.config_path}")
    logger.info(f"📊 Configured {len(app_config.upstream_services)} upstream services")
    logger.info(f"🔑 Configured {len(app_config.client_authentication.allowed_keys)} client keys")

    MODEL_TO_SERVICE_MAPPING, ALIAS_MAPPING = config_loader.get_model_to_service_mapping()
    DEFAULT_SERVICE = config_loader.get_default_service()
    ALLOWED_CLIENT_KEYS = config_loader.get_allowed_client_keys()
    # One trigger signal is generated per process start.
    GLOBAL_TRIGGER_SIGNAL = generate_random_trigger_signal()

    logger.info(f"🎯 Configured {len(MODEL_TO_SERVICE_MAPPING)} model mappings")
    if ALIAS_MAPPING:
        logger.info(f"🔄 Configured {len(ALIAS_MAPPING)} model aliases: {list(ALIAS_MAPPING.keys())}")
    logger.info(f"🔄 Default service: {DEFAULT_SERVICE['name']}")

except Exception as e:
    logger.error(f"❌ Configuration loading failed: {type(e).__name__}")
    logger.error(f"❌ Error details: {str(e)}")
    logger.error("💡 Please ensure config.yaml file exists and is properly formatted")
    # The exit() helper is injected by the site module and may be missing
    # (python -S, frozen builds); SystemExit works everywhere.
    raise SystemExit(1)
def build_tool_call_index_from_messages(messages: List[Dict[str, Any]]) -> Dict[str, Dict[str, str]]:
    """
    Collect every tool call made by assistant messages, keyed by its id.

    Only dict messages with role "assistant" and a non-empty list under
    "tool_calls" are inspected. An entry is recorded when the tool call has a
    truthy id, a dict "function" and a non-empty function name. Non-string
    arguments are serialised to JSON (falling back to str()).

    Args:
        messages: List of message dicts from the request

    Returns:
        Dict mapping tool_call_id to {name, arguments}
    """
    index = {}
    for msg in messages:
        if not isinstance(msg, dict) or msg.get("role") != "assistant":
            continue

        tool_calls = msg.get("tool_calls")
        if not tool_calls or not isinstance(tool_calls, list):
            continue

        for tc in tool_calls:
            if not isinstance(tc, dict):
                continue

            tc_id = tc.get("id")
            func = tc.get("function", {})
            if not tc_id or not isinstance(func, dict):
                continue

            name = func.get("name", "")
            arguments = func.get("arguments", "{}")
            if not isinstance(arguments, str):
                try:
                    arguments = json.dumps(arguments, ensure_ascii=False)
                except Exception:
                    arguments = str(arguments)

            if not name:
                continue

            index[tc_id] = {"name": name, "arguments": arguments}
            logger.debug(f"🔧 Indexed tool_call_id: {tc_id} -> {name}")

    logger.debug(f"🔧 Built tool_call index with {len(index)} entries")
    return index
|
|
def get_fc_error_retry_prompt(original_response: str, error_details: str) -> str:
    """Build the follow-up prompt that asks the model to redo a malformed function call.

    When ``app_config.features.fc_error_retry_prompt_template`` is set, it is
    filled in with ``str.format`` using the ``original_response`` and
    ``error_details`` fields; otherwise the built-in prompt below is used.
    """
    template = app_config.features.fc_error_retry_prompt_template
    if not template:
        return f"""Your previous response attempted to make a function call but the format was invalid or could not be parsed.

**Your original response:**
```
{original_response}
```

**Error details:**
{error_details}

**Instructions:**
Please retry and output the function call in the correct XML format. Remember:
1. Start with the trigger signal on its own line
2. Immediately follow with the <function_calls> XML block
3. Use <args_json> with valid JSON for parameters
4. Do not add any text after </function_calls>

Please provide the corrected function call now. DO NOT OUTPUT ANYTHING ELSE."""

    return template.format(
        original_response=original_response,
        error_details=error_details,
    )
|
|
|
|
| def _schema_type_name(v: Any) -> str: |
| if v is None: |
| return "null" |
| if isinstance(v, bool): |
| return "boolean" |
| if isinstance(v, int) and not isinstance(v, bool): |
| return "integer" |
| if isinstance(v, float): |
| return "number" |
| if isinstance(v, str): |
| return "string" |
| if isinstance(v, list): |
| return "array" |
| if isinstance(v, dict): |
| return "object" |
| return type(v).__name__ |
|
|
|
|
| def _validate_value_against_schema(value: Any, schema: Dict[str, Any], path: str = "args", depth: int = 0) -> List[str]: |
| """Best-effort JSON Schema validation for tool arguments. |
| |
| Intentional subset: |
| - type, properties, required, additionalProperties |
| - items (array) |
| - enum, const |
| - anyOf/oneOf/allOf (basic) |
| - pattern/minLength/maxLength (string) |
| |
| Returns a list of human-readable errors. |
| """ |
| if schema is None: |
| schema = {} |
| if depth > 8: |
| return [] |
|
|
| errors: List[str] = [] |
|
|
| |
| if isinstance(schema.get("allOf"), list): |
| for idx, sub in enumerate(schema["allOf"]): |
| errors.extend(_validate_value_against_schema(value, sub or {}, f"{path}.allOf[{idx}]", depth + 1)) |
| return errors |
|
|
| if isinstance(schema.get("anyOf"), list): |
| option_errors = [ |
| _validate_value_against_schema(value, sub or {}, path, depth + 1) |
| for sub in schema["anyOf"] |
| ] |
| if not any(len(e) == 0 for e in option_errors): |
| errors.append(f"{path}: value does not satisfy anyOf options") |
| return errors |
|
|
| if isinstance(schema.get("oneOf"), list): |
| option_errors = [ |
| _validate_value_against_schema(value, sub or {}, path, depth + 1) |
| for sub in schema["oneOf"] |
| ] |
| ok_count = sum(1 for e in option_errors if len(e) == 0) |
| if ok_count != 1: |
| errors.append(f"{path}: value must satisfy exactly one oneOf option (matched {ok_count})") |
| return errors |
|
|
| |
| if "const" in schema: |
| if value != schema.get("const"): |
| errors.append(f"{path}: expected const={schema.get('const')!r}, got {value!r}") |
| return errors |
|
|
| enum_vals = schema.get("enum") |
| if isinstance(enum_vals, list): |
| if value not in enum_vals: |
| errors.append(f"{path}: expected one of {enum_vals!r}, got {value!r}") |
| return errors |
|
|
| stype = schema.get("type") |
| if stype is None: |
| |
| if any(k in schema for k in ("properties", "required", "additionalProperties")): |
| stype = "object" |
|
|
| |
| def _type_ok(t: str) -> bool: |
| if t == "object": |
| return isinstance(value, dict) |
| if t == "array": |
| return isinstance(value, list) |
| if t == "string": |
| return isinstance(value, str) |
| if t == "boolean": |
| return isinstance(value, bool) |
| if t == "integer": |
| return isinstance(value, int) and not isinstance(value, bool) |
| if t == "number": |
| return (isinstance(value, (int, float)) and not isinstance(value, bool)) |
| if t == "null": |
| return value is None |
| return True |
|
|
| if isinstance(stype, str): |
| if not _type_ok(stype): |
| errors.append(f"{path}: expected type '{stype}', got '{_schema_type_name(value)}'") |
| return errors |
| elif isinstance(stype, list): |
| if not any(_type_ok(t) for t in stype if isinstance(t, str)): |
| errors.append(f"{path}: expected type in {stype!r}, got '{_schema_type_name(value)}'") |
| return errors |
|
|
| |
| if isinstance(value, str): |
| min_len = schema.get("minLength") |
| max_len = schema.get("maxLength") |
| if isinstance(min_len, int) and len(value) < min_len: |
| errors.append(f"{path}: string shorter than minLength={min_len}") |
| if isinstance(max_len, int) and len(value) > max_len: |
| errors.append(f"{path}: string longer than maxLength={max_len}") |
| pat = schema.get("pattern") |
| if isinstance(pat, str): |
| try: |
| if re.search(pat, value) is None: |
| errors.append(f"{path}: string does not match pattern {pat!r}") |
| except re.error: |
| |
| pass |
|
|
| |
| if isinstance(value, dict): |
| props = schema.get("properties") |
| if props is None: |
| props = {} |
| if not isinstance(props, dict): |
| props = {} |
| required = schema.get("required") |
| if required is None: |
| required = [] |
| if not isinstance(required, list): |
| required = [] |
| required = [k for k in required if isinstance(k, str)] |
|
|
| for k in required: |
| if k not in value: |
| errors.append(f"{path}: missing required property '{k}'") |
|
|
| additional = schema.get("additionalProperties", True) |
|
|
| for k, v in value.items(): |
| if k in props: |
| errors.extend(_validate_value_against_schema(v, props.get(k) or {}, f"{path}.{k}", depth + 1)) |
| else: |
| if additional is False: |
| errors.append(f"{path}: unexpected property '{k}'") |
| elif isinstance(additional, dict): |
| errors.extend(_validate_value_against_schema(v, additional, f"{path}.{k}", depth + 1)) |
|
|
| |
| if isinstance(value, list): |
| items = schema.get("items") |
| if isinstance(items, dict): |
| for i, v in enumerate(value): |
| errors.extend(_validate_value_against_schema(v, items, f"{path}[{i}]", depth + 1)) |
|
|
| return errors |
|
|
|
|
def validate_parsed_tools(parsed_tools: List[Dict[str, Any]], tools: List["Tool"]) -> Optional[str]:
    """Check parsed tool calls against the tools declared in the request.

    Calls are checked in order; the first problem found (missing name, unknown
    tool, non-object arguments, or schema violation) is reported.

    Returns a single error string if invalid, else None.
    """
    # Declared tool name -> parameter schema ({} when none was given).
    declared: Dict[str, Any] = {}
    for t in tools or []:
        if t and t.function and t.function.name:
            declared[t.function.name] = t.function.parameters or {}
    allowed_names = sorted(declared)

    for number, call in enumerate(parsed_tools or [], start=1):
        entry = call or {}
        name = entry.get("name")
        args = entry.get("args")

        if not isinstance(name, str) or not name:
            return f"Tool call #{number}: missing tool name"

        if name not in declared:
            return (
                f"Tool call #{number}: unknown tool '{name}'. "
                f"Allowed tools: {allowed_names}"
            )

        if not isinstance(args, dict):
            return f"Tool call #{number} '{name}': arguments must be a JSON object, got {_schema_type_name(args)}"

        errs = _validate_value_against_schema(args, declared[name] or {}, path=f"{name}")
        if not errs:
            continue

        # Keep the message short: show at most six violations.
        preview = "; ".join(errs[:6])
        more = f" (+{len(errs) - 6} more)" if len(errs) > 6 else ""
        return f"Tool call #{number} '{name}': schema validation failed: {preview}{more}"

    return None
|
|
|
|
async def attempt_fc_parse_with_retry(
    content: str,
    trigger_signal: str,
    messages: List[Dict[str, Any]],
    upstream_url: str,
    headers: Dict[str, str],
    model: str,
    tools: List["Tool"],
    timeout: int
) -> Optional[List[Dict[str, Any]]]:
    """
    Attempt to parse function calls from content. If parsing fails and retry is enabled,
    send error details back to the model for correction.

    Args:
        content: Model output to parse.
        trigger_signal: Trigger tag that introduces a function-call block.
        messages: Conversation sent upstream so far; not mutated.
        upstream_url: Endpoint used for correction requests.
        headers: HTTP headers for correction requests.
        model: Model name for correction requests.
        tools: Declared tools used to validate parsed calls.
        timeout: Timeout passed to the HTTP client.

    Returns parsed tool calls or None if parsing ultimately fails.
    """
    def _parse_and_validate(current_content: str) -> tuple[Optional[List[Dict[str, Any]]], Optional[str]]:
        # Returns (parsed_calls, None) on success, (None, error) when the calls
        # parse but fail validation, and (None, None) when nothing parses.
        parsed = parse_function_calls_xml(current_content, trigger_signal)
        if not parsed:
            return None, None
        validation_error = validate_parsed_tools(parsed, tools)
        if validation_error:
            return None, validation_error
        return parsed, None

    if not app_config.features.enable_fc_error_retry:
        parsed, _err = _parse_and_validate(content)
        return parsed

    # Always parse the original content at least once, even if the configured
    # limit is zero or negative.
    max_attempts = max(1, app_config.features.fc_error_retry_max_attempts)
    current_content = content
    current_messages = messages.copy()

    for attempt in range(max_attempts):
        parsed_tools, validation_error = _parse_and_validate(current_content)

        if parsed_tools:
            if attempt > 0:
                logger.info(f"✅ Function call parsing succeeded on retry attempt {attempt + 1}")
            return parsed_tools

        # No trigger outside <think> blocks means the model was not trying to
        # call a tool, so there is nothing to correct.
        if find_last_trigger_signal_outside_think(current_content, trigger_signal) == -1:
            logger.debug("🔧 No trigger signal found outside <think> blocks; not a function call attempt")
            return None

        if attempt >= max_attempts - 1:
            logger.warning(f"⚠️ Function call parsing failed after {max_attempts} attempts")
            return None

        error_details = validation_error or _diagnose_fc_parse_error(current_content, trigger_signal)
        retry_prompt = get_fc_error_retry_prompt(current_content, error_details)

        logger.info(f"🔄 Function call parsing failed, attempting retry {attempt + 2}/{max_attempts}")
        logger.debug(f"🔧 Error details: {error_details}")

        # Show the model its own faulty output followed by the correction request.
        retry_messages = current_messages + [
            {"role": "assistant", "content": current_content},
            {"role": "user", "content": retry_prompt}
        ]

        try:
            retry_response = await http_client.post(
                upstream_url,
                json={"model": model, "messages": retry_messages, "stream": False},
                headers=headers,
                timeout=timeout
            )
            retry_response.raise_for_status()
            retry_json = retry_response.json()

            if retry_json.get("choices") and len(retry_json["choices"]) > 0:
                message = retry_json["choices"][0].get("message") or {}
                # "content" may be present but null; treat that as empty text
                # instead of failing on len(None).
                current_content = message.get("content") or ""
                current_messages = retry_messages
                logger.debug(f"🔧 Received retry response, length: {len(current_content)}")
            else:
                logger.warning("⚠️ Retry response has no valid choices")
                return None

        except Exception as e:
            # Best effort: any transport or decoding problem ends the retry loop.
            logger.error(f"❌ Retry request failed: {e}")
            return None

    return None
|
|
|
|
| def _diagnose_fc_parse_error(content: str, trigger_signal: str) -> str: |
| """Diagnose why function call parsing failed and return error description.""" |
| errors = [] |
| |
| if trigger_signal not in content: |
| errors.append(f"Trigger signal '{trigger_signal[:30]}...' not found in response") |
| return "; ".join(errors) |
| |
| cleaned = remove_think_blocks(content) |
| |
| if "<function_calls>" not in cleaned: |
| errors.append("Missing <function_calls> tag after trigger signal") |
| elif "</function_calls>" not in cleaned: |
| errors.append("Missing closing </function_calls> tag") |
| |
| if "<function_call>" not in cleaned: |
| errors.append("No <function_call> blocks found inside <function_calls>") |
| elif "</function_call>" not in cleaned: |
| errors.append("Missing closing </function_call> tag") |
| |
| fc_match = re.search(r"<function_calls>([\s\S]*?)</function_calls>", cleaned) |
| if fc_match: |
| fc_content = fc_match.group(1) |
| |
| if "<tool>" not in fc_content: |
| errors.append("Missing <tool> tag inside function_call") |
| |
| if "<args_json>" not in fc_content and "<args>" not in fc_content: |
| errors.append("Missing <args_json> or <args> tag inside function_call") |
| |
| args_json_match = re.search(r"<args_json>([\s\S]*?)</args_json>", fc_content) |
| if args_json_match: |
| args_content = args_json_match.group(1).strip() |
| cdata_match = re.search(r"<!\[CDATA\[([\s\S]*?)\]\]>", args_content) |
| json_to_parse = cdata_match.group(1) if cdata_match else args_content |
| |
| try: |
| parsed = json.loads(json_to_parse) |
| if not isinstance(parsed, dict): |
| errors.append(f"args_json must be a JSON object, got {type(parsed).__name__}") |
| except json.JSONDecodeError as e: |
| errors.append(f"Invalid JSON in args_json: {str(e)}") |
| |
| if not errors: |
| errors.append("XML structure appears correct but parsing failed for unknown reason") |
| |
| return "; ".join(errors) |
|
|
|
|
def format_tool_result_for_ai(tool_name: str, tool_arguments: str, result_content: str) -> str:
    """
    Render a tool execution result, with its call context, as text for the upstream model.

    Args:
        tool_name: Name of the tool that was called
        tool_arguments: Arguments passed to the tool (JSON string)
        result_content: Execution result from the tool

    Returns:
        Multi-line text with the result wrapped in <tool_result> tags
    """
    lines = [
        "Tool execution result:",
        f"- Tool name: {tool_name}",
        f"- Tool arguments: {tool_arguments}",
        "- Execution result:",
        "<tool_result>",
        f"{result_content}",
        "</tool_result>",
    ]
    logger.debug(f"🔧 Formatted tool result for {tool_name}")
    return "\n".join(lines)
|
|
def format_assistant_tool_calls_for_ai(tool_calls: List[Dict[str, Any]], trigger_signal: str) -> str:
    """Render assistant tool calls as the trigger signal followed by a <function_calls> XML block.

    Raises:
        HTTPException: 400 when a call's arguments are not a JSON object
            (either a dict or a string that decodes to one).
    """
    logger.debug(f"🔧 Formatting assistant tool calls. Count: {len(tool_calls)}")

    def _wrap_cdata(text: str) -> str:
        # A literal "]]>" would end the CDATA section early, so split it
        # across two adjacent sections.
        safe = (text or "").replace("]]>", "]]]]><![CDATA[>")
        return f"<![CDATA[{safe}]]>"

    def _as_object(raw: Any) -> Dict[str, Any]:
        # Accept a dict as-is, or a JSON string that decodes to a dict.
        if isinstance(raw, dict):
            return raw
        if not isinstance(raw, str):
            raise ValueError(f"arguments must be a JSON object or JSON string, got {type(raw).__name__}")
        decoded = json.loads(raw or "{}")
        if not isinstance(decoded, dict):
            raise ValueError(f"arguments must be a JSON object, got {type(decoded).__name__}")
        return decoded

    rendered = []
    for tool_call in tool_calls:
        function_info = tool_call.get("function", {})
        name = function_info.get("name", "")
        arguments_val = function_info.get("arguments", "{}")

        try:
            args_dict = _as_object(arguments_val)
        except Exception as e:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid assistant.tool_calls arguments for tool '{name}': {e}"
            )

        args_payload = json.dumps(args_dict, ensure_ascii=False)
        rendered.append(
            "<function_call>\n"
            f"<tool>{name}</tool>\n"
            f"<args_json>{_wrap_cdata(args_payload)}</args_json>\n"
            "</function_call>"
        )

    all_calls = "\n".join(rendered)
    final_str = f"{trigger_signal}\n<function_calls>\n{all_calls}\n</function_calls>"

    logger.debug("🔧 Assistant tool calls formatted successfully.")
    return final_str
|
|
def get_function_call_prompt_template(trigger_signal: str) -> str:
    """
    Return the system-prompt template for the given trigger signal.

    The result still contains a literal ``{tools_list}`` placeholder for the
    caller to substitute. A custom template from
    ``app_config.features.prompt_template`` takes precedence over the built-in
    text.
    """
    configured = app_config.features.prompt_template
    if not configured:
        return f"""
You have access to the following available tools to help solve problems:

{{tools_list}}

**IMPORTANT CONTEXT NOTES:**
1. You can call MULTIPLE tools in a single response if needed.
2. Even though you can call multiple tools, you MUST respect the user's later constraints and preferences (e.g., the user may request no tools, only one tool, or a specific tool/workflow).
3. The conversation context may already contain tool execution results from previous function calls. Review the conversation history carefully to avoid unnecessary duplicate tool calls.
4. When tool execution results are present in the context, they will be formatted with XML tags like <tool_result>...</tool_result> for easy identification.
5. This is the ONLY format you can use for tool calls, and any deviation will result in failure.

When you need to use tools, you **MUST** strictly follow this format. Do NOT include any extra text, explanations, or dialogue on the first and second lines of the tool call syntax:

1. When starting tool calls, begin on a new line with exactly:
{trigger_signal}
No leading or trailing spaces, output exactly as shown above. The trigger signal MUST be on its own line and appear only once. Do not output a trigger signal for each tool call.

2. Starting from the second line, **immediately** follow with the complete <function_calls> XML block.

3. For multiple tool calls, include multiple <function_call> blocks within the same <function_calls> wrapper, not separate blocks. Output the trigger signal only once, then one <function_calls> with all <function_call> children.

4. Do not add any text or explanation after the closing </function_calls> tag.

STRICT ARGUMENT KEY RULES:
- You MUST use parameter keys EXACTLY as defined (case- and punctuation-sensitive). Do NOT rename, add, or remove characters.
- If a key starts with a hyphen (e.g., "-i", "-C"), you MUST keep the leading hyphen in the JSON key. Never convert "-i" to "i" or "-C" to "C".
- The <tool> tag must contain the exact name of a tool from the list. Any other tool name is invalid.
- The <args_json> tag must contain a single JSON object with all required arguments for that tool.
- You MAY wrap the JSON content inside <![CDATA[...]]> to avoid XML escaping issues.

CORRECT Example (multiple tool calls):
...response content (optional)...
{trigger_signal}
<function_calls>
<function_call>
<tool>Grep</tool>
<args_json><![CDATA[{{"-i": true, "-C": 2, "path": "."}}]]></args_json>
</function_call>
<function_call>
<tool>search</tool>
<args_json><![CDATA[{{"keywords": ["Python Document", "how to use python"]}}]]></args_json>
</function_call>
</function_calls>

INCORRECT Example (extra text + wrong key names — DO NOT DO THIS):
...response content (optional)...
{trigger_signal}
I will call the tools for you.
<function_calls>
<function_call>
<tool>Grep</tool>
<args>
<i>true</i>
<C>2</C>
<path>.</path>
</args>
</function_call>
</function_calls>

INCORRECT Example (output non-XML format — DO NOT DO THIS):
...response content (optional)...
```json
{{"files":[{{"path":"system.py"}}]}}
```

Now please be ready to strictly follow the above specifications.
"""

    logger.info("🔧 Using custom prompt template from configuration")
    # Re-insert "{tools_list}" verbatim so the caller can substitute it later.
    return configured.format(
        trigger_signal=trigger_signal,
        tools_list="{tools_list}",
    )
|
|
class ToolFunction(BaseModel):
    # Function definition carried inside a tool entry of the request.
    name: str
    description: Optional[str] = None
    # JSON Schema for the arguments. Optional: a function that takes no
    # arguments may omit it, and callers in this module already fall back to
    # an empty schema via `parameters or {}`.
    parameters: Optional[Dict[str, Any]] = None
|
|
class Tool(BaseModel):
    # One entry of the request's `tools` list; only the "function" type is accepted.
    type: Literal["function"]
    # The function definition (name, description, parameter schema).
    function: ToolFunction
|
|
class Message(BaseModel):
    # A single chat message. Every field except `role` is optional.
    role: str
    content: Optional[str] = None
    # Tool calls attached to the message, kept as raw dicts.
    tool_calls: Optional[List[Dict[str, Any]]] = None
    tool_call_id: Optional[str] = None
    name: Optional[str] = None

    class Config:
        # Accept and keep fields that are not declared above.
        extra = "allow"
|
|
class ToolChoice(BaseModel):
    # Object form of `tool_choice`; only the "function" type is accepted.
    type: Literal["function"]
    # String-to-string mapping describing the chosen function.
    function: Dict[str, str]
|
|
class ChatCompletionRequest(BaseModel):
    # Request body of a chat completion call. Only `model` and `messages` are required.
    model: str
    # Messages are kept as raw dicts rather than parsed into Message models.
    messages: List[Dict[str, Any]]
    tools: Optional[List[Tool]] = None
    # Either a plain string or the object form defined by ToolChoice.
    tool_choice: Optional[Union[str, ToolChoice]] = None
    stream: Optional[bool] = False
    stream_options: Optional[Dict[str, Any]] = None
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    top_p: Optional[float] = None
    frequency_penalty: Optional[float] = None
    presence_penalty: Optional[float] = None
    n: Optional[int] = None
    stop: Optional[Union[str, List[str]]] = None

    class Config:
        # Accept and keep fields that are not declared above.
        extra = "allow"
|
|
|
|
def _dump_for_prompt(value: Any) -> str:
    """Render ``value`` as JSON for the prompt, falling back to ``str()`` when it cannot be serialised."""
    try:
        return json.dumps(value, ensure_ascii=False)
    except Exception:
        # Best effort: the prompt should still be generated for odd values.
        return str(value)


def generate_function_prompt(tools: List[Tool], trigger_signal: str) -> tuple[str, str]:
    """
    Generate injected system prompt based on tools definition in client request.

    For each tool a numbered entry is built with its description, a parameter
    summary, the required parameters and a per-parameter detail list (type,
    required flag, description, enum, default, examples, constraints). The
    entries are substituted for ``{tools_list}`` in the prompt template.

    Args:
        tools: Tool definitions from the request.
        trigger_signal: Trigger tag embedded in the prompt.

    Returns: (prompt_content, trigger_signal)

    Raises:
        HTTPException: If tool schema validation fails (e.g., required keys not in properties)
    """
    tools_list_str = []
    for i, tool in enumerate(tools):
        func = tool.function
        name = func.name
        description = func.description or ""

        schema: Dict[str, Any] = func.parameters or {}

        props_raw = schema.get("properties", {})
        if props_raw is None:
            props_raw = {}
        if not isinstance(props_raw, dict):
            raise HTTPException(
                status_code=400,
                detail=f"Tool '{name}': 'properties' must be an object, got {type(props_raw).__name__}"
            )
        # A property schema that is not an object (null, boolean schema, ...)
        # is treated as an empty schema instead of crashing on .get().
        props: Dict[str, Dict[str, Any]] = {
            p_name: (p_info if isinstance(p_info, dict) else {})
            for p_name, p_info in props_raw.items()
        }

        required_raw = schema.get("required", [])
        if required_raw is None:
            required_raw = []
        if not isinstance(required_raw, list):
            raise HTTPException(
                status_code=400,
                detail=f"Tool '{name}': 'required' must be a list, got {type(required_raw).__name__}"
            )

        non_string_required = [k for k in required_raw if not isinstance(k, str)]
        if non_string_required:
            raise HTTPException(
                status_code=400,
                detail=f"Tool '{name}': 'required' entries must be strings, got {non_string_required}"
            )

        required_list: List[str] = required_raw

        missing_keys = [key for key in required_list if key not in props]
        if missing_keys:
            raise HTTPException(
                status_code=400,
                detail=f"Tool '{name}': required parameters {missing_keys} are not defined in properties"
            )

        # One-line overview: "name (type), name (type), ..."
        params_summary = ", ".join([
            f"{p_name} ({p_info.get('type', 'any')})" for p_name, p_info in props.items()
        ]) or "None"

        # Per-parameter detail block.
        detail_lines: List[str] = []
        for p_name, p_info in props.items():
            p_type = p_info.get("type", "any")
            is_required = "Yes" if p_name in required_list else "No"
            p_desc = p_info.get("description")
            enum_vals = p_info.get("enum")
            default_val = p_info.get("default")
            examples_val = p_info.get("examples") or p_info.get("example")

            # Copy through the validation keywords that are present.
            constraints: Dict[str, Any] = {}
            for key in [
                "minimum", "maximum", "exclusiveMinimum", "exclusiveMaximum",
                "minLength", "maxLength", "pattern", "format",
                "minItems", "maxItems", "uniqueItems"
            ]:
                if key in p_info:
                    constraints[key] = p_info.get(key)

            # For arrays, surface the element type as a constraint.
            if p_type == "array":
                items = p_info.get("items") or {}
                if isinstance(items, dict):
                    itype = items.get("type")
                    if itype:
                        constraints["items.type"] = itype

            detail_lines.append(f"- {p_name}:")
            detail_lines.append(f" - type: {p_type}")
            detail_lines.append(f" - required: {is_required}")
            if p_desc:
                detail_lines.append(f" - description: {p_desc}")
            if enum_vals is not None:
                detail_lines.append(f" - enum: {_dump_for_prompt(enum_vals)}")
            if default_val is not None:
                detail_lines.append(f" - default: {_dump_for_prompt(default_val)}")
            if examples_val is not None:
                detail_lines.append(f" - examples: {_dump_for_prompt(examples_val)}")
            if constraints:
                detail_lines.append(f" - constraints: {_dump_for_prompt(constraints)}")

        detail_block = "\n".join(detail_lines) if detail_lines else "(no parameter details)"

        desc_block = f"```\n{description}\n```" if description else "None"

        tools_list_str.append(
            f"{i + 1}. <tool name=\"{name}\">\n"
            f" Description:\n{desc_block}\n"
            f" Parameters summary: {params_summary}\n"
            f" Required parameters: {', '.join(required_list) if required_list else 'None'}\n"
            f" Parameter details:\n{detail_block}"
        )

    prompt_template = get_function_call_prompt_template(trigger_signal)
    # Plain replace (not str.format): the tool text may itself contain braces.
    prompt_content = prompt_template.replace("{tools_list}", "\n\n".join(tools_list_str))

    return prompt_content, trigger_signal
|
|
def remove_think_blocks(text: str) -> str:
    """
    Return ``text`` with every balanced <think>...</think> block stripped out.

    Nested think tags are matched by depth, so an outer block is removed
    together with everything inside it. Processing stops at the first <think>
    that has no matching close; the text from there on is left untouched.
    Intended for temporary parsing only; callers keep the original content.
    """
    tag_re = re.compile(r"</?think>")

    while "<think>" in text and "</think>" in text:
        open_at = text.find("<think>")

        # Walk the following tags, tracking nesting, until the block closes.
        nesting = 1
        close_end = None
        for tag in tag_re.finditer(text, open_at + 7):
            nesting += 1 if tag.group() == "<think>" else -1
            if nesting == 0:
                close_end = tag.end()
                break

        if close_end is None:
            # Unbalanced block: leave the remainder as it is.
            break

        text = text[:open_at] + text[close_end:]

    return text
|
|
def find_last_trigger_signal_outside_think(text: str, trigger_signal: str) -> int:
    """
    Return the index of the last trigger_signal occurrence that lies outside
    every <think>...</think> block, or -1 when there is none (including when
    either argument is empty).

    Think tags are tracked with a depth counter; a stray </think> never takes
    the depth below zero.
    """
    if not text or not trigger_signal:
        return -1

    cursor = 0
    depth = 0
    found = -1
    end = len(text)

    while cursor < end:
        if text.startswith("<think>", cursor):
            depth += 1
            cursor += 7
        elif text.startswith("</think>", cursor):
            depth = max(0, depth - 1)
            cursor += 8
        else:
            if depth == 0 and text.startswith(trigger_signal, cursor):
                found = cursor
            # Advance one character at a time so later matches are still seen.
            cursor += 1

    return found
|
|
class StreamingFunctionCallDetector:
    """Incremental scanner that spots the tool-call trigger signal in streamed text.

    Behaviour:
    1. A trigger signal that appears inside a <think> block is ignored.
    2. <think> blocks (tags included) are passed through as normal output.
    3. Nested <think> tags are tracked with a depth counter.
    """

    def __init__(self, trigger_signal: str):
        self.trigger_signal = trigger_signal
        self.reset()

    def reset(self):
        """Put the detector back into its initial, empty "detecting" state."""
        self.content_buffer = ""
        self.state = "detecting"
        self.in_think_block = False
        self.think_depth = 0
        self.signal = self.trigger_signal
        self.signal_len = len(self.signal)

    def process_chunk(self, delta_content: str) -> tuple[bool, str]:
        """
        Feed one streamed text fragment into the detector.

        Returns ``(is_tool_call_detected, content_to_yield)``. Once the trigger
        signal has been seen (state "tool_parsing") every further fragment is
        only accumulated in ``content_buffer`` and nothing is yielded.
        """
        if not delta_content:
            return False, ""

        self.content_buffer += delta_content

        if self.state == "tool_parsing":
            return False, ""

        logger.debug(f"🔧 Processing chunk: {repr(delta_content[:50])}{'...' if len(delta_content) > 50 else ''}, buffer length: {len(self.content_buffer)}, think state: {self.in_think_block}")

        emitted = []
        cursor = 0
        while cursor < len(self.content_buffer):
            tag_len = self._update_think_state(cursor)
            if tag_len > 0:
                # A complete think tag is forwarded verbatim.
                emitted.append(self.content_buffer[cursor:cursor + tag_len])
                cursor += tag_len
                continue

            if not self.in_think_block and self._can_detect_signal_at(cursor):
                candidate = self.content_buffer[cursor:cursor + self.signal_len]
                if candidate == self.signal:
                    logger.debug(f"🔧 Improved detector: detected trigger signal in non-think block! Signal: {self.signal[:20]}...")
                    logger.debug(f"🔧 Trigger signal position: {cursor}, think state: {self.in_think_block}, think depth: {self.think_depth}")
                    self.state = "tool_parsing"
                    self.content_buffer = self.content_buffer[cursor:]
                    return True, "".join(emitted)

            # Hold back a tail that could still turn out to be the start of the
            # trigger signal or of a think tag (the longest tag is 8 characters).
            unread = len(self.content_buffer) - cursor
            if unread < self.signal_len or unread < 8:
                break

            emitted.append(self.content_buffer[cursor])
            cursor += 1

        self.content_buffer = self.content_buffer[cursor:]
        return False, "".join(emitted)

    def _update_think_state(self, pos: int):
        """Update nesting state if a think tag starts at ``pos``; return the tag length, or 0 if none."""
        buffer = self.content_buffer

        if buffer.startswith('<think>', pos):
            self.think_depth += 1
            self.in_think_block = True
            logger.debug(f"🔧 Entering think block, depth: {self.think_depth}")
            return 7

        if buffer.startswith('</think>', pos):
            self.think_depth = max(0, self.think_depth - 1)
            self.in_think_block = self.think_depth > 0
            logger.debug(f"🔧 Exiting think block, depth: {self.think_depth}")
            return 8

        return 0

    def _can_detect_signal_at(self, pos: int) -> bool:
        """True when outside a think block and a full signal fits in the buffer at ``pos``."""
        fits_in_buffer = pos + self.signal_len <= len(self.content_buffer)
        return fits_in_buffer and not self.in_think_block

    def finalize(self) -> Optional[List[Dict[str, Any]]]:
        """Parse the buffered tool-call text at end of stream; None if no trigger was detected."""
        if self.state != "tool_parsing":
            return None
        return parse_function_calls_xml(self.content_buffer, self.trigger_signal)
|
|
def parse_function_calls_xml(xml_string: str, trigger_signal: str) -> Optional[List[Dict[str, Any]]]:
    """
    Extract tool calls from model output that follows a dynamic trigger signal.

    Steps:
    1. Work on a copy with <think>...</think> blocks removed; the caller's
       string is not modified, so think content can still be returned to the user.
    2. Collect every position of the trigger signal in that copy.
    3. Starting from the last position, pick the first one that is followed by
       a <function_calls>...</function_calls> section.
    4. Parse that section with ElementTree; if that fails, fall back to a
       regex-based parser that starts from an empty result list.

    Args:
        xml_string: Full model output text.
        trigger_signal: Marker string that precedes the function_calls section.

    Returns:
        A list of ``{"name": str, "args": dict}`` entries, or None when the
        input is empty, no trigger signal / function_calls section is found,
        no named call is present, or an ``args_json`` payload is not a valid
        JSON object.
    """
    logger.debug(f"🔧 Improved parser starting processing, input length: {len(xml_string) if xml_string else 0}")
    logger.debug(f"🔧 Using trigger signal: {trigger_signal[:20]}...")

    if not xml_string or trigger_signal not in xml_string:
        logger.debug(f"🔧 Input is empty or doesn't contain trigger signal")
        return None

    cleaned_content = remove_think_blocks(xml_string)
    logger.debug(f"🔧 Content length after temporarily removing think blocks: {len(cleaned_content)}")

    # Advance by one character after each hit so overlapping occurrences are kept.
    signal_positions = []
    start_pos = 0
    while True:
        pos = cleaned_content.find(trigger_signal, start_pos)
        if pos == -1:
            break
        signal_positions.append(pos)
        start_pos = pos + 1

    if not signal_positions:
        logger.debug(f"🔧 No trigger signal found in cleaned content")
        return None

    logger.debug(f"🔧 Found {len(signal_positions)} trigger signal positions: {signal_positions}")

    # Prefer the last trigger signal that is actually followed by a function_calls section.
    calls_content_match = None
    for idx in range(len(signal_positions) - 1, -1, -1):
        pos = signal_positions[idx]
        sub = cleaned_content[pos:]
        m = re.search(r"<function_calls>([\s\S]*?)</function_calls>", sub)
        if m:
            calls_content_match = m
            logger.debug(f"🔧 Using trigger signal index {idx} at pos {pos}; content preview: {repr(sub[:100])}")
            break

    if calls_content_match is None:
        logger.debug(f"🔧 No function_calls tag found after any trigger signal (triggers={len(signal_positions)})")
        return None

    calls_xml = calls_content_match.group(0)
    calls_content = calls_content_match.group(1)
    logger.debug(f"🔧 function_calls content: {repr(calls_content)}")

    def _coerce_value(v: str):
        """Decode a legacy <args> value as JSON when possible, else keep the raw string."""
        try:
            return json.loads(v)
        except Exception:
            return v

    def _parse_args_json_payload(payload: str) -> Optional[Dict[str, Any]]:
        """Strict args_json parsing.

        - Empty payload -> {}
        - Must be valid JSON and MUST decode to an object (dict)
        - Any invalid / non-object payload -> None (treated as parse failure)
        """
        if payload is None:
            return {}
        s = payload.strip()
        if not s:
            return {}
        try:
            parsed = json.loads(s)
        except Exception as e:
            logger.debug(f"🔧 Invalid JSON in args_json: {type(e).__name__}: {e}")
            return None
        if not isinstance(parsed, dict):
            logger.debug(f"🔧 args_json must decode to an object, got {type(parsed).__name__}")
            return None
        return parsed

    def _extract_cdata_text(raw: str) -> str:
        """Return the concatenated CDATA sections of ``raw``, or ``raw`` itself if it has none."""
        if raw is None:
            return ""
        if "<![CDATA[" not in raw:
            return raw
        parts = re.findall(r"<!\[CDATA\[(.*?)\]\]>", raw, flags=re.DOTALL)
        return "".join(parts) if parts else raw

    # Attempt 1: real XML parser.
    # NOTE(security): the xml.etree.ElementTree documentation states the module is
    # not secure against maliciously constructed data; the input here is model
    # output, so a hardened parser may be worth considering.
    try:
        root = ET.fromstring(calls_xml)
        xml_results: List[Dict[str, Any]] = []
        for i, fc in enumerate(root.findall("function_call")):
            tool_el = fc.find("tool")
            name = (tool_el.text or "").strip() if tool_el is not None else ""
            if not name:
                logger.debug(f"🔧 No tool tag found in function_call #{i+1}")
                continue

            args: Dict[str, Any] = {}

            args_json_el = fc.find("args_json")
            if args_json_el is not None:
                parsed_args = _parse_args_json_payload(args_json_el.text or "")
                if parsed_args is None:
                    logger.debug(f"🔧 Invalid args_json in function_call #{i+1}; treating as parse failure")
                    return None
                args = parsed_args
            else:
                # Legacy format: one child element per argument inside <args>.
                args_el = fc.find("args")
                if args_el is not None:
                    for child in list(args_el):
                        args[child.tag] = _coerce_value(child.text or "")

            result = {"name": name, "args": args}
            xml_results.append(result)
            logger.debug(f"🔧 Added tool call: {result}")

        logger.debug(f"🔧 Final parsing result (XML): {xml_results}")
        return xml_results if xml_results else None
    except Exception as e:
        logger.debug(f"🔧 XML library parse failed, falling back to regex parser: {type(e).__name__}: {e}")

    # Attempt 2: regex fallback. It uses its own list so that calls collected by
    # a partially completed XML pass can never be duplicated here.
    regex_results: List[Dict[str, Any]] = []

    call_blocks = re.findall(r"<function_call>([\s\S]*?)</function_call>", calls_content)
    logger.debug(f"🔧 Found {len(call_blocks)} function_call blocks")

    for i, block in enumerate(call_blocks):
        logger.debug(f"🔧 Processing function_call #{i+1}: {repr(block)}")

        # [\s\S] rather than "." so a tool name wrapped onto its own line still matches.
        tool_match = re.search(r"<tool>([\s\S]*?)</tool>", block)
        if not tool_match:
            logger.debug(f"🔧 No tool tag found in block #{i+1}")
            continue

        name = tool_match.group(1).strip()
        if not name:
            # Same rule as the XML path: a call without a tool name is skipped.
            logger.debug(f"🔧 Empty tool name in block #{i+1}")
            continue

        args = {}

        args_json_match = re.search(r"<args_json>([\s\S]*?)</args_json>", block)
        if args_json_match:
            raw_payload = args_json_match.group(1)
            payload = _extract_cdata_text(raw_payload)
            parsed_args = _parse_args_json_payload(payload)
            if parsed_args is None:
                logger.debug(f"🔧 Invalid args_json in function_call #{i+1} (regex path); treating as parse failure")
                return None
            args = parsed_args
        else:
            # Legacy format: <args><key>value</key>...</args>
            args_block_match = re.search(r"<args>([\s\S]*?)</args>", block)
            if args_block_match:
                args_content_inner = args_block_match.group(1)
                arg_matches = re.findall(r"<([^\s>/]+)>([\s\S]*?)</\1>", args_content_inner)
                for k, v in arg_matches:
                    args[k] = _coerce_value(v)

        result = {"name": name, "args": args}
        regex_results.append(result)
        logger.debug(f"🔧 Added tool call: {result}")

    logger.debug(f"🔧 Final parsing result (regex): {regex_results}")
    return regex_results if regex_results else None
|
|
def find_upstream(model_name: str) -> tuple[Dict[str, Any], str]:
    """Resolve the upstream service and the model name to send for ``model_name``.

    In model passthrough mode the service named 'openai' is used and the model
    name is forwarded unchanged. Otherwise an alias is resolved to a randomly
    chosen entry, the entry is looked up in MODEL_TO_SERVICE_MAPPING (falling
    back to DEFAULT_SERVICE), and anything up to the first ':' in the entry is
    dropped from the returned model name.

    Raises:
        HTTPException: status 500 when the selected service has no API key, or
            when passthrough mode is on but no 'openai' service is configured.
    """
    if app_config.features.model_passthrough:
        logger.info("🔄 Model passthrough mode is active. Forwarding to 'openai' service.")
        openai_service = next(
            (svc.model_dump() for svc in app_config.upstream_services if svc.name == "openai"),
            None,
        )

        if not openai_service:
            raise HTTPException(status_code=500, detail="Configuration error: 'model_passthrough' is enabled, but no upstream service named 'openai' was found.")
        if not openai_service.get("api_key"):
            raise HTTPException(status_code=500, detail="Configuration error: API key not found for the 'openai' service in model passthrough mode.")
        return openai_service, model_name

    # Alias resolution: one of the alias targets is picked at random per request.
    if model_name in ALIAS_MAPPING:
        chosen_model_entry = random.choice(ALIAS_MAPPING[model_name])
        logger.info(f"🔄 Model alias '{model_name}' detected. Randomly selected '{chosen_model_entry}' for this request.")
    else:
        chosen_model_entry = model_name

    service = MODEL_TO_SERVICE_MAPPING.get(chosen_model_entry)

    if not service:
        logger.warning(f"⚠️  Model '{model_name}' not found in configuration, using default service")
        service = DEFAULT_SERVICE
        if not service.get("api_key"):
            raise HTTPException(status_code=500, detail="Service configuration error: Default API key not found.")
    elif not service.get("api_key"):
        raise HTTPException(status_code=500, detail=f"Model configuration error: API key not found for service '{service.get('name')}'.")

    # "prefix:model" entries are sent upstream as just "model".
    _, separator, remainder = chosen_model_entry.partition(':')
    actual_model_name = remainder if separator else chosen_model_entry

    return service, actual_model_name
|
|
# ASGI application object; the middleware, exception handlers and routes below register on it.
app = FastAPI()
# Module-level async HTTP client, created once at import time and used for upstream requests.
http_client = httpx.AsyncClient()
|
|
@app.middleware("http")
async def debug_middleware(request: Request, call_next):
    """Log a debug note for responses with status 422; conversation content is never logged."""
    response = await call_next(request)

    if response.status_code != 422:
        return response

    logger.debug(f"🔍 Validation error detected for {request.method} {request.url.path}")
    logger.debug("🔍 Response status code: 422 (Pydantic validation failure)")
    return response
|
|
@app.exception_handler(ValidationError)
async def validation_exception_handler(request: Request, exc: ValidationError):
    """Log every detail of a Pydantic ValidationError and answer with a generic 422 error envelope."""
    logger.error(f"❌ Pydantic validation error: {exc}")
    logger.error(f"❌ Request URL: {request.url}")
    logger.error(f"❌ Error details: {exc.errors()}")

    # One log line per field of each individual error entry.
    field_labels = (("location", "loc"), ("message", "msg"), ("type", "type"))
    for entry in exc.errors():
        for label, field in field_labels:
            logger.error(f"❌ Validation error {label}: {entry.get(field)}")

    error_body = {
        "message": "Invalid request format",
        "type": "invalid_request_error",
        "code": "invalid_request",
    }
    return JSONResponse(status_code=422, content={"error": error_body})
|
|
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    """Convert a FastAPI HTTPException into an OpenAI-compatible error envelope.

    The response keeps the exception's status code; statuses without a
    dedicated entry are reported as ``server_error`` / ``internal_error``.
    """
    # status code -> (error type, error code)
    known_errors = {
        400: ("invalid_request_error", "invalid_request"),
        401: ("authentication_error", "unauthorized"),
        403: ("permission_error", "forbidden"),
        429: ("rate_limit_error", "rate_limit_exceeded"),
    }
    status = exc.status_code
    err_type, code = known_errors.get(status, ("server_error", "internal_error"))

    error_body = {
        "message": str(exc.detail),
        "type": err_type,
        "code": code,
    }
    return JSONResponse(status_code=status, content={"error": error_body})
|
|
@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
    """Catch-all handler: log the exception with its stack and return a generic 500 envelope."""
    log_lines = (
        f"❌ Unhandled exception: {exc}",
        f"❌ Request URL: {request.url}",
        f"❌ Exception type: {type(exc).__name__}",
        f"❌ Error stack: {traceback.format_exc()}",
    )
    for line in log_lines:
        logger.error(line)

    # The client only ever sees a generic message; details stay in the log.
    error_body = {
        "message": "Internal server error",
        "type": "server_error",
        "code": "internal_error",
    }
    return JSONResponse(status_code=500, content={"error": error_body})
|
|
async def verify_api_key(authorization: str = Header(...)):
    """Dependency: extract and verify the client API key from the Authorization header.

    Only a leading ``Bearer `` scheme prefix is removed (matched
    case-insensitively, since HTTP authentication scheme names are
    case-insensitive). Text inside the key itself is never altered. A header
    without that prefix is used as the key verbatim.

    Returns:
        The client key. In key passthrough mode it is returned without being
        checked against ALLOWED_CLIENT_KEYS.

    Raises:
        HTTPException: status 401 when the key is not in ALLOWED_CLIENT_KEYS.
    """
    scheme_prefix = "bearer "
    if authorization[:len(scheme_prefix)].lower() == scheme_prefix:
        client_key = authorization[len(scheme_prefix):]
    else:
        client_key = authorization

    if app_config.features.key_passthrough:
        # Passthrough mode: the key is not validated here and is returned as-is.
        return client_key

    if client_key not in ALLOWED_CLIENT_KEYS:
        raise HTTPException(status_code=401, detail="Unauthorized")
    return client_key
|
|
def preprocess_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Rewrite the conversation so it can be sent upstream:
    - role=tool messages become role=user text built by format_tool_result_for_ai
    - assistant messages carrying tool_calls get those calls rendered into their content
    - developer messages become system messages when the feature flag is enabled
    - everything else (including non-dict entries) is passed through unchanged

    Raises:
        HTTPException: status 400 when a tool message lacks tool_call_id or
            content, or refers to a tool_call_id absent from the conversation.
    """
    tool_call_index = build_tool_call_index_from_messages(messages)
    converted: List[Dict[str, Any]] = []

    for msg in messages:
        if not isinstance(msg, dict):
            converted.append(msg)
            continue

        role = msg.get("role")

        if role == "tool":
            tool_call_id = msg.get("tool_call_id")
            content = msg.get("content")

            if not tool_call_id:
                raise HTTPException(status_code=400, detail="Tool message missing tool_call_id")

            # Only a missing (None) content is rejected; an empty string is accepted.
            if content is None:
                raise HTTPException(status_code=400, detail=f"Tool message missing content for tool_call_id={tool_call_id}")

            tool_info = tool_call_index.get(tool_call_id)
            if not tool_info:
                raise HTTPException(
                    status_code=400,
                    detail=(
                        f"tool_call_id={tool_call_id} not found in conversation history. "
                        f"Ensure the assistant message with this tool_call is included in the messages array."
                    )
                )

            formatted_content = format_tool_result_for_ai(
                tool_name=tool_info["name"],
                tool_arguments=tool_info["arguments"],
                result_content=content,
            )
            converted.append({"role": "user", "content": formatted_content})
            logger.debug(f"🔧 Converted tool message to user message: tool_call_id={tool_call_id}, tool={tool_info['name']}")
            continue

        if role == "assistant" and msg.get("tool_calls"):
            rendered_calls = format_assistant_tool_calls_for_ai(msg.get("tool_calls", []), GLOBAL_TRIGGER_SIGNAL)
            original_content = msg.get("content") or ""

            rewritten = {
                "role": "assistant",
                "content": f"{original_content}\n{rendered_calls}".strip(),
            }
            # Carry over every other field of the original message untouched.
            rewritten.update(
                (key, value)
                for key, value in msg.items()
                if key not in ["role", "content", "tool_calls"]
            )
            converted.append(rewritten)
            logger.debug("🔧 Converted assistant tool_calls to content.")
            continue

        if role == "developer":
            if app_config.features.convert_developer_to_system:
                as_system = msg.copy()
                as_system["role"] = "system"
                converted.append(as_system)
                logger.debug("🔧 Converted developer message to system message for better upstream compatibility")
            else:
                converted.append(msg)
                logger.debug("🔧 Keeping developer role unchanged (based on configuration)")
            continue

        converted.append(msg)

    return converted
|
|
| @app.post("/v1/chat/completions") |
| async def chat_completions( |
| request: Request, |
| body: ChatCompletionRequest, |
| _api_key: str = Depends(verify_api_key) |
| ): |
| """Main chat completion endpoint, proxy and inject function calling capabilities.""" |
| start_time = time.time() |
| |
| try: |
| logger.debug(f"🔧 Received request, model: {body.model}") |
| logger.debug(f"🔧 Number of messages: {len(body.messages)}") |
| logger.debug(f"🔧 Number of tools: {len(body.tools) if body.tools else 0}") |
| logger.debug(f"🔧 Streaming: {body.stream}") |
| |
| upstream, actual_model = find_upstream(body.model) |
| upstream_url = f"{upstream['base_url']}/chat/completions" |
| |
| logger.debug(f"🔧 Starting message preprocessing, original message count: {len(body.messages)}") |
| processed_messages = preprocess_messages(body.messages) |
| logger.debug(f"🔧 Preprocessing completed, processed message count: {len(processed_messages)}") |
| |
| if not validate_message_structure(processed_messages): |
| logger.error(f"❌ Message structure validation failed, but continuing processing") |
| |
| request_body_dict = body.model_dump(exclude_unset=True) |
| request_body_dict["model"] = actual_model |
| request_body_dict["messages"] = processed_messages |
| is_fc_enabled = app_config.features.enable_function_calling |
| has_tools_in_request = bool(body.tools) |
| has_function_call = is_fc_enabled and has_tools_in_request |
| |
| logger.debug(f"🔧 Request body constructed, message count: {len(processed_messages)}") |
| |
| except HTTPException as e: |
| |
| logger.error(f"❌ Request rejected: status_code={e.status_code}, detail={e.detail}") |
| return JSONResponse( |
| status_code=e.status_code, |
| content={ |
| "error": { |
| "message": str(e.detail), |
| "type": "invalid_request_error" if e.status_code == 400 else ( |
| "authentication_error" if e.status_code == 401 else ( |
| "permission_error" if e.status_code == 403 else ( |
| "rate_limit_error" if e.status_code == 429 else "server_error" |
| ) |
| ) |
| ), |
| "code": "invalid_request" if e.status_code == 400 else ( |
| "unauthorized" if e.status_code == 401 else ( |
| "forbidden" if e.status_code == 403 else ( |
| "rate_limit_exceeded" if e.status_code == 429 else "internal_error" |
| ) |
| ) |
| ) |
| } |
| } |
| ) |
|
|
| except Exception as e: |
| logger.error(f"❌ Request preprocessing failed: {str(e)}") |
| logger.error(f"❌ Error type: {type(e).__name__}") |
| if hasattr(app_config, 'debug') and app_config.debug: |
| logger.error(f"❌ Error stack: {traceback.format_exc()}") |
| |
| return JSONResponse( |
| status_code=422, |
| content={ |
| "error": { |
| "message": "Invalid request format", |
| "type": "invalid_request_error", |
| "code": "invalid_request" |
| } |
| } |
| ) |
|
|
| if has_function_call: |
| logger.debug(f"🔧 Using global trigger signal for this request: {GLOBAL_TRIGGER_SIGNAL}") |
|
|
| tools_for_request: List[Tool] = body.tools or [] |
| function_prompt, _ = generate_function_prompt(tools_for_request, GLOBAL_TRIGGER_SIGNAL) |
| |
| tool_choice_prompt = safe_process_tool_choice(body.tool_choice, tools_for_request) |
| if tool_choice_prompt: |
| function_prompt += tool_choice_prompt |
|
|
| system_message = {"role": "system", "content": function_prompt} |
| request_body_dict["messages"].insert(0, system_message) |
| |
| if "tools" in request_body_dict: |
| del request_body_dict["tools"] |
| if "tool_choice" in request_body_dict: |
| del request_body_dict["tool_choice"] |
|
|
| elif has_tools_in_request and not is_fc_enabled: |
| logger.info(f"🔧 Function calling is disabled by configuration, ignoring 'tools' and 'tool_choice' in request.") |
| if "tools" in request_body_dict: |
| del request_body_dict["tools"] |
| if "tool_choice" in request_body_dict: |
| del request_body_dict["tool_choice"] |
|
|
| prompt_tokens = token_counter.count_tokens(request_body_dict["messages"], body.model) |
| logger.info(f"📊 Request to {body.model} - Actual input tokens (including all preprocessing & injected prompts): {prompt_tokens}") |
|
|
| headers = { |
| "Content-Type": "application/json", |
| "Authorization": f"Bearer {_api_key}" if app_config.features.key_passthrough else f"Bearer {upstream['api_key']}", |
| "Accept": "application/json" if not body.stream else "text/event-stream" |
| } |
|
|
| logger.info(f"📝 Forwarding request to upstream: {upstream['name']}") |
| logger.info(f"📝 Model: {request_body_dict.get('model', 'unknown')}, Messages: {len(request_body_dict.get('messages', []))}") |
|
|
| if not body.stream: |
| try: |
| logger.debug(f"🔧 Sending upstream request to: {upstream_url}") |
| logger.debug(f"🔧 has_function_call: {has_function_call}") |
| logger.debug(f"🔧 Request body contains tools: {bool(body.tools)}") |
| |
| upstream_response = await http_client.post( |
| upstream_url, json=request_body_dict, headers=headers, timeout=app_config.server.timeout |
| ) |
| upstream_response.raise_for_status() |
| |
| response_json = upstream_response.json() |
| logger.debug(f"🔧 Upstream response status code: {upstream_response.status_code}") |
| |
| |
| completion_text = "" |
| if response_json.get("choices") and len(response_json["choices"]) > 0: |
| message = response_json["choices"][0].get("message", {}) |
| |
| |
| content = message.get("content") |
| if content: |
| completion_text = content |
| |
| |
| reasoning_content = message.get("reasoning_content") |
| if reasoning_content: |
| completion_text = (completion_text + "\n" + reasoning_content).strip() if completion_text else reasoning_content |
| logger.debug(f"🔧 Found reasoning_content, adding {len(reasoning_content)} chars to token count") |
| |
| |
| estimated_completion_tokens = token_counter.count_text_tokens(completion_text, body.model) if completion_text else 0 |
| estimated_prompt_tokens = prompt_tokens |
| estimated_total_tokens = estimated_prompt_tokens + estimated_completion_tokens |
| elapsed_time = time.time() - start_time |
| |
| |
| upstream_usage = response_json.get("usage", {}) |
| if upstream_usage: |
| |
| final_usage = upstream_usage.copy() |
| |
| |
| if not final_usage.get("prompt_tokens") or final_usage.get("prompt_tokens") == 0: |
| final_usage["prompt_tokens"] = estimated_prompt_tokens |
| logger.debug(f"🔧 Replaced zero/missing prompt_tokens with estimate: {estimated_prompt_tokens}") |
| |
| if not final_usage.get("completion_tokens") or final_usage.get("completion_tokens") == 0: |
| final_usage["completion_tokens"] = estimated_completion_tokens |
| logger.debug(f"🔧 Replaced zero/missing completion_tokens with estimate: {estimated_completion_tokens}") |
| |
| if not final_usage.get("total_tokens") or final_usage.get("total_tokens") == 0: |
| final_usage["total_tokens"] = final_usage.get("prompt_tokens", estimated_prompt_tokens) + final_usage.get("completion_tokens", estimated_completion_tokens) |
| logger.debug(f"🔧 Replaced zero/missing total_tokens with calculated value: {final_usage['total_tokens']}") |
| |
| response_json["usage"] = final_usage |
| logger.debug(f"🔧 Preserved upstream usage with replacements: {final_usage}") |
| else: |
| |
| response_json["usage"] = { |
| "prompt_tokens": estimated_prompt_tokens, |
| "completion_tokens": estimated_completion_tokens, |
| "total_tokens": estimated_total_tokens |
| } |
| logger.debug(f"🔧 No upstream usage found, using estimates") |
| |
| |
| actual_usage = response_json["usage"] |
| logger.info("=" * 60) |
| logger.info(f"📊 Token Usage Statistics - Model: {body.model}") |
| logger.info(f" Input Tokens: {actual_usage.get('prompt_tokens', 0)}") |
| logger.info(f" Output Tokens: {actual_usage.get('completion_tokens', 0)}") |
| logger.info(f" Total Tokens: {actual_usage.get('total_tokens', 0)}") |
| logger.info(f" Duration: {elapsed_time:.2f}s") |
| logger.info("=" * 60) |
| |
| if has_function_call: |
| content = response_json["choices"][0]["message"]["content"] |
| logger.debug(f"🔧 Complete response content: {repr(content)}") |
| |
| parsed_tools = await attempt_fc_parse_with_retry( |
| content=content, |
| trigger_signal=GLOBAL_TRIGGER_SIGNAL, |
| messages=request_body_dict["messages"], |
| upstream_url=upstream_url, |
| headers=headers, |
| model=actual_model, |
| tools=body.tools or [], |
| timeout=app_config.server.timeout |
| ) |
| logger.debug(f"🔧 XML parsing result: {parsed_tools}") |
| |
| if parsed_tools: |
| logger.debug(f"🔧 Successfully parsed {len(parsed_tools)} tool calls") |
| estimated_completion_tokens = token_counter.count_text_tokens(content, body.model) |
| estimated_total_tokens = estimated_prompt_tokens + estimated_completion_tokens |
| logger.debug(f"🔧 Completion tokens: {estimated_completion_tokens}") |
| |
| tool_calls = [] |
| for tool in parsed_tools: |
| tool_call_id = f"call_{uuid.uuid4().hex}" |
| tool_calls.append({ |
| "id": tool_call_id, |
| "type": "function", |
| "function": { |
| "name": tool["name"], |
| "arguments": json.dumps(tool["args"]) |
| } |
| }) |
| logger.debug(f"🔧 Converted tool_calls: {tool_calls}") |
| |
| prefix_pos = find_last_trigger_signal_outside_think(content, GLOBAL_TRIGGER_SIGNAL) |
| prefix_text = None |
| if prefix_pos != -1: |
| prefix_text = content[:prefix_pos].rstrip() |
| if prefix_text == "": |
| prefix_text = None |
|
|
| |
| original_message = response_json["choices"][0]["message"] |
| new_message = { |
| "role": "assistant", |
| "content": prefix_text, |
| "tool_calls": tool_calls, |
| } |
| |
| for key in original_message: |
| if key not in ["role", "content", "tool_calls"]: |
| new_message[key] = original_message[key] |
| response_json["choices"][0]["message"] = new_message |
| response_json["choices"][0]["finish_reason"] = "tool_calls" |
| logger.debug(f"🔧 Function call conversion completed") |
| else: |
| logger.debug(f"🔧 No tool calls detected, returning original content (including think blocks)") |
| else: |
| logger.debug(f"🔧 No function calls detected or conversion conditions not met") |
| |
| return JSONResponse(content=response_json) |
|
|
| except httpx.HTTPStatusError as e: |
| logger.error(f"❌ Upstream service response error: status_code={e.response.status_code}") |
| logger.error(f"❌ Upstream error details: {e.response.text}") |
| |
| if e.response.status_code == 400: |
| error_response = { |
| "error": { |
| "message": "Invalid request parameters", |
| "type": "invalid_request_error", |
| "code": "bad_request" |
| } |
| } |
| elif e.response.status_code == 401: |
| error_response = { |
| "error": { |
| "message": "Authentication failed", |
| "type": "authentication_error", |
| "code": "unauthorized" |
| } |
| } |
| elif e.response.status_code == 403: |
| error_response = { |
| "error": { |
| "message": "Access forbidden", |
| "type": "permission_error", |
| "code": "forbidden" |
| } |
| } |
| elif e.response.status_code == 429: |
| error_response = { |
| "error": { |
| "message": "Rate limit exceeded", |
| "type": "rate_limit_error", |
| "code": "rate_limit_exceeded" |
| } |
| } |
| elif e.response.status_code >= 500: |
| error_response = { |
| "error": { |
| "message": "Upstream service temporarily unavailable", |
| "type": "service_error", |
| "code": "upstream_error" |
| } |
| } |
| else: |
| error_response = { |
| "error": { |
| "message": "Request processing failed", |
| "type": "api_error", |
| "code": "unknown_error" |
| } |
| } |
| |
| return JSONResponse(content=error_response, status_code=e.response.status_code) |
| |
| else: |
| async def stream_with_token_count(): |
| completion_tokens = 0 |
| completion_text = "" |
| done_received = False |
| stream_id = None |
| upstream_usage_chunk = None |
| |
| async for chunk in stream_proxy_with_fc_transform( |
| upstream_url, |
| request_body_dict, |
| headers, |
| body.model, |
| has_function_call, |
| GLOBAL_TRIGGER_SIGNAL, |
| request_body_dict["messages"], |
| tools=body.tools or [], |
| ): |
| |
| if chunk.startswith(b"data: "): |
| try: |
| line_data = chunk[6:].decode('utf-8').strip() |
| if line_data == "[DONE]": |
| done_received = True |
| |
| break |
| elif line_data: |
| chunk_json = json.loads(line_data) |
|
|
| if stream_id is None and isinstance(chunk_json, dict): |
| stream_id = chunk_json.get("id") |
| |
| if chunk_json.get("object") == "chat.completion.chunk.internal": |
| raw_fc_content = chunk_json.get("_internal_fc_raw_content", "") |
| if raw_fc_content: |
| completion_text += raw_fc_content |
| logger.debug(f"🔧 Received internal FC raw content for token counting: {len(raw_fc_content)} chars") |
| continue |
| |
| |
| if "usage" in chunk_json: |
| upstream_usage_chunk = chunk_json |
| logger.debug(f"🔧 Detected upstream usage data in chunk") |
| if not ("choices" in chunk_json and len(chunk_json["choices"]) > 0): |
| continue |
| else: |
| chunk_json = {k: v for k, v in chunk_json.items() if k != "usage"} |
| chunk = f"data: {json.dumps(chunk_json)}\n\n".encode('utf-8') |
| |
| |
| if "choices" in chunk_json and len(chunk_json["choices"]) > 0: |
| delta = chunk_json["choices"][0].get("delta", {}) |
| |
| |
| content = delta.get("content", "") |
| if content: |
| completion_text += content |
| |
| |
| reasoning_content = delta.get("reasoning_content", "") |
| if reasoning_content: |
| completion_text += reasoning_content |
| logger.debug(f"🔧 Found reasoning_content in stream, accumulating for token count") |
| except (json.JSONDecodeError, KeyError, UnicodeDecodeError) as e: |
| logger.debug(f"Failed to parse chunk for token counting: {e}") |
| pass |
| |
| yield chunk |
| |
| |
| estimated_completion_tokens = token_counter.count_text_tokens(completion_text, body.model) if completion_text else 0 |
| estimated_prompt_tokens = prompt_tokens |
| estimated_total_tokens = estimated_prompt_tokens + estimated_completion_tokens |
| elapsed_time = time.time() - start_time |
| |
| |
| final_usage = None |
| if upstream_usage_chunk and "usage" in upstream_usage_chunk: |
| |
| upstream_usage = upstream_usage_chunk["usage"] |
| final_usage = upstream_usage.copy() |
| |
| if not final_usage.get("prompt_tokens") or final_usage.get("prompt_tokens") == 0: |
| final_usage["prompt_tokens"] = estimated_prompt_tokens |
| logger.debug(f"🔧 Replaced zero/missing prompt_tokens with estimate: {estimated_prompt_tokens}") |
| |
| if not final_usage.get("completion_tokens") or final_usage.get("completion_tokens") == 0: |
| final_usage["completion_tokens"] = estimated_completion_tokens |
| logger.debug(f"🔧 Replaced zero/missing completion_tokens with estimate: {estimated_completion_tokens}") |
| |
| if not final_usage.get("total_tokens") or final_usage.get("total_tokens") == 0: |
| final_usage["total_tokens"] = final_usage.get("prompt_tokens", estimated_prompt_tokens) + final_usage.get("completion_tokens", estimated_completion_tokens) |
| logger.debug(f"🔧 Replaced zero/missing total_tokens with calculated value: {final_usage['total_tokens']}") |
| |
| logger.debug(f"🔧 Using upstream usage with replacements: {final_usage}") |
| else: |
| |
| final_usage = { |
| "prompt_tokens": estimated_prompt_tokens, |
| "completion_tokens": estimated_completion_tokens, |
| "total_tokens": estimated_total_tokens |
| } |
| logger.debug(f"🔧 No upstream usage found, using estimates") |
| |
| |
| logger.info("=" * 60) |
| logger.info(f"📊 Token Usage Statistics - Model: {body.model}") |
| logger.info(f" Input Tokens: {final_usage['prompt_tokens']}") |
| logger.info(f" Output Tokens: {final_usage['completion_tokens']}") |
| logger.info(f" Total Tokens: {final_usage['total_tokens']}") |
| logger.info(f" Duration: {elapsed_time:.2f}s") |
| logger.info("=" * 60) |
| |
| |
| if body.stream_options and body.stream_options.get("include_usage", False): |
| usage_chunk_to_send = { |
| "id": (upstream_usage_chunk.get("id") if isinstance(upstream_usage_chunk, dict) else None) or stream_id or f"chatcmpl-{uuid.uuid4().hex}", |
| "object": "chat.completion.chunk", |
| "created": int(time.time()), |
| "model": body.model, |
| "choices": [], |
| "usage": final_usage |
| } |
| |
| |
| if upstream_usage_chunk: |
| for key in upstream_usage_chunk: |
| if key not in ["usage", "choices"] and key not in usage_chunk_to_send: |
| usage_chunk_to_send[key] = upstream_usage_chunk[key] |
| |
| yield f"data: {json.dumps(usage_chunk_to_send)}\n\n".encode('utf-8') |
| logger.debug(f"🔧 Sent usage chunk in stream: {usage_chunk_to_send['usage']}") |
| |
| |
| if done_received: |
| yield b"data: [DONE]\n\n" |
| |
| return StreamingResponse( |
| stream_with_token_count(), |
| media_type="text/event-stream" |
| ) |
|
|
async def _attempt_streaming_fc_retry(
    original_content: str,
    trigger_signal: str,
    messages: List[Dict[str, Any]],
    url: str,
    headers: Dict[str, str],
    model: str,
    timeout: int,
    tools: Optional[List["Tool"]] = None,
) -> Optional[List[Dict[str, Any]]]:
    """Try to recover valid tool calls from a streamed response that failed to parse.

    The buffered content is parsed and validated once. If that fails, the
    upstream model is asked (non-streaming) to correct its output, with the
    faulty content and an error description appended to the conversation.
    The initial parse counts as the first attempt, so at most
    ``fc_error_retry_max_attempts - 1`` correction requests are sent.

    Args:
        original_content: Text accumulated from the stream.
        trigger_signal: Marker that introduces a function-call block.
        messages: Conversation sent upstream; it is copied, not mutated.
        url: Upstream chat-completions endpoint.
        headers: HTTP headers for the upstream request.
        model: Model name sent in the retry request.
        timeout: Timeout passed to the HTTP client for each retry request.
        tools: Tool definitions used for validation (``None`` is treated as
            an empty list).

    Returns:
        The parsed tool calls once they parse and validate, otherwise ``None``
        (no trigger signal outside ``<think>`` blocks, attempts exhausted,
        unusable upstream response, or request failure).
    """
    max_attempts = app_config.features.fc_error_retry_max_attempts
    current_content = original_content
    current_messages = messages.copy()

    def _parse_and_validate(content: str) -> tuple[Optional[List[Dict[str, Any]]], Optional[str]]:
        """Return ``(tools, None)`` on success, else ``(None, validation_error_or_None)``."""
        parsed = parse_function_calls_xml(content, trigger_signal)
        if not parsed:
            return None, None
        error = validate_parsed_tools(parsed, tools or [])
        if error:
            return None, error
        return parsed, None

    # Always describes the most recent parse of `current_content`. It is
    # deliberately NOT reset per iteration, so that a schema-validation failure
    # in a retry response is reported back to the model in the next prompt.
    validation_error: Optional[str] = None

    for attempt in range(max_attempts):
        # Without a trigger signal there is nothing to repair.
        if find_last_trigger_signal_outside_think(current_content, trigger_signal) == -1:
            logger.debug("🔧 Streaming retry: no trigger signal found outside <think> blocks; aborting retry")
            return None

        # Retry responses are parsed right after they arrive (below), so only
        # the original content still needs parsing here.
        if attempt == 0:
            parsed_tools, validation_error = _parse_and_validate(current_content)
            if parsed_tools:
                return parsed_tools

        if attempt >= max_attempts - 1:
            logger.warning(f"⚠️ Streaming FC retry failed after {max_attempts} attempts")
            return None

        error_details = validation_error or _diagnose_fc_parse_error(current_content, trigger_signal)
        retry_prompt = get_fc_error_retry_prompt(current_content, error_details)

        logger.info(f"🔄 Streaming FC retry attempt {attempt + 2}/{max_attempts}")

        retry_messages = current_messages + [
            {"role": "assistant", "content": current_content},
            {"role": "user", "content": retry_prompt},
        ]

        try:
            retry_response = await http_client.post(
                url,
                json={"model": model, "messages": retry_messages, "stream": False},
                headers=headers,
                timeout=timeout,
            )
            retry_response.raise_for_status()
            retry_json = retry_response.json()

            if retry_json.get("choices") and len(retry_json["choices"]) > 0:
                # `message` / `content` may be JSON null; normalise to "".
                retry_message = retry_json["choices"][0].get("message") or {}
                current_content = retry_message.get("content") or ""
                current_messages = retry_messages

                parsed_tools, validation_error = _parse_and_validate(current_content)
                if parsed_tools:
                    return parsed_tools
            else:
                logger.warning("⚠️ Streaming FC retry response has no valid choices")
                return None

        except Exception as e:
            # Best effort: any failure while retrying simply ends the recovery.
            logger.error(f"❌ Streaming FC retry request failed: {e}")
            return None

    return None
|
|
|
|
async def stream_proxy_with_fc_transform(
    url: str,
    body: dict,
    headers: dict,
    model: str,
    has_fc: bool,
    trigger_signal: str,
    original_messages: Optional[List[Dict[str, Any]]] = None,
    tools: Optional[List["Tool"]] = None,
):
    """Proxy an upstream SSE stream, converting XML tool calls into ``tool_calls`` chunks.

    Without function calling (``has_fc`` false or empty ``trigger_signal``) the
    upstream bytes are forwarded unchanged. Otherwise each ``data:`` line is
    inspected: reasoning content is forwarded, regular content is routed through
    ``StreamingFunctionCallDetector``, and once the detector enters the
    ``tool_parsing`` state the content is buffered until ``</function_calls>``
    appears or the stream ends. The buffer is then parsed and validated, and the
    result is emitted as OpenAI-style tool-call chunks.

    Args:
        url: Upstream chat-completions endpoint.
        body: JSON request body forwarded upstream.
        headers: HTTP headers for the upstream request.
        model: Model name written into every generated chunk.
        has_fc: Whether function-call transformation is enabled.
        trigger_signal: Marker that introduces a function-call block.
        original_messages: Conversation used for error-correction retries;
            retries are skipped when it is empty or ``None``.
        tools: Tool definitions used to validate parsed calls.

    Yields:
        UTF-8 encoded SSE frames (``bytes``).
    """
    logger.info(f"📝 Starting streaming response from: {url}")
    logger.info(f"📝 Function calling enabled: {has_fc}")

    # Plain pass-through when no transformation is required.
    if not has_fc or not trigger_signal:
        try:
            async with http_client.stream("POST", url, json=body, headers=headers, timeout=app_config.server.timeout) as response:
                async for chunk in response.aiter_bytes():
                    yield chunk
        except httpx.RemoteProtocolError:
            logger.debug("🔧 Upstream closed connection prematurely, ending stream response")
            return
        return

    detector = StreamingFunctionCallDetector(trigger_signal)

    def _sse(payload: Dict[str, Any]) -> bytes:
        """Serialize a payload as one encoded SSE ``data:`` frame."""
        return f"data: {json.dumps(payload)}\n\n".encode("utf-8")

    def _chunk(
        chunk_id: str,
        choice: Dict[str, Any],
        *,
        model_id: str = model,
        created: Optional[int] = None,
    ) -> Dict[str, Any]:
        """Build a ``chat.completion.chunk`` envelope around a single choice."""
        return {
            "id": chunk_id,
            "object": "chat.completion.chunk",
            # Creation time is "now" unless the upstream chunk supplied one
            # (previously the source file's mtime was used by mistake).
            "created": created or int(time.time()),
            "model": model_id,
            "choices": [choice],
        }

    def _prepare_tool_calls(parsed_tools: List[Dict[str, Any]]):
        """Convert parsed ``{"name", "args"}`` dicts into OpenAI tool_call entries."""
        tool_calls = []
        for i, tool in enumerate(parsed_tools):
            tool_calls.append({
                "index": i, "id": f"call_{uuid.uuid4().hex}", "type": "function",
                "function": {"name": tool["name"], "arguments": json.dumps(tool["args"])},
            })
        return tool_calls

    def _build_tool_call_sse_chunks(parsed_tools: List[Dict[str, Any]], model_id: str, raw_content: str = "") -> List[str]:
        """Return the SSE frames (as text) that deliver a complete tool-call response."""
        tool_calls = _prepare_tool_calls(parsed_tools)
        chunks: List[str] = []

        # Internal frame carrying the raw XML text; the consumer of this
        # generator is expected to use it for token counting and drop it.
        if raw_content:
            metadata_chunk = {
                "object": "chat.completion.chunk.internal",
                "_internal_fc_raw_content": raw_content,
            }
            chunks.append(f"data: {json.dumps(metadata_chunk)}\n\n")

        initial_chunk = _chunk(
            f"chatcmpl-{uuid.uuid4().hex}",
            {"index": 0, "delta": {"role": "assistant", "content": None, "tool_calls": tool_calls}, "finish_reason": None},
            model_id=model_id,
        )
        chunks.append(f"data: {json.dumps(initial_chunk)}\n\n")

        final_chunk = _chunk(
            f"chatcmpl-{uuid.uuid4().hex}",
            {"index": 0, "delta": {}, "finish_reason": "tool_calls"},
            model_id=model_id,
        )
        chunks.append(f"data: {json.dumps(final_chunk)}\n\n")
        chunks.append("data: [DONE]\n\n")
        return chunks

    try:
        async with http_client.stream("POST", url, json=body, headers=headers, timeout=app_config.server.timeout) as response:
            if response.status_code != 200:
                error_content = await response.aread()
                logger.error(f"❌ Upstream service stream response error: status_code={response.status_code}")
                logger.error(f"❌ Upstream error details: {error_content.decode('utf-8', errors='ignore')}")

                # Report a generic message; upstream details stay in the log.
                if response.status_code == 401:
                    error_message = "Authentication failed"
                elif response.status_code == 403:
                    error_message = "Access forbidden"
                elif response.status_code == 429:
                    error_message = "Rate limit exceeded"
                elif response.status_code >= 500:
                    error_message = "Upstream service temporarily unavailable"
                else:
                    error_message = "Request processing failed"

                yield _sse({"error": {"message": error_message, "type": "upstream_error"}})
                yield b"data: [DONE]\n\n"
                return

            async for line in response.aiter_lines():
                # Only SSE data lines matter; everything else is ignored.
                if not line.startswith("data:"):
                    continue
                # The space after "data:" is optional in SSE, so strip the bare
                # field name and let .strip() remove any separator.
                line_data = line[len("data:"):].strip()
                if not line_data or line_data == "[DONE]":
                    continue

                if detector.state == "tool_parsing":
                    # Buffer silently until the closing tag shows up.
                    try:
                        chunk_json = json.loads(line_data)
                        delta_content = chunk_json.get("choices", [{}])[0].get("delta", {}).get("content", "") or ""
                    except (json.JSONDecodeError, IndexError):
                        continue
                    detector.content_buffer += delta_content

                    if "</function_calls>" not in detector.content_buffer:
                        continue

                    logger.debug("🔧 Detected </function_calls> in stream, finalizing early...")
                    parsed_tools = detector.finalize()
                    if parsed_tools:
                        validation_error = validate_parsed_tools(parsed_tools, tools or [])
                        if validation_error:
                            logger.info(f"🔧 Tool/schema validation failed in stream finalize: {validation_error}")
                            parsed_tools = None

                    if parsed_tools:
                        logger.debug(f"🔧 Early finalize: parsed {len(parsed_tools)} tool calls")
                        for sse in _build_tool_call_sse_chunks(parsed_tools, model, detector.content_buffer):
                            yield sse.encode('utf-8')
                        return

                    if app_config.features.enable_fc_error_retry and original_messages:
                        logger.info("🔄 Early finalize FC parsing failed, attempting retry...")
                        retry_parsed = await _attempt_streaming_fc_retry(
                            original_content=detector.content_buffer,
                            trigger_signal=trigger_signal,
                            messages=original_messages,
                            url=url,
                            headers=headers,
                            model=model,
                            timeout=app_config.server.timeout,
                            tools=tools,
                        )
                        if retry_parsed:
                            logger.info(f"✅ Early finalize FC retry succeeded, parsed {len(retry_parsed)} tool calls")
                            for sse in _build_tool_call_sse_chunks(retry_parsed, model, detector.content_buffer):
                                yield sse.encode('utf-8')
                            return
                        logger.warning("⚠️ Early finalize FC retry also failed, ending stream")
                    else:
                        logger.warning(
                            "⚠️ Early finalize detected </function_calls> but failed to parse tool calls; "
                            "silently ending stream. buffer_len=%s preview=%r",
                            len(detector.content_buffer),
                            detector.content_buffer[:200],
                        )

                    # Unrecoverable tool call: close the stream without
                    # emitting the buffered XML.
                    yield _sse(_chunk(
                        f"chatcmpl-{uuid.uuid4().hex}",
                        {"index": 0, "delta": {}, "finish_reason": "stop"},
                    ))
                    yield b"data: [DONE]\n\n"
                    return

                try:
                    chunk_json = json.loads(line_data)
                    choice = chunk_json.get("choices", [{}])[0]
                    delta = choice.get("delta", {})
                    delta_content = delta.get("content", "") or ""
                    delta_reasoning = delta.get("reasoning_content", "") or ""
                    finish_reason = choice.get("finish_reason")
                except (json.JSONDecodeError, IndexError):
                    # Not a regular choice chunk (e.g. an empty `choices`
                    # list): forward the line as received.
                    yield (line + "\n\n").encode("utf-8")
                    continue

                if delta_reasoning:
                    yield _sse(_chunk(
                        chunk_json.get("id") or f"chatcmpl-passthrough-{uuid.uuid4().hex}",
                        {"index": 0, "delta": {"reasoning_content": delta_reasoning}},
                        created=chunk_json.get("created"),
                    ))

                if delta_content:
                    is_detected, content_to_yield = detector.process_chunk(delta_content)

                    if content_to_yield:
                        yield _sse(_chunk(
                            f"chatcmpl-passthrough-{uuid.uuid4().hex}",
                            {"index": 0, "delta": {"content": content_to_yield}},
                        ))

                    if is_detected:
                        # A tool call has started; do not forward this
                        # chunk's finish_reason.
                        continue

                if finish_reason:
                    yield _sse(_chunk(
                        chunk_json.get("id") or f"chatcmpl-passthrough-{uuid.uuid4().hex}",
                        {"index": 0, "delta": {}, "finish_reason": finish_reason},
                        created=chunk_json.get("created"),
                    ))

    except httpx.RequestError as e:
        logger.error(f"❌ Failed to connect to upstream service: {e}")
        logger.error(f"❌ Error type: {type(e).__name__}")

        yield _sse({"error": {"message": "Failed to connect to upstream service", "type": "connection_error"}})
        yield b"data: [DONE]\n\n"
        return

    # The upstream stream has ended; deal with whatever is still buffered.
    if detector.state == "tool_parsing":
        logger.debug("🔧 Stream ended, starting to parse tool call XML...")
        parsed_tools = detector.finalize()
        if parsed_tools:
            validation_error = validate_parsed_tools(parsed_tools, tools or [])
            if validation_error:
                logger.info(f"🔧 Tool/schema validation failed at stream end: {validation_error}")
                parsed_tools = None

        if parsed_tools:
            logger.debug(f"🔧 Streaming processing: Successfully parsed {len(parsed_tools)} tool calls")
            for sse in _build_tool_call_sse_chunks(parsed_tools, model, detector.content_buffer):
                yield sse.encode("utf-8")
            return

        if app_config.features.enable_fc_error_retry and original_messages:
            logger.info("🔄 Streaming FC parsing failed, attempting retry with error correction...")
            retry_parsed = await _attempt_streaming_fc_retry(
                original_content=detector.content_buffer,
                trigger_signal=trigger_signal,
                messages=original_messages,
                url=url,
                headers=headers,
                model=model,
                timeout=app_config.server.timeout,
                tools=tools,
            )
            if retry_parsed:
                logger.info(f"✅ Streaming FC retry succeeded, parsed {len(retry_parsed)} tool calls")
                for sse in _build_tool_call_sse_chunks(retry_parsed, model, detector.content_buffer):
                    yield sse.encode("utf-8")
                return
            logger.warning("⚠️ Streaming FC retry also failed, falling back to text output")
        else:
            logger.warning(
                "⚠️ Detected tool call signal but XML parsing failed; outputting accumulated text. "
                "buffer_len=%s preview=%r",
                len(detector.content_buffer),
                detector.content_buffer[:300],
            )

        # Fall back to delivering the raw buffered text to the client.
        if detector.content_buffer:
            yield _sse(_chunk(
                f"chatcmpl-fallback-{uuid.uuid4().hex}",
                {"index": 0, "delta": {"content": detector.content_buffer}},
            ))

    elif detector.state == "detecting" and detector.content_buffer:
        # Flush text the detector was still holding back.
        yield _sse(_chunk(
            f"chatcmpl-finalflush-{uuid.uuid4().hex}",
            {"index": 0, "delta": {"content": detector.content_buffer}},
        ))

    yield _sse(_chunk(
        f"chatcmpl-{uuid.uuid4().hex}",
        {"index": 0, "delta": {}, "finish_reason": "stop"},
    ))
    yield b"data: [DONE]\n\n"
|
|
|
|
@app.get("/")
def read_root():
    """Return a status banner plus a summary of the loaded configuration."""
    summary = {
        "upstream_services_count": len(app_config.upstream_services),
        "client_keys_count": len(app_config.client_authentication.allowed_keys),
        "models_count": len(MODEL_TO_SERVICE_MAPPING),
    }

    feature_cfg = app_config.features
    summary["features"] = {
        "function_calling": feature_cfg.enable_function_calling,
        "log_level": feature_cfg.log_level,
        "convert_developer_to_system": feature_cfg.convert_developer_to_system,
        "random_trigger": True,
    }

    return {"status": "Toolify is running", "config": summary}
|
|
@app.get("/v1/models")
async def list_models(_api_key: str = Depends(verify_api_key)):
    """Return the model catalogue in OpenAI ``/v1/models`` list format.

    Mapping keys containing ``:`` are exposed by the part before the first
    colon only; other keys are exposed unchanged. Entries are de-duplicated
    and sorted by id.
    """
    # split(':', 1)[0] is the alias for "alias:rest" keys and the whole name
    # for keys without a colon.
    public_ids = {name.split(':', 1)[0] for name in MODEL_TO_SERVICE_MAPPING}

    return {
        "object": "list",
        "data": [
            {
                "id": model_id,
                "object": "model",
                "created": 1677610602,
                "owned_by": "openai",
                "permission": [],
                "root": model_id,
                "parent": None,
            }
            for model_id in sorted(public_ids)
        ],
    }
|
|
|
|
def validate_message_structure(messages: List[Dict[str, Any]]) -> bool:
    """Check that every message has a permitted role and required fields.

    A message is rejected when it has no ``role``, when the role is not
    allowed (``developer`` is allowed only while
    ``convert_developer_to_system`` is off), or when a ``tool`` message lacks
    ``tool_call_id``. Any exception raised during validation is logged and
    reported as a failure.

    Returns:
        ``True`` if all messages pass, ``False`` otherwise.
    """

    def _summarize(payload: Any) -> str:
        """Build the short content description appended to the debug log line."""
        if not payload:
            return ", content=empty"
        if isinstance(payload, str):
            return f", content=text({len(payload)} chars)"
        if isinstance(payload, list):
            n_text = sum(1 for part in payload if isinstance(part, dict) and part.get('type') == 'text')
            n_images = sum(1 for part in payload if isinstance(part, dict) and part.get('type') == 'image_url')
            return f", content=multimodal(text={n_text}, images={n_images})"
        return f", content={type(payload).__name__}"

    try:
        accepted_roles = ("system", "user", "assistant", "tool")
        if not app_config.features.convert_developer_to_system:
            accepted_roles += ("developer",)

        for position, message in enumerate(messages):
            if "role" not in message:
                logger.error(f"❌ Message {position} missing role field")
                return False

            role = message["role"]
            if role not in accepted_roles:
                logger.error(f"❌ Invalid role value for message {position}: {role}")
                return False

            if role == "tool" and "tool_call_id" not in message:
                logger.error(f"❌ Tool message {position} missing tool_call_id field")
                return False

            summary = _summarize(message.get("content"))
            logger.debug(f"✅ Message {position} validation passed: role={role}{summary}")

        logger.debug(f"✅ All messages validated successfully, total {len(messages)} messages")
        return True
    except Exception as e:
        logger.error(f"❌ Message validation exception: {e}")
        return False
|
|
def safe_process_tool_choice(tool_choice, tools: Optional[List[Tool]] = None) -> str:
    """
    Translate the request's tool_choice into extra prompt instructions.

    Args:
        tool_choice: ``None``, one of the strings "none" / "auto" / "required",
            or an object exposing a ``function`` dict with a ``name`` entry.
        tools: Tools supplied with the request; needed to verify a specifically
            requested tool.

    Returns:
        Text to append to the function calling prompt; empty when no extra
        instruction applies, when the value is unrecognized, or when an
        unexpected error occurs (it is logged).

    Raises:
        HTTPException: 400 when a specific tool is requested but its
            specification is malformed, ``tools`` is empty, or the tool is not
            present in ``tools``.
    """
    keyword_directives = {
        "none": "\n\n**IMPORTANT:** You are prohibited from using any tools in this round. Please respond like a normal chat assistant and answer the user's question directly.",
        # "auto" is the default behavior and needs no additional instruction.
        "auto": "",
        "required": "\n\n**IMPORTANT:** You MUST call at least one tool in this response. Do not respond without using tools.",
    }

    try:
        if tool_choice is None:
            return ""

        if isinstance(tool_choice, str):
            directive = keyword_directives.get(tool_choice)
            if directive is None:
                logger.warning(f"⚠️ Unknown tool_choice string value: {tool_choice}")
                return ""
            return directive

        if not hasattr(tool_choice, 'function'):
            logger.warning(f"⚠️ Unsupported tool_choice type: {type(tool_choice)}")
            return ""

        # A specific tool was requested: validate the request before use.
        function_spec = tool_choice.function
        if not isinstance(function_spec, dict):
            raise HTTPException(status_code=400, detail="tool_choice.function must be an object")

        required_tool_name = function_spec.get("name")
        if not required_tool_name or not isinstance(required_tool_name, str):
            raise HTTPException(status_code=400, detail="tool_choice.function.name must be a non-empty string")

        if not tools:
            raise HTTPException(status_code=400, detail="tool_choice requires a non-empty tools list in the request")

        tool_names = [t.function.name for t in tools]
        if required_tool_name not in tool_names:
            raise HTTPException(
                status_code=400,
                detail=f"tool_choice specifies tool '{required_tool_name}' which is not in the tools list. Available tools: {tool_names}"
            )

        return f"\n\n**IMPORTANT:** In this round, you must use ONLY the tool named `{required_tool_name}`. Generate the necessary parameters and output in the specified XML format."

    except HTTPException:
        # Client errors must reach the caller untouched.
        raise
    except Exception as e:
        logger.error(f"❌ Error processing tool_choice: {e}")
        return ""
|
|
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn using the loaded config.
    import uvicorn

    server_settings = app_config.server
    logger.info(f"🚀 Starting server on {server_settings.host}:{server_settings.port}")
    logger.info(f"⏱️ Request timeout: {server_settings.timeout} seconds")

    # "DISABLED" is not passed to uvicorn; "critical" is used in its place.
    configured_level = app_config.features.log_level
    if configured_level == "DISABLED":
        uvicorn_log_level = "critical"
    else:
        uvicorn_log_level = configured_level.lower()

    uvicorn.run(
        app,
        host=server_settings.host,
        port=server_settings.port,
        log_level=uvicorn_log_level,
    )
|
|