# chatdoctor-api / app.py
# Hugging Face Space by arjunkmoorthy (commit c98a751)
import os

# Redirect all cache paths to /tmp BEFORE importing unsloth/transformers.
# HF libraries resolve HF_HOME / TRANSFORMERS_CACHE at import time, and the
# default cache locations are not writable inside the Space container — so
# setting these after the import (as the original code did) has no effect.
HF_CACHE = "/tmp/hf_cache"
os.environ["TRANSFORMERS_CACHE"] = HF_CACHE
os.environ["HF_HOME"] = HF_CACHE
os.environ["TRITON_CACHE_DIR"] = "/tmp/triton_cache"
os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/tmp/torchinductor_cache"

# Create all the cache dirs up front so downstream libraries can write to them.
for path in (HF_CACHE, "/tmp/triton_cache", "/tmp/torchinductor_cache"):
    os.makedirs(path, exist_ok=True)

# Deliberately imported after the env setup (PEP 8 exception) so the
# model/tokenizer caches actually land under /tmp.
from fastapi import FastAPI
from pydantic import BaseModel
from unsloth import FastModel
app = FastAPI()

# Lazy load the model on first request: these module-level globals start as
# None and are populated by chat_handler the first time /chat is hit, so the
# container boots quickly and only pays the load cost when actually used.
model = None
tokenizer = None
class ChatInput(BaseModel):
    """Request body for the /chat endpoint: a single patient message."""

    # Free-text message from the patient.
    message: str
@app.post("/chat")
async def chat_handler(input: ChatInput):
    """Generate one provider reply to a patient's chat message.

    Lazily loads the model/tokenizer into the module-level globals on the
    first request, then runs a single generation pass.

    Returns:
        dict: {"response": <model's reply text>} — only the newly generated
        tokens, not the prompt.
    """
    global model, tokenizer
    if model is None or tokenizer is None:
        print("🟡 Loading model...")
        model, tokenizer = FastModel.from_pretrained(
            model_name="microsoft/phi-2",
            # NOTE(review): adapter is named "phi3-Mini" while the base model
            # is phi-2 — confirm these checkpoints are actually compatible.
            adapter_name="srikar-v05/phi3-Mini-Medical-Chat",
            load_in_4bit=True,
            max_seq_length=2048,
        )
        FastModel.for_inference(model)
        print("🟢 Model loaded!")

    # Build a single-turn triage prompt around the patient's message.
    prompt = (
        "You are a kind, attentive oncology provider speaking to a patient.\n"
        "Ask one follow-up question at a time to triage their symptoms.\n\n"
        f"Patient: {input.message}\nProvider:"
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=300)

    # Bug fix: decoder-only generate() returns prompt + continuation, so
    # decoding outputs[0] in full echoed the entire system prompt and the
    # patient's message back to the client. Slice off the prompt tokens and
    # decode only the newly generated ones.
    new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    return {"response": response}