File size: 3,967 Bytes
41a8f4f
6d886e2
41a8f4f
6d886e2
072a239
58bcf35
 
6d886e2
072a239
2f3e185
aaa9022
58bcf35
072a239
6d886e2
41a8f4f
 
f99d81b
41a8f4f
 
 
 
 
 
 
6d886e2
 
072a239
 
 
6d886e2
 
 
 
 
 
 
 
 
 
 
aaa9022
6d886e2
072a239
6d886e2
 
 
 
 
 
41a8f4f
6d886e2
072a239
 
 
 
6d886e2
 
072a239
6d886e2
 
 
 
 
 
 
072a239
6d886e2
 
 
 
 
072a239
6d886e2
 
072a239
6d886e2
 
072a239
6d886e2
 
072a239
6d886e2
 
072a239
6d886e2
 
 
 
 
 
 
 
072a239
6d886e2
 
 
 
 
 
 
072a239
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import json
import secrets
import time

from fastapi import FastAPI, HTTPException, Request, Depends, Security
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel
from gradio_client import Client

# Configure your Gradio Space ID and default endpoint
SPACE_ID = "openfree/Llama-4-Maverick-17B-Research-korea"  # Hugging Face Space hosting the model
DEFAULT_API = "/query_deepseek_streaming"  # Gradio endpoint name passed to client.predict

# NOTE(review): constructing the client at import time performs network I/O to
# the Space — if it is unreachable, importing this module fails.
client = Client(SPACE_ID)

# Security setup
security = HTTPBearer()
# SECURITY: hardcoded credential checked into source — move to an environment
# variable or secret store before deployment.
VALID_API_KEY = "sk-1234"  # Replace with your actual API key

async def get_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)) -> str:
    """Validate the ``Authorization: Bearer <token>`` header for an endpoint.

    Args:
        credentials: Parsed bearer credentials injected by FastAPI via the
            module-level ``HTTPBearer`` security scheme.

    Returns:
        The validated API key string (so it can be used as a dependency value).

    Raises:
        HTTPException: 403 when the scheme is not ``Bearer`` or the key is wrong.
    """
    if credentials.scheme != "Bearer":
        raise HTTPException(status_code=403, detail="Invalid authentication scheme")
    # Constant-time comparison prevents timing attacks that could leak the key
    # one character at a time (a plain `!=` short-circuits on first mismatch).
    if not secrets.compare_digest(credentials.credentials, VALID_API_KEY):
        raise HTTPException(status_code=403, detail="Invalid API key")
    return credentials.credentials

def chat_with_gradio(message: str, api_name: str = DEFAULT_API):
    """Send a chat message to the Gradio API and return the response.

    Args:
        message: User message forwarded to the Space.
        api_name: Gradio endpoint name; defaults to ``DEFAULT_API``.

    Returns:
        Whatever ``client.predict`` returns for the endpoint.

    Raises:
        RuntimeError: If the underlying Gradio call fails; the original
            exception is chained as the cause.
    """
    try:
        return client.predict(message=message, api_name=api_name)
    except Exception as e:
        # `from e` preserves the original traceback so the root cause
        # (network error, bad endpoint, ...) is visible in logs.
        raise RuntimeError(f"Gradio API error: {e}") from e

class ChatRequest(BaseModel):
    """Request body for the ``/query_deepseek_streaming`` proxy endpoint."""
    # User message forwarded verbatim to the Gradio Space.
    message: str
    # Gradio endpoint to call; defaults to the streaming query endpoint.
    api_name: str = DEFAULT_API

# ASGI application exposing the proxy and OpenAI-compatible endpoints.
app = FastAPI()

@app.post("/query_deepseek_streaming", dependencies=[Depends(get_api_key)])
async def chat_endpoint(req: ChatRequest):
    """Forward chat requests to the Gradio API.

    Returns a ``{"reply": ...}`` payload, or HTTP 502 when the upstream
    Gradio call fails.
    """
    try:
        gradio_reply = chat_with_gradio(req.message, req.api_name)
    except RuntimeError as exc:
        # Upstream failure -> Bad Gateway, surfacing the wrapped error text.
        raise HTTPException(status_code=502, detail=str(exc))
    return {"reply": gradio_reply}

@app.post("/v1/chat/completions", dependencies=[Depends(get_api_key)])
async def openai_chat_completions(request: Request):
    """
    OpenAI-compatible chat completions endpoint that forwards to Gradio.
    Supports both streaming (SSE) and non-streaming responses.

    Raises:
        HTTPException: 400 on malformed input, 502 when the Gradio call fails.
    """
    body = await request.json()
    messages = body.get("messages")
    model = body.get("model")
    stream = body.get("stream", False)

    # Only the last message's content is forwarded; validate it is a dict.
    if not messages or not isinstance(messages, list) or not isinstance(messages[-1], dict):
        raise HTTPException(status_code=400, detail="`messages` must be a list of dicts.")

    user_msg = messages[-1].get("content", "")

    # Call Gradio
    try:
        reply = chat_with_gradio(user_msg, DEFAULT_API)
    except RuntimeError as e:
        raise HTTPException(status_code=502, detail=str(e))

    created = int(time.time())
    completion_id = f"chatcmpl-{created}"

    # Build usage (rough token count by whitespace-separated words — not a
    # real tokenizer, but good enough for client-side accounting).
    prompt_tokens = sum(len(m.get("content", "").split()) for m in messages if isinstance(m, dict))
    completion_tokens = len(str(reply).split())
    usage = {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": prompt_tokens + completion_tokens,
    }

    if stream:
        # Stream word by word as OpenAI-style SSE chunks. Each chunk carries
        # the id/object/created/model envelope expected by OpenAI clients,
        # and the stream terminates with the literal `[DONE]` sentinel.
        def event_generator():
            envelope = {
                "id": completion_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": model,
            }
            # First chunk announces the assistant role, per the OpenAI spec.
            first = dict(envelope, choices=[{"delta": {"role": "assistant"}, "index": 0, "finish_reason": None}])
            yield f"data: {json.dumps(first)}\n\n"
            for word in str(reply).split():
                chunk = dict(envelope, choices=[{"delta": {"content": word + " "}, "index": 0, "finish_reason": None}])
                yield f"data: {json.dumps(chunk)}\n\n"
                # Artificial pacing; sync generator runs in a threadpool so
                # this does not block the event loop.
                time.sleep(0.05)
            done = dict(envelope, choices=[{"delta": {}, "index": 0, "finish_reason": "stop"}])
            yield f"data: {json.dumps(done)}\n\n"
            yield "data: [DONE]\n\n"
        return StreamingResponse(event_generator(), media_type="text/event-stream")
    else:
        response = {
            "id": completion_id,
            "object": "chat.completion",
            "created": created,
            "model": model,
            "choices": [{"index": 0, "message": {"role": "assistant", "content": reply}, "finish_reason": "stop"}],
            "usage": usage,
        }
        return JSONResponse(response)

if __name__ == "__main__":
    # Local/dev entry point; 0.0.0.0 exposes the server on all interfaces.
    import uvicorn
    print(f"Starting server on http://0.0.0.0:7860 using {SPACE_ID}{DEFAULT_API} and OpenAI-compatible endpoint /v1/chat/completions")
    uvicorn.run(app, host="0.0.0.0", port=7860)