"""Minimal OpenAI-compatible chat-completions server for Qwen2-0.5B on CPU.

Exposes POST /v1/chat/completions with optional SSE streaming. The model is
loaded once at import time; all generation is pushed off the asyncio event
loop so concurrent requests are not starved by CPU-bound inference.
"""

import asyncio
import json
from threading import Thread

import torch
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
from sse_starlette.sse import EventSourceResponse
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

app = FastAPI()

model_name = "Qwen/Qwen2-0.5B-Instruct"

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
print("Loading model...")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    # NOTE(review): recent transformers accept `dtype=`; older releases only
    # accept `torch_dtype=` — confirm against the pinned transformers version.
    dtype=torch.float32,
    low_cpu_mem_usage=True,
).to("cpu")
print("Model loaded successfully!")


class ChatRequest(BaseModel):
    """Subset of the OpenAI /v1/chat/completions request body."""

    model: str = "auric-ai"
    messages: list
    stream: bool = False
    max_tokens: int = 512
    temperature: float = 0.1


@app.post("/v1/chat/completions")
async def chat(req: ChatRequest):
    """Generate a chat completion, either streamed as SSE or as one JSON body.

    Args:
        req: Parsed request. ``messages`` is passed straight to the tokenizer's
            chat template, so it must be a list of ``{"role": ..., "content": ...}``
            dicts (presumably OpenAI-shaped — enforced only by the template).

    Returns:
        An ``EventSourceResponse`` of delta chunks terminated by ``[DONE]`` when
        ``req.stream`` is true, otherwise an OpenAI-style ``choices`` dict.
    """
    # Qwen2's chat template handles system/user/assistant formatting.
    prompt = tokenizer.apply_chat_template(
        req.messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(prompt, return_tensors="pt").to("cpu")
    prompt_length = inputs["input_ids"].shape[-1]

    # Clamp: temperature=0 breaks sampling; cap tokens to bound request cost.
    temperature = max(req.temperature, 0.01)
    max_tokens = min(req.max_tokens, 2048)

    # ---------------- STREAM MODE ----------------
    if req.stream:
        streamer = TextIteratorStreamer(
            tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        generation_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=max_tokens,
            temperature=temperature,
            do_sample=True,
        )
        # daemon=True so a mid-generation server shutdown cannot hang on join.
        thread = Thread(target=model.generate, kwargs=generation_kwargs, daemon=True)
        thread.start()

        async def event_generator():
            # The streamer's iterator blocks on an internal queue while the
            # generation thread produces tokens; pull each token on a worker
            # thread so the event loop keeps serving other requests.
            sentinel = object()
            token_iter = iter(streamer)
            while True:
                token = await asyncio.to_thread(next, token_iter, sentinel)
                if token is sentinel:
                    break
                data = {"choices": [{"delta": {"content": token}}]}
                yield {"event": "message", "data": json.dumps(data)}
            yield {"event": "message", "data": "[DONE]"}

        return EventSourceResponse(event_generator())

    # ---------------- NORMAL MODE ----------------
    # generate() is CPU-bound and blocking: run it off the event loop.
    output = await asyncio.to_thread(
        model.generate,
        **inputs,
        max_new_tokens=max_tokens,
        temperature=temperature,
        do_sample=True,
    )
    # Only decode the newly generated tokens, not the prompt.
    response = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)
    return {
        "choices": [
            {"message": {"role": "assistant", "content": response.strip()}}
        ]
    }


if __name__ == "__main__":
    # workers=1: the model loads at import time, so every extra worker process
    # would hold another full float32 copy of the weights in RAM.
    uvicorn.run("app:app", host="0.0.0.0", port=7860, workers=1)