| from __future__ import annotations |
|
|
| import collections.abc |
| import hashlib |
| import json |
| import logging |
| import re |
| import threading |
| import time |
| import uuid |
| from datetime import timedelta |
| from pathlib import Path |
| from typing import Callable, Optional, Sequence, Union |
|
|
| import aiohttp |
| import mimeparse |
| from open_webui.env import CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE |
|
|
# Module-level logger, named after this module via __name__.
log = logging.getLogger(__name__)
|
|
|
|
def deep_update(d, u):
    """
    Recursively merge mapping ``u`` into mapping ``d`` in place.

    Nested mappings are merged key by key; any other value in ``u``
    overwrites the corresponding entry in ``d``. When ``u`` holds a mapping
    but ``d`` holds a non-mapping (or nothing) under the same key, a fresh
    dict is used as the merge target instead of crashing on item assignment.

    :param d: The mapping to update (mutated in place).
    :param u: The mapping whose entries are merged into ``d``.
    :return: ``d``, for call chaining.
    """
    for key, value in u.items():
        if isinstance(value, collections.abc.Mapping):
            existing = d.get(key)
            if isinstance(existing, collections.abc.MutableMapping):
                target = existing
            elif isinstance(existing, collections.abc.Mapping):
                # Read-only mapping: copy it so it can be updated.
                target = dict(existing)
            else:
                # Missing, None, or a scalar: start from an empty dict.
                target = {}
            d[key] = deep_update(target, value)
        else:
            d[key] = value
    return d
|
|
|
|
def get_allow_block_lists(filter_list):
    """
    Split a filter list into allowed and blocked patterns.

    Entries are stripped first. Entries that then start with ``!`` are
    blocked, and all others are allowed. Blank entries, including a bare
    ``!``, are ignored. Under suffix matching an empty pattern would match
    every string, silently allowing or blocking everything.

    :param filter_list: Iterable of filter strings, or a falsy value.
    :return: Tuple ``(allow_list, block_list)`` of stripped patterns.
    """
    allow_list = []
    block_list = []
    for entry in filter_list or ():
        entry = entry.strip()
        if entry.startswith('!'):
            pattern = entry[1:].strip()
            if pattern:
                block_list.append(pattern)
        elif entry:
            allow_list.append(entry)
    return allow_list, block_list
|
|
|
|
| def is_string_allowed(string: Union[str, Sequence[str]], filter_list: list[str | None] = None) -> bool: |
| """ |
| Checks if a string is allowed based on the provided filter list. |
| :param string: The string or sequence of strings to check (e.g., domain or hostname). |
| :param filter_list: List of allowed/blocked strings. Strings starting with "!" are blocked. |
| :return: True if the string or sequence of strings is allowed, False otherwise. |
| """ |
| if not filter_list: |
| return True |
|
|
| allow_list, block_list = get_allow_block_lists(filter_list) |
| strings = [string] if isinstance(string, str) else list(string) |
|
|
| |
| if allow_list: |
| if not any(s.endswith(allowed) for s in strings for allowed in allow_list): |
| return False |
|
|
| |
| if any(s.endswith(blocked) for s in strings for blocked in block_list): |
| return False |
|
|
| return True |
|
|
|
|
def get_message_list(messages_map, message_id):
    """
    Reconstruct the ordered chain of messages ending at ``message_id``.

    Starting from ``message_id``, follows each message's ``parentId`` link
    back to the root, then returns the chain root-first. Traversal stops at
    a missing or falsy parent, or when a cycle is detected. Cycles are
    detected both by the message's own ``id`` field and by the map key used
    to reach it. Messages lacking an ``id`` therefore cannot cause an
    endless loop.

    :param messages_map: Message history dict mapping message ids to message dicts.
    :param message_id: ID of the last message of the chain to reconstruct.
    :return: List of messages ordered from the root to ``message_id``; empty
        when the map is empty or ``message_id`` is not present.
    """
    if not messages_map:
        return []

    chain = []
    seen_ids = set()
    seen_keys = set()
    key = message_id
    current = messages_map.get(key)

    while current:
        own_id = current.get('id')
        if own_id in seen_ids or key in seen_keys:
            # Cycle in the parent links.
            break
        if own_id is not None:
            seen_ids.add(own_id)
        seen_keys.add(key)

        chain.append(current)
        key = current.get('parentId')
        current = messages_map.get(key) if key else None

    chain.reverse()
    return chain
|
|
|
|
def get_messages_content(messages: list[dict]) -> str:
    """
    Render messages as a plain-text transcript.

    Each message becomes one ``ROLE: text`` line, with the text taken from
    ``get_content_from_message``. Lines are joined with newlines.
    """
    lines = (f'{msg["role"].upper()}: {get_content_from_message(msg)}' for msg in messages)
    return '\n'.join(lines)
