"""Waste Classification API.

FastAPI service that loads a Keras MobileNetV2-based classifier at startup
and exposes endpoints to classify an uploaded image into one of the labels
in ``class_labels``.
"""

import io
import logging
import os

import numpy as np
import tensorflow as tf
import uvicorn
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from PIL import Image

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(
    title="Waste Classification API",
    description="API for classifying waste into categories: Glass, Metal, Organic, Paper, Plastic",
    version="1.0.0",
    docs_url="/",  # Swagger UI at root for easy access
)

# Add CORS middleware for web access
# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site make credentialed requests; kept as-is to avoid breaking existing
# clients, but consider restricting origins or disabling credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global state, populated by startup_event()
model = None
# True only when load_model() obtained trained weights from disk; False when
# it fell back to the freshly built (untrained-head) architecture.
model_is_trained = False

# IMPORTANT: Class order from training (alphabetical from image_dataset_from_directory)
class_labels = ["glass", "metal", "organic", "paper", "plastic"]


def _try_load_model_file(model_file):
    """Try every loading strategy on one file.

    Returns:
        ``(model, description)`` on success, ``(None, None)`` if every
        strategy failed. A model is only returned when weights were really
        read from ``model_file``.
    """
    # Method 1: Standard loading
    try:
        loaded = tf.keras.models.load_model(model_file, compile=False)
        logger.info(f"✅ Loaded with standard method: {model_file}")
        return loaded, model_file
    except Exception as e1:
        logger.warning(f"Standard loading failed: {e1}")

    # Method 2: Load with custom objects (for compatibility)
    try:
        custom_objects = {
            'InputLayer': tf.keras.layers.InputLayer,
            'Rescaling': tf.keras.layers.Rescaling,
        }
        loaded = tf.keras.models.load_model(
            model_file,
            custom_objects=custom_objects,
            compile=False,
        )
        logger.info(f"✅ Loaded with custom objects: {model_file}")
        return loaded, model_file
    except Exception as e2:
        logger.warning(f"Custom objects loading failed: {e2}")

    # Method 3: Rebuild the architecture and load weights only. This is only
    # attempted for .h5 files; for anything else there is nothing to load, so
    # report failure instead of passing off an untrained model as loaded.
    if not model_file.endswith('.h5'):
        return None, None
    try:
        candidate = create_model_architecture()
        candidate.load_weights(model_file)
        logger.info(f"✅ Loaded weights only: {model_file}")
        return candidate, f"{model_file} (weights only)"
    except Exception as e3:
        logger.warning(f"Weights loading failed: {e3}")

    return None, None


def load_model():
    """Load the trained TensorFlow/Keras model with version compatibility handling.

    Tries each candidate path in order and returns the first model that loads.
    If none load, falls back to a freshly built architecture (predictions are
    then meaningless) and records that in the module-level
    ``model_is_trained`` flag.

    Returns:
        A compiled Keras model.

    Raises:
        RuntimeError: if neither loading nor building the fallback succeeds.
    """
    global model_is_trained
    try:
        # Debug info
        logger.info(f"TensorFlow version: {tf.__version__}")
        logger.info("=== DEBUGGING MODEL LOADING ===")

        model_files = [
            'model/waste_model.keras',
            'model/waste_model.h5',
            './model/waste_model.keras',
            './model/waste_model.h5',
        ]

        loaded_model = None
        loaded_from = None
        for model_file in model_files:
            if not os.path.exists(model_file):
                continue
            logger.info(f"Attempting to load: {model_file}")
            loaded_model, loaded_from = _try_load_model_file(model_file)
            if loaded_model is not None:
                break
            logger.warning(f"Failed to load {model_file}: all loading methods failed")

        if loaded_model is None:
            logger.warning("All loading methods failed. Creating model from architecture...")
            loaded_model = create_model_architecture()
            model_is_trained = False
            logger.warning("⚠️ Using untrained model - predictions will be random!")
        else:
            model_is_trained = True
            logger.info(f"✅ Model loaded successfully from: {loaded_from}")

        # Recompile the model (it was loaded with compile=False)
        loaded_model.compile(
            optimizer='adam',
            loss='categorical_crossentropy',
            metrics=['accuracy'],
        )
        return loaded_model

    except Exception as e:
        logger.error(f"Critical error loading model: {e}")
        raise RuntimeError(f"Model loading failed: {e}") from e


