Spaces:
Runtime error
Runtime error
| # ============================================================================== | |
| # Phase 2: AI-Enabled Healthcare Diagnostic Tool - Backend API (Corrected) | |
| # ============================================================================== | |
| import torch | |
| import torch.nn.functional as F | |
| from fastapi import FastAPI, UploadFile, File, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from transformers import ViTForImageClassification, ViTImageProcessor | |
| from PIL import Image | |
| import io | |
| import logging | |
| import os | |
| from datetime import datetime | |
# Logging: INFO-level root configuration plus a logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Point the Hugging Face caches at a writable location (/tmp/cache).
# NOTE(review): transformers / huggingface_hub usually resolve these variables
# when they are first imported, and transformers is imported above, so these
# assignments may not change their default cache paths. The model/processor
# loaders in this module pass cache_dir explicitly, which covers that case.
# Consider moving this block above the transformers import. TODO confirm.
os.environ.update(
    {
        'TRANSFORMERS_CACHE': '/tmp/cache',
        'HF_HOME': '/tmp/cache',
        'HF_HUB_DISABLE_SYMLINKS': '1',
    }
)
# --- 1. Application Setup ---
# The FastAPI application object. The title/description/version below populate
# the generated OpenAPI schema (served at /docs).
app = FastAPI(
    title="Pneumonia Detection API",
    version="1.0.0",
    description="An API to detect pneumonia from chest X-ray images using a Vision Transformer model.",
)

# CORS: accept cross-origin requests from any origin, with any method and header.
# NOTE(review): Starlette documents that credentials require explicit origins,
# methods and headers rather than ["*"]. This API does not appear to use cookies
# or auth, so consider allow_credentials=False or an explicit origin list.
# Left unchanged to avoid breaking existing frontends.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
# --- 2. Configuration ---
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Fine-tuned weights (loaded as a state_dict). This is a relative path, so it is
# resolved against the process's current working directory.
MODEL_SAVE_PATH = 'pneumonia_detection_model.pth'

# Hugging Face Hub checkpoint that supplies the ViT architecture and its image
# processor. A tiny ViT variant is used to keep storage small.
BASE_MODEL = "WinKawaks/vit-tiny-patch16-224"

# Label order: index i of the model's output maps to CLASS_NAMES[i].
CLASS_NAMES = ['NORMAL', 'PNEUMONIA']

# Module-level handles filled in by load_model(); both are None until loading
# succeeds, and are reset to None if it fails.
model = processor = None
# --- 3. Model Loading ---
@app.on_event("startup")
async def load_model():
    """Load the image processor and fine-tuned ViT classifier into module globals.

    Registered as a FastAPI startup handler. Steps:

    1. Load the processor and base model for ``BASE_MODEL`` from the Hugging Face
       Hub (cached under /tmp/cache). The model is built with
       ``num_labels=len(CLASS_NAMES)`` and ``ignore_mismatched_sizes=True``, so a
       head of a different size in the base checkpoint is re-initialised.
    2. Overlay the fine-tuned weights from ``MODEL_SAVE_PATH``. Missing or
       unexpected keys are logged as warnings. Missing classification-head
       weights are treated as a failure.
    3. Move the model to ``DEVICE`` and switch it to eval mode.

    Failures are deliberately not re-raised. The error and traceback are logged,
    and both ``model`` and ``processor`` are reset to ``None``, so the server
    still starts and the ``/health`` endpoint reports "unhealthy".
    """
    global model, processor
    cache_dir = '/tmp/cache'
    try:
        logger.info(f"===== Application Startup at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} =====")
        logger.info(f"Device: {DEVICE}")
        logger.info(f"Cache directory: {cache_dir}")
        os.makedirs(cache_dir, exist_ok=True)

        logger.info(f"Loading processor and model from Hugging Face Hub: {BASE_MODEL}")
        processor = ViTImageProcessor.from_pretrained(BASE_MODEL, cache_dir=cache_dir)
        logger.info("Processor loaded successfully.")

        # ignore_mismatched_sizes lets transformers re-initialise the
        # classification head when the base checkpoint's label count differs
        # from len(CLASS_NAMES).
        model = ViTForImageClassification.from_pretrained(
            BASE_MODEL,
            num_labels=len(CLASS_NAMES),
            ignore_mismatched_sizes=True,
            cache_dir=cache_dir,
        )
        logger.info("Base model loaded successfully.")

        logger.info(f"Loading trained weights from {MODEL_SAVE_PATH}...")
        # weights_only=True restricts unpickling to tensors and plain containers,
        # so a tampered checkpoint cannot execute arbitrary code.
        state_dict = torch.load(MODEL_SAVE_PATH, map_location=DEVICE, weights_only=True)
        # strict=False only tolerates missing/unexpected KEYS. A tensor whose
        # shape differs from the model's (e.g. weights trained on a different
        # ViT size) still makes load_state_dict raise, and that is handled by
        # the except clause below.
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
        if missing_keys:
            logger.warning(f"Missing keys: {missing_keys}")
        if unexpected_keys:
            logger.warning(f"Unexpected keys: {unexpected_keys}")

        # The head was freshly initialised above, so without trained classifier
        # weights the model would still emit confident-looking (but meaningless)
        # probabilities. Refuse to serve in that state.
        missing_head = [key for key in missing_keys if key.startswith("classifier.")]
        if missing_head:
            raise RuntimeError(
                f"Trained weights do not contain the classification head: {missing_head}"
            )
        logger.info("Trained weights loaded successfully.")

        model.to(DEVICE)
        model.eval()
        logger.info("✓ Model and processor loaded and ready!")
    except Exception as e:
        # logger.exception logs at ERROR level and appends the full traceback.
        logger.exception(f"Error loading model or processor: {e}")
        model = None
        processor = None
