# aurora / main.py
# Author: yasyn14 — "edited docker and main.py" (commit e7b36c2)
from config import CLASS_NAMES
import os, logging, numpy as np, asyncio
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from huggingface_hub import hf_hub_download
import gradio as gr
from PIL import Image
from api.v1 import router as v1_router
from models.model_loader import load_skin_condition_model
from utils.predictor import predict_skin_condition
# --- Logging configuration ---
LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)

# Module-level logger, named after this module per stdlib convention.
logger = logging.getLogger(__name__)

# Opt out of TensorFlow's oneDNN custom ops (set before TF initializes)
# to avoid minor numerical differences from op reordering.
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan: download, load, and warm up the model at startup.

    On startup, downloads the Keras model from the Hugging Face Hub,
    loads it off the event loop, runs one dummy prediction to warm it
    up, and stores it on ``app.state.model``. On shutdown, drops the
    model reference.

    Raises:
        RuntimeError: if the model cannot be downloaded or loaded.
    """
    # --- Startup: keep this try limited to model loading so that runtime
    # errors propagating through `yield` are NOT misreported as a failed
    # model load (the original wrapped the yield in this same handler).
    try:
        # Cache directory from environment variable, or a default.
        cache_dir = os.environ.get("HF_HOME", "/tmp/huggingface")
        logger.info(f"Downloading model from Hugging Face Hub using cache_dir: {cache_dir}...")
        model_path = hf_hub_download(
            repo_id="yasyn14/skin-analyzer",
            filename="model-v1.keras",
            cache_dir=cache_dir
        )
        logger.info(f"Loading model from path: {model_path}")
        # Load in a worker thread so the event loop is not blocked.
        model = await asyncio.to_thread(load_skin_condition_model, model_path)
        # Warm-up: first predict call compiles/initializes the graph.
        dummy = np.zeros((1, 224, 224, 3), dtype=np.uint8)
        await asyncio.to_thread(model.predict, dummy)
        app.state.model = model
        logger.info("Model ready ✅")
    except Exception as e:
        logger.exception("Failed during startup:")
        raise RuntimeError("Failed to load skin-condition model") from e

    # --- Serve: cleanup always runs, but app-runtime exceptions are no
    # longer swallowed/re-labelled by the startup handler above.
    try:
        yield
    finally:
        logger.info("Shutting down: releasing resources")
        if hasattr(app.state, "model"):
            del app.state.model
# === FastAPI Setup ===
# Application object; the `lifespan` context manager above loads the model
# before requests are served and releases it on shutdown.
app = FastAPI(
    lifespan=lifespan,
    title="Skin Condition Classifier API",
    description="Upload skin images to detect skin conditions using AI",
    version="1.0.0",
)
# Allow browser clients from any origin to call the API.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# very permissive — restrict the origin list before production deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # tighten in prod
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/healthz", tags=["Health"])
async def health_check():
    """Liveness probe; always reports the service as running."""
    return dict(status="ok")

# Mount the versioned REST API routes onto the application.
app.include_router(v1_router)
# === Gradio UI Setup ===
def predict_skin_condition_grad(image: Image.Image) -> str:
    """Gradio handler: classify a skin image with the loaded model.

    Args:
        image: PIL image from the Gradio widget; ``None`` when the user
            submits without providing one.

    Returns:
        A human-readable prediction string, or an explanatory message
        when no image or model is available.
    """
    if image is None:
        return "No image provided"
    if not hasattr(app.state, "model"):
        return "Model is not loaded yet. Please try again in a moment."
    model = app.state.model
    # Preprocess: convert to RGB *before* resizing — resizing palette ("P")
    # images first uses nearest-neighbour on palette indices and degrades
    # quality; converting first gives a proper RGB resample.
    img = image.convert("RGB").resize((224, 224))
    img_array = np.array(img)
    # Predict via the shared predictor helper.
    prediction = predict_skin_condition(img_array, model)
    confidence = prediction.get("confidence")
    label = prediction.get("condition")
    # .get() defaults to None; f"{None:.2%}" would raise TypeError.
    if confidence is None:
        return f"{label} (confidence unavailable)"
    return f"{label} ({confidence:.2%} confidence)"
# Interactive demo UI served alongside the REST API.
gradio_interface = gr.Interface(
    fn=predict_skin_condition_grad,
    inputs=gr.Image(
        type="pil",  # hand the handler a PIL image, matching its signature
        label="Upload or capture a skin image",
        sources=["upload", "webcam"],  # Explicitly enable both upload and webcam
        webcam_options={"facingMode": "environment"}  # Use back camera by default (better for skin photos)
    ),
    outputs=gr.Text(label="Prediction"),
    title="Skin Analyzer",
    description="Upload a photo or use your camera to detect skin conditions like acne, eczema, dryness, etc.",
    examples=[
        # Optional: Add example images if you have them
        # ["examples/acne.jpg"],
        # ["examples/eczema.jpg"]
    ],
    # NOTE(review): allow_flagging is deprecated in Gradio 4.x in favour of
    # flagging_mode — confirm against the pinned Gradio version.
    allow_flagging="never"  # Disable flagging option
)
# Mount Gradio on root
# path="" serves the UI at "/" while the FastAPI routes (/healthz and the
# v1 router) remain reachable on the same ASGI app.
app = gr.mount_gradio_app(app, gradio_interface, path="")