|
|
|
|
| def get_last_user_message_item(messages: list[dict]) -> dict | None: |
| for message in reversed(messages): |
| if message['role'] == 'user': |
| return message |
| return None |
|
|
|
|
| def get_content_from_message(message: dict) -> str | None: |
| if isinstance(message.get('content'), list): |
| for item in message['content']: |
| if item['type'] == 'text': |
| return item['text'] |
| else: |
| return message.get('content') |
| return None |
|
|
|
|
def reconcile_tool_pairs(messages: list[dict]) -> list[dict]:
    """Remove tool calls and tool results that lack a counterpart.

    Persisted output is not always complete. A tool result may be absent
    (for example because the knowledge base changed mid-chat, or the call was
    interrupted), or a tool result may outlive the call that produced it.
    Strict providers such as Anthropic and AWS Bedrock Converse reject
    either kind of mismatch.

    - ``role: tool`` messages whose ``tool_call_id`` was never requested by an
      assistant message are dropped.
    - Assistant ``tool_calls`` entries without a matching tool result are
      removed. If none remain, the assistant message is kept without
      ``tool_calls`` only when it still has non-blank content or
      ``reasoning_content``; otherwise it is dropped.

    Well-formed conversations pass through unchanged because every id pairs.
    """
    answered = {
        msg['tool_call_id']
        for msg in messages
        if msg.get('role') == 'tool' and msg.get('tool_call_id')
    }
    requested = {
        call['id']
        for msg in messages
        for call in msg.get('tool_calls') or ()
        if msg.get('role') == 'assistant' and call.get('id')
    }

    result = []
    for msg in messages:
        role = msg.get('role')

        if role == 'tool':
            # Keep a tool result only if some assistant asked for it.
            if msg.get('tool_call_id') in requested:
                result.append(msg)
            continue

        calls = msg.get('tool_calls') if role == 'assistant' else None
        if not calls:
            result.append(msg)
            continue

        kept = [call for call in calls if call.get('id') in answered]
        if kept:
            result.append({**msg, 'tool_calls': kept})
            continue

        # Every call was unanswered: keep the turn only if it still says something.
        text = msg.get('content', '')
        meaningful = text.strip() if isinstance(text, str) else text
        if meaningful or msg.get('reasoning_content'):
            result.append({k: v for k, v in msg.items() if k != 'tool_calls'})

    return result
|
|
|
|
| def convert_output_to_messages( |
| output: list, |
| raw: bool = False, |
| reasoning_format: str | None = None, |
| ) -> list[dict]: |
| """ |
| Convert OR-aligned output items to OpenAI Chat Completion-format messages. |
| |
| This reconstructs the full conversation from the stored Responses API-native |
| output items, including assistant messages with tool_calls arrays and tool |
| role messages. |
| |
| Args: |
| output: List of OR-aligned output items (Responses API format). |
| raw: If True, include code interpreter blocks for LLM re-processing |
| follow-ups. |
| reasoning_format: How to include reasoning blocks in the output: |
| - None: skip reasoning (default, safe for strict providers). |
| - ``'think_tags'``: wrap in ``<think>`` tags inside content |
| (for Ollama, which expects reasoning as tagged content). |
| - ``'reasoning_content'``: set as ``reasoning_content`` top-level field |
| (for llama.cpp, which routes it via the chat template). |
| """ |
| if not output or not isinstance(output, list): |
| return [] |
|
|
| messages = [] |
| pending_tool_calls = [] |
| pending_content = [] |
| pending_reasoning = [] |
|
|
| def flush_pending(): |
| nonlocal pending_content, pending_tool_calls, pending_reasoning |
| if not pending_content and not pending_tool_calls and not pending_reasoning: |
| return |
|
|
| message = { |
| 'role': 'assistant', |
| 'content': '\n'.join(pending_content) if pending_content else '', |
| **({'tool_calls': pending_tool_calls} if pending_tool_calls else {}), |
| } |
|
|
| if pending_reasoning: |
| message['reasoning_content'] = '\n'.join(pending_reasoning) |
|
|
| messages.append(message) |
| pending_content = [] |
| pending_tool_calls = [] |
| pending_reasoning = [] |
|
|
| for item in output: |
| item_type = item.get('type', '') |
|
|
| if item_type == 'message': |
| |
| content_parts = item.get('content', []) |
| text = '' |
| for part in content_parts: |
| if part.get('type') == 'output_text': |
| text += part.get('text', '') |
| if text: |
| pending_content.append(text) |
|
|
| elif item_type == 'function_call': |
| |
| arguments = item.get('arguments', '{}') |
| |
| if not isinstance(arguments, str): |
| arguments = json.dumps(arguments) |
| pending_tool_calls.append( |
| { |
| 'id': item.get('call_id', ''), |
| 'type': 'function', |
| 'function': { |
| 'name': item.get('name', ''), |
| 'arguments': arguments, |
| }, |
| } |
| ) |
|
|
| elif item_type == 'function_call_output': |
| |
| flush_pending() |
|
|
| |
| output_parts = item.get('output', []) |
| content = '' |
| image_urls = [] |
| for part in output_parts: |
| if part.get('type') == 'input_text': |
| output_text = part.get('text', '') |
| content += str(output_text) if not isinstance(output_text, str) else output_text |
| elif part.get('type') == 'input_image': |
| url = part.get('image_url', '') |
| if url: |
| image_urls.append(url) |
|
|
| if image_urls: |
| |
| messages.append( |
| { |
| 'role': 'tool', |
| 'tool_call_id': item.get('call_id', ''), |
| 'content': [ |
| {'type': 'input_text', 'text': content}, |
| *[{'type': 'input_image', 'image_url': url} for url in image_urls], |
| ], |
| } |
| ) |
| else: |
| messages.append( |
| { |
| 'role': 'tool', |
| 'tool_call_id': item.get('call_id', ''), |
| 'content': content, |
| } |
| ) |
|
|
| elif item_type == 'reasoning': |
| if not reasoning_format: |
| continue |
|
|
| reasoning_text = '' |
| source_list = item.get('summary', []) or item.get('content', []) |
| for part in source_list: |
| if part.get('type') == 'output_text': |
| reasoning_text += part.get('text', '') |
| elif 'text' in part: |
| reasoning_text += part.get('text', '') |
|
|
| if reasoning_text: |
| if reasoning_format == 'think_tags': |
| |
| start_tag = item.get('start_tag', '<think>') |
| end_tag = item.get('end_tag', '</think>') |
| pending_content.append(f'{start_tag}{reasoning_text}{end_tag}') |
| elif reasoning_format == 'reasoning_content': |
| |
| pending_reasoning.append(reasoning_text) |
|
|
| elif item_type == 'open_webui:code_interpreter': |
| |
| |
| code = item.get('code', '') |
| code_output = item.get('output', '') |
|
|
| if code: |
| pending_content.append(f'<code_interpreter>\n{code}\n</code_interpreter>') |
|
|
| if code_output: |
| if isinstance(code_output, dict): |
| stdout = code_output.get('stdout', '') |
| result = code_output.get('result', '') |
| output_text = stdout or result |
| else: |
| output_text = str(code_output) |
| if output_text: |
| pending_content.append(f'<code_interpreter_output>\n{output_text}\n</code_interpreter_output>') |
|
|
| elif item_type.startswith('open_webui:'): |
| |
| pass |
|
|
| |
| flush_pending() |
|
|
| return reconcile_tool_pairs(messages) |
|
|
|
|
def get_last_user_message(messages: list[dict]) -> str | None:
    """Return the text of the most recent user message, or None if there is no user message."""
    last_user = get_last_user_message_item(messages)
    return None if last_user is None else get_content_from_message(last_user)
