File size: 2,712 Bytes
edb3801
 
 
 
0d2cdcb
a2e7e63
b4818d8
 
b8c5c95
 
b4818d8
 
 
 
 
 
 
b8c5c95
 
 
edb3801
 
 
 
b8c5c95
b4818d8
edb3801
 
 
 
 
 
 
 
 
 
 
a2e7e63
 
edb3801
 
 
 
0d2cdcb
 
edb3801
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a2e7e63
 
 
 
edb3801
a2e7e63
 
 
edb3801
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import json
import re
import time
import uuid
from typing import List

from fastapi import FastAPI
from fastapi.responses import JSONResponse, StreamingResponse
from meta_ai_api import MetaAI
from pydantic import BaseModel

class Message(BaseModel):
    """One turn of an OpenAI-style ``messages`` array."""

    # Who produced the turn, e.g. "user".
    role: str
    # Plain-text body of the turn.
    content: str

class ChatRequest(BaseModel):
    """Request body accepted by ``POST /v1/chat/completions``.

    Only ``messages`` is declared on this model; other OpenAI request fields
    are not modelled here.
    """

    # Conversation turns; the chat endpoint treats the last "user" entry as
    # the prompt to answer.
    messages: List[Message]

app = FastAPI()
# One Meta AI client shared by every request this app handles.
meta = MetaAI()


@app.get("/")
def read_root():
    """Return a short usage hint pointing at the chat completions endpoint."""
    hint = "Try POST /v1/chat/completions with stream=True"
    return {"msg": hint}

# Whitespace other than line breaks. Excluding "\r" and "\n" keeps the
# model's Markdown layout (paragraphs, lists, code fences) intact.
_SPACE_BEFORE_PUNCT = re.compile(r"[^\S\r\n]+([.,!?;:])")
_HORIZONTAL_SPACE_RUN = re.compile(r"[^\S\r\n]+")


def _normalize_text(text):
    """Tidy model output for streaming.

    Removes spaces/tabs directly before punctuation, collapses runs of
    spaces/tabs into a single space, and trims both ends. Newlines are
    preserved.
    """
    text = _SPACE_BEFORE_PUNCT.sub(r"\1", text)
    text = _HORIZONTAL_SPACE_RUN.sub(" ", text)
    return text.strip()


def _unsent_suffix(previous, current):
    """Return the part of *current* that should be sent as the next delta.

    Normally *current* extends *previous*, and only the appended tail is new.
    If earlier text changed (an upstream revision, or normalization shifting
    characters), a delta cannot retract what the client already has. In that
    case only the text after the longest common prefix is emitted, instead of
    resending the whole answer and duplicating it on the client.
    """
    if current.startswith(previous):
        return current[len(previous):]
    common = 0
    for old_ch, new_ch in zip(previous, current):
        if old_ch != new_ch:
            break
        common += 1
    return current[common:]


def _sse_chunk(completion_id, created, delta, finish_reason=None):
    """Serialize one ``chat.completion.chunk`` as a Server-Sent Events frame."""
    payload = {
        "id": completion_id,
        "object": "chat.completion.chunk",
        "created": created,
        "model": "meta-ai",  # fixed label; the request's model is not modelled
        "choices": [{
            "delta": delta,
            "index": 0,
            "finish_reason": finish_reason,
        }],
    }
    return f"data: {json.dumps(payload)}\n\n"


@app.post("/v1/chat/completions")
async def chat_completions(req: ChatRequest):
    """Stream a Meta AI answer to the last user message as OpenAI-style SSE.

    Only the most recent ``role == "user"`` message is forwarded to
    ``meta.prompt``; other turns are not sent. The response is always a
    ``text/event-stream``. It carries content deltas, then a chunk with
    ``finish_reason="stop"``, then the ``data: [DONE]`` sentinel.

    Returns HTTP 400 with ``{"error": ...}`` when no user message is present.
    """
    user_msgs = [m.content for m in req.messages if m.role == "user"]
    if not user_msgs:
        return JSONResponse(
            status_code=400,
            content={"error": "No user message provided"},
        )
    last_user = user_msgs[-1]

    # All chunks of one completion share a single id and timestamp.
    completion_id = f"meta-{uuid.uuid4()}"
    created = int(time.time())

    def event_stream():
        # Each upstream chunk carries the full text so far, not a delta.
        sent = ""
        for chunk in meta.prompt(message=last_user, stream=True):
            raw = chunk.get("message")
            if raw is None:  # avoid streaming the literal text "None"
                continue
            full_text = str(raw)
            if not full_text:
                continue

            normalized = _normalize_text(full_text)
            piece = _unsent_suffix(sent, normalized)
            if piece:
                yield _sse_chunk(completion_id, created, {"content": piece})
            sent = normalized

        yield _sse_chunk(completion_id, created, {}, finish_reason="stop")
        # OpenAI-compatible clients use this sentinel to detect end of stream.
        yield "data: [DONE]\n\n"

    return StreamingResponse(event_stream(), media_type="text/event-stream")