from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# === CONFIG ===
MODEL_NAME = "microsoft/phi-2" # Replace with phi-4 if available
# === INIT ===
app = FastAPI()
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# from_pretrained loads on CPU by default; move the weights explicitly so the
# float16 path actually runs on the GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16 if device == "cuda" else torch.float32).to(device)
model.eval()
# === CORS (for browser clients) ===
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Replace with your frontend domain in production
    allow_methods=["*"],
    allow_headers=["*"],
)
# === Request Schema ===
class ChatRequest(BaseModel):
    model: str = "phi-4"
    messages: list  # [{"role": "user" | "assistant", "content": "..."}]
# === Helper Function ===
def format_prompt(messages):
    prompt = "You are Billy, a helpful and friendly assistant.\n\n"
    for msg in messages:
        role = msg["role"]
        content = msg["content"]
        if role == "user":
            prompt += f"User: {content}\n"
        elif role == "assistant":
            prompt += f"Billy: {content}\n"
    prompt += "Billy:"  # cue the model to continue as Billy
    return prompt
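# Tracing the function above: for messages = [{"role": "user", "content": "Hi"}]
# it returns the string
#
#   You are Billy, a helpful and friendly assistant.
#
#   User: Hi
#   Billy: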
# === Chat Endpoint ===
@app.post("/chat")
async def chat(req: ChatRequest):
prompt = format_prompt(req.messages)
inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)
with torch.no_grad():
output = model.generate(
**inputs,
max_new_tokens=100,
do_sample=True,
temperature=0.7,
top_p=0.9,
pad_token_id=tokenizer.eos_token_id
)
decoded = tokenizer.decode(output[0], skip_special_tokens=True)
response = decoded.split("Billy:")[-1].strip()
return {"message": {"content": response}}
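
# Example client call against the /chat endpoint defined above (host and port
# are illustrative; they depend on how you launch the server):
#
#   curl -X POST http://localhost:8000/chat \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": "Hello, Billy!"}]}'

# Minimal local runner, a sketch assuming uvicorn is installed and port 8000 is free:
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)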