|
|
|
|
def set_last_user_message_content(content: str, messages: list[dict]) -> list[dict]:
    """
    Replace the text content of the last user message in place.

    Handles both plain-string and list-of-parts content. For list content,
    the first ``text`` part is overwritten. If the list has no text part, a
    new ``{'type': 'text', ...}`` part is inserted at the front so the
    content is not silently dropped.

    :param content: The replacement text.
    :param messages: Chat messages; mutated in place.
    :return: The same ``messages`` list.
    """
    for message in reversed(messages):
        if message.get('role') != 'user':
            continue
        parts = message.get('content')
        if isinstance(parts, list):
            for part in parts:
                if isinstance(part, dict) and part.get('type') == 'text':
                    part['text'] = content
                    break
            else:
                # No text part (e.g. an image-only message): add one.
                parts.insert(0, {'type': 'text', 'text': content})
        else:
            message['content'] = content
        break
    return messages
|
|
|
|
| def get_last_assistant_message_item(messages: list[dict]) -> dict | None: |
| for message in reversed(messages): |
| if message['role'] == 'assistant': |
| return message |
| return None |
|
|
|
|
| def get_last_assistant_message(messages: list[dict]) -> str | None: |
| for message in reversed(messages): |
| if message['role'] == 'assistant': |
| return get_content_from_message(message) |
| return None |
|
|
|
|
| def get_system_message(messages: list[dict]) -> dict | None: |
| for message in messages: |
| if message['role'] == 'system': |
| return message |
| return None |
|
|
|
|
def remove_system_message(messages: list[dict]) -> list[dict]:
    """
    Return a new list with every ``system`` message removed.

    Messages without a ``role`` key are kept rather than raising ``KeyError``.
    The input list is not mutated.
    """
    return [message for message in messages if message.get('role') != 'system']
|
|
|
|
def pop_system_message(messages: list[dict]) -> tuple[dict | None, list[dict]]:
    """
    Separate the system message from the rest without mutating ``messages``.

    :return: ``(first_system_message_or_None, messages_without_any_system_message)``.
    """
    system_message = get_system_message(messages)
    remaining = remove_system_message(messages)
    return system_message, remaining
|
|
|
|
def merge_system_messages(messages: list[dict]) -> list[dict]:
    """
    Consolidate every system message into a single one at index 0.

    Some chat templates (e.g. Qwen) require exactly one system message at
    the start, while several pipeline stages may each insert their own. The
    text of each system message, taken via ``get_content_from_message``, is
    joined with newlines. System messages without text are dropped.
    Non-system messages keep their relative order.

    :return: A new list; the input list is not mutated.
    """
    texts: list[str] = []
    rest: list[dict] = []
    for msg in messages:
        if msg.get('role') != 'system':
            rest.append(msg)
            continue
        text = get_content_from_message(msg)
        if text:
            texts.append(text)

    if texts:
        return [{'role': 'system', 'content': '\n'.join(texts)}, *rest]
    return rest
