import asyncio
import json
import time
from typing import List, Optional

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from starlette.responses import StreamingResponse
from transformers import AutoModelForCausalLM, AutoTokenizer

app = FastAPI(title="OpenAI-compatible API")

# Load the model and tokenizer once at startup
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")
model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2")
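
# Optional (an assumption, not part of the original setup): move the model to a
# CUDA GPU when one is available; CPU generation works but is slow for phi-2.
# import torch
# if torch.cuda.is_available():
#     model = model.to("cuda")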


# Data models mirroring the OpenAI chat completions request shape
class Message(BaseModel):
    role: str
    content: str


class ChatCompletionRequest(BaseModel):
    model: Optional[str] = "mock-gpt-model"
    messages: List[Message]
    max_tokens: Optional[int] = 512
    temperature: Optional[float] = 0.1
    stream: Optional[bool] = False
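
# An incoming request body matching ChatCompletionRequest looks like this
# (illustrative example, not from the original file):
# {
#   "model": "mock-gpt-model",
#   "messages": [{"role": "user", "content": "Hello!"}],
#   "max_tokens": 64,
#   "temperature": 0.1,
#   "stream": false
# }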


async def _resp_async_generator(text_resp: str, model: str):
    """Stream a pre-generated reply word by word as OpenAI-style SSE chunks."""
    tokens = text_resp.split(" ")
    for i, token in enumerate(tokens):
        chunk = {
            "id": i,
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": model,
            "choices": [{"index": 0, "delta": {"content": token + " "}}],
        }
        yield f"data: {json.dumps(chunk)}\n\n"
        await asyncio.sleep(0.05)
    yield "data: [DONE]\n\n"


@app.post("/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    if not request.messages:
        raise HTTPException(status_code=400, detail="No messages provided.")

    # Build a plain-text prompt from the chat history; system messages are
    # passed through verbatim
    prompt = ""
    for msg in request.messages:
        if msg.role == "user":
            prompt += f"User: {msg.content}\n"
        elif msg.role == "assistant":
            prompt += f"Assistant: {msg.content}\n"
        elif msg.role == "system":
            prompt += f"{msg.content}\n"
    prompt += "Assistant:"

    # Tokenize and generate; inputs follow the model's device (CPU or GPU)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=request.max_tokens,
        temperature=request.temperature,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Keep only the assistant's new reply (generate() echoes the prompt)
    assistant_reply = generated_text[len(prompt):].strip()

    if request.stream:
        return StreamingResponse(
            _resp_async_generator(assistant_reply, request.model),
            media_type="text/event-stream",
        )
    return {
        "id": "1337",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": request.model,
        "choices": [
            {"index": 0, "message": Message(role="assistant", content=assistant_reply)}
        ],
    }


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
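
# Example client call (a sketch, not part of the original file): with the server
# running, the official `openai` Python client can talk to this route, since it
# posts to {base_url}/chat/completions.
#
#   from openai import OpenAI
#   client = OpenAI(base_url="http://localhost:8000", api_key="not-needed")
#   resp = client.chat.completions.create(
#       model="mock-gpt-model",
#       messages=[{"role": "user", "content": "Hello!"}],
#   )
#   print(resp.choices[0].message.content)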