"""Minimal OpenAI-compatible chat-completions API backed by a local HF model."""

import asyncio
import json
import time
from typing import List, Optional

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from starlette.responses import StreamingResponse
from transformers import AutoModelForCausalLM, AutoTokenizer

app = FastAPI(title="OpenAI-compatible API")

# Loaded once at import time: slow, memory-heavy startup, but no per-request
# model-load cost.
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")
model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2")


class Message(BaseModel):
    """One chat message; role is expected to be 'user' or 'assistant'."""

    role: str
    content: str


class ChatCompletionRequest(BaseModel):
    """Subset of the OpenAI /chat/completions request body this server honors."""

    model: Optional[str] = "mock-gpt-model"
    messages: List[Message]
    max_tokens: Optional[int] = 512
    temperature: Optional[float] = 0.1
    stream: Optional[bool] = False


async def _resp_async_generator(text_resp: str, model: str):
    """Fake-stream an already-generated reply as SSE chat.completion.chunk events.

    Generation has already completed; this splits the reply on spaces and
    emits one chunk per token with a small delay to simulate streaming,
    followed by a terminal chunk carrying finish_reason and the [DONE]
    sentinel, per the OpenAI streaming format.
    """
    tokens = text_resp.split(" ")
    for i, token in enumerate(tokens):
        chunk = {
            "id": i,
            "object": "chat.completion.chunk",
            # OpenAI uses integer unix timestamps, not floats.
            "created": int(time.time()),
            "model": model,
            "choices": [{"index": 0, "delta": {"content": token + " "}}],
        }
        yield f"data: {json.dumps(chunk)}\n\n"
        await asyncio.sleep(0.05)
    # Terminal chunk with an empty delta and finish_reason, as clients expect.
    final_chunk = {
        "id": len(tokens),
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": model,
        "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
    }
    yield f"data: {json.dumps(final_chunk)}\n\n"
    yield "data: [DONE]\n\n"


def _build_prompt(messages: List[Message]) -> str:
    """Flatten the chat history into a plain-text prompt ending in 'Assistant:'.

    Roles other than 'user'/'assistant' (e.g. 'system') are dropped, matching
    the original behavior.
    """
    parts = []
    for msg in messages:
        if msg.role == "user":
            parts.append(f"User: {msg.content}\n")
        elif msg.role == "assistant":
            parts.append(f"Assistant: {msg.content}\n")
    parts.append("Assistant:")
    return "".join(parts)


def _generate_reply(prompt: str, max_tokens: int, temperature: float) -> str:
    """Run blocking HF generation and return only the newly generated text.

    Intended to be executed in a worker thread (see chat_completions);
    model.generate() is synchronous and compute-bound.
    """
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_tokens,
        temperature=temperature,
        # transformers rejects sampling with temperature <= 0; fall back to
        # greedy decoding in that case instead of raising.
        do_sample=temperature is not None and temperature > 0,
        pad_token_id=tokenizer.eos_token_id,
    )
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # decode() returns prompt + completion; strip the echoed prompt to keep
    # only the assistant's new reply.
    return generated_text[len(prompt):].strip()


@app.post("/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """Handle an OpenAI-style chat completion request.

    Returns either a Server-Sent Events stream (request.stream=True) or a
    single chat.completion JSON object. Raises HTTP 400 when no messages
    are supplied.
    """
    if not request.messages:
        raise HTTPException(status_code=400, detail="No messages provided.")

    prompt = _build_prompt(request.messages)

    # model.generate() blocks for the whole generation; run it in a worker
    # thread so the event loop stays responsive to concurrent requests.
    assistant_reply = await asyncio.to_thread(
        _generate_reply, prompt, request.max_tokens, request.temperature
    )

    if request.stream:
        return StreamingResponse(
            _resp_async_generator(assistant_reply, request.model),
            media_type="text/event-stream",
        )

    return {
        "id": "1337",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": request.model,
        "choices": [
            {
                "index": 0,
                "message": Message(role="assistant", content=assistant_reply),
                "finish_reason": "stop",
            }
        ],
    }


# if __name__ == "__main__":
#     import uvicorn
#     uvicorn.run(app, host="0.0.0.0", port=8000)