def create_model_architecture():
    """Create the model architecture matching your training setup.

    Builds a frozen ImageNet MobileNetV2 base followed by a small dense head
    with a 5-way softmax. The head is randomly initialised.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model taking 224x224x3 input.
    """
    try:
        # Create the same architecture as in your training notebook
        base_model = tf.keras.applications.MobileNetV2(
            weights='imagenet',
            include_top=False,
            input_shape=(224, 224, 3),
        )

        # Freeze base model
        base_model.trainable = False

        # Create complete model
        architecture = tf.keras.Sequential([
            tf.keras.layers.Rescaling(1./255, input_shape=(224, 224, 3)),
            base_model,
            tf.keras.layers.GlobalAveragePooling2D(),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(128, activation='relu'),
            tf.keras.layers.Dropout(0.5),
            tf.keras.layers.Dense(5, activation='softmax'),  # 5 classes
        ])

        logger.info("Created model architecture successfully")
        return architecture

    except Exception as e:
        logger.error(f"Failed to create model architecture: {e}")
        raise


def preprocess_image(image_data):
    """Turn raw image bytes into a model input batch.

    Steps: decode as RGB, resize to 224x224 (LANCZOS), convert to a float32
    array and add a leading batch dimension. Pixel values are left in 0-255
    because the model contains its own ``Rescaling(1./255)`` layer.

    Args:
        image_data: encoded image bytes (any format PIL can open).

    Returns:
        ``numpy.ndarray`` of shape ``(1, 224, 224, 3)``, dtype float32.

    Raises:
        HTTPException: status 400 if the bytes cannot be decoded/processed.
    """
    try:
        # Convert bytes to PIL Image
        image = Image.open(io.BytesIO(image_data)).convert("RGB")

        # Resize to model input size
        image = image.resize((224, 224), Image.Resampling.LANCZOS)

        pixels = np.array(image).astype("float32")

        # Add batch dimension; no manual normalization (done inside the model)
        return np.expand_dims(pixels, axis=0)

    except Exception as e:
        logger.error(f"Image preprocessing error: {e}")
        raise HTTPException(
            status_code=400, detail=f"Image preprocessing failed: {e}"
        ) from e


def _require_model():
    """Return the loaded model or raise 503 if startup has not provided one."""
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    return model


async def _read_image_upload(file):
    """Validate an upload's content type and return its non-empty bytes.

    Raises:
        HTTPException: status 400 for a non-image content type or empty body.
    """
    if not file.content_type or not file.content_type.startswith('image/'):
        logger.warning(f"Invalid file type: {file.content_type}")
        raise HTTPException(status_code=400, detail="File must be an image")

    image_data = await file.read()
    if not image_data:
        raise HTTPException(status_code=400, detail="Empty image file")
    return image_data


def _predict_probabilities(image_data):
    """Preprocess ``image_data`` and return the model's output row for it."""
    classifier = _require_model()
    processed_image = preprocess_image(image_data)
    predictions = classifier.predict(processed_image, verbose=0)
    return predictions[0]


@app.get("/health")
async def health_check():
    """Health check endpoint.

    Runs one prediction on random input to confirm the model executes, and
    reports whether trained weights were loaded (``real_model``) or the
    untrained fallback architecture is in use (``dummy_model``).
    """
    if model is None:
        return {
            "status": "unhealthy",
            "error": "Model not loaded",
            "model_loaded": False,
            "classes": class_labels,
        }

    try:
        # Quick model test
        dummy_input = np.random.random((1, 224, 224, 3)).astype(np.float32)
        prediction = model.predict(dummy_input, verbose=0)

        # A softmax output always sums to ~1, so the output cannot tell a
        # trained model from the fallback; use the flag set by load_model().
        model_status = "real_model" if model_is_trained else "dummy_model"

        return {
            "status": "healthy",
            "model_status": model_status,
            "model_loaded": True,
            "classes": class_labels,
            "input_shape": "(224, 224, 3)",
            "model_type": "TensorFlow/Keras MobileNetV2",
            "prediction_sample": prediction[0].tolist(),  # Show first prediction
        }
    except Exception as e:
        return {
            "status": "unhealthy",
            "error": str(e),
            "model_loaded": model is not None,
            "classes": class_labels,
        }


@app.on_event("startup")
async def startup_event():
    """Load model on startup and smoke-test it with one dummy prediction.

    Any failure is logged and re-raised so the application does not start
    without a usable model.
    """
    global model
    try:
        # Log available files for debugging
        logger.info("Available files in root:")
        for entry in os.listdir('.'):
            logger.info(f"  {entry}")

        if os.path.exists('model'):
            logger.info("Available files in model/ directory:")
            for entry in os.listdir('model'):
                logger.info(f"  model/{entry}")

        model = load_model()
        logger.info("API startup complete")

        # Test model with dummy input
        dummy_input = np.random.random((1, 224, 224, 3)).astype(np.float32)
        _ = model.predict(dummy_input, verbose=0)
        logger.info("Model test prediction successful")

    except Exception as e:
        logger.error(f"Startup failed: {e}")
        raise


@app.post("/classify")
async def classify_image(file: UploadFile = File(...)):
    """
    Main classification endpoint for ESP32

    Expected usage:
    curl -X POST -F "file=@image.jpg" https://your-space-url.hf.space/classify

    Returns:
    JSON: {"label": "plastic"} or {"error": "message"}
    """
    try:
        image_data = await _read_image_upload(file)
        logger.info(f"Processing image: {file.filename}, size: {len(image_data)} bytes")

        probabilities = _predict_probabilities(image_data)
        predicted_class_index = int(np.argmax(probabilities))
        predicted_class = class_labels[predicted_class_index]
        confidence = float(probabilities[predicted_class_index])

        logger.info(f"Prediction: {predicted_class} (confidence: {confidence:.3f})")

        # Return simple response for ESP32 - match your ESP32 expectation exactly
        return {"label": predicted_class.capitalize()}  # Capitalize to match your ESP32 labels

    except HTTPException:
        # Client errors (400) and "model not loaded" (503) pass through as-is
        raise
    except Exception as e:
        logger.error(f"Classification error: {str(e)}")
        return JSONResponse(
            status_code=500,
            content={"error": f"Classification failed: {str(e)}"},
        )


@app.post("/classify/detailed")
async def classify_detailed(file: UploadFile = File(...)):
    """
    Detailed classification endpoint with confidence scores

    Returns the top label, its confidence (percent) and the per-class
    probabilities (percent). Raises HTTP 400 for a non-image or empty upload
    and HTTP 500 for unexpected inference failures.
    """
    try:
        image_data = await _read_image_upload(file)

        # Make prediction with full details
        probabilities = _predict_probabilities(image_data)
        predicted_class_index = int(np.argmax(probabilities))
        predicted_class = class_labels[predicted_class_index]
        confidence = float(probabilities[predicted_class_index])

        # Get all class probabilities
        all_probs = {
            label.capitalize(): round(float(probabilities[i]) * 100, 2)
            for i, label in enumerate(class_labels)
        }

        return {
            "label": predicted_class.capitalize(),
            "confidence": round(confidence * 100, 2),
            "all_probabilities": all_probs,
            "model_info": {
                "architecture": "MobileNetV2",
                "input_size": "224x224",
                "classes": len(class_labels),
            },
            "status": "success",
        }

    except HTTPException:
        # Without this, the 400s raised above would be rewrapped as 500
        raise
    except Exception as e:
        logger.error(f"Detailed classification error: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Classification failed: {str(e)}"
        ) from e


@app.get("/info")
async def get_info():
    """API information endpoint"""
    return {
        "api_name": "Waste Classification API",
        "version": "1.0.0",
        "model": {
            "architecture": "MobileNetV2 + Custom Head",
            "framework": "TensorFlow/Keras",
            "input_size": "224x224x3",
            "preprocessing": "RGB, Resize, Rescaling (internal)",
        },
        "classes": [label.capitalize() for label in class_labels],
        "endpoints": {
            "/classify": "POST - Main classification endpoint (returns simple label)",
            "/classify/detailed": "POST - Detailed classification with confidence",
            "/health": "GET - Health check",
            "/info": "GET - API information",
        },
        "usage": {
            "esp32": "POST image to /classify endpoint",
            "curl_example": "curl -X POST -F 'file=@image.jpg' https://your-space-url.hf.space/classify",
        },
    }


@app.post("/test")
async def test_with_dummy():
    """Test endpoint with dummy data for debugging"""
    try:
        # Create dummy image (random noise); upper bound is exclusive, so 256
        # is needed to cover the full 0-255 pixel range.
        dummy_image = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)
        dummy_input = np.expand_dims(dummy_image.astype(np.float32), axis=0)

        # Make prediction
        predictions = model.predict(dummy_input, verbose=0)
        predicted_class_index = int(np.argmax(predictions[0]))
        predicted_class = class_labels[predicted_class_index]

        return {
            "test_status": "success",
            "predicted_class": predicted_class.capitalize(),
            "confidence": float(predictions[0][predicted_class_index]),
            "all_predictions": [float(p) for p in predictions[0]],
        }
    except Exception as e:
        return {"test_status": "failed", "error": str(e)}


if __name__ == "__main__":
    # Run the FastAPI app
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=port,
        log_level="info",
    )