File size: 2,712 Bytes
edb3801 0d2cdcb a2e7e63 b4818d8 b8c5c95 b4818d8 b8c5c95 edb3801 b8c5c95 b4818d8 edb3801 a2e7e63 edb3801 0d2cdcb edb3801 a2e7e63 edb3801 a2e7e63 edb3801 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
import json
import re
import time
import uuid
from typing import List

from fastapi import FastAPI
from fastapi.responses import JSONResponse, StreamingResponse
from meta_ai_api import MetaAI
from pydantic import BaseModel
class Message(BaseModel):
    """One chat message from the request's ``messages`` list.

    The /v1/chat/completions handler only reads messages whose ``role`` is
    exactly ``"user"``; messages with any other role are accepted but ignored.
    """

    role: str  # sender role; compared against the literal "user" by the handler
    content: str  # message text; declared as a plain string, so non-string content fails validation
class ChatRequest(BaseModel):
    """Request body accepted by ``POST /v1/chat/completions``.

    Only ``messages`` is declared here. Other fields an OpenAI-style client may
    send (e.g. ``model`` or ``stream``) are not part of this model.
    NOTE(review): with pydantic's default config such undeclared fields are
    ignored rather than rejected — confirm no custom config changes that.
    """

    # Conversation messages. Presumably oldest first: the handler treats the
    # last "user" entry as the newest prompt.
    messages: List[Message]
# Module-level singletons: the FastAPI application that this file's route
# decorators register on, and the single MetaAI client every request reuses.
# NOTE(review): MetaAI() is built once at import time and shared by all
# requests (possibly concurrently) — confirm the client is safe to share.
app, meta = FastAPI(), MetaAI()
@app.get("/")
def read_root():
    """Answer GET / with a one-line hint pointing callers at the chat endpoint."""
    hint = "Try POST /v1/chat/completions with stream=True"
    return dict(msg=hint)
# Whitespace directly before one of . , ! ? ; : (the punctuation mark is
# captured so the substitution can keep it).
_SPACE_BEFORE_PUNCT_RE = re.compile(r'\s+([.,!?;:])')
# Any run of whitespace characters, including newlines and tabs.
_WHITESPACE_RUN_RE = re.compile(r'\s+')


def _normalize_text(text):
    """Return *text* with its whitespace normalized.

    Removes whitespace before ``. , ! ? ; :``, collapses every whitespace run
    (newlines and tabs included) into a single space, and strips both ends.
    Because newlines are collapsed too, any line structure in the model's
    answer (paragraphs, Markdown lists, code blocks) is flattened.
    """
    text = _SPACE_BEFORE_PUNCT_RE.sub(r'\1', text)
    text = _WHITESPACE_RUN_RE.sub(' ', text)
    return text.strip()


def _sse_chunk(completion_id, created, delta, finish_reason=None):
    """Serialize one ``chat.completion.chunk`` object as an SSE ``data:`` event.

    Args:
        completion_id: id shared by every chunk of the same completion.
        created: Unix timestamp shared by every chunk of the same completion.
        delta: the ``delta`` dict for choice 0 (``{"content": ...}`` or ``{}``).
        finish_reason: ``None`` for content chunks, ``"stop"`` for the last one.

    Returns:
        The event text, terminated by the blank line SSE requires.
    """
    payload = {
        "id": completion_id,
        "object": "chat.completion.chunk",
        "created": created,
        "choices": [{
            "delta": delta,
            "index": 0,
            "finish_reason": finish_reason,
        }],
    }
    return f"data: {json.dumps(payload)}\n\n"


@app.post("/v1/chat/completions")
async def chat_completions(req: ChatRequest):
    """Proxy the newest user message to Meta AI and stream the reply as SSE.

    The reply is emitted in the OpenAI ``chat.completion.chunk`` streaming
    format: one event per new piece of text, then a ``finish_reason: "stop"``
    event, then the ``data: [DONE]`` sentinel. The response is always a
    stream; nothing in this handler reads a ``stream`` flag.

    Returns:
        A ``text/event-stream`` StreamingResponse, or a 400 JSON response with
        ``{"error": ...}`` when ``req.messages`` contains no "user" message.
    """
    user_msgs = [m.content for m in req.messages if m.role == "user"]
    if not user_msgs:
        # Same body as before, but flagged as a client error instead of HTTP 200.
        return JSONResponse(status_code=400, content={"error": "No user message provided"})
    # NOTE(review): only the latest user message is forwarded; earlier turns and
    # non-user messages in req.messages are not sent upstream. Whether `meta`
    # keeps its own conversation state between prompt() calls is not visible here.
    last_user = user_msgs[-1]

    # The OpenAI streaming format gives every chunk of one completion the same
    # id and created timestamp, so generate them once per request.
    completion_id = f"meta-{uuid.uuid4()}"
    created = int(time.time())

    def event_stream():
        """Yield SSE events carrying only the text not yet sent to the client."""
        # The diffing below assumes each upstream chunk's "message" holds the
        # whole answer so far, not just the newest fragment.
        last_normalized = ""
        for chunk in meta.prompt(message=last_user, stream=True):
            # `or ""` also covers a present-but-None "message", which str()
            # would otherwise turn into the literal text "None".
            full_text = str(chunk.get("message") or "")
            if not full_text:
                continue
            normalized_full = _normalize_text(full_text)
            if normalized_full.startswith(last_normalized):
                new_piece = normalized_full[len(last_normalized):]
            else:
                # NOTE(review): the upstream text diverged from what was already
                # streamed; deltas cannot retract text, so the whole new text is
                # re-sent and the client will see it appended (duplicated).
                new_piece = normalized_full
            if new_piece:
                yield _sse_chunk(completion_id, created, {"content": new_piece})
            last_normalized = normalized_full
        yield _sse_chunk(completion_id, created, {}, finish_reason="stop")
        # Stream terminator expected by OpenAI-compatible clients.
        yield "data: [DONE]\n\n"

    return StreamingResponse(event_stream(), media_type="text/event-stream")