# apichange / main.py
# bibibi12345's picture
# Update main.py
# a8e4f2c verified
from fastapi import FastAPI, Request
from fastapi.responses import Response, StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
import requests
import os
import json
# Base URL of the upstream API that requests are relayed to; the value is read
# from the environment, with a placeholder default when the variable is unset.
THIRD_PARTY_API_URL = os.getenv("THIRD_PARTY_API_URL", "https://default-api.com")
# Model ids reported by the model-listing endpoint
model_name = ["DeepSeek-R1","DeepSeek-V3"]

app = FastAPI()
# CORS is left wide open so that browser front-ends such as SillyTavern can
# call this service.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],  # every origin; narrow this if SillyTavern's origin is known
    allow_headers=["*"],  # every request header
    allow_methods=["*"],  # every HTTP method
)
@app.get("/hf/v1/models")
async def list_models():
    """Return the list of supported models.

    The payload has an ``"object": "list"`` envelope whose ``"data"`` holds
    one entry per id in the module-level ``model_name`` list.
    """
    entries = []
    for supported_id in model_name:
        entry = {
            "id": supported_id,
            "object": "model",
            "created": 1677610602,
            "owned_by": "system",
        }
        entries.append(entry)

    listing = {"object": "list", "data": entries}
    return listing
# Incoming headers that must not be copied verbatim onto the upstream request:
# - host: has to name the upstream server, not this proxy.
# - content-length: the body may be rewritten below; `requests` recomputes it.
# - accept-encoding: let `requests` advertise only encodings it can decode,
#   because the decoded payload is what gets relayed back to the client.
_SKIPPED_REQUEST_HEADERS = frozenset({"host", "content-length", "accept-encoding"})

# Upstream headers that describe the upstream wire format. `requests` hands us
# an already-decoded body, so relaying these would announce a compression or
# length that no longer matches the bytes we send.
_SKIPPED_RESPONSE_HEADERS = frozenset(
    {"content-encoding", "content-length", "transfer-encoding", "connection", "keep-alive"}
)

# Top-level JSON fields removed from request bodies before forwarding.
_STRIPPED_JSON_FIELDS = ("frequency_penalty", "top_p")


def _media_type(content_type):
    """Return the bare, lower-cased media type of a Content-Type header value."""
    return (content_type or "").split(";", 1)[0].strip().lower()


def _strip_unsupported_params(body):
    """Return ``body`` with the fields in ``_STRIPPED_JSON_FIELDS`` removed.

    The original bytes are returned untouched when the body is not valid JSON,
    is not a JSON object, or contains none of the fields.
    """
    try:
        payload = json.loads(body)
    except ValueError:
        # Covers both JSONDecodeError and UnicodeDecodeError: forward unchanged.
        return body
    if not isinstance(payload, dict):
        return body

    removed_any = False
    for field in _STRIPPED_JSON_FIELDS:
        if field in payload:
            del payload[field]
            print(f"Removed {field} from request body")
            removed_any = True

    if not removed_any:
        return body
    return json.dumps(payload).encode("utf-8")


def _forwardable_response_headers(upstream_headers):
    """Copy upstream headers, dropping those in ``_SKIPPED_RESPONSE_HEADERS``."""
    return {
        name: value
        for name, value in upstream_headers.items()
        if name.lower() not in _SKIPPED_RESPONSE_HEADERS
    }


def _iter_and_close(upstream, chunk_size=1024):
    """Yield the upstream body chunk by chunk, closing the connection at the end."""
    try:
        yield from upstream.iter_content(chunk_size=chunk_size)
    finally:
        upstream.close()


def _read_and_close(upstream):
    """Read the whole upstream body, then close the connection."""
    try:
        return upstream.content
    finally:
        upstream.close()


@app.api_route("/hf/{path:path}", methods=["POST"])
async def proxy(request: Request, path: str):
    """Forward a POST under ``/hf/`` to the third-party API and relay the reply.

    JSON request bodies have ``frequency_penalty`` and ``top_p`` removed before
    forwarding; any other body is passed through unchanged.

    Args:
        request: The incoming request.
        path: Everything after ``/hf/`` in the URL; appended to
            ``THIRD_PARTY_API_URL`` to build the upstream URL.

    Returns:
        A ``StreamingResponse`` when the upstream answers with
        ``text/event-stream``, otherwise a ``Response`` carrying the upstream
        body, status code and headers. A 502 ``Response`` is returned when the
        upstream request itself fails.
    """
    # Local import keeps this block self-contained; fastapi is already a
    # dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    print(f"Received request: {request.method} {path}")

    body = await request.body()
    headers = {
        name: value
        for name, value in request.headers.items()
        if name.lower() not in _SKIPPED_REQUEST_HEADERS
    }

    # Compare the media type only, so "application/json; charset=utf-8" matches.
    if _media_type(request.headers.get("Content-Type")) == "application/json":
        body = _strip_unsupported_params(body)

    # Tolerate a trailing slash in the configured base URL.
    url = f"{THIRD_PARTY_API_URL.rstrip('/')}/{path}"
    print(f"Forwarding to: {url}")

    try:
        # `requests` is blocking; run it off the event loop so one slow upstream
        # call does not stall every other request.
        response = await run_in_threadpool(
            requests.request,
            method=request.method,
            url=url,
            headers=headers,
            data=body,
            # multi_items() keeps repeated query keys; the plain mapping view
            # exposes only one value per key.
            params=request.query_params.multi_items(),
            stream=True,  # headers first; the body is read lazily below
        )
    except requests.RequestException as exc:
        print(f"Upstream request failed: {exc}")
        return Response(
            content="Upstream request failed",
            status_code=502,
            media_type="text/plain",
        )

    # Streaming replies (e.g. OpenAI-style server-sent events).
    if _media_type(response.headers.get("Content-Type")) == "text/event-stream":
        return StreamingResponse(
            _iter_and_close(response),
            status_code=response.status_code,
            media_type="text/event-stream",
        )

    # Non-streaming replies: reading the body blocks, so do it in the threadpool.
    content = await run_in_threadpool(_read_and_close, response)
    return Response(
        content=content,
        status_code=response.status_code,
        headers=_forwardable_response_headers(response.headers),
    )