File size: 1,633 Bytes
488fea0
fb91760
 
97bceda
3f41316
c98a751
b7b15c8
1fdd678
8c048e4
 
 
1fdd678
c98a751
 
 
fb91760
 
 
c98a751
 
 
97bceda
fb91760
 
 
 
 
c98a751
 
 
 
 
 
 
 
 
 
 
 
 
 
97bceda
 
 
 
 
 
8c048e4
97bceda
c98a751
fb91760
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import os

# Cache configuration MUST happen before importing unsloth (which pulls in
# transformers / huggingface_hub): those libraries resolve HF_HOME and
# TRANSFORMERS_CACHE at import time, so setting the variables after the
# import silently has no effect on where models are cached.
HF_CACHE = "/tmp/hf_cache"
os.environ["TRITON_CACHE_DIR"] = "/tmp/triton_cache"
os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/tmp/torchindutor_cache" if False else "/tmp/torchinductor_cache"
os.environ["TRANSFORMERS_CACHE"] = HF_CACHE  # legacy name, kept for older transformers versions
os.environ["HF_HOME"] = HF_CACHE

# Create the cache dirs up front so the libraries can write to them.
for path in [HF_CACHE, "/tmp/triton_cache", "/tmp/torchinductor_cache"]:
    os.makedirs(path, exist_ok=True)

from fastapi import FastAPI  # noqa: E402  (must come after env setup above)
from pydantic import BaseModel  # noqa: E402
from unsloth import FastModel  # noqa: E402

app = FastAPI()

# Model and tokenizer are lazily loaded on the first /chat request so the
# server process starts quickly and GPU/host memory is only consumed once
# a request actually arrives.
model = None
tokenizer = None

class ChatInput(BaseModel):
    """Request body for POST /chat."""
    # Free-text message from the patient; interpolated into the prompt.
    message: str

@app.post("/chat")
async def chat_handler(input: ChatInput):
    """Generate a provider-style reply to a patient message.

    Lazily loads the model/tokenizer on the first request (subsequent
    requests reuse the module-level globals). Returns
    ``{"response": <text>}`` containing ONLY the newly generated tokens —
    the prompt is not echoed back to the caller.
    """
    global model, tokenizer

    if model is None or tokenizer is None:
        print("🟡 Loading model...")
        # NOTE(review): the adapter name says "phi3-Mini" while the base
        # model is microsoft/phi-2 — confirm the adapter was actually
        # trained against this base, otherwise the weights won't line up.
        model, tokenizer = FastModel.from_pretrained(
            model_name="microsoft/phi-2",
            adapter_name="srikar-v05/phi3-Mini-Medical-Chat",
            load_in_4bit=True,
            max_seq_length=2048,
        )
        FastModel.for_inference(model)
        print("🟢 Model loaded!")

    # Build the instruction prompt around the patient's message.
    prompt = (
        "You are a kind, attentive oncology provider speaking to a patient.\n"
        "Ask one follow-up question at a time to triage their symptoms.\n\n"
        f"Patient: {input.message}\nProvider:"
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=300)

    # Bug fix: decode only the tokens generated AFTER the prompt.
    # Decoding outputs[0] wholesale echoed the system instructions and the
    # patient's own message back inside the response.
    prompt_len = inputs["input_ids"].shape[1]
    response = tokenizer.decode(
        outputs[0][prompt_len:], skip_special_tokens=True
    ).strip()

    return {"response": response}