# LLama / main.py — uploaded by Hivra, commit edb3801 (verified): "Update main.py"
import time
import uuid
import json
import re
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from typing import List
from meta_ai_api import MetaAI
class Message(BaseModel):
    """A single chat message in the OpenAI chat-completions format."""

    # Speaker role, e.g. "user" or "assistant"; only "user" messages are
    # consumed by the /v1/chat/completions endpoint below.
    role: str
    # The message text.
    content: str
class ChatRequest(BaseModel):
    """Request body for POST /v1/chat/completions."""

    # Conversation history; the endpoint forwards only the most recent
    # message whose role is "user".
    messages: List[Message]
# FastAPI application instance exposing the routes defined below.
app = FastAPI()
# Single MetaAI client created at import time and shared by all requests.
# NOTE(review): one shared client means all requests hit the same upstream
# session — confirm meta_ai_api.MetaAI is safe for concurrent use.
meta = MetaAI()
@app.get("/")
def read_root():
    """Landing endpoint: points callers at the streaming chat API."""
    hint = {"msg": "Try POST /v1/chat/completions with stream=True"}
    return hint
@app.post("/v1/chat/completions")
async def chat_completions(req: ChatRequest):
    """OpenAI-compatible streaming chat completion endpoint.

    Forwards the most recent user message to MetaAI and re-emits its
    cumulative response snapshots as Server-Sent Events in the OpenAI
    ``chat.completion.chunk`` format.

    Returns:
        StreamingResponse of SSE chunks, or a plain ``{"error": ...}``
        payload when the request contains no user message (original
        contract preserved — no HTTP error status is raised).
    """
    user_msgs = [m.content for m in req.messages if m.role == "user"]
    if not user_msgs:
        return {"error": "No user message provided"}
    last_user = user_msgs[-1]

    # One id for the entire stream, per OpenAI semantics. The original code
    # minted a fresh uuid for every chunk, which breaks clients that group
    # chunks by id.
    stream_id = f"meta-{uuid.uuid4()}"

    def normalize_text(text):
        """Normalize text by removing spaces before punctuation and collapsing whitespace.

        MetaAI yields cumulative snapshots whose whitespace can jitter
        between chunks; normalizing makes consecutive snapshots diffable.
        """
        text = re.sub(r'\s+([.,!?;:])', r'\1', text)  # Remove spaces before punctuation
        text = re.sub(r'\s+', ' ', text)  # Replace multiple spaces with single space
        return text.strip()  # Trim leading and trailing spaces

    def _make_chunk(delta, finish_reason):
        """Build one OpenAI-style chat.completion.chunk payload."""
        return {
            "id": stream_id,
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "choices": [{
                "delta": delta,
                "index": 0,
                "finish_reason": finish_reason,
            }],
        }

    def event_stream():
        # Previous normalized snapshot; only the newly appended suffix of
        # each snapshot is forwarded to the client.
        last_normalized = ""
        for chunk in meta.prompt(message=last_user, stream=True):
            full_text = str(chunk.get("message", ""))
            if not full_text:
                continue
            normalized_full = normalize_text(full_text)
            if normalized_full.startswith(last_normalized):
                new_piece = normalized_full[len(last_normalized):]
            else:
                # Snapshot no longer extends the previous one (upstream
                # rewrote earlier text); fall back to resending it whole.
                new_piece = normalized_full
            if new_piece:
                data = _make_chunk({"content": new_piece}, None)
                yield f"data: {json.dumps(data)}\n\n"
            last_normalized = normalized_full
        # Final stop chunk, then the [DONE] sentinel that OpenAI SDKs wait
        # for before closing the stream (the original never sent it, so
        # spec-conformant clients could hang).
        done = _make_chunk({}, "stop")
        yield f"data: {json.dumps(done)}\n\n"
        yield "data: [DONE]\n\n"

    return StreamingResponse(event_stream(), media_type="text/event-stream")