import os

from fastapi import FastAPI
from pydantic import BaseModel
from unsloth import FastModel

# Redirect HF / TorchInductor / Triton caches to /tmp so the server can run
# in sandboxed environments where the default cache paths are not writable.
HF_CACHE = "/tmp/hf_cache"
TRITON_CACHE = "/tmp/triton_cache"
INDUCTOR_CACHE = "/tmp/torchinductor_cache"

os.environ["TRITON_CACHE_DIR"] = TRITON_CACHE
os.environ["TORCHINDUCTOR_CACHE_DIR"] = INDUCTOR_CACHE
os.environ["TRANSFORMERS_CACHE"] = HF_CACHE  # legacy var; HF_HOME is the modern one
os.environ["HF_HOME"] = HF_CACHE

# Create all the cache dirs up front.
for path in (HF_CACHE, TRITON_CACHE, INDUCTOR_CACHE):
    os.makedirs(path, exist_ok=True)

app = FastAPI()

# Model and tokenizer are lazily initialized on the first request so the
# server starts quickly and only pays the load cost when actually used.
model = None
tokenizer = None


class ChatInput(BaseModel):
    """Request body for /chat: a single free-text patient message."""

    message: str


@app.post("/chat")
async def chat_handler(payload: ChatInput):
    """Triage-style chat endpoint.

    Lazily loads the model on the first call, wraps the patient's message
    in an oncology-provider prompt, generates a reply, and returns ONLY
    the newly generated text (the prompt is stripped from the output).
    """
    global model, tokenizer
    if model is None or tokenizer is None:
        print("🟡 Loading model...")
        # NOTE(review): base model is phi-2 but the adapter name says
        # "phi3-Mini" — confirm the adapter was trained on this base,
        # otherwise the LoRA weights will not line up.
        model, tokenizer = FastModel.from_pretrained(
            model_name="microsoft/phi-2",
            adapter_name="srikar-v05/phi3-Mini-Medical-Chat",
            load_in_4bit=True,
            max_seq_length=2048,
        )
        FastModel.for_inference(model)
        print("🟢 Model loaded!")

    # Build the generation prompt around the patient's message.
    prompt = (
        "You are a kind, attentive oncology provider speaking to a patient.\n"
        "Ask one follow-up question at a time to triage their symptoms.\n\n"
        f"Patient: {payload.message}\nProvider:"
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # NOTE(review): model.generate is blocking and runs on the event loop;
    # under concurrent load consider fastapi.concurrency.run_in_threadpool.
    outputs = model.generate(**inputs, max_new_tokens=300)
    # Decode only the newly generated tokens. The original code decoded the
    # full sequence, which echoed the entire prompt back in the response.
    prompt_len = inputs["input_ids"].shape[1]
    response = tokenizer.decode(
        outputs[0][prompt_len:], skip_special_tokens=True
    ).strip()
    return {"response": response}