|
|
|
|
def update_message_content(message: dict, content: str, append: bool = True) -> dict:
    """
    Add text to a message's content, after or before the existing text.

    For string content, the new text is joined to the old with a newline.
    For list (multimodal) content, exactly one text part is edited: the last
    text part when appending, the first when prepending. This keeps
    multi-part messages from receiving duplicate copies. If the list has no
    text part, a new one is added at the end (append) or the start (prepend).

    :param message: Chat message dict with a ``content`` key; mutated in place.
    :param content: Text to add.
    :param append: True to add after the existing text, False to add before it.
    :return: The same ``message`` dict.
    """
    existing = message['content']
    if isinstance(existing, list):
        text_parts = [part for part in existing if isinstance(part, dict) and part.get('type') == 'text']
        if not text_parts:
            new_part = {'type': 'text', 'text': content}
            if append:
                existing.append(new_part)
            else:
                existing.insert(0, new_part)
        elif append:
            target = text_parts[-1]
            target['text'] = f'{target.get("text", "")}\n{content}'
        else:
            target = text_parts[0]
            target['text'] = f'{content}\n{target.get("text", "")}'
    elif append:
        message['content'] = f'{existing}\n{content}'
    else:
        message['content'] = f'{content}\n{existing}'
    return message
|
|
|
|
def replace_system_message_content(content: str, messages: list[dict]) -> list[dict]:
    """
    Overwrite the content of the first system message, in place.

    Only the first message with role ``system`` is changed, and messages
    without a ``role`` key are skipped. If there is no system message,
    nothing changes.

    :param content: New system message content.
    :param messages: Chat messages; mutated in place.
    :return: The same ``messages`` list.
    """
    for message in messages:
        if message.get('role') == 'system':
            message['content'] = content
            break
    return messages
|
|
|
|
def add_or_update_system_message(content: str, messages: list[dict], append: bool = False):
    """
    Ensure ``content`` is in a system message at the start of ``messages``.

    If ``messages[0]`` is a system message, ``content`` is merged into it via
    ``update_message_content``: prepended by default, appended when
    ``append`` is True. Otherwise a new system message is inserted at index 0.

    :param content: The system prompt text to add.
    :param messages: The list of message dictionaries; mutated in place.
    :param append: Add after (True) or before (False) the existing system text.
    :return: The same ``messages`` list.
    """
    starts_with_system = bool(messages) and messages[0].get('role') == 'system'
    if not starts_with_system:
        messages.insert(0, {'role': 'system', 'content': content})
        return messages
    messages[0] = update_message_content(messages[0], content, append)
    return messages
|
|
|
|
def add_or_update_user_message(content: str, messages: list[dict], append: bool = True):
    """
    Ensure ``content`` is in a user message at the end of ``messages``.

    If the last message has role ``user``, ``content`` is merged into it via
    ``update_message_content``: appended by default, prepended when
    ``append`` is False. Otherwise a new user message is appended.

    :param content: Text to add.
    :param messages: The list of message dictionaries; mutated in place.
    :param append: Add after (True) or before (False) the existing text.
    :return: The same ``messages`` list.
    """
    if not messages or messages[-1].get('role') != 'user':
        messages.append({'role': 'user', 'content': content})
        return messages
    messages[-1] = update_message_content(messages[-1], content, append)
    return messages
|
|
|
|
def prepend_to_first_user_message_content(content: str, messages: list[dict]) -> list[dict]:
    """
    Prepend ``content`` to the first user message, in place.

    Delegates to ``update_message_content(..., append=False)``, so both string
    and list content are handled. Messages without a ``role`` key are skipped
    rather than raising ``KeyError``. If there is no user message, nothing changes.

    :return: The same ``messages`` list.
    """
    for message in messages:
        if message.get('role') == 'user':
            update_message_content(message, content, append=False)
            break
    return messages
|
|
|
|
def append_or_update_assistant_message(content: str, messages: list[dict]):
    """
    Add ``content`` to the assistant message at the end of ``messages``.

    If the last message is an assistant message, ``content`` is appended to
    it after a newline. ``None`` content (for example a tool-call-only turn)
    is replaced rather than rendered as the text "None". List content is
    updated via ``update_message_content``. Otherwise a new assistant
    message is appended.

    :param content: The text to add.
    :param messages: The list of message dictionaries; mutated in place.
    :return: The same ``messages`` list.
    """
    if messages and messages[-1].get('role') == 'assistant':
        last = messages[-1]
        existing = last.get('content')
        if existing is None:
            last['content'] = content
        elif isinstance(existing, list):
            update_message_content(last, content, append=True)
        else:
            last['content'] = f'{existing}\n{content}'
    else:
        messages.append({'role': 'assistant', 'content': content})
    return messages
|
|
|
|
def strip_empty_content_blocks(messages: list[dict]) -> list[dict]:
    """
    Remove empty text content blocks from multimodal message content arrays.

    Providers like Gemini and Claude reject messages where a text block has
    an empty string. This can happen when a user sends only file/image
    attachments without typing any text. A text block counts as empty when
    its ``text`` is missing, None, or whitespace-only. If removing blocks
    would leave a message with no content at all, that message is left
    unchanged.

    :param messages: Chat messages; list contents are replaced in place.
    :return: The same ``messages`` list.
    """
    for message in messages:
        content = message.get('content')
        if not isinstance(content, list):
            continue
        cleaned = [
            block
            for block in content
            if not (
                isinstance(block, dict)
                and block.get('type') == 'text'
                # ``or ''`` guards against ``text: None``, which has no ``.strip()``.
                and not str(block.get('text') or '').strip()
            )
        ]
        if cleaned:
            message['content'] = cleaned
    return messages
|
|
|
|
def openai_chat_message_template(model: str):
    """
    Build the skeleton shared by OpenAI-style chat completion responses.

    The id is ``<model>-<uuid4>`` and ``created`` is the current Unix time in
    whole seconds. The single choice starts with no logprobs and no finish
    reason.
    """
    completion_id = f'{model}-{uuid.uuid4()}'
    created_at = int(time.time())
    first_choice = {'index': 0, 'logprobs': None, 'finish_reason': None}
    return {
        'id': completion_id,
        'created': created_at,
        'model': model,
        'choices': [first_choice],
    }
