# NOTE(review): removed non-code paste artifacts here ("Spaces:" header and two
# "Runtime error" log lines) — they are not Python and would be a SyntaxError.
import os

# Redirect every cache (HF, TorchInductor, Triton) into /tmp, where the
# runtime user is guaranteed write permission.
#
# IMPORTANT: these env vars MUST be set *before* importing unsloth —
# unsloth pulls in transformers, which reads TRANSFORMERS_CACHE / HF_HOME
# at import time. In the original ordering (imports first, env vars after)
# the HF cache redirection silently had no effect.
HF_CACHE = "/tmp/hf_cache"
_CACHE_ENV = {
    "TRITON_CACHE_DIR": "/tmp/triton_cache",
    "TORCHINDUCTOR_CACHE_DIR": "/tmp/torchinductor_cache",
    "TRANSFORMERS_CACHE": HF_CACHE,
    "HF_HOME": HF_CACHE,
}
for _name, _path in _CACHE_ENV.items():
    os.environ[_name] = _path
    # Create the cache dir up front; exist_ok makes repeats harmless.
    os.makedirs(_path, exist_ok=True)

from fastapi import FastAPI  # noqa: E402 (must come after cache env setup)
from pydantic import BaseModel  # noqa: E402
from unsloth import FastModel  # noqa: E402
# ASGI application instance served by uvicorn/HF Spaces.
app = FastAPI()

# The model and tokenizer are loaded lazily on the first request so the
# server process starts quickly; both remain None until that happens.
model = None
tokenizer = None
class ChatInput(BaseModel):
    """Request body for the chat endpoint: one free-text patient message."""

    message: str
@app.post("/chat")  # FIX: without a route decorator this handler was never registered
async def chat_handler(input: ChatInput):
    """Generate one provider turn for the patient's message.

    Lazily loads the 4-bit model + LoRA adapter on the first call, then
    builds a triage-style prompt and returns {"response": <generated text>}.
    """
    global model, tokenizer
    if model is None or tokenizer is None:
        # NOTE(review): this lazy load is not concurrency-safe — two
        # overlapping first requests could both trigger a load. Acceptable
        # for a single-worker demo; guard with a lock otherwise.
        print("🟡 Loading model...")
        # NOTE(review): base model is phi-2 but the adapter name says
        # "phi3-Mini" — confirm the adapter matches this base checkpoint.
        model, tokenizer = FastModel.from_pretrained(
            model_name="microsoft/phi-2",
            adapter_name="srikar-v05/phi3-Mini-Medical-Chat",
            load_in_4bit=True,
            max_seq_length=2048,
        )
        FastModel.for_inference(model)
        print("🟢 Model loaded!")

    # Build the instruction prompt around the patient's message.
    prompt = (
        "You are a kind, attentive oncology provider speaking to a patient.\n"
        "Ask one follow-up question at a time to triage their symptoms.\n\n"
        f"Patient: {input.message}\nProvider:"
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=300)
    # FIX: decode only the newly generated tokens. Decoding outputs[0]
    # wholesale (as before) echoed the entire prompt back in the response,
    # since generate() returns prompt + completion ids.
    generated_ids = outputs[0][inputs["input_ids"].shape[1]:]
    response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
    return {"response": response}