# --- 4. API Endpoints ---
@app.get("/")
def read_root():
    """Return a static description of the API: a welcome message, a running
    flag, and the paths of the main endpoints."""
    return {
        "message": "Welcome to the Pneumonia Detection API",
        "status": "running",
        "endpoints": {
            "docs": "/docs",
            "health": "/health",
            "predict": "/predict/",
        },
    }
@app.get("/health")
def health_check():
    """Report whether the module-level model and processor are loaded.

    The response is always HTTP 200. The "status" field ("healthy" or
    "unhealthy") carries the result, together with the torch device in use and,
    when healthy, the class labels.
    """
    if model is None or processor is None:
        return {
            "status": "unhealthy",
            "reason": "Model not loaded",
            "device": str(DEVICE),
        }
    return {
        "status": "healthy",
        "model_loaded": True,
        "device": str(DEVICE),
        "class_names": CLASS_NAMES,
    }
@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded chest X-ray image as one of ``CLASS_NAMES``.

    Args:
        file: Multipart upload containing an image that Pillow can open. It is
            converted to RGB before preprocessing.

    Returns:
        A dict containing:
        - the upload's filename;
        - the predicted label;
        - its confidence;
        - a label -> probability mapping for every entry in ``CLASS_NAMES``.
        Probabilities are formatted as strings with four decimals.

    Raises:
        HTTPException: status 500 if the model/processor are not loaded or
            inference fails; status 400 if the upload cannot be decoded as an
            image.
    """
    # Explicit None checks, matching the /health endpoint.
    if model is None or processor is None:
        raise HTTPException(
            status_code=500,
            detail="Model is not loaded. Check server logs or visit /health endpoint.",
        )

    contents = await file.read()

    # Decode the upload. Anything Pillow cannot open is a client error (400).
    try:
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except Exception as e:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid image file: {str(e)}",
        ) from e

    # NOTE(review): preprocessing and the forward pass run synchronously inside
    # this async handler, i.e. on the event-loop thread.
    try:
        inputs = processor(images=image, return_tensors="pt").to(DEVICE)
        with torch.no_grad():
            logits = model(**inputs).logits
        # One image per request, so take row 0 of the batch.
        probabilities = F.softmax(logits, dim=1)[0].tolist()
        predicted_class_idx = max(range(len(probabilities)), key=probabilities.__getitem__)
        confidence = probabilities[predicted_class_idx]
        return {
            "filename": file.filename,
            "prediction": CLASS_NAMES[predicted_class_idx],
            "confidence": f"{confidence:.4f}",
            "probabilities": {
                name: f"{probabilities[idx]:.4f}"
                for idx, name in enumerate(CLASS_NAMES)
            },
        }
    except Exception as e:
        # Keep the traceback in the logs, not just the message.
        logger.exception(f"Prediction error: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Error during prediction: {str(e)}",
        ) from e
# --- 5. (Optional) For running directly ---
if __name__ == "__main__":
    import uvicorn

    # Listen on all interfaces. The port defaults to 8000 but can be overridden
    # through the PORT environment variable (some hosting platforms assign the
    # port that way).
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "8000")))