|
|
|
|
def openai_chat_chunk_message_template(
    model: str,
    content: str | None = None,
    reasoning_content: str | None = None,
    tool_calls: list[dict] | None = None,
    usage: dict | None = None,
) -> dict:
    """
    Build one streaming chunk (``object: chat.completion.chunk``).

    Each non-empty argument among ``content``, ``reasoning_content`` and
    ``tool_calls`` is placed in ``choices[0].delta``. When all three are
    empty, the chunk is treated as the final one and its ``finish_reason``
    is set to ``'stop'``. ``usage`` is attached at the top level when given.
    """
    template = openai_chat_message_template(model)
    template['object'] = 'chat.completion.chunk'

    choice = template['choices'][0]
    delta = {}
    if content:
        delta['content'] = content
    if reasoning_content:
        delta['reasoning_content'] = reasoning_content
    if tool_calls:
        delta['tool_calls'] = tool_calls
    choice['delta'] = delta

    # An empty delta marks the end of the stream.
    if not delta:
        choice['finish_reason'] = 'stop'

    if usage:
        template['usage'] = usage
    return template
|
|
|
|
def openai_chat_completion_message_template(
    model: str,
    message: str | None = None,
    reasoning_content: str | None = None,
    tool_calls: list[dict] | None = None,
    usage: dict | None = None,
) -> dict:
    """
    Build a non-streaming response (``object: chat.completion``).

    An assistant ``message`` is attached to ``choices[0]`` when ``message``
    is not None, or when ``tool_calls`` are given. Tool-call-only responses
    therefore keep their calls, with ``content`` set to None.
    ``reasoning_content`` and ``tool_calls`` are included only when
    non-empty. ``finish_reason`` is ``'tool_calls'`` when tool calls are
    present, else ``'stop'``. ``usage`` is attached at the top level when given.
    """
    template = openai_chat_message_template(model)
    template['object'] = 'chat.completion'

    choice = template['choices'][0]
    if message is not None or tool_calls:
        choice['message'] = {
            'role': 'assistant',
            'content': message,
            **({'reasoning_content': reasoning_content} if reasoning_content else {}),
            **({'tool_calls': tool_calls} if tool_calls else {}),
        }

    choice['finish_reason'] = 'tool_calls' if tool_calls else 'stop'

    if usage:
        template['usage'] = usage
    return template
|
|
|
|
def get_gravatar_url(email):
    """
    Return the Gravatar avatar URL for ``email``.

    The address is converted with ``str``, trimmed and lower-cased, then
    hashed with SHA-256. The hex digest forms the avatar path, with the
    ``d=mp`` default-image query parameter appended.
    """
    normalized = str(email).strip().lower()
    digest = hashlib.sha256(normalized.encode()).hexdigest()
    return f'https://www.gravatar.com/avatar/{digest}?d=mp'
|
|
|
|
| |
| |
| |
| |
def calculate_sha256(file_path, chunk_size):
    """
    Return the hex SHA-256 digest of a file's bytes.

    The file is opened in binary mode and read ``chunk_size`` bytes at a
    time, so it is never loaded into memory all at once.
    """
    digest = hashlib.sha256()
    with open(file_path, 'rb') as handle:
        for block in iter(lambda: handle.read(chunk_size), b''):
            digest.update(block)
    return digest.hexdigest()
|
|
|
|
def calculate_sha256_string(string):
    """Return the hex SHA-256 digest of ``string`` encoded as UTF-8."""
    return hashlib.sha256(string.encode('utf-8')).hexdigest()
|
|
|
|
def validate_email_format(email: str) -> bool:
    """
    Loosely check that ``email`` is shaped like an email address.

    Accepts ``local@localhost`` and ``local@domain.tld``-shaped strings,
    where no part contains ``@``. The whole string must match. So trailing
    text after a valid-looking prefix (``a@b.c@d``) is rejected, and so is
    an empty local part before ``@localhost``.
    """
    if re.fullmatch(r'[^@]+@localhost', email):
        return True
    return re.fullmatch(r'[^@]+@[^@]+\.[^@]+', email) is not None
|
|
|
|
def sanitize_filename(file_name):
    """
    Turn ``file_name`` into a lower-case, hyphen-separated slug.

    Characters that are neither word characters nor whitespace are removed.
    This includes dots, so any extension is merged into the name. Each run
    of whitespace then becomes a single hyphen.
    """
    without_punctuation = re.sub(r'[^\w\s]', '', file_name.lower())
    return re.sub(r'\s+', '-', without_punctuation)
|
|
|
|
def sanitize_text_for_db(text: str) -> str:
    """
    Remove null characters and lone surrogates from text for PostgreSQL storage.

    Null characters (U+0000) are deleted. Lone surrogate code points
    (U+D800 to U+DFFF), which cannot be encoded as UTF-8, are dropped.
    Strings containing neither are returned unchanged (the fast path), and
    non-string values are returned as-is.
    """
    if not isinstance(text, str):
        return text

    # Fast path. Previously only null characters were checked here, so
    # surrogate-only strings skipped the cleanup the docstring promises.
    if '\x00' not in text and re.search(r'[\ud800-\udfff]', text) is None:
        return text

    text = text.replace('\x00', '')
    try:
        # surrogatepass turns lone surrogates into byte sequences that are
        # invalid UTF-8; decoding with errors='ignore' then drops them.
        text = text.encode('utf-8', errors='surrogatepass').decode('utf-8', errors='ignore')
    except (UnicodeEncodeError, UnicodeDecodeError):
        pass
    return text
