"""FastAPI backend: wound / skin-disease image analysis plus an LLM Q&A endpoint.

Models are loaded lazily on first use so the server starts quickly — important
on memory-constrained hosts such as HuggingFace Spaces.
"""

import io

import torch
import uvicorn
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from ultralytics import YOLO

app = FastAPI()

# Allow the React frontend to call this API from any origin.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# rejected by browsers per the CORS spec (credentialed requests require an
# explicit origin) — pin concrete origins if cookies/auth headers are needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# ------------------------------------------------------------
# Lazily-loaded model singletons (populated on first request).
# ------------------------------------------------------------
wound_pipe = None
yolo_model = None
tokenizer = None
chat_model = None


def load_wound_model():
    """Return the wound-classification pipeline, loading it on first call."""
    global wound_pipe
    if wound_pipe is None:
        wound_pipe = pipeline(
            "image-classification",
            model="Hemg/Wound-Image-classification",
        )
    return wound_pipe


def load_yolo_model():
    """Return the YOLO skin-disease detector, loading it on first call."""
    global yolo_model
    if yolo_model is None:
        yolo_model = YOLO("best.pt")  # best.pt MUST be in the repo root
    return yolo_model


def load_llm():
    """Return (tokenizer, model) for the chat LLM, loading both on first call."""
    global tokenizer, chat_model
    if tokenizer is None or chat_model is None:
        llm_id = "google/medgemma-2b-it"  # smaller model, fits HF free tier
        tokenizer = AutoTokenizer.from_pretrained(llm_id)
        chat_model = AutoModelForCausalLM.from_pretrained(
            llm_id,
            device_map="cpu",  # force CPU mode (no GPU on the target host)
            torch_dtype=torch.float32,
        )
    return tokenizer, chat_model


# ------------------------------------------------------------
# API 1 — Analyze an uploaded image with both vision models.
# ------------------------------------------------------------
@app.post("/analyze-image/")
async def analyze_image(file: UploadFile = File(...)):
    """Classify wound type and detect skin disease in the uploaded image.

    Returns ``{"wound_label": str, "wound_conf": float, "skin_label": str}``.
    ``skin_label`` is the string ``"None"`` when YOLO finds no detections —
    kept as a string for backward compatibility with the existing frontend.
    """
    wound_pipe = load_wound_model()
    yolo_model = load_yolo_model()

    img_bytes = await file.read()
    try:
        img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
    except Exception as exc:
        # Corrupt or non-image upload: answer 400 instead of an opaque 500.
        raise HTTPException(status_code=400, detail="Invalid image file") from exc

    # Wound classification: take the top-scoring prediction.
    wound_pred = wound_pipe(img)[0]
    wound_label = wound_pred["label"]
    wound_conf = float(wound_pred["score"])

    # YOLO detection: report the class of the first detected box, if any.
    yolo_out = yolo_model(img)
    if len(yolo_out[0].boxes) > 0:
        cls_id = int(yolo_out[0].boxes.cls[0])
        skin_label = yolo_model.names[cls_id]
    else:
        skin_label = "None"

    return {
        "wound_label": wound_label,
        "wound_conf": wound_conf,
        "skin_label": skin_label,
    }


# ------------------------------------------------------------
# API 2 — Ask the LLM a medical question about the findings.
# ------------------------------------------------------------
@app.post("/ask-ai/")
async def ask_ai(data: dict):
    """Generate an LLM answer grounded in the detected wound/skin findings.

    Expects JSON ``{"wound": str, "skin": str, "question": str}``.
    Answers 400 (instead of a 500 KeyError) when a required key is missing.
    """
    missing = [key for key in ("wound", "skin", "question") if key not in data]
    if missing:
        raise HTTPException(
            status_code=400,
            detail=f"Missing field(s): {', '.join(missing)}",
        )

    tokenizer, chat_model = load_llm()

    prompt = f"""
Detected wound: {data['wound']}
Detected skin disease: {data['skin']}
User question: {data['question']}
Respond like a medical expert.
"""

    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = chat_model.generate(**inputs, max_new_tokens=200)
    reply = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"reply": reply}


if __name__ == "__main__":
    uvicorn.run("app:app", host="0.0.0.0", port=7860)