# Medical-Chatbot / app.py
# Source-page metadata: author Srikar00007, commit "Update app.py" (0f4af1c, verified).
import io

import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from ultralytics import YOLO
app = FastAPI()

# ✅ Allow the React frontend (served from a different origin) to call this API.
# Credentials are disabled on purpose: a wildcard origin combined with
# allow_credentials=True lets *any* website issue credentialed requests, and
# the FastAPI CORS docs require explicit origins when credentials are enabled.
# If the frontend ever needs cookies/auth headers, list its exact origin(s) in
# allow_origins and turn allow_credentials back on.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ------------------------------------------------------------
# ✅ LAZY LOAD MODELS (Important for HuggingFace Spaces)
# ------------------------------------------------------------
# Module-level cache slots. Each one stays None until the matching load_*
# function fills it on first use:
#   wound_pipe            -> load_wound_model()
#   yolo_model            -> load_yolo_model()
#   tokenizer, chat_model -> load_llm()
wound_pipe = yolo_model = tokenizer = chat_model = None
def load_wound_model():
    """Return the shared wound image-classification pipeline.

    The pipeline is built on the first call and stored in the module-level
    ``wound_pipe``; later calls return that same object.
    """
    global wound_pipe
    if wound_pipe is not None:
        return wound_pipe

    wound_pipe = pipeline(
        "image-classification",
        model="Hemg/Wound-Image-classification",
    )
    return wound_pipe
def load_yolo_model():
    """Return the shared YOLO model, loading ``best.pt`` on the first call.

    The loaded model is stored in the module-level ``yolo_model`` and reused
    by every later call.
    """
    global yolo_model
    if yolo_model is not None:
        return yolo_model

    # best.pt MUST be in root
    yolo_model = YOLO("best.pt")
    return yolo_model
def load_llm():
    """Return the shared ``(tokenizer, chat_model)`` pair for the chat LLM.

    Both objects are created on the first call (or whenever either slot is
    still empty) and stored in the module-level ``tokenizer`` and
    ``chat_model``.
    """
    global tokenizer, chat_model
    if tokenizer is not None and chat_model is not None:
        return tokenizer, chat_model

    # ✅ Smaller model, fits HF
    # NOTE(review): confirm this model ID actually exists on the Hub.
    model_id = "google/medgemma-2b-it"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    chat_model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="cpu",  # ✅ Forces CPU mode
        torch_dtype=torch.float32,
    )
    return tokenizer, chat_model
# ------------------------------------------------------------
# ✅ API 1 — Analyze Image
# ------------------------------------------------------------
@app.post("/analyze-image/")
async def analyze_image(file: UploadFile = File(...)):
    """Run the wound classifier and the YOLO detector on an uploaded image.

    Args:
        file: The uploaded image (multipart form field ``file``).

    Returns:
        A dict with:
          * ``wound_label`` - label of the first wound-classifier prediction,
          * ``wound_conf``  - that prediction's score as a float,
          * ``skin_label``  - class name of the first YOLO box, or the string
            ``"None"`` when no box is detected.

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image.
    """
    img_bytes = await file.read()

    # Decode the upload before touching the models so that a bad file is
    # rejected with a 400 instead of surfacing as an unhandled 500, and so
    # junk input never triggers the expensive first-time model load.
    try:
        img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
    except (OSError, ValueError, Image.DecompressionBombError) as exc:
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image.",
        ) from exc

    classifier = load_wound_model()
    detector = load_yolo_model()

    # NOTE(review): both inference calls below are synchronous and run on the
    # event loop; consider a worker thread if concurrent requests matter.

    # ✅ Wound model: keep only the first prediction
    wound_pred = classifier(img)[0]
    wound_label = wound_pred["label"]
    wound_conf = float(wound_pred["score"])

    # ✅ YOLO detection: report the class of the first box, if any
    boxes = detector(img)[0].boxes
    if len(boxes) > 0:
        cls_id = int(boxes.cls[0])
        skin_label = detector.names[cls_id]
    else:
        skin_label = "None"

    return {
        "wound_label": wound_label,
        "wound_conf": wound_conf,
        "skin_label": skin_label,
    }
# ------------------------------------------------------------
# ✅ API 2 — Ask Medical Question
# ------------------------------------------------------------
@app.post("/ask-ai/")
async def ask_ai(data: dict):
    """Generate an LLM answer to a question about the detected conditions.

    Args:
        data: JSON body that must contain the keys ``wound``, ``skin`` and
            ``question``.

    Returns:
        ``{"reply": <text>}`` where the text is only what the model generated
        (the prompt itself is not echoed back).

    Raises:
        HTTPException: 400 when any required key is missing from ``data``.
    """
    # Validate up front: a missing key would otherwise raise KeyError while
    # building the prompt and surface as an opaque 500.
    missing = [key for key in ("wound", "skin", "question") if key not in data]
    if missing:
        raise HTTPException(
            status_code=400,
            detail=f"Missing required field(s): {', '.join(missing)}",
        )

    tokenizer, chat_model = load_llm()

    prompt = (
        "\n"
        f"Detected wound: {data['wound']}\n"
        f"Detected skin disease: {data['skin']}\n"
        f"User question: {data['question']}\n"
        "Respond like a medical expert.\n"
    )

    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = chat_model.generate(**inputs, max_new_tokens=200)

    # For a causal LM, generate() returns the prompt tokens followed by the
    # continuation. Slice the prompt off so the reply holds only the answer.
    prompt_len = inputs["input_ids"].shape[-1]
    reply = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return {"reply": reply}
# ✅ Start the Uvicorn server when this file is executed directly
import uvicorn

if __name__ == "__main__":
    # "app:app" = the `app` object in module `app` (this file, app.py).
    # Listens on all interfaces; 7860 is presumably the port the hosting
    # Space expects (verify against the deployment config).
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=7860,
    )