|
|
|
|
| def _strip_null_bytes_deep(obj): |
| """Inner recursive walk — only called when null bytes are known to be present.""" |
| if isinstance(obj, str): |
| return sanitize_text_for_db(obj) |
| elif isinstance(obj, dict): |
| return {k: _strip_null_bytes_deep(v) for k, v in obj.items()} |
| elif isinstance(obj, list): |
| return [_strip_null_bytes_deep(v) for v in obj] |
| return obj |
|
|
|
|
def sanitize_data_for_db(obj):
    """
    Recursively sanitize every string in ``obj`` for database storage.

    Strings go straight to ``sanitize_text_for_db``. For anything else, a
    cheap pre-check serializes ``obj`` to JSON once and looks for an escaped
    null character. When none is found (the usual case), ``obj`` is returned
    as-is and the recursive walk is skipped. If ``obj`` cannot be
    serialized, the walk is always performed.
    """
    if isinstance(obj, str):
        return sanitize_text_for_db(obj)

    try:
        serialized = json.dumps(obj, ensure_ascii=False)
    except (TypeError, ValueError):
        serialized = None

    if serialized is not None and '\\u0000' not in serialized:
        return obj
    return _strip_null_bytes_deep(obj)
|
|
|
|
def sanitize_metadata(metadata: dict) -> dict:
    """
    Return a JSON-safe copy of a metadata dict for database storage.

    Middleware metadata collects non-serializable Python objects (e.g.
    callable tool functions, MCP client instances) that make PostgreSQL JSON
    inserts fail. Dicts and lists are copied recursively, and entries that
    are callable or fail ``json.dumps`` are dropped from their container.
    The primitive data needed for file-to-chat linking survives. Non-dict
    input is returned unchanged.
    """
    if not isinstance(metadata, dict):
        return metadata

    def _fits_json(value):
        # Dicts and lists are accepted here and cleaned recursively by _clean.
        if isinstance(value, (str, int, float, bool, type(None), dict, list)):
            return True
        try:
            json.dumps(value)
        except (TypeError, ValueError):
            return False
        return True

    def _keep(value):
        return not callable(value) and _fits_json(value)

    def _clean(value):
        if isinstance(value, (str, int, float, bool, type(None))):
            return value
        if isinstance(value, dict):
            return {key: _clean(item) for key, item in value.items() if _keep(item)}
        if isinstance(value, list):
            return [_clean(item) for item in value if _keep(item)]
        if callable(value):
            return None
        try:
            json.dumps(value)
        except (TypeError, ValueError):
            return None
        return value

    return _clean(metadata)
|
|
|
|
def extract_folders_after_data_docs(path):
    """
    List the cumulative folder paths that follow ``data`` ... ``docs`` in ``path``.

    Finds the first ``data`` component, then the first ``docs`` component
    after it. Returns every prefix of the folders between ``docs`` and the
    final component (the file name), joined with ``/``. For
    ``.../data/docs/a/b/file.txt`` this gives ``['a', 'a/b']``. Returns an
    empty list if either component is missing.
    """
    parts = Path(path).parts
    try:
        after_data = parts.index('data') + 1
        after_docs = parts.index('docs', after_data) + 1
    except ValueError:
        return []

    folders = parts[after_docs:-1]
    return ['/'.join(folders[:depth]) for depth in range(1, len(folders) + 1)]
