# app/main.py — OpenAI-compatible mock/proxy API
# (Hugging Face Space page residue removed; provenance: update by ynsbyrm, commit a58eac3, verified)
from fastapi import FastAPI, Request, File, UploadFile, Form
from fastapi.responses import StreamingResponse, JSONResponse
import os
import uuid
import time
import httpx
import json
import numpy as np
import soundfile as sf
import io
import wave
app = FastAPI(
    title="OpenAI Compatible API (Mock)",
    version="1.0.0",
)

# Hugging Face token for authenticating against the upstream Spaces.
# May be unset in local/dev runs: build the header conditionally instead of
# crashing at import time with `"Bearer " + None` (TypeError).
HF_TOKEN = os.getenv("HF_TOKEN")

headers_with_auth = {"Content-Type": "application/json"}
if HF_TOKEN:
    headers_with_auth["Authorization"] = f"Bearer {HF_TOKEN}"

# model_id -> upstream base URL; refreshed by GET /v1/models. Initialised here
# so handlers that consult it before /v1/models has been called answer a clean
# "model not found" instead of raising NameError.
# (Replaces the dead `list_models = None`, which was silently rebound by the
# `def list_models` route handler below.)
MODEL_REGISTRY: dict[str, str] = {}

# Shared timeout for all upstream calls; long read window because upstream
# model responses can be slow.
timeout = httpx.Timeout(
    connect=5.0,
    read=120.0,
    write=5.0,
    pool=5.0,
)
def log_request(request: Request, body: dict):
    """Pretty-print an incoming request (method, path, headers, JSON body) to stdout."""
    divider = "=" * 80
    timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
    print(divider)
    print(f"[{timestamp}] {request.method} {request.url.path}")
    print("Headers:")
    for name, value in request.headers.items():
        print(f" {name}: {value}")
    print("Body:")
    print(json.dumps(body, indent=2, ensure_ascii=False))
    print(divider)
# ------------------------------------------------------------------
# HEALTH CHECK
# ------------------------------------------------------------------
@app.get("/")
async def health_check():
    """Liveness probe: returns a plain status string with HTTP 200."""
    status_message = "Service up and running!"
    return status_message
# ------------------------------------------------------------------
# MODELS
# ------------------------------------------------------------------
@app.get("/v1/models")
async def list_models():
    """Proxy the upstream model catalogue and refresh MODEL_REGISTRY.

    Returns the upstream JSON unchanged. Side effect: rebuilds the
    module-level ``MODEL_REGISTRY`` (model id -> base URL) that the chat,
    transcription and speech endpoints use for routing.
    """
    upstream_url = "https://ynsbyrm-api-chat-service-models.hf.space/v1/models"
    # Use the shared timeout config rather than httpx's 5 s default read
    # timeout, which the rest of the file deliberately overrides.
    async with httpx.AsyncClient(timeout=timeout) as client:
        resp = await client.get(upstream_url, headers=headers_with_auth)
    response = resp.json()

    # model_id -> url map; tolerate malformed (non-dict) entries instead of
    # crashing the whole endpoint on one bad record.
    global MODEL_REGISTRY
    MODEL_REGISTRY = {
        m["id"]: m["url"]
        for m in response
        if isinstance(m, dict) and "id" in m and "url" in m
    }
    return response
# ------------------------------------------------------------------
# RESPONSES
# ------------------------------------------------------------------
@app.post("/v1/responses")
async def create_response(request: Request):
    """Mock of the OpenAI Responses API: logs the request, returns canned text."""
    payload = await request.json()
    log_request(request, payload)
    assistant_message = {
        "id": "msg_1",
        "type": "message",
        "role": "assistant",
        "content": [
            {
                "type": "output_text",
                "text": "This is a mock response.",
            }
        ],
    }
    return {
        "id": f"resp_{uuid.uuid4().hex}",
        "object": "response",
        "created": int(time.time()),
        "model": payload.get("model", "unknown"),
        "output": [assistant_message],
    }
