# Deepseek_Test / app/main.py
# (Hugging Face Space file — author: Hivra, commit 2f3e185, verified)
import json
import os
import time

from fastapi import FastAPI, HTTPException, Request, Depends, Security
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from gradio_client import Client
from pydantic import BaseModel
# Gradio Space that backs this proxy, and the endpoint it exposes.
SPACE_ID = "openfree/Llama-4-Maverick-17B-Research-korea"
DEFAULT_API = "/query_deepseek_streaming"

# Shared client for the upstream Gradio Space (created once at import time).
client = Client(SPACE_ID)

# Security setup: bearer-token auth for all endpoints.
security = HTTPBearer()
# Read the key from the environment so the secret is not committed to source
# control; the literal default preserves behavior for existing deployments.
VALID_API_KEY = os.environ.get("API_KEY", "sk-1234")
async def get_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)) -> str:
    """Validate the Bearer token from the Authorization header.

    Args:
        credentials: Parsed Authorization header supplied by HTTPBearer.

    Returns:
        The validated API key string.

    Raises:
        HTTPException: 403 when the scheme is not Bearer or the key is wrong.
    """
    # RFC 7235: the auth-scheme token is case-insensitive, so accept
    # "bearer", "Bearer", "BEARER", etc.
    if credentials.scheme.lower() != "bearer":
        raise HTTPException(status_code=403, detail="Invalid authentication scheme")
    if credentials.credentials != VALID_API_KEY:
        raise HTTPException(status_code=403, detail="Invalid API key")
    return credentials.credentials
def chat_with_gradio(message: str, api_name: str = DEFAULT_API):
    """Send a chat message to the Gradio API and return the response.

    Args:
        message: User message forwarded to the Space.
        api_name: Gradio endpoint name (defaults to ``DEFAULT_API``).

    Raises:
        RuntimeError: Wrapping any error raised by the Gradio client, with
            the original exception chained for debugging.
    """
    try:
        return client.predict(message=message, api_name=api_name)
    except Exception as e:
        # Chain the cause so tracebacks show the underlying Gradio failure.
        raise RuntimeError(f"Gradio API error: {e}") from e
class ChatRequest(BaseModel):
    """Request body for POST /query_deepseek_streaming."""

    message: str  # user message forwarded to the Gradio Space
    api_name: str = DEFAULT_API  # Gradio endpoint to invoke


# FastAPI application; routes are attached below via decorators.
app = FastAPI()
@app.post("/query_deepseek_streaming", dependencies=[Depends(get_api_key)])
async def chat_endpoint(req: ChatRequest):
"""Forward chat requests to the Gradio API."""
try:
reply = chat_with_gradio(req.message, req.api_name)
return {"reply": reply}
except RuntimeError as e:
raise HTTPException(status_code=502, detail=str(e))
@app.post("/v1/chat/completions", dependencies=[Depends(get_api_key)])
async def openai_chat_completions(request: Request):
"""
OpenAI-compatible chat completions endpoint that forwards to Gradio.
Supports both streaming and non-streaming.
"""
body = await request.json()
messages = body.get("messages")
model = body.get("model")
stream = body.get("stream", False)
if not messages or not isinstance(messages, list):
raise HTTPException(status_code=400, detail="`messages` must be a list of dicts.")
user_msg = messages[-1].get("content", "")
# Call Gradio
try:
reply = chat_with_gradio(user_msg, DEFAULT_API)
except RuntimeError as e:
raise HTTPException(status_code=502, detail=str(e))
# Build usage (simple token count by words)
prompt_tokens = sum(len(m.get("content", "").split()) for m in messages)
completion_tokens = len(str(reply).split())
usage = {"prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens, "total_tokens": prompt_tokens + completion_tokens}
if stream:
# Stream word by word as OpenAI SSE
def event_generator():
for word in str(reply).split():
chunk = {"choices": [{"delta": {"content": word+" "}, "index": 0, "finish_reason": None}]}
yield f"data: {json.dumps(chunk)}\n\n"
time.sleep(0.05)
# send done
done = {"choices": [{"delta": {}, "index": 0, "finish_reason": "stop"}]}
yield f"data: {json.dumps(done)}\n\n"
return StreamingResponse(event_generator(), media_type="text/event-stream")
else:
response = {
"id": f"chatcmpl-{int(time.time())}",
"object": "chat.completion",
"created": int(time.time()),
"model": model,
"choices": [{"index": 0, "message": {"role": "assistant", "content": reply}, "finish_reason": "stop"}],
"usage": usage
}
return JSONResponse(response)
if __name__ == "__main__":
    # Local/dev entry point: serve on all interfaces, Spaces default port.
    import uvicorn

    banner = (
        f"Starting server on http://0.0.0.0:7860 using {SPACE_ID}{DEFAULT_API} "
        "and OpenAI-compatible endpoint /v1/chat/completions"
    )
    print(banner)
    uvicorn.run(app, host="0.0.0.0", port=7860)