|
|
|
|
| def parse_duration(duration: str) -> timedelta | None: |
| if duration == '-1' or duration == '0': |
| return None |
|
|
| |
| pattern = r'(-?\d+(\.\d+)?)(ms|s|m|h|d|w)' |
| matches = re.findall(pattern, duration) |
|
|
| if not matches: |
| raise ValueError('Invalid duration string') |
|
|
| total_duration = timedelta() |
|
|
| for number, _, unit in matches: |
| number = float(number) |
| if unit == 'ms': |
| total_duration += timedelta(milliseconds=number) |
| elif unit == 's': |
| total_duration += timedelta(seconds=number) |
| elif unit == 'm': |
| total_duration += timedelta(minutes=number) |
| elif unit == 'h': |
| total_duration += timedelta(hours=number) |
| elif unit == 'd': |
| total_duration += timedelta(days=number) |
| elif unit == 'w': |
| total_duration += timedelta(weeks=number) |
|
|
| return total_duration |
|
|
|
|
def parse_ollama_modelfile(model_text):
    """
    Parse an Ollama Modelfile into Open WebUI model data.

    The ``FROM`` directive becomes ``base_model_id``. The ``TEMPLATE``,
    ``PARAMETER`` (including repeated ``stop``), ``ADAPTER``, ``SYSTEM`` and
    ``MESSAGE`` directives are collected into ``params``. Directive names
    match case-insensitively. Known numeric and boolean parameters are
    converted to their types; values that fail conversion are logged and
    skipped.

    :param model_text: Full Modelfile text.
    :return: ``{'base_model_id': str | None, 'params': dict}``.
    """
    parameters_meta = {
        'mirostat': int,
        'mirostat_eta': float,
        'mirostat_tau': float,
        'num_ctx': int,
        'repeat_last_n': int,
        'repeat_penalty': float,
        'temperature': float,
        'seed': int,
        'tfs_z': float,
        'num_predict': int,
        'top_k': int,
        'top_p': float,
        'num_keep': int,
        'presence_penalty': float,
        'frequency_penalty': float,
        'num_batch': int,
        'num_gpu': int,
        'use_mmap': bool,
        'use_mlock': bool,
        'num_thread': int,
    }

    data = {'base_model_id': None, 'params': {}}

    # Capture the whole model reference (e.g. "llama3:8b" or a file path).
    # A word-character-only pattern would stop at ':' and drop the tag.
    base_model_match = re.search(r'^FROM\s+(\S+)', model_text, re.MULTILINE | re.IGNORECASE)
    if base_model_match:
        data['base_model_id'] = base_model_match.group(1)

    template_match = re.search(r'TEMPLATE\s+"""(.+?)"""', model_text, re.DOTALL | re.IGNORECASE)
    if template_match:
        data['params'] = {'template': template_match.group(1).strip()}

    stops = re.findall(r'PARAMETER stop "(.*?)"', model_text, re.IGNORECASE)
    if stops:
        data['params']['stop'] = stops

    for param, param_type in parameters_meta.items():
        param_match = re.search(rf'PARAMETER {param} (.+)', model_text, re.IGNORECASE)
        if not param_match:
            continue

        # Strip surrounding whitespace, including a trailing carriage return
        # from CRLF files, which would otherwise make booleans parse as False.
        value = param_match.group(1).strip()
        try:
            if param_type is int:
                value = int(value)
            elif param_type is float:
                value = float(value)
            elif param_type is bool:
                value = value.lower() == 'true'
        except Exception as e:
            log.exception(f'Failed to parse parameter {param}: {e}')
            continue

        data['params'][param] = value

    adapter_match = re.search(r'ADAPTER (.+)', model_text, re.IGNORECASE)
    if adapter_match:
        data['params']['adapter'] = adapter_match.group(1)

    # A triple-quoted SYSTEM block takes precedence over a single-line SYSTEM.
    system_desc_match = re.search(r'SYSTEM\s+"""(.+?)"""', model_text, re.DOTALL | re.IGNORECASE)
    system_desc_match_single = re.search(r'SYSTEM\s+([^\n]+)', model_text, re.IGNORECASE)
    if system_desc_match:
        data['params']['system'] = system_desc_match.group(1).strip()
    elif system_desc_match_single:
        data['params']['system'] = system_desc_match_single.group(1).strip()

    messages = [
        {'role': role, 'content': content}
        for role, content in re.findall(r'MESSAGE (\w+) (.+)', model_text, re.IGNORECASE)
    ]
    if messages:
        data['params']['messages'] = messages

    return data
|
|
|
|
| def convert_logit_bias_input_to_json(logit_bias_input) -> str | None: |
| if not logit_bias_input: |
| return None |
|
|
| if isinstance(logit_bias_input, dict): |
| return json.dumps(logit_bias_input) |
|
|
| logit_bias_pairs = logit_bias_input.split(',') |
| logit_bias_json = {} |
| for pair in logit_bias_pairs: |
| token, bias = pair.split(':') |
| token = str(token.strip()) |
| bias = int(bias.strip()) |
| bias = 100 if bias > 100 else -100 if bias < -100 else bias |
| logit_bias_json[token] = bias |
| return json.dumps(logit_bias_json) |
|
|
|
|
def freeze(value):
    """
    Recursively convert ``value`` into a hashable equivalent.

    Dicts become frozensets of ``(key, frozen_value)`` pairs. Lists and
    tuples become tuples, and sets become frozensets, with their elements
    frozen recursively. Any other value is returned unchanged.
    """
    if isinstance(value, dict):
        return frozenset((key, freeze(item)) for key, item in value.items())
    if isinstance(value, (list, tuple)):
        return tuple(freeze(item) for item in value)
    if isinstance(value, (set, frozenset)):
        return frozenset(freeze(item) for item in value)
    return value
|
|
|
|
def throttle(interval: float = 10.0):
    """
    Decorator to prevent an async function from being called more than once within a specified duration.

    Calls are keyed by their positional and keyword arguments, each frozen
    with ``freeze`` so list or dict arguments remain usable as keys. If the
    function was already invoked with the same arguments less than
    ``interval`` seconds ago, the call is skipped and returns None, so the
    decorated function's return type is effectively ``T | None``. With
    ``interval=None``, throttling is disabled.

    :param interval: Duration in seconds to wait before allowing the function to be called again.
    """
    import functools

    def decorator(func):
        # NOTE: entries are never evicted, so this grows with the number of distinct argument sets.
        last_calls = {}
        lock = threading.Lock()

        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            if interval is None:
                return await func(*args, **kwargs)

            key = (tuple(freeze(arg) for arg in args), freeze(kwargs))
            # monotonic() is unaffected by wall-clock adjustments.
            now = time.monotonic()
            with lock:
                if now - last_calls.get(key, float('-inf')) < interval:
                    return None
                last_calls[key] = now
            return await func(*args, **kwargs)

        return wrapper

    return decorator