# ------------------------------------------------------------------
# EMBEDDINGS
# ------------------------------------------------------------------
@app.post("/v1/embeddings")
async def create_embeddings(request: Request):
    """Mock embeddings endpoint: one 1536-dim zero vector per input item."""
    payload = await request.json()
    log_request(request, payload)

    raw_input = payload.get("input", [])
    # The OpenAI API accepts either a single string or a list of strings.
    items = raw_input if isinstance(raw_input, list) else [raw_input]

    data = []
    for index, _ in enumerate(items):
        data.append(
            {
                "object": "embedding",
                "embedding": [0.0] * 1536,
                "index": index,
            }
        )
    return {
        "object": "list",
        "data": data,
        "model": payload.get("model", "text-embedding-3"),
    }
# ------------------------------------------------------------------
# IMAGES
# ------------------------------------------------------------------
@app.post("/v1/images/generations")
async def generate_image(request: Request):
    """Mock image generation: always returns one placeholder image URL."""
    payload = await request.json()
    log_request(request, payload)
    placeholder_url = "https://dummyimage.com/1024x1024/000/fff.png&text=mock"
    return {
        "created": int(time.time()),
        "data": [{"url": placeholder_url}],
    }
# ------------------------------------------------------------------
# CHAT COMPLETIONS (LEGACY-OLD)
# ------------------------------------------------------------------
@app.post("/v1/chat/completions_old")
async def chat_completions_legacy(request: Request):
    """Legacy mock chat endpoint kept for reference; always returns canned text.

    Function renamed from ``chat_completions``: the live proxying handler on
    /v1/chat/completions reused the same name, silently rebinding the module
    attribute (flake8 F811). The route path is unchanged, so clients are
    unaffected — FastAPI registered both routes at decoration time either way.
    """
    body = await request.json()
    log_request(request, body)
    return {
        "id": f"chatcmpl_{uuid.uuid4().hex}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": body.get("model", "gpt-4"),
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": "This is a mock chat completion.",
                },
                "finish_reason": "stop",
            }
        ],
    }
# ------------------------------------------------------------------
# CHAT COMPLETIONS (New with Streaming)
# ------------------------------------------------------------------
@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
    """Proxy an OpenAI-style chat completion to the upstream model service.

    Looks the requested model up in MODEL_REGISTRY (populated by
    GET /v1/models) and forwards the request either as a pass-through SSE
    stream or as a single JSON response, mirroring the upstream status code.
    """
    body = await request.json()
    log_request(request, body)
    stream = bool(body.get("stream", False))  # was computed then ignored
    model_name = body.get("model", "unknown")

    # MODEL_REGISTRY only exists after /v1/models has been called at least
    # once; fall back to an empty mapping so an unknown/early request gets a
    # clean 400 instead of a NameError (HTTP 500).
    registry = globals().get("MODEL_REGISTRY") or {}
    if model_name not in registry:
        return JSONResponse(
            status_code=400,
            content={"error": f"Model not found: {model_name}"},
        )
    upstream_url = registry[model_name] + "/v1/chat/completions"

    if stream:
        async def proxy_stream():
            # Re-emit upstream SSE bytes untouched.
            async with httpx.AsyncClient(timeout=timeout) as client:
                async with client.stream(
                    "POST",
                    upstream_url,
                    json=body,
                    headers=headers_with_auth,
                ) as response:
                    async for chunk in response.aiter_raw():
                        yield chunk

        return StreamingResponse(
            proxy_stream(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",  # disable proxy buffering for SSE
            },
        )

    # Non-streaming: forward once and mirror the upstream status code
    # (consistent with the transcription endpoint; previously errors were
    # re-served as HTTP 200).
    async with httpx.AsyncClient(timeout=timeout) as client:
        resp = await client.post(upstream_url, json=body, headers=headers_with_auth)
    return JSONResponse(status_code=resp.status_code, content=resp.json())
