# Auric-Bot / app.py — Hugging Face Space file header (scrape residue converted to comments)
# Author: Valtry
# Commit: d89fbc7 (verified) — "Update app.py"
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from sse_starlette.sse import EventSourceResponse
import torch
from threading import Thread
import json
import uvicorn
# FastAPI application exposing an OpenAI-compatible /v1/chat/completions API.
app = FastAPI()
# Small instruct-tuned model — sized to run on CPU-only hosts (e.g. a free HF Space).
model_name = "Qwen/Qwen2-0.5B-Instruct"
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
print("Loading model...")
# float32 on CPU: half precision is slow or unsupported on many CPUs.
# low_cpu_mem_usage streams weights in to keep peak RAM down during load.
# NOTE(review): transformers historically takes `torch_dtype=` here; the
# `dtype=` alias only exists in recent versions — confirm the pinned version.
model = AutoModelForCausalLM.from_pretrained(
model_name,
dtype=torch.float32,
low_cpu_mem_usage=True
).to("cpu")
print("Model loaded successfully!")
class ChatRequest(BaseModel):
    """Request body mirroring the OpenAI chat-completions schema."""
    model: str = "auric-ai"   # accepted for client compatibility; does not select a model
    messages: list            # chat history; presumably [{"role": ..., "content": ...}] dicts — TODO confirm against clients
    stream: bool = False      # True -> SSE token stream; False -> single JSON response
    max_tokens: int = 512     # capped at 2048 by the endpoint
    temperature: float = 0.1  # floored at 0.01 by the endpoint (0 would break sampling)
@app.post("/v1/chat/completions")
async def chat(req: ChatRequest):
    """OpenAI-compatible chat completion endpoint.

    Formats `req.messages` with the model's chat template, then either
    streams generated tokens as SSE delta events (``req.stream=True``)
    or returns one JSON completion object. Sampling parameters are
    clamped: temperature >= 0.01, max_tokens <= 2048.
    """
    # Use Qwen2's chat template for proper system/user/assistant formatting
    prompt = tokenizer.apply_chat_template(
        req.messages,
        tokenize=False,
        add_generation_prompt=True
    )
    inputs = tokenizer(prompt, return_tensors="pt").to("cpu")
    prompt_length = inputs["input_ids"].shape[-1]
    # Clamp user-supplied sampling knobs to safe ranges.
    temperature = max(req.temperature, 0.01)
    max_tokens = min(req.max_tokens, 2048)
    # Shared generation arguments for both paths. pad_token_id is set
    # explicitly to avoid the per-request "pad_token not set" warning.
    generation_kwargs = dict(
        **inputs,
        max_new_tokens=max_tokens,
        temperature=temperature,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    # ---------------- STREAM MODE ----------------
    if req.stream:
        streamer = TextIteratorStreamer(
            tokenizer,
            skip_prompt=True,
            skip_special_tokens=True
        )
        # Run generation in a background thread; daemon=True so a hung
        # generation after a client disconnect cannot block process exit.
        thread = Thread(
            target=model.generate,
            kwargs={**generation_kwargs, "streamer": streamer},
            daemon=True,
        )
        thread.start()

        # Deliberately a *sync* generator: iterating TextIteratorStreamer
        # blocks on an internal queue, and doing that inside an `async def`
        # would stall the whole event loop for the duration of generation.
        # sse-starlette runs sync iterators in a threadpool instead.
        def event_generator():
            for token in streamer:
                data = {"choices": [{"delta": {"content": token}}]}
                yield {
                    "event": "message",
                    "data": json.dumps(data)
                }
            # OpenAI-style stream terminator.
            yield {
                "event": "message",
                "data": "[DONE]"
            }

        return EventSourceResponse(event_generator())
    # ---------------- NORMAL MODE ----------------
    else:
        # inference_mode(): skip autograd bookkeeping — less memory and
        # faster generation; this endpoint never needs gradients.
        with torch.inference_mode():
            output = model.generate(**generation_kwargs)
        # Only decode the newly generated tokens, not the prompt
        response = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)
        return {
            "choices": [
                {
                    "message": {
                        "role": "assistant",
                        "content": response.strip()
                    }
                }
            ]
        }
if __name__ == "__main__":
    # Served via the "app:app" import string, which uvicorn requires when
    # workers > 1 (each worker re-imports this module).
    # NOTE(review): workers=3 means three processes each load their own copy
    # of the model at import time, on top of the parent's copy — confirm the
    # host has the RAM headroom, otherwise drop to workers=1.
    uvicorn.run("app:app", host="0.0.0.0", port=7860, workers=3)