from __future__ import annotations import os import json import logging import time import uuid import threading import requests from flask import Flask, request, Response, jsonify, stream_with_context from flask_cors import CORS # ── App setup ───────────────────────────────────────────────────────────────── app = Flask(__name__) CORS(app) app.config["MAX_CONTENT_LENGTH"] = 50 * 1024 * 1024 # 50MB max body # ── Logging ─────────────────────────────────────────────────────────────────── logging.basicConfig( level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S", ) logger = logging.getLogger("bridge") # ── Config ──────────────────────────────────────────────────────────────────── import os PROVIDER_BASE_URL = os.environ["PROVIDER_BASE_URL"] # will raise KeyError if missing PROVIDER_API_KEY = os.environ["PROVIDER_API_KEY"] PROVIDER_HEADERS = { "Authorization": f"Bearer {PROVIDER_API_KEY}", "Content-Type": "application/json", } TIMEOUTS = { "chat": 120, "image": 180, "audio": 180, "embeddings": 60, "models": 10, "moderations": 30, "files": 60, "batches": 30, } RETRY_MAX = 3 RETRY_STATUSES = {429, 500, 502, 503, 504} RETRY_BASE_DELAY = 1.0 KNOWN_CHAT_PARAMS = { "model", "messages", "stream", "temperature", "top_p", "n", "max_tokens", "stop", "presence_penalty", "frequency_penalty", "logit_bias", "logprobs", "top_logprobs", "user", "tools", "tool_choice", "response_format", "seed", "functions", "function_call", } KNOWN_COMPLETIONS_PARAMS = { "model", "prompt", "stream", "temperature", "top_p", "n", "max_tokens", "stop", "presence_penalty", "frequency_penalty", "logit_bias", "logprobs", "echo", "best_of", "suffix", "user", "seed", } VALID_ROLES = {"system", "user", "assistant", "tool", "function"} # ── In-memory usage stats ───────────────────────────────────────────────────── _stats_lock = threading.Lock() _stats = { "started_at": int(time.time()), "total_requests": 0, "total_errors": 0, "total_tokens": 0, "total_latency_s": 0.0, "by_endpoint": {}, } def _stats_record(endpoint: str, tokens: int = 0, latency: float = 0.0, error: bool = False): with _stats_lock: _stats["total_requests"] += 1 _stats["total_latency_s"] += latency _stats["total_tokens"] += tokens if error: _stats["total_errors"] += 1 ep = _stats["by_endpoint"].setdefault(endpoint, { "requests": 0, "errors": 0, "tokens": 0, "latency_s": 0.0 }) ep["requests"] += 1 ep["latency_s"] += latency ep["tokens"] += tokens if error: ep["errors"] += 1 # ── In-memory file + batch stores (stub backing store) ─────────────────────── # Real file bytes are stored here keyed by file_id. # In production you'd swap this for S3/GCS/disk. _files_lock = threading.Lock() _files_store: dict = {} # file_id -> {meta, bytes} _batches_lock = threading.Lock() _batches_store: dict = {} # batch_id -> batch object # ── Logging helpers ─────────────────────────────────────────────────────────── def log_req(endpoint: str, model: str = "-", stream: bool = False): logger.info(f"→ {endpoint} | model={model} | stream={stream}") def log_resp(endpoint: str, model: str, usage: dict, latency: float): tokens = (usage.get("total_tokens") or usage.get("input_tokens", 0) + usage.get("output_tokens", 0)) logger.info(f"← {endpoint} | model={model} | tokens={tokens} | latency={latency:.2f}s") _stats_record(endpoint, tokens=tokens, latency=latency) def log_err(endpoint: str, status: int, message: str): logger.warning(f"✗ {endpoint} | status={status} | {message}") _stats_record(endpoint, error=True) # ── Error helpers ───────────────────────────────────────────────────────────── def openai_error(message, error_type="invalid_request_error", code=None, status=400, param=None): log_err(request.path, status, message) return jsonify({"error": { "message": message, "type": error_type, "param": param, "code": code, }}), status def provider_error_to_openai(http_error): status = http_error.response.status_code try: d = http_error.response.json() msg = d.get("error", {}).get("message") or d.get("message") or str(http_error) etype = d.get("error", {}).get("type", "provider_error") code = d.get("error", {}).get("code") except Exception: msg, etype, code = str(http_error), "provider_error", None return openai_error(msg, etype, code, status) # ── Retry logic ─────────────────────────────────────────────────────────────── def _do_request(method: str, url: str, timeout: int, **kwargs) -> requests.Response: is_stream = kwargs.get("stream", False) delay = RETRY_BASE_DELAY for attempt in range(1, RETRY_MAX + 1): try: resp = requests.request(method, url, timeout=timeout, **kwargs) if resp.status_code in RETRY_STATUSES and not is_stream and attempt < RETRY_MAX: retry_after = float(resp.headers.get("Retry-After", delay)) logger.warning(f"Provider {resp.status_code} — retrying in {retry_after:.1f}s ({attempt}/{RETRY_MAX})") time.sleep(retry_after) delay *= 2 continue resp.raise_for_status() return resp except (requests.exceptions.Timeout, requests.exceptions.ConnectionError) as e: if is_stream or attempt == RETRY_MAX: raise logger.warning(f"{e.__class__.__name__} — retrying in {delay:.1f}s ({attempt}/{RETRY_MAX})") time.sleep(delay) delay *= 2 raise requests.exceptions.RetryError("Max retries exceeded") # ── Validation ──────────────────────────────────────────────────────────────── def validate_messages(messages): if not isinstance(messages, list) or len(messages) == 0: return "Field 'messages' must be a non-empty array." for i, msg in enumerate(messages): if not isinstance(msg, dict): return f"messages[{i}] must be an object." if "role" not in msg: return f"messages[{i}] is missing required field 'role'." if msg["role"] not in VALID_ROLES: return (f"messages[{i}].role '{msg['role']}' is not valid. " f"Must be one of: {', '.join(sorted(VALID_ROLES))}.") if "content" not in msg and "tool_calls" not in msg: return f"messages[{i}] must have 'content' or 'tool_calls'." return None def check_unknown_params(body: dict, known: set, endpoint: str): unknown = set(body.keys()) - known if unknown: logger.warning(f"{endpoint} — unknown params forwarded: {sorted(unknown)}") def validate_numeric(body: dict, param: str, lo: float, hi: float): if param not in body: return None v = body[param] if not isinstance(v, (int, float)): return f"'{param}' must be a number." if not (lo <= v <= hi): return f"'{param}' must be between {lo} and {hi}." return None # ── JSON mode / response_format handling ───────────────────────────────────── def _validate_response_format(response_format) -> str | None: """ Validate response_format and return an error string if invalid, else None. Accepted forms: {"type": "text"} {"type": "json_object"} {"type": "json_schema", "json_schema": {...}} """ if response_format is None: return None if not isinstance(response_format, dict): return "'response_format' must be an object." rf_type = response_format.get("type") if rf_type not in ("text", "json_object", "json_schema"): return (f"'response_format.type' must be one of: 'text', 'json_object', 'json_schema'. " f"Got: {rf_type!r}.") if rf_type == "json_schema": if "json_schema" not in response_format: return "'response_format.json_schema' is required when type is 'json_schema'." schema = response_format["json_schema"] if not isinstance(schema, dict): return "'response_format.json_schema' must be an object." if "name" not in schema: return "'response_format.json_schema.name' is required." return None def _handle_json_mode(body: dict, provider_resp_data: dict) -> dict: """ If response_format is json_object or json_schema, try to ensure the assistant reply is valid JSON. If it isn't, wrap it gracefully. """ rf = body.get("response_format", {}) if not rf or rf.get("type") not in ("json_object", "json_schema"): return provider_resp_data for choice in provider_resp_data.get("choices", []): content = choice.get("message", {}).get("content", "") if not content: continue # Strip markdown code fences if present stripped = content.strip() if stripped.startswith("```"): lines = stripped.split("\n") # Remove opening fence (```json or ```) and closing fence inner = lines[1:] if lines[0].startswith("```") else lines if inner and inner[-1].strip() == "```": inner = inner[:-1] stripped = "\n".join(inner).strip() try: json.loads(stripped) # Valid JSON — update content with cleaned version choice["message"]["content"] = stripped except json.JSONDecodeError: logger.warning( "Provider response is not valid JSON despite json_object mode. " "Wrapping in error envelope." ) # Return a valid JSON object that signals the parse failure choice["message"]["content"] = json.dumps({ "error": "Provider did not return valid JSON.", "raw": content, }) return provider_resp_data def _inject_json_mode_system_message(messages: list, response_format: dict) -> list: """ If provider doesn't natively support response_format, inject a system message instructing the model to reply with pure JSON. We do this transparently so clients get reliable JSON back. """ rf_type = (response_format or {}).get("type") if rf_type not in ("json_object", "json_schema"): return messages # Check if there's already a system message with JSON instruction for msg in messages: if msg.get("role") == "system" and "json" in msg.get("content", "").lower(): return messages # Already has JSON instruction instruction = "You must respond with valid JSON only. Do not include any explanation, markdown, or code fences." if rf_type == "json_schema": schema_obj = response_format.get("json_schema", {}) schema_def = schema_obj.get("schema") if schema_def: instruction += f" Your response must conform to this JSON schema: {json.dumps(schema_def)}" # Prepend or append to existing system message new_messages = [] has_system = any(m.get("role") == "system" for m in messages) if has_system: for msg in messages: if msg.get("role") == "system": new_messages.append({**msg, "content": msg["content"] + "\n\n" + instruction}) else: new_messages.append(msg) else: new_messages = [{"role": "system", "content": instruction}] + messages return new_messages # ── Provider calls ──────────────────────────────────────────────────────────── def call_provider_chat(messages, model="gpt-4", stream=False, **kwargs): payload = {"model": model, "messages": messages, "stream": stream, **kwargs} return _do_request( "POST", f"{PROVIDER_BASE_URL}/chat/completions", timeout=TIMEOUTS["chat"], headers=PROVIDER_HEADERS, json=payload, stream=stream, ) # ── Streaming heartbeat ─────────────────────────────────────────────────────── def heartbeat_generator(inner_gen, interval=15): import queue q = queue.Queue() sentinel = object() def producer(): try: for chunk in inner_gen: q.put(chunk) finally: q.put(sentinel) t = threading.Thread(target=producer, daemon=True) t.start() while True: try: item = q.get(timeout=interval) if item is sentinel: break yield item except Exception: yield ": keep-alive\n\n" # ── /v1/models ──────────────────────────────────────────────────────────────── @app.route("/v1/models", methods=["GET"]) def list_models(): logger.info("→ GET /v1/models") t0 = time.time() try: resp = _do_request("GET", f"{PROVIDER_BASE_URL}/models", timeout=TIMEOUTS["models"], headers=PROVIDER_HEADERS) logger.info(f"← GET /v1/models | latency={time.time()-t0:.2f}s") return jsonify(resp.json()) except requests.HTTPError as e: return provider_error_to_openai(e) except requests.exceptions.Timeout: return openai_error("Provider timed out fetching models.", "server_error", "timeout", 504) except requests.exceptions.ConnectionError: return openai_error("Could not connect to provider.", "server_error", "connection_error", 502) @app.route("/v1/models/", methods=["GET"]) def retrieve_model(model_id): logger.info(f"→ GET /v1/models/{model_id}") t0 = time.time() try: resp = _do_request("GET", f"{PROVIDER_BASE_URL}/models/{model_id}", timeout=TIMEOUTS["models"], headers=PROVIDER_HEADERS) logger.info(f"← GET /v1/models/{model_id} | latency={time.time()-t0:.2f}s") return jsonify(resp.json()) except requests.HTTPError as e: return provider_error_to_openai(e) except requests.exceptions.Timeout: return openai_error("Provider timed out fetching model.", "server_error", "timeout", 504) except requests.exceptions.ConnectionError: return openai_error("Could not connect to provider.", "server_error", "connection_error", 502) # ── /v1/moderations ────────────────────────────────────────────────────────── @app.route("/v1/moderations", methods=["POST"]) def moderations(): body = request.get_json(force=True, silent=True) if not body: return openai_error("Could not parse request body as JSON.") if "input" not in body: return openai_error("Missing required field: input.") try: resp = _do_request("POST", f"{PROVIDER_BASE_URL}/moderations", timeout=TIMEOUTS["moderations"], headers=PROVIDER_HEADERS, json=body) return jsonify(resp.json()) except Exception: pass inputs = body["input"] if isinstance(body["input"], list) else [body["input"]] cats = ["sexual", "hate", "harassment", "self-harm", "sexual/minors", "hate/threatening", "violence/graphic", "self-harm/intent", "self-harm/instructions", "harassment/threatening", "violence"] results = [{"flagged": False, "categories": {k: False for k in cats}, "category_scores": {k: 0.0 for k in cats}} for _ in inputs] logger.info("→ /v1/moderations | stub (provider unsupported)") return jsonify({"id": "modr-" + str(uuid.uuid4()).replace("-", ""), "model": body.get("model", "text-moderation-latest"), "results": results}) # ── /v1/chat/completions ────────────────────────────────────────────────────── @app.route("/v1/chat/completions", methods=["POST"]) def chat_completions(): body = request.get_json(force=True, silent=True) if not body: return openai_error("Could not parse request body as JSON.") err = validate_messages(body.get("messages", [])) if err: return openai_error(err, param="messages") for param, lo, hi in [("temperature", 0.0, 2.0), ("top_p", 0.0, 1.0), ("presence_penalty", -2.0, 2.0), ("frequency_penalty", -2.0, 2.0), ("n", 1, 128)]: err = validate_numeric(body, param, lo, hi) if err: return openai_error(err, param=param) # ── response_format / JSON mode ────────────────────────────────────────── rf = body.get("response_format") err = _validate_response_format(rf) if err: return openai_error(err, param="response_format") check_unknown_params(body, KNOWN_CHAT_PARAMS, "/v1/chat/completions") model = body.get("model", "gpt-4") stream = body.get("stream", False) messages = body["messages"] # Inject JSON instruction into messages so providers that don't support # response_format natively still return valid JSON messages = _inject_json_mode_system_message(messages, rf) log_req("/v1/chat/completions", model, stream) t0 = time.time() extra = {k: v for k, v in body.items() if k not in ("messages", "model", "stream")} try: provider_resp = call_provider_chat(messages=messages, model=model, stream=stream, **extra) except requests.HTTPError as e: return provider_error_to_openai(e) except requests.exceptions.Timeout: return openai_error("Request to provider timed out.", "server_error", "timeout", 504) except requests.exceptions.ConnectionError: return openai_error("Could not connect to provider.", "server_error", "connection_error", 502) if stream: logger.info(f"← /v1/chat/completions | stream=True | latency={time.time()-t0:.2f}s") _stats_record("/v1/chat/completions", latency=time.time() - t0) def generate(): for line in provider_resp.iter_lines(): if line: yield line.decode("utf-8") + "\n\n" return Response(stream_with_context(heartbeat_generator(generate())), content_type="text/event-stream") data = provider_resp.json() # Post-process JSON mode — ensure content is valid JSON data = _handle_json_mode(body, data) log_resp("/v1/chat/completions", model, data.get("usage", {}), time.time() - t0) return jsonify(data) # ── /v1/completions (legacy) ────────────────────────────────────────────────── def _completions_to_chat_messages(prompt) -> list: if isinstance(prompt, str): return [{"role": "user", "content": prompt}] if isinstance(prompt, list): return [{"role": "user", "content": "\n".join(str(p) for p in prompt)}] return [{"role": "user", "content": str(prompt)}] def _chat_to_completions_response(chat_resp: dict, model: str) -> dict: choices = [] for i, choice in enumerate(chat_resp.get("choices", [])): text = choice.get("message", {}).get("content", "") or "" choices.append({"text": text, "index": i, "logprobs": None, "finish_reason": choice.get("finish_reason", "stop")}) return { "id": chat_resp.get("id", "cmpl-" + str(uuid.uuid4()).replace("-", "")), "object": "text_completion", "created": chat_resp.get("created", int(time.time())), "model": chat_resp.get("model", model), "choices": choices, "usage": chat_resp.get("usage", {}), } def _stream_completions_format(provider_resp, model: str): cmpl_id = "cmpl-" + str(uuid.uuid4()).replace("-", "") created = int(time.time()) for line in provider_resp.iter_lines(): if not line: continue decoded = line.decode("utf-8") if decoded.startswith("data: "): decoded = decoded[6:] if decoded.strip() == "[DONE]": yield "data: [DONE]\n\n" break try: chunk = json.loads(decoded) except json.JSONDecodeError: continue delta = chunk.get("choices", [{}])[0].get("delta", {}) finish_reason = chunk.get("choices", [{}])[0].get("finish_reason") text = delta.get("content", "") or "" out = {"id": cmpl_id, "object": "text_completion", "created": created, "model": model, "choices": [{"text": text, "index": 0, "logprobs": None, "finish_reason": finish_reason}]} yield f"data: {json.dumps(out)}\n\n" @app.route("/v1/completions", methods=["POST"]) def legacy_completions(): body = request.get_json(force=True, silent=True) if not body: return openai_error("Could not parse request body as JSON.") if "prompt" not in body: return openai_error("Missing required field: prompt.", param="prompt") for param, lo, hi in [("temperature", 0.0, 2.0), ("top_p", 0.0, 1.0), ("presence_penalty", -2.0, 2.0), ("frequency_penalty", -2.0, 2.0)]: err = validate_numeric(body, param, lo, hi) if err: return openai_error(err, param=param) check_unknown_params(body, KNOWN_COMPLETIONS_PARAMS, "/v1/completions") model = body.get("model", "gpt-3.5-turbo") stream = body.get("stream", False) messages = _completions_to_chat_messages(body["prompt"]) extra = {k: v for k, v in body.items() if k in ("temperature", "top_p", "n", "max_tokens", "stop", "presence_penalty", "frequency_penalty", "logit_bias", "user", "seed")} log_req("/v1/completions", model, stream) t0 = time.time() try: provider_resp = call_provider_chat(messages=messages, model=model, stream=stream, **extra) except requests.HTTPError as e: return provider_error_to_openai(e) except requests.exceptions.Timeout: return openai_error("Request to provider timed out.", "server_error", "timeout", 504) except requests.exceptions.ConnectionError: return openai_error("Could not connect to provider.", "server_error", "connection_error", 502) if stream: logger.info(f"← /v1/completions | stream=True | latency={time.time()-t0:.2f}s") _stats_record("/v1/completions", latency=time.time() - t0) return Response(stream_with_context(heartbeat_generator(_stream_completions_format(provider_resp, model))), content_type="text/event-stream") data = provider_resp.json() log_resp("/v1/completions", model, data.get("usage", {}), time.time() - t0) return jsonify(_chat_to_completions_response(data, model)) # ── /v1/responses ───────────────────────────────────────────────────────────── def messages_from_responses_input(input_data): if isinstance(input_data, str): return [{"role": "user", "content": input_data}] if isinstance(input_data, list): messages = [] for item in input_data: if isinstance(item, str): messages.append({"role": "user", "content": item}) elif isinstance(item, dict): role = item.get("role", "user") content = item.get("content", "") if isinstance(content, list): parts = [] for p in content: if not isinstance(p, dict): continue ptype = p.get("type", "") if ptype in ("input_text", "text"): parts.append({"type": "text", "text": p.get("text", "")}) elif ptype == "input_image": url = p.get("image_url") or p.get("url", "") if url: parts.append({"type": "image_url", "image_url": {"url": url}}) content = parts if parts else "" messages.append({"role": role, "content": content}) return messages return [] def tools_from_responses(tools): if not tools: return [] converted = [] for t in tools: if t.get("type") == "function": converted.append(t) elif t.get("type") == "web_search": converted.append({"type": "function", "function": { "name": "web_search", "description": "Search the web for current information.", "parameters": { "type": "object", "properties": {"query": {"type": "string", "description": "Search query"}}, "required": ["query"], }, }}) return converted def chat_choice_to_responses_output(choice): message = choice.get("message", {}) content_text = message.get("content", "") or "" tool_calls = message.get("tool_calls") or [] item_id = "msg_" + str(uuid.uuid4()).replace("-", "") output_content = [] if content_text: output_content.append({"type": "output_text", "text": content_text}) for tc in tool_calls: fn = tc.get("function", {}) try: args = json.loads(fn.get("arguments", "{}")) except Exception: args = {} output_content.append({ "type": "tool_use", "id": tc.get("id", "call_" + str(uuid.uuid4()).replace("-", "")), "name": fn.get("name", ""), "input": args, }) return {"id": item_id, "type": "message", "role": "assistant", "content": output_content, "status": "completed"} def chat_response_to_responses_format(chat_resp: dict, model: str) -> dict: choice = chat_resp.get("choices", [{}])[0] usage = chat_resp.get("usage", {}) output_item = chat_choice_to_responses_output(choice) return { "id": "resp_" + chat_resp.get("id", str(uuid.uuid4())).replace("chatcmpl-", ""), "object": "response", "created_at": chat_resp.get("created", int(time.time())), "model": chat_resp.get("model", model), "status": "completed", "output": [output_item], "usage": { "input_tokens": usage.get("prompt_tokens", 0), "output_tokens": usage.get("completion_tokens", 0), "total_tokens": usage.get("total_tokens", 0), }, "finish_reason": choice.get("finish_reason", "stop"), } def stream_responses_format(provider_resp, model): response_id = "resp_" + str(uuid.uuid4()).replace("-", "") created_at = int(time.time()) output_index = 0 item_id = "msg_" + str(uuid.uuid4()).replace("-", "") def event(data): return f"data: {json.dumps(data)}\n\n" yield event({"type": "response.created", "response": { "id": response_id, "object": "response", "created_at": created_at, "model": model, "status": "in_progress", "output": [], }}) yield event({"type": "response.output_item.added", "output_index": output_index, "item": { "id": item_id, "type": "message", "role": "assistant", "content": [], "status": "in_progress", }}) yield event({"type": "response.content_part.added", "item_id": item_id, "output_index": output_index, "content_index": 0, "part": {"type": "output_text", "text": ""}}) full_text = "" tool_calls_map = {} for line in provider_resp.iter_lines(): if not line: continue decoded = line.decode("utf-8") if decoded.startswith("data: "): decoded = decoded[6:] if decoded.strip() == "[DONE]": break try: chunk = json.loads(decoded) except json.JSONDecodeError: continue delta = chunk.get("choices", [{}])[0].get("delta", {}) text = delta.get("content", "") if text: full_text += text yield event({"type": "response.output_text.delta", "item_id": item_id, "output_index": output_index, "content_index": 0, "delta": text}) for tc_delta in (delta.get("tool_calls") or []): idx = tc_delta.get("index", 0) if idx not in tool_calls_map: tool_calls_map[idx] = { "id": tc_delta.get("id", "call_" + str(uuid.uuid4()).replace("-", "")), "name": "", "arguments": "", } fn = tc_delta.get("function", {}) if fn.get("name"): tool_calls_map[idx]["name"] += fn["name"] if fn.get("arguments"): tool_calls_map[idx]["arguments"] += fn["arguments"] yield event({"type": "response.content_part.done", "item_id": item_id, "output_index": output_index, "content_index": 0, "part": {"type": "output_text", "text": full_text}}) final_content = [{"type": "output_text", "text": full_text}] content_index = 1 for tc in tool_calls_map.values(): try: args = json.loads(tc["arguments"] or "{}") except Exception: args = {} tool_part = {"type": "tool_use", "id": tc["id"], "name": tc["name"], "input": args} final_content.append(tool_part) yield event({"type": "response.content_part.added", "item_id": item_id, "output_index": output_index, "content_index": content_index, "part": tool_part}) yield event({"type": "response.content_part.done", "item_id": item_id, "output_index": output_index, "content_index": content_index, "part": tool_part}) content_index += 1 yield event({"type": "response.output_item.done", "output_index": output_index, "item": { "id": item_id, "type": "message", "role": "assistant", "content": final_content, "status": "completed", }}) yield event({"type": "response.done", "response": { "id": response_id, "object": "response", "created_at": created_at, "model": model, "status": "completed", "output": [{"id": item_id, "type": "message", "role": "assistant", "content": final_content, "status": "completed"}], }}) yield "data: [DONE]\n\n" @app.route("/v1/responses", methods=["POST"]) def responses(): body = request.get_json(force=True, silent=True) if not body: return openai_error("Could not parse request body as JSON.") if "input" not in body: return openai_error("Missing required field: input.") model = body.get("model", "gpt-4") stream = body.get("stream", False) log_req("/v1/responses", model, stream) t0 = time.time() messages = [] if body.get("instructions"): messages.append({"role": "system", "content": body["instructions"]}) messages.extend(messages_from_responses_input(body["input"])) for prev in body.get("conversation", []): messages.append({"role": prev.get("role", "user"), "content": prev.get("content", "")}) extra = {} for src, dst in [("max_output_tokens", "max_tokens"), ("temperature", "temperature"), ("top_p", "top_p"), ("stop", "stop")]: if src in body: extra[dst] = body[src] tools = tools_from_responses(body.get("tools")) if tools: extra["tools"] = tools if body.get("tool_choice"): extra["tool_choice"] = body["tool_choice"] try: provider_resp = call_provider_chat(messages=messages, model=model, stream=stream, **extra) except requests.HTTPError as e: return provider_error_to_openai(e) except requests.exceptions.Timeout: return openai_error("Request to provider timed out.", "server_error", "timeout", 504) except requests.exceptions.ConnectionError: return openai_error("Could not connect to provider.", "server_error", "connection_error", 502) if stream: logger.info(f"← /v1/responses | stream=True | latency={time.time()-t0:.2f}s") _stats_record("/v1/responses", latency=time.time() - t0) return Response(stream_with_context(heartbeat_generator(stream_responses_format(provider_resp, model))), content_type="text/event-stream") data = provider_resp.json() log_resp("/v1/responses", model, data.get("usage", {}), time.time() - t0) return jsonify(chat_response_to_responses_format(data, model)) # ── /v1/images/generations ──────────────────────────────────────────────────── @app.route("/v1/images/generations", methods=["POST"]) @app.route("/v1/image/generations", methods=["POST"]) def image_generations(): body = request.get_json(force=True, silent=True) if not body: return openai_error("Could not parse request body as JSON.") if "prompt" not in body: return openai_error("Missing required field: prompt.", param="prompt") logger.info(f"→ /v1/images/generations | prompt={body['prompt'][:60]!r}") t0 = time.time() try: resp = _do_request("POST", f"{PROVIDER_BASE_URL}/images/generations", timeout=TIMEOUTS["image"], headers=PROVIDER_HEADERS, json=body) except requests.HTTPError as e: return provider_error_to_openai(e) except requests.exceptions.Timeout: return openai_error("Request to provider timed out.", "server_error", "timeout", 504) except requests.exceptions.ConnectionError: return openai_error("Could not connect to provider.", "server_error", "connection_error", 502) log_resp("/v1/images/generations", body.get("model", "dall-e-3"), {}, time.time() - t0) return jsonify(resp.json()) # ── /v1/embeddings ──────────────────────────────────────────────────────────── @app.route("/v1/embeddings", methods=["POST"]) def embeddings(): body = request.get_json(force=True, silent=True) if not body: return openai_error("Could not parse request body as JSON.") if "input" not in body: return openai_error("Missing required field: input.", param="input") t0 = time.time() try: resp = _do_request("POST", f"{PROVIDER_BASE_URL}/embeddings", timeout=TIMEOUTS["embeddings"], headers=PROVIDER_HEADERS, json=body) log_resp("/v1/embeddings", body.get("model", "text-embedding-ada-002"), resp.json().get("usage", {}), time.time() - t0) return jsonify(resp.json()) except requests.HTTPError as e: return provider_error_to_openai(e) except Exception: return openai_error("Embeddings are not supported by this provider.", "invalid_request_error", "unsupported_endpoint", 400) # ── /v1/audio/* ─────────────────────────────────────────────────────────────── @app.route("/v1/audio/transcriptions", methods=["POST"]) def audio_transcriptions(): if "file" not in request.files: return openai_error("Missing required file field: file.", param="file") audio_file = request.files["file"] form = request.form.to_dict() logger.info(f"→ /v1/audio/transcriptions | model={form.get('model', 'whisper-1')}") t0 = time.time() try: resp = _do_request( "POST", f"{PROVIDER_BASE_URL}/audio/transcriptions", timeout=TIMEOUTS["audio"], headers={"Authorization": f"Bearer {PROVIDER_API_KEY}"}, files={"file": (audio_file.filename, audio_file.stream, audio_file.content_type)}, data=form, ) logger.info(f"← /v1/audio/transcriptions | latency={time.time()-t0:.2f}s") return jsonify(resp.json()) except requests.HTTPError as e: return provider_error_to_openai(e) except Exception: return openai_error("Audio transcription is not supported by this provider.", "invalid_request_error", "unsupported_endpoint", 501) @app.route("/v1/audio/speech", methods=["POST"]) def audio_speech(): body = request.get_json(force=True, silent=True) if not body: return openai_error("Could not parse request body as JSON.") if "input" not in body: return openai_error("Missing required field: input.", param="input") logger.info(f"→ /v1/audio/speech | voice={body.get('voice', 'alloy')}") t0 = time.time() try: resp = _do_request("POST", f"{PROVIDER_BASE_URL}/audio/speech", timeout=TIMEOUTS["audio"], headers=PROVIDER_HEADERS, json=body, stream=True) content_type = resp.headers.get("Content-Type", "audio/mpeg") logger.info(f"← /v1/audio/speech | latency={time.time()-t0:.2f}s") return Response(resp.iter_content(chunk_size=4096), content_type=content_type) except requests.HTTPError as e: return provider_error_to_openai(e) except Exception: return openai_error("Text-to-speech is not supported by this provider.", "invalid_request_error", "unsupported_endpoint", 501) # ── /v1/files ───────────────────────────────────────────────────────────────── # Full CRUD implementation backed by in-memory store. # Tries provider first; falls back to local store if provider doesn't support files. def _file_meta(file_id: str, filename: str, purpose: str, size: int, created: int) -> dict: return { "id": file_id, "object": "file", "bytes": size, "created_at": created, "filename": filename, "purpose": purpose, "status": "processed", "status_details": None, } @app.route("/v1/files", methods=["POST"]) def upload_file(): """Upload a file. Tries provider first, falls back to local store.""" if "file" not in request.files: return openai_error("Missing required field: file.", param="file") purpose = request.form.get("purpose", "") if not purpose: return openai_error("Missing required field: purpose.", param="purpose") valid_purposes = {"fine-tune", "assistants", "batch", "vision", "user_data"} if purpose not in valid_purposes: return openai_error(f"'purpose' must be one of: {', '.join(sorted(valid_purposes))}.", param="purpose") uploaded = request.files["file"] file_bytes = uploaded.read() log_req("/v1/files", stream=False) # Try provider try: resp = _do_request( "POST", f"{PROVIDER_BASE_URL}/files", timeout=TIMEOUTS["files"], headers={"Authorization": f"Bearer {PROVIDER_API_KEY}"}, files={"file": (uploaded.filename, file_bytes, uploaded.content_type)}, data={"purpose": purpose}, ) data = resp.json() logger.info(f"← /v1/files (upload) | id={data.get('id')} | provider=true") return jsonify(data), 200 except Exception: pass # Local fallback file_id = "file-" + str(uuid.uuid4()).replace("-", "") created = int(time.time()) meta = _file_meta(file_id, uploaded.filename or "upload", purpose, len(file_bytes), created) with _files_lock: _files_store[file_id] = {"meta": meta, "data": file_bytes} logger.info(f"← /v1/files (upload) | id={file_id} | provider=false (local)") return jsonify(meta), 200 @app.route("/v1/files", methods=["GET"]) def list_files(): """List files. Tries provider first, merges with local store.""" purpose = request.args.get("purpose") log_req("/v1/files", stream=False) provider_files = [] try: params = {} if purpose: params["purpose"] = purpose resp = _do_request("GET", f"{PROVIDER_BASE_URL}/files", timeout=TIMEOUTS["files"], headers=PROVIDER_HEADERS, params=params) provider_files = resp.json().get("data", []) except Exception: pass with _files_lock: local_files = [v["meta"] for v in _files_store.values() if not purpose or v["meta"]["purpose"] == purpose] # Merge — provider files take priority if same id seen = {f["id"] for f in provider_files} merged = provider_files + [f for f in local_files if f["id"] not in seen] merged.sort(key=lambda f: f.get("created_at", 0), reverse=True) return jsonify({"object": "list", "data": merged}) @app.route("/v1/files/", methods=["GET"]) def retrieve_file(file_id): """Retrieve file metadata.""" log_req(f"/v1/files/{file_id}", stream=False) try: resp = _do_request("GET", f"{PROVIDER_BASE_URL}/files/{file_id}", timeout=TIMEOUTS["files"], headers=PROVIDER_HEADERS) return jsonify(resp.json()) except Exception: pass with _files_lock: entry = _files_store.get(file_id) if not entry: return openai_error(f"No such file: {file_id!r}", "invalid_request_error", "file_not_found", 404) return jsonify(entry["meta"]) @app.route("/v1/files/", methods=["DELETE"]) def delete_file(file_id): """Delete a file.""" log_req(f"/v1/files/{file_id} DELETE", stream=False) try: resp = _do_request("DELETE", f"{PROVIDER_BASE_URL}/files/{file_id}", timeout=TIMEOUTS["files"], headers=PROVIDER_HEADERS) return jsonify(resp.json()) except Exception: pass with _files_lock: deleted = _files_store.pop(file_id, None) if not deleted: return openai_error(f"No such file: {file_id!r}", "invalid_request_error", "file_not_found", 404) return jsonify({"id": file_id, "object": "file", "deleted": True}) @app.route("/v1/files//content", methods=["GET"]) def retrieve_file_content(file_id): """Download raw file content.""" log_req(f"/v1/files/{file_id}/content", stream=False) try: resp = _do_request("GET", f"{PROVIDER_BASE_URL}/files/{file_id}/content", timeout=TIMEOUTS["files"], headers=PROVIDER_HEADERS, stream=True) content_type = resp.headers.get("Content-Type", "application/octet-stream") return Response(resp.iter_content(chunk_size=8192), content_type=content_type) except Exception: pass with _files_lock: entry = _files_store.get(file_id) if not entry: return openai_error(f"No such file: {file_id!r}", "invalid_request_error", "file_not_found", 404) return Response(entry["data"], content_type="application/octet-stream", headers={"Content-Disposition": f"attachment; filename={entry['meta']['filename']}"}) # ── /v1/batches ─────────────────────────────────────────────────────────────── # Full lifecycle: create, retrieve, list, cancel. # Tries provider first; falls back to local async execution backed by threads. BATCH_VALID_ENDPOINTS = {"/v1/chat/completions", "/v1/embeddings", "/v1/completions"} def _make_batch(batch_id: str, input_file_id: str, endpoint: str, completion_window: str, metadata: dict, created: int) -> dict: return { "id": batch_id, "object": "batch", "endpoint": endpoint, "errors": None, "input_file_id": input_file_id, "completion_window": completion_window, "status": "validating", "output_file_id": None, "error_file_id": None, "created_at": created, "in_progress_at": None, "expires_at": created + 86400, "finalizing_at": None, "completed_at": None, "failed_at": None, "expired_at": None, "cancelling_at": None, "cancelled_at": None, "request_counts": {"total": 0, "completed": 0, "failed": 0}, "metadata": metadata or {}, } def _run_batch(batch_id: str): """Background worker: read JSONL from file store, run requests, write output file.""" with _batches_lock: batch = _batches_store.get(batch_id) if not batch: return now = int(time.time()) batch["status"] = "in_progress" batch["in_progress_at"] = now # Read input file with _files_lock: entry = _files_store.get(batch["input_file_id"]) if not entry: batch["status"] = "failed" batch["failed_at"] = int(time.time()) batch["errors"] = {"object": "list", "data": [ {"code": "file_not_found", "message": "Input file not found.", "line": None, "param": None} ]} return try: lines = entry["data"].decode("utf-8").strip().split("\n") except Exception: batch["status"] = "failed" batch["failed_at"] = int(time.time()) return output_lines = [] total = len(lines) success = 0 failed = 0 endpoint = batch["endpoint"] for i, line in enumerate(lines): line = line.strip() if not line: continue try: req = json.loads(line) except json.JSONDecodeError: failed += 1 output_lines.append(json.dumps({ "id": f"batch_req_{i}", "custom_id": None, "response": None, "error": {"code": "json_parse_error", "message": "Could not parse JSONL line."}, })) continue custom_id = req.get("custom_id", f"req_{i}") body = req.get("body", {}) try: if endpoint == "/v1/chat/completions": messages = body.get("messages", []) model = body.get("model", "gpt-4") extra = {k: v for k, v in body.items() if k not in ("messages", "model", "stream")} pr = call_provider_chat(messages=messages, model=model, stream=False, **extra) resp_body = pr.json() elif endpoint == "/v1/embeddings": pr = _do_request("POST", f"{PROVIDER_BASE_URL}/embeddings", timeout=TIMEOUTS["embeddings"], headers=PROVIDER_HEADERS, json=body) resp_body = pr.json() elif endpoint == "/v1/completions": msgs = _completions_to_chat_messages(body.get("prompt", "")) extra = {k: v for k, v in body.items() if k in ("temperature", "top_p", "max_tokens", "stop", "user")} pr = call_provider_chat(messages=msgs, model=body.get("model", "gpt-3.5-turbo"), stream=False, **extra) resp_body = _chat_to_completions_response(pr.json(), body.get("model", "gpt-3.5-turbo")) else: raise ValueError(f"Unsupported batch endpoint: {endpoint}") success += 1 output_lines.append(json.dumps({ "id": f"batch_req_{uuid.uuid4().hex[:8]}", "custom_id": custom_id, "response": {"status_code": 200, "request_id": str(uuid.uuid4()), "body": resp_body}, "error": None, })) except Exception as e: failed += 1 output_lines.append(json.dumps({ "id": f"batch_req_{uuid.uuid4().hex[:8]}", "custom_id": custom_id, "response": None, "error": {"code": "request_failed", "message": str(e)}, })) # Write output file output_bytes = "\n".join(output_lines).encode("utf-8") output_file_id = "file-" + str(uuid.uuid4()).replace("-", "") out_meta = _file_meta(output_file_id, f"batch_{batch_id}_output.jsonl", "batch_output", len(output_bytes), int(time.time())) with _files_lock: _files_store[output_file_id] = {"meta": out_meta, "data": output_bytes} now = int(time.time()) with _batches_lock: b = _batches_store[batch_id] b["status"] = "completed" b["output_file_id"] = output_file_id b["finalizing_at"] = now b["completed_at"] = now b["request_counts"] = {"total": total, "completed": success, "failed": failed} logger.info(f"Batch {batch_id} completed | total={total} ok={success} failed={failed}") @app.route("/v1/batches", methods=["POST"]) def create_batch(): body = request.get_json(force=True, silent=True) if not body: return openai_error("Could not parse request body as JSON.") for field in ("input_file_id", "endpoint", "completion_window"): if field not in body: return openai_error(f"Missing required field: {field}.", param=field) if body["endpoint"] not in BATCH_VALID_ENDPOINTS: return openai_error( f"'endpoint' must be one of: {', '.join(sorted(BATCH_VALID_ENDPOINTS))}.", param="endpoint" ) log_req("/v1/batches", stream=False) # Try provider first try: resp = _do_request("POST", f"{PROVIDER_BASE_URL}/batches", timeout=TIMEOUTS["batches"], headers=PROVIDER_HEADERS, json=body) return jsonify(resp.json()), 200 except Exception: pass # Local fallback: create batch and run in background thread batch_id = "batch_" + str(uuid.uuid4()).replace("-", "") created = int(time.time()) batch = _make_batch(batch_id, body["input_file_id"], body["endpoint"], body["completion_window"], body.get("metadata"), created) with _batches_lock: _batches_store[batch_id] = batch t = threading.Thread(target=_run_batch, args=(batch_id,), daemon=True) t.start() logger.info(f"← /v1/batches | id={batch_id} | local=true") return jsonify(batch), 200 @app.route("/v1/batches/", methods=["GET"]) def retrieve_batch(batch_id): log_req(f"/v1/batches/{batch_id}", stream=False) try: resp = _do_request("GET", f"{PROVIDER_BASE_URL}/batches/{batch_id}", timeout=TIMEOUTS["batches"], headers=PROVIDER_HEADERS) return jsonify(resp.json()) except Exception: pass with _batches_lock: batch = _batches_store.get(batch_id) if not batch: return openai_error(f"No such batch: {batch_id!r}", "invalid_request_error", "batch_not_found", 404) return jsonify(batch) @app.route("/v1/batches", methods=["GET"]) def list_batches(): log_req("/v1/batches list", stream=False) provider_batches = [] try: resp = _do_request("GET", f"{PROVIDER_BASE_URL}/batches", timeout=TIMEOUTS["batches"], headers=PROVIDER_HEADERS, params=request.args) provider_batches = resp.json().get("data", []) except Exception: pass with _batches_lock: local_batches = list(_batches_store.values()) seen = {b["id"] for b in provider_batches} merged = provider_batches + [b for b in local_batches if b["id"] not in seen] merged.sort(key=lambda b: b.get("created_at", 0), reverse=True) return jsonify({"object": "list", "data": merged, "has_more": False, "first_id": None, "last_id": None}) @app.route("/v1/batches//cancel", methods=["POST"]) def cancel_batch(batch_id): log_req(f"/v1/batches/{batch_id}/cancel", stream=False) try: resp = _do_request("POST", f"{PROVIDER_BASE_URL}/batches/{batch_id}/cancel", timeout=TIMEOUTS["batches"], headers=PROVIDER_HEADERS) return jsonify(resp.json()) except Exception: pass with _batches_lock: batch = _batches_store.get(batch_id) if not batch: return openai_error(f"No such batch: {batch_id!r}", "invalid_request_error", "batch_not_found", 404) if batch["status"] in ("completed", "failed", "expired", "cancelled"): return openai_error(f"Batch {batch_id} cannot be cancelled (status: {batch['status']}).", "invalid_request_error", "batch_not_cancellable") now = int(time.time()) batch["status"] = "cancelled" batch["cancelled_at"] = now batch["cancelling_at"] = now return jsonify(batch) # ── /stats ──────────────────────────────────────────────────────────────────── @app.route("/stats", methods=["GET"]) def stats(): with _stats_lock: s = dict(_stats) by_ep = {k: dict(v) for k, v in _stats["by_endpoint"].items()} total_req = s["total_requests"] avg_latency = round(s["total_latency_s"] / total_req, 3) if total_req else 0.0 error_rate = round(s["total_errors"] / total_req, 4) if total_req else 0.0 for ep, d in by_ep.items(): d["avg_latency_s"] = round(d["latency_s"] / d["requests"], 3) if d["requests"] else 0.0 d["error_rate"] = round(d["errors"] / d["requests"], 4) if d["requests"] else 0.0 with _files_lock: file_count = len(_files_store) with _batches_lock: batch_count = len(_batches_store) return jsonify({ "uptime_s": int(time.time()) - s["started_at"], "total_requests": total_req, "total_errors": s["total_errors"], "error_rate": error_rate, "total_tokens": s["total_tokens"], "avg_latency_s": avg_latency, "local_files": file_count, "local_batches": batch_count, "by_endpoint": by_ep, }) # ── /v1/ discovery ──────────────────────────────────────────────────────────── @app.route("/v1/", methods=["GET"]) def openapi_discovery(): return jsonify({ "name": "OpenAI-compatible bridge", "version": "1.0.0", "provider": PROVIDER_BASE_URL, "notes": { "json_mode": ( "response_format={type:json_object} and json_schema are fully supported. " "If the provider does not natively support them, a system message is injected " "to enforce JSON output and the response is post-processed to ensure validity." ), "files": ( "Full CRUD file support. Tries provider first; falls back to in-memory store. " "In-memory files are lost on restart." ), "batches": ( "Full batch lifecycle (create/retrieve/list/cancel). Tries provider first; " "falls back to local background thread execution. " "Supports endpoints: /v1/chat/completions, /v1/embeddings, /v1/completions." ), }, "endpoints": [ {"method": "GET", "path": "/v1/models", "description": "List models (proxied from provider)."}, {"method": "GET", "path": "/v1/models/{model}", "description": "Retrieve a model."}, {"method": "POST", "path": "/v1/chat/completions", "description": "Chat completion. Supports response_format (json_object, json_schema, text), tools, streaming."}, {"method": "POST", "path": "/v1/completions", "description": "Legacy text completions (converted to chat internally)."}, {"method": "POST", "path": "/v1/responses", "description": "OpenAI Responses API (converted to chat internally)."}, {"method": "POST", "path": "/v1/images/generations", "description": "Image generation."}, {"method": "POST", "path": "/v1/embeddings", "description": "Text embeddings."}, {"method": "POST", "path": "/v1/moderations", "description": "Content moderation (stub fallback)."}, {"method": "POST", "path": "/v1/audio/transcriptions", "description": "Audio transcription."}, {"method": "POST", "path": "/v1/audio/speech", "description": "Text-to-speech."}, {"method": "POST", "path": "/v1/files", "description": "Upload a file."}, {"method": "GET", "path": "/v1/files", "description": "List files."}, {"method": "GET", "path": "/v1/files/{file_id}", "description": "Retrieve file metadata."}, {"method": "DELETE", "path": "/v1/files/{file_id}", "description": "Delete a file."}, {"method": "GET", "path": "/v1/files/{file_id}/content", "description": "Download file content."}, {"method": "POST", "path": "/v1/batches", "description": "Create a batch job."}, {"method": "GET", "path": "/v1/batches", "description": "List batch jobs."}, {"method": "GET", "path": "/v1/batches/{batch_id}", "description": "Retrieve a batch job."}, {"method": "POST", "path": "/v1/batches/{batch_id}/cancel", "description": "Cancel a batch job."}, {"method": "GET", "path": "/stats", "description": "Bridge usage stats."}, {"method": "GET", "path": "/health", "description": "Provider health check."}, {"method": "GET", "path": "/ping", "description": "Liveness probe."}, ], }) # ── Health & ping ───────────────────────────────────────────────────────────── @app.route("/health", methods=["GET"]) def health(): provider_ok = False provider_latency = None try: t0 = time.time() r = requests.get(f"{PROVIDER_BASE_URL}/models", headers=PROVIDER_HEADERS, timeout=5) provider_latency = round(time.time() - t0, 3) provider_ok = r.ok except Exception: pass with _stats_lock: total_req = _stats["total_requests"] uptime = int(time.time()) - _stats["started_at"] return jsonify({ "status": "healthy" if provider_ok else "degraded", "uptime_s": uptime, "provider": {"reachable": provider_ok, "latency_s": provider_latency, "url": PROVIDER_BASE_URL}, "requests": total_req, "timestamp": int(time.time()), }), 200 if provider_ok else 207 @app.route("/ping", methods=["GET"]) def ping(): return jsonify({"pong": True, "timestamp": int(time.time())}) if __name__ == "__main__": app.run(host="0.0.0.0", port=7860, debug=False)