# ------------------------------------------------------------------
# COMPLETIONS (LEGACY)
# ------------------------------------------------------------------
@app.post("/v1/completions")
async def completions(request: Request):
    """Mock of the legacy text-completions API; always returns canned text."""
    payload = await request.json()
    log_request(request, payload)
    choice = {
        "text": "This is a mock completion.",
        "index": 0,
        "finish_reason": "stop",
    }
    return {
        "id": f"cmpl_{uuid.uuid4().hex}",
        "object": "text_completion",
        "created": int(time.time()),
        "model": payload.get("model", "text-davinci-003"),
        "choices": [choice],
    }
# ------------------------------------------------------------------
# MODERATIONS
# ------------------------------------------------------------------
@app.post("/v1/moderations")
async def moderations(request: Request):
    """Mock moderations endpoint: everything is reported as not flagged."""
    payload = await request.json()
    log_request(request, payload)
    clean_result = {
        "flagged": False,
        "categories": {},
        "category_scores": {},
    }
    return {
        "id": f"modr_{uuid.uuid4().hex}",
        "model": payload.get("model", "omni-moderation-latest"),
        "results": [clean_result],
    }
# ------------------------------------------------------------------
# AUDIO TRANSCRIPTIONS (STT)
# ------------------------------------------------------------------
@app.post("/v1/audio/transcriptions")
async def audio_transcriptions(
    file: UploadFile = File(...),
    model: str = Form("whisper-1"),
    language: str | None = Form(None),
    prompt: str | None = Form(None),
    response_format: str = Form("json"),
):
    """Proxy a speech-to-text multipart upload to the model's upstream Space.

    Forwards the uploaded file plus form fields unchanged and mirrors the
    upstream status code and JSON body back to the caller.
    """
    print("Request hit /v1/audio/transcriptions")

    # Fall back to an empty mapping if /v1/models has not populated the
    # registry yet, so the error path below is reached instead of a NameError.
    registry = globals().get("MODEL_REGISTRY") or {}
    if model not in registry:
        # BUG FIX: the message referenced the undefined name `model_name`,
        # so this branch raised NameError (HTTP 500) instead of returning
        # the intended 400.
        return JSONResponse(
            status_code=400,
            content={"error": f"Model not found: {model}"},
        )
    upstream_url = registry[model] + "/v1/audio/transcriptions"

    file.file.seek(0)  # rewind in case the upload stream was partially read
    files = {
        "file": (
            file.filename,
            file.file,
            file.content_type or "application/octet-stream",
        )
    }
    data = {
        "model": model,
        "response_format": response_format,
    }
    if language:
        data["language"] = language
    if prompt:
        data["prompt"] = prompt

    # Drop our JSON Content-Type so httpx can set the multipart boundary.
    safe_headers = {
        k: v
        for k, v in headers_with_auth.items()
        if k.lower() != "content-type"
    }
    async with httpx.AsyncClient(timeout=timeout) as client:
        resp = await client.post(
            upstream_url,
            files=files,
            data=data,
            headers=safe_headers,
        )
    return JSONResponse(status_code=resp.status_code, content=resp.json())
# ------------------------------------------------------------------
# AUDIO SPEECH (TTS)
# ------------------------------------------------------------------
@app.post("/v1/audio/speech")
async def text_to_speech(request: Request):
    """Proxy a text-to-speech request and stream the audio bytes back."""
    body = await request.json()
    log_request(request, body)
    model_name = body.get("model", "unknown")

    # Consistent with the other proxy endpoints: answer 400 for unknown
    # models instead of an unhandled KeyError (HTTP 500). Also tolerates
    # MODEL_REGISTRY not having been populated yet.
    registry = globals().get("MODEL_REGISTRY") or {}
    if model_name not in registry:
        return JSONResponse(
            status_code=400,
            content={"error": f"Model not found: {model_name}"},
        )
    upstream_url = registry[model_name] + "/v1/audio/speech"

    async with httpx.AsyncClient(timeout=timeout) as client:
        resp = await client.post(
            upstream_url,
            json=body,
            headers=headers_with_auth,
        )
    # client.post() buffers the full body, so iterating after the client has
    # closed is safe: StreamingResponse just re-chunks the in-memory bytes.
    return StreamingResponse(
        resp.aiter_bytes(),
        media_type=resp.headers.get("content-type", "audio/wav"),
        status_code=resp.status_code,
    )