|
|
|
|
def strict_match_mime_type(supported: list[str] | str, header: str) -> str | None:
    """
    Strictly match the mime type with the supported mime types.

    ``supported`` may be a list or a comma-separated string. Entries are
    stripped, and those without a ``/`` are ignored. If none remain,
    ``['audio/*', 'video/webm']`` is used. The best match from
    ``mimeparse.best_match`` is accepted only if every parameter it declares
    appears with the same value in ``header``. Any exception is logged and
    treated as no match.

    :param supported: The supported mime types.
    :param header: The header to match.
    :return: The matched mime type, or None if no match is found.
    """
    try:
        if isinstance(supported, str):
            supported = supported.split(',')

        # Strip entries so the returned match carries no stray whitespace.
        supported = [entry.strip() for entry in supported if entry.strip() and '/' in entry]

        if not supported:
            # NOTE(review): hard-coded fallback; presumably audio-upload defaults. Confirm with callers.
            supported = ['audio/*', 'video/webm']

        match = mimeparse.best_match(supported, header)
        if not match:
            return None

        # Every parameter on the matched type must also appear, with the same value, in the header.
        _, _, match_params = mimeparse.parse_mime_type(match)
        _, _, header_params = mimeparse.parse_mime_type(header)
        if any(header_params.get(key) != value for key, value in match_params.items()):
            return None

        return match
    except Exception as e:
        log.exception(f'Failed to match mime type {header}: {e}')
        return None
|
|
|
|
def extract_urls(text: str) -> list[str]:
    """
    Return every ``http://`` or ``https://`` URL found in ``text``.

    A URL runs until the next whitespace character, so trailing punctuation
    attached to it is included. The scheme is matched case-insensitively.
    """
    return re.findall(r'(https?://[^\s]+)', text, flags=re.IGNORECASE)
|
|
|
|
| |
| |
| |
async def cleanup_response(
    response: aiohttp.ClientResponse | None,
    session: aiohttp.ClientSession | None,
):
    """
    Close an HTTP response and its session; either may be None.

    The result of each ``close()`` call is awaited only when it is not None,
    so both synchronous and coroutine ``close`` implementations are
    supported. Already-closed objects are skipped. The session is closed in
    a ``finally`` block so it is not leaked if closing the response raises.
    """
    try:
        if response and not response.closed:
            result = response.close()
            if result is not None:
                await result
    finally:
        if session and not session.closed:
            result = session.close()
            if result is not None:
                await result
|
|
|
|
async def stream_wrapper(response, session, content_handler=None):
    """
    Relay a response body chunk by chunk, always cleaning up afterwards.

    ``response.content`` is passed through ``content_handler`` when one is
    given. ``cleanup_response`` runs in a ``finally`` block, so it executes
    even if iteration is interrupted or the handler raises. This is more
    reliable than a BackgroundTask, which may not run if the client
    disconnects.
    """
    try:
        if content_handler:
            source = content_handler(response.content)
        else:
            source = response.content
        async for piece in source:
            yield piece
    finally:
        await cleanup_response(response, session)
|
|
|
|
def stream_chunks_handler(stream: aiohttp.StreamReader):
    """
    Handle stream response chunks, supporting lines that exceed the original 16kb limit.

    Re-splits the raw byte stream on newlines. Each line of at most
    ``CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE`` bytes is yielded with a
    trailing newline. A longer line is replaced by the placeholder line
    ``data: {}`` and switches on skip mode. In skip mode, further oversized
    lines are also replaced by the placeholder, until a normally sized line
    arrives. If the setting is None or non-positive, ``stream`` is returned
    unchanged.

    :param stream: The stream reader to handle.
    :return: An async generator that yields the stream data (or ``stream`` itself).
    """
    max_buffer_size = CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE
    if max_buffer_size is None or max_buffer_size <= 0:
        return stream

    async def yield_safe_stream_chunks():
        buffer = b''
        skip_mode = False

        async for data, _ in stream.iter_chunks():
            if not data:
                continue

            # While skipping, an oversized partial line carried over is discarded.
            if skip_mode and len(buffer) > max_buffer_size:
                buffer = b''

            lines = (buffer + data).split(b'\n')

            # Every segment but the last is a complete line.
            for line in lines[:-1]:
                if skip_mode:
                    if len(line) <= max_buffer_size:
                        # A normally sized line ends skip mode. Re-add the newline
                        # removed by split(), as the non-skip branch does, so this
                        # line is not glued to the next one downstream.
                        skip_mode = False
                        yield line + b'\n'
                    else:
                        yield b'data: {}\n'
                else:
                    if len(line) > max_buffer_size:
                        skip_mode = True
                        yield b'data: {}\n'
                        log.info(f'Skip mode triggered, line size: {len(line)}')
                    else:
                        yield line + b'\n'

            # The last segment is an incomplete line; keep it for the next chunk.
            buffer = lines[-1]

            if not skip_mode and len(buffer) > max_buffer_size:
                skip_mode = True
                log.info(f'Skip mode triggered, buffer size: {len(buffer)}')
                # NOTE(review): the rest of this discarded line arrives as the first
                # segment of a later chunk; if that tail is short it is yielded as a
                # regular line. Confirm whether that is acceptable.
                buffer = b''

        # Flush a final line that had no terminating newline.
        if buffer and not skip_mode:
            yield buffer + b'\n'

    return yield_safe_stream_chunks()
|
|