Spaces:
Runtime error
Runtime error
File size: 1,633 Bytes
488fea0 fb91760 97bceda 3f41316 c98a751 b7b15c8 1fdd678 8c048e4 1fdd678 c98a751 fb91760 c98a751 97bceda fb91760 c98a751 97bceda 8c048e4 97bceda c98a751 fb91760 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
import os

# Redirect all cache paths to writable /tmp locations (HF Spaces containers
# have a read-only home directory).
#
# BUG FIX: these env vars MUST be set before importing unsloth — unsloth pulls
# in transformers/huggingface_hub, and both read HF_HOME / TRANSFORMERS_CACHE
# at import time, so setting them after the import has no effect.
HF_CACHE = "/tmp/hf_cache"
os.environ["TRANSFORMERS_CACHE"] = HF_CACHE
os.environ["HF_HOME"] = HF_CACHE
os.environ["TRITON_CACHE_DIR"] = "/tmp/triton_cache"
os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/tmp/torchinductor_cache"

# Create all the cache dirs so the libraries can write into them immediately.
for path in [HF_CACHE, "/tmp/triton_cache", "/tmp/torchinductor_cache"]:
    os.makedirs(path, exist_ok=True)

from fastapi import FastAPI
from pydantic import BaseModel
from unsloth import FastModel
# ASGI application instance picked up by the server (e.g. uvicorn).
app = FastAPI()
# Lazy load the model on first request: both stay None until the first
# /chat call populates them (see chat_handler), keeping startup fast.
model = None
tokenizer = None
class ChatInput(BaseModel):
    """Request body for POST /chat."""

    # Free-text patient message the model should respond to.
    message: str
@app.post("/chat")
async def chat_handler(input: ChatInput):
    """Generate a provider-style triage reply to the patient's message.

    Lazily loads the model/tokenizer on the first request, builds a fixed
    triage prompt around ``input.message``, and returns only the newly
    generated text as ``{"response": ...}``.
    """
    global model, tokenizer
    if model is None or tokenizer is None:
        # NOTE(review): this lazy init is not concurrency-safe — two
        # overlapping first requests could both trigger a load. Fine for a
        # single-worker demo; add an asyncio.Lock if concurrency matters.
        print("🟡 Loading model...")
        model, tokenizer = FastModel.from_pretrained(
            # NOTE(review): base model is phi-2 but the adapter name says
            # "phi3-Mini" — confirm the adapter was trained on this base.
            model_name = "microsoft/phi-2",
            adapter_name = "srikar-v05/phi3-Mini-Medical-Chat",
            load_in_4bit = True,
            max_seq_length = 2048,
        )
        FastModel.for_inference(model)
        print("🟢 Model loaded!")
    # Wrap the patient's message in the fixed triage prompt.
    prompt = (
        "You are a kind, attentive oncology provider speaking to a patient.\n"
        "Ask one follow-up question at a time to triage their symptoms.\n\n"
        f"Patient: {input.message}\nProvider:"
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=300)
    # BUG FIX: generate() returns prompt + completion; decoding outputs[0]
    # wholesale echoed the entire prompt back to the client. Slice off the
    # prompt tokens so only the generated reply is returned.
    prompt_len = inputs["input_ids"].shape[-1]
    response = tokenizer.decode(
        outputs[0][prompt_len:], skip_special_tokens=True
    ).strip()
    return {"response": response}
|