Update main.py
Browse files
main.py
CHANGED
|
@@ -1,4 +1,7 @@
|
|
| 1 |
-
import
|
|
|
|
|
|
|
|
|
|
| 2 |
from fastapi import FastAPI
|
| 3 |
from fastapi.responses import StreamingResponse
|
| 4 |
from pydantic import BaseModel
|
|
@@ -10,52 +13,68 @@ class Message(BaseModel):
|
|
| 10 |
content: str
|
| 11 |
|
| 12 |
class ChatRequest(BaseModel):
|
| 13 |
-
model: str = None
|
| 14 |
messages: List[Message]
|
| 15 |
|
| 16 |
app = FastAPI()
|
| 17 |
-
DEFAULT_MODEL = os.getenv("MODEL_NAME", "llama-3-70b")
|
| 18 |
meta = MetaAI()
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
@app.post("/v1/chat/completions")
|
| 21 |
async def chat_completions(req: ChatRequest):
|
| 22 |
-
|
| 23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
def event_stream():
|
| 26 |
-
|
| 27 |
-
for chunk in meta.prompt(message=
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
last_text = full
|
| 31 |
-
if not new_piece:
|
| 32 |
continue
|
| 33 |
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
done = {
|
| 49 |
"id": f"meta-{uuid.uuid4()}",
|
| 50 |
"object": "chat.completion.chunk",
|
| 51 |
"created": int(time.time()),
|
| 52 |
-
"
|
| 53 |
-
"choices": [{
|
| 54 |
-
"delta": {},
|
| 55 |
-
"index": 0,
|
| 56 |
-
"finish_reason": "stop"
|
| 57 |
-
}]
|
| 58 |
}
|
| 59 |
yield f"data: {json.dumps(done)}\n\n"
|
| 60 |
|
| 61 |
-
return StreamingResponse(event_stream(), media_type="text/event-stream")
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
import uuid
|
| 3 |
+
import json
|
| 4 |
+
import re
|
| 5 |
from fastapi import FastAPI
|
| 6 |
from fastapi.responses import StreamingResponse
|
| 7 |
from pydantic import BaseModel
|
|
|
|
| 13 |
content: str
|
| 14 |
|
| 15 |
class ChatRequest(BaseModel):
    # OpenAI-style request body. Only the conversation history is used by
    # the handler; a previous "model" field appears to have been dropped
    # (backend model is fixed server-side) — presumably intentional.
    # NOTE(review): `List` must be imported from typing in the (unseen)
    # top of this file — confirm, or switch to `list[Message]` on 3.9+.
    messages: List[Message]
|
| 17 |
|
| 18 |
# Application instance and a single module-level MetaAI client shared by
# every request.
# NOTE(review): assumes the MetaAI client is safe for concurrent use from
# multiple in-flight requests — confirm before production use.
app = FastAPI()
meta = MetaAI()
|
| 20 |
|
| 21 |
+
@app.get("/")
def read_root():
    """Landing endpoint: directs callers to the streaming chat route."""
    hint = {"msg": "Try POST /v1/chat/completions with stream=True"}
    return hint
|
| 24 |
+
|
| 25 |
@app.post("/v1/chat/completions")
async def chat_completions(req: ChatRequest):
    """OpenAI-compatible streaming chat endpoint backed by MetaAI.

    Takes the most recent user message from ``req.messages``, streams the
    MetaAI response, and re-emits it as OpenAI-style
    ``chat.completion.chunk`` server-sent events containing only the text
    that has not been sent yet.

    Returns:
        StreamingResponse of ``text/event-stream`` data, or a plain
        ``{"error": ...}`` dict when no user message is present.
    """
    user_msgs = [m.content for m in req.messages if m.role == "user"]
    if not user_msgs:
        # NOTE: returned with HTTP 200; callers must inspect the body.
        return {"error": "No user message provided"}
    last_user = user_msgs[-1]

    # Compile once per request instead of once per streamed chunk.
    before_punct = re.compile(r'\s+([.,!?;:])')
    multi_space = re.compile(r'\s+')

    def normalize_text(text):
        """Drop spaces before punctuation, collapse runs of whitespace, trim."""
        text = before_punct.sub(r'\1', text)
        text = multi_space.sub(' ', text)
        return text.strip()

    def event_stream():
        last_normalized = ""
        for chunk in meta.prompt(message=last_user, stream=True):
            # Each chunk carries the cumulative message so far, not a delta.
            full_text = str(chunk.get("message", ""))
            if not full_text:
                continue

            normalized_full = normalize_text(full_text)

            # Emit only the unsent suffix; if the stream rewrote earlier
            # text (prefix no longer matches), fall back to resending all.
            if normalized_full.startswith(last_normalized):
                new_piece = normalized_full[len(last_normalized):]
            else:
                new_piece = normalized_full

            if new_piece:
                data = {
                    "id": f"meta-{uuid.uuid4()}",
                    "object": "chat.completion.chunk",
                    "created": int(time.time()),
                    "choices": [{
                        "delta": {"content": new_piece},
                        "index": 0,
                        "finish_reason": None
                    }]
                }
                yield f"data: {json.dumps(data)}\n\n"

            last_normalized = normalized_full

        # Final chunk with finish_reason="stop", then the OpenAI stream
        # terminator — without "data: [DONE]" many OpenAI-compatible
        # clients never detect end-of-stream and hang.
        done = {
            "id": f"meta-{uuid.uuid4()}",
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "choices": [{"delta": {}, "index": 0, "finish_reason": "stop"}]
        }
        yield f"data: {json.dumps(done)}\n\n"
        yield "data: [DONE]\n\n"

    return StreamingResponse(event_stream(), media_type="text/event-stream")
|