# Spaces: Running
import logging

from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware

from app.config import settings
from app.schemas import PredictionResponse
app = FastAPI(
    title="Derm Foundation Classifier API",
    description="Derm Foundation embedding backbone + PyTorch MLP head.",
    version="1.0.0",
)

# CORS origins come from a comma-separated setting. Blank entries (an empty
# setting, a trailing comma, or ", ,") are dropped instead of being registered
# as an "" origin.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        origin.strip()
        for origin in settings.cors_origins.split(",")
        if origin.strip()
    ],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Populated lazily by get_predictor() on first use, so importing this module
# does not load the model.
app.state.predictor = None
def get_predictor():
    """Return the shared TwoStageDermPredictor, building it on first call.

    The instance is cached on ``app.state.predictor``. The first call imports
    ``app.services.predictor`` and constructs the predictor from ``settings``.
    Later calls return the cached object unchanged. There is no locking, so
    concurrent first calls from several threads could each build one.
    """
    cached = app.state.predictor
    if cached is not None:
        return cached

    print("Loading TwoStageDermPredictor...", flush=True)
    # Imported here so the heavy model stack loads only when first needed.
    from app.services.predictor import TwoStageDermPredictor

    cached = TwoStageDermPredictor(
        derm_model_id=settings.derm_model_id,
        head_checkpoint_path=str(settings.head_checkpoint_path),
        hf_token=settings.hf_token,
        local_files_only=settings.local_files_only,
        image_size=settings.image_size,
        device_name=settings.device,
    )
    app.state.predictor = cached
    print("TwoStageDermPredictor loaded.", flush=True)
    return cached
@app.get("/")
def root():
    """Return a static message confirming the API process is up.

    NOTE(review): the "/" path is inferred from the handler name; confirm it
    matches what clients call.
    """
    return {"message": "Derm Foundation API is running"}
@app.get("/health")
def health():
    """Lightweight health check; does not touch or load the predictor.

    NOTE(review): the "/health" path is inferred from the handler name;
    confirm it matches the deployment's health-check configuration.
    """
    return {"status": "ok"}
@app.post("/predict", response_model=PredictionResponse)
async def predict(file: UploadFile = File(...)):
    """Run the two-stage predictor on an uploaded image.

    Args:
        file: Multipart upload. If the client sends a content type, it must
            start with ``image/``. A missing content type is accepted.

    Returns:
        The result of ``predictor.predict(image_bytes)``, validated and
        serialised against ``PredictionResponse``.

    Raises:
        HTTPException: 400 for a non-image content type or an empty upload.
            500 if loading the predictor or running inference fails. The
            underlying error is logged server-side rather than echoed back
            to the client.
    """
    # NOTE(review): the route path and response_model are inferred from the
    # handler name and the PredictionResponse import. Confirm the predictor's
    # return value matches that schema.
    if file.content_type is not None and not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="Uploaded file must be an image.")

    image_bytes = await file.read()
    if not image_bytes:
        raise HTTPException(status_code=400, detail="Uploaded image is empty.")

    try:
        # Lazy model loading stays on the event loop. get_predictor() has no
        # lock, so calling it from worker threads could build the model twice.
        predictor = get_predictor()
        # Inference is synchronous. Run it in the threadpool so it does not
        # block the event loop (and every other endpoint) while it runs.
        # NOTE(review): assumes predictor.predict is safe to call from several
        # threads at once. Confirm, or serialise the calls.
        return await run_in_threadpool(predictor.predict, image_bytes)
    except Exception as exc:
        # Keep the traceback in server logs; don't expose internals to clients.
        logging.getLogger(__name__).exception("Prediction failed")
        raise HTTPException(status_code=500, detail="Prediction failed.") from exc