ethiotech4848 committed on
Commit
f524bae
·
verified ·
1 Parent(s): f7e486c

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +53 -0
main.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from fastapi.responses import StreamingResponse
3
+ from pydantic import BaseModel, Field
4
+ from typing import List, Optional, Any, Dict
5
+ from deepinfra_handler import DeepInfraHandler
6
+ import json
7
+
8
# Application object that the route decorators in this module register on.
app = FastAPI()

# One handler instance, created at import time and reused by every request.
api_handler = DeepInfraHandler()
10
+
11
class Message(BaseModel):
    """One entry of the ``messages`` list in a chat completion request."""

    # Free-form string; no restriction to a fixed set of roles is enforced here.
    role: str

    # Text payload of the message.
    content: str
14
+
15
class ChatCompletionRequest(BaseModel):
    """Request body accepted by the ``/chat/completions`` endpoint.

    Only ``model`` and ``messages`` are required; every other field falls
    back to the default declared below. Numeric fields are range-checked
    by pydantic before the route handler runs.
    """

    # Required fields.
    model: str
    messages: List[Message]

    # Sampling controls, each bounded by the constraints given to ``Field``.
    temperature: Optional[float] = Field(0.7, ge=0.0, le=2.0)
    max_tokens: Optional[int] = Field(4096, ge=1)
    top_p: Optional[float] = Field(1.0, ge=0.0, le=1.0)
    frequency_penalty: Optional[float] = Field(0.0, ge=-2.0, le=2.0)
    presence_penalty: Optional[float] = Field(0.0, ge=-2.0, le=2.0)

    # Stop sequences; empty list by default.
    stop: Optional[List[str]] = Field([])

    # When true, the endpoint answers with a server-sent-event stream.
    stream: Optional[bool] = False
25
+
26
@app.post("/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """Produce a chat completion, as a single response or as an SSE stream.

    All request fields are forwarded as keyword arguments to
    ``api_handler.generate_completion``.

    Args:
        request: Validated request body.

    Returns:
        When ``request.stream`` is true, a ``StreamingResponse`` with media
        type ``text/event-stream`` that emits one ``data:`` event per chunk
        yielded by the handler, followed by a final ``data: [DONE]`` event.
        Otherwise, whatever ``generate_completion`` returns.

    Raises:
        HTTPException: Status 500 with the error text as ``detail`` when the
            non-streaming call fails. An ``HTTPException`` raised by the
            handler itself is propagated unchanged.
    """
    # Pydantic v2 renamed ``.dict()`` to ``.model_dump()`` and deprecated the
    # old name; use the new method when it exists, the old one otherwise.
    dump = getattr(request, "model_dump", None) or request.dict
    params = dump()

    if request.stream:

        def event_stream():
            # The handler is only invoked once the response body is being
            # iterated, i.e. after this route function has already returned.
            # Failures raised here therefore cannot be turned into the
            # HTTP 500 produced on the non-streaming path below.
            for chunk in api_handler.generate_completion(**params):
                payload = json.dumps({"choices": [{"delta": {"content": chunk}}]})
                yield f"data: {payload}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(event_stream(), media_type="text/event-stream")

    # Local import keeps this block self-contained; fastapi is already a
    # dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    try:
        # generate_completion is called synchronously (never awaited) in this
        # file, so run it in a worker thread rather than on the event loop,
        # where a slow call would stall every other request.
        return await run_in_threadpool(api_handler.generate_completion, **params)
    except HTTPException:
        # Keep a deliberate HTTP error's status code instead of masking it as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
50
+
51
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when this file is run directly.
    import uvicorn

    # Serve the app on all network interfaces, port 8000.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
    )