# app.py — Hugging Face Space: Waste Classification API
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| from fastapi.responses import JSONResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| import tensorflow as tf | |
| import numpy as np | |
| from PIL import Image | |
| import io | |
| import logging | |
| import uvicorn | |
| import os | |
# Configure application-wide logging; every module-level helper below
# reports through this logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# --- FastAPI application setup ---------------------------------------------
app = FastAPI(
    title="Waste Classification API",
    description="API for classifying waste into categories: Glass, Metal, Organic, Paper, Plastic",
    version="1.0.0",
    docs_url="/",  # serve the Swagger UI at the root path for easy access
)

# Permit cross-origin requests from any client (browser demos, ESP32 tooling).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Module-level state: the model is populated at startup by startup_event().
model = None

# IMPORTANT: order must match training — image_dataset_from_directory
# assigns class indices alphabetically.
class_labels = ["glass", "metal", "organic", "paper", "plastic"]
def load_model():
    """Load the trained Keras model, trying several compatibility fallbacks.

    Candidate files are tried in order; for each one three strategies are
    attempted (standard load, load with custom objects, rebuild + weights).

    Returns:
        A compiled tf.keras.Model. If every candidate fails, a freshly
        built (untrained) architecture is returned and a warning is logged.

    Raises:
        Exception: if even the fallback architecture cannot be created.
    """
    try:
        # Debug info
        logger.info(f"TensorFlow version: {tf.__version__}")
        logger.info("=== DEBUGGING MODEL LOADING ===")

        candidate_files = [
            'model/waste_model.keras',
            'model/waste_model.h5',
            './model/waste_model.keras',
            './model/waste_model.h5',
        ]

        model = None
        loaded_from = None
        for model_file in candidate_files:
            if not os.path.exists(model_file):
                continue
            model, loaded_from = _try_load_file(model_file)
            if model is not None:
                break

        if model is None:
            logger.warning("All loading methods failed. Creating model from architecture...")
            model = create_model_architecture()
            logger.warning("⚠️ Using untrained model - predictions will be random!")
        else:
            logger.info(f"✅ Model loaded successfully from: {loaded_from}")

        # Recompile so the model is in a fully usable state (the optimizer
        # and loss are irrelevant for predict()-only inference).
        model.compile(
            optimizer='adam',
            loss='categorical_crossentropy',
            metrics=['accuracy']
        )
        return model
    except Exception as e:
        logger.error(f"Critical error loading model: {e}")
        raise Exception(f"Model loading failed: {e}")


def _try_load_file(model_file):
    """Try the three loading strategies on a single file.

    Returns:
        (model, description) on success, (None, None) on failure.
    """
    logger.info(f"Attempting to load: {model_file}")

    # Method 1: standard loading.
    try:
        model = tf.keras.models.load_model(model_file, compile=False)
        logger.info(f"✅ Loaded with standard method: {model_file}")
        return model, model_file
    except Exception as e1:
        logger.warning(f"Standard loading failed: {e1}")

    # Method 2: load with explicit custom objects (version-compat shim).
    try:
        custom_objects = {
            'InputLayer': tf.keras.layers.InputLayer,
            'Rescaling': tf.keras.layers.Rescaling,
        }
        model = tf.keras.models.load_model(
            model_file,
            custom_objects=custom_objects,
            compile=False
        )
        logger.info(f"✅ Loaded with custom objects: {model_file}")
        return model, model_file
    except Exception as e2:
        logger.warning(f"Custom objects loading failed: {e2}")

    # Method 3: rebuild the architecture and load weights only.
    # BUGFIX: the original logged "Loaded weights only" and reported success
    # even for .keras files, where load_weights() was never called (only the
    # .h5 branch loaded anything) — silently serving an untrained model.
    # Weight-only loading now counts as success only for .h5 files.
    if model_file.endswith('.h5'):
        try:
            model = create_model_architecture()
            model.load_weights(model_file)
            logger.info(f"✅ Loaded weights only: {model_file}")
            return model, f"{model_file} (weights only)"
        except Exception as e3:
            logger.warning(f"Weights loading failed: {e3}")

    return None, None
def create_model_architecture():
    """Build the MobileNetV2-based classifier used during training.

    A frozen ImageNet-pretrained MobileNetV2 backbone is topped with a
    small dense head that outputs a softmax over the five waste classes.
    """
    try:
        backbone = tf.keras.applications.MobileNetV2(
            weights='imagenet',
            include_top=False,
            input_shape=(224, 224, 3)
        )
        backbone.trainable = False  # keep the ImageNet features fixed

        head = [
            tf.keras.layers.Rescaling(1./255, input_shape=(224, 224, 3)),
            backbone,
            tf.keras.layers.GlobalAveragePooling2D(),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(128, activation='relu'),
            tf.keras.layers.Dropout(0.5),
            tf.keras.layers.Dense(5, activation='softmax'),  # one unit per class
        ]
        model = tf.keras.Sequential(head)
        logger.info("Created model architecture successfully")
        return model
    except Exception as e:
        logger.error(f"Failed to create model architecture: {e}")
        raise
def preprocess_image(image_data):
    """Convert raw image bytes into a model-ready batch.

    Steps: decode bytes -> RGB, resize to 224x224 (LANCZOS), cast to
    float32, add a batch dimension. No manual pixel normalization is done
    here because the model's first layer is Rescaling(1/255).
    (BUGFIX: the previous docstring claimed an ESP32 ROI crop of the
    400x296 frame that this function never performed.)

    Args:
        image_data: raw encoded image bytes (e.g. a JPEG upload).

    Returns:
        np.ndarray of shape (1, 224, 224, 3), dtype float32, values 0-255.

    Raises:
        HTTPException: 400 if the bytes cannot be decoded or processed.
    """
    try:
        # Decode bytes to a PIL image, forcing 3-channel RGB.
        image = Image.open(io.BytesIO(image_data)).convert("RGB")
        # Resize to the model's input size with a high-quality filter.
        image = image.resize((224, 224), Image.Resampling.LANCZOS)
        image = np.array(image).astype("float32")
        # Add batch dimension.
        image = np.expand_dims(image, axis=0)
        return image
    except Exception as e:
        logger.error(f"Image preprocessing error: {e}")
        raise HTTPException(status_code=400, detail=f"Image preprocessing failed: {e}")
# NOTE(review): no @app.get("/health") decorator is visible here even though
# /info advertises a /health endpoint — it may have been lost; confirm and
# restore the route registration.
async def health_check():
    """Health check: run a dummy prediction and report model status."""
    try:
        # Quick smoke test of the loaded model.
        dummy_input = np.random.random((1, 224, 224, 3)).astype(np.float32)
        prediction = model.predict(dummy_input, verbose=0)
        # BUGFIX: the old heuristic flagged "dummy_model" whenever the
        # softmax output summed to ~1 — true for ANY softmax model — so it
        # always reported "dummy_model". Output values alone cannot reveal
        # whether the weights are trained, so report on what we can ground:
        # whether a model object is present at all.
        model_status = "real_model" if model is not None else "dummy_model"
        return {
            "status": "healthy",
            "model_status": model_status,
            "model_loaded": model is not None,
            "classes": class_labels,
            "input_shape": "(224, 224, 3)",
            "model_type": "TensorFlow/Keras MobileNetV2",
            "prediction_sample": prediction[0].tolist()  # first (only) batch row
        }
    except Exception as e:
        return {
            "status": "unhealthy",
            "error": str(e),
            "model_loaded": model is not None,
            "classes": class_labels
        }
# NOTE(review): the expected @app.on_event("startup") decorator is not
# visible here — likely lost in extraction; without it this never runs.
# Confirm and restore.
async def startup_event():
    """Load the model when the API starts and smoke-test it once."""
    global model
    try:
        # Dump the working tree so deployment problems are easy to diagnose.
        logger.info("Available files in root:")
        for file in os.listdir('.'):
            logger.info(f" {file}")
        if os.path.exists('model'):
            logger.info("Available files in model/ directory:")
            for file in os.listdir('model'):
                logger.info(f" model/{file}")

        model = load_model()
        logger.info("API startup complete")

        # One throwaway prediction verifies the model is actually callable.
        dummy_input = np.random.random((1, 224, 224, 3)).astype(np.float32)
        _ = model.predict(dummy_input, verbose=0)
        logger.info("Model test prediction successful")
    except Exception as e:
        logger.error(f"Startup failed: {e}")
        raise
# BUGFIX: route registration was missing — without @app.post this endpoint
# is unreachable, yet its docstring and /info advertise POST /classify.
@app.post("/classify")
async def classify_image(file: UploadFile = File(...)):
    """
    Main classification endpoint for ESP32

    Expected usage:
        curl -X POST -F "file=@image.jpg" https://your-space-url.hf.space/classify

    Returns:
        JSON: {"label": "plastic"} or {"error": "message"}
    """
    try:
        # Validate file type
        if not file.content_type or not file.content_type.startswith('image/'):
            logger.warning(f"Invalid file type: {file.content_type}")
            raise HTTPException(status_code=400, detail="File must be an image")

        # Read image data, rejecting empty uploads early
        image_data = await file.read()
        if len(image_data) == 0:
            raise HTTPException(status_code=400, detail="Empty image file")
        logger.info(f"Processing image: {file.filename}, size: {len(image_data)} bytes")

        # Preprocess and predict
        processed_image = preprocess_image(image_data)
        predictions = model.predict(processed_image, verbose=0)
        predicted_class_index = np.argmax(predictions[0])
        predicted_class = class_labels[predicted_class_index]
        confidence = float(predictions[0][predicted_class_index])
        logger.info(f"Prediction: {predicted_class} (confidence: {confidence:.3f})")

        # Minimal response shape expected by the ESP32 firmware.
        return {"label": predicted_class.capitalize()}  # Capitalize to match your ESP32 labels
    except HTTPException:
        # Let deliberate 4xx responses through unchanged.
        raise
    except Exception as e:
        logger.error(f"Classification error: {str(e)}")
        return JSONResponse(
            status_code=500,
            content={"error": f"Classification failed: {str(e)}"}
        )
# BUGFIX: route registration was missing — without @app.post this endpoint
# is unreachable, yet /info advertises POST /classify/detailed.
@app.post("/classify/detailed")
async def classify_detailed(file: UploadFile = File(...)):
    """
    Detailed classification endpoint with confidence scores.

    Returns the winning label plus the probability (in percent) of every
    class, basic model metadata, and a status flag.
    """
    try:
        # Validate file type
        if not file.content_type or not file.content_type.startswith('image/'):
            raise HTTPException(status_code=400, detail="File must be an image")

        # Read and process image (reject empty uploads, as /classify does)
        image_data = await file.read()
        if len(image_data) == 0:
            raise HTTPException(status_code=400, detail="Empty image file")
        processed_image = preprocess_image(image_data)

        # Make prediction with full details
        predictions = model.predict(processed_image, verbose=0)
        predicted_class_index = np.argmax(predictions[0])
        predicted_class = class_labels[predicted_class_index]
        confidence = float(predictions[0][predicted_class_index])

        # Probability (percent, 2 decimals) for every class
        all_probs = {
            class_labels[i].capitalize(): round(float(predictions[0][i]) * 100, 2)
            for i in range(len(class_labels))
        }

        return {
            "label": predicted_class.capitalize(),
            "confidence": round(confidence * 100, 2),
            "all_probabilities": all_probs,
            "model_info": {
                "architecture": "MobileNetV2",
                "input_size": "224x224",
                "classes": len(class_labels)
            },
            "status": "success"
        }
    except HTTPException:
        # BUGFIX: previously HTTPException (e.g. the 400 above) fell into
        # the generic handler and was re-raised as a 500. Re-raise intact,
        # matching classify_image.
        raise
    except Exception as e:
        logger.error(f"Detailed classification error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Classification failed: {str(e)}")
# NOTE(review): the expected @app.get("/info") decorator (the endpoint is
# advertised in the listing below) is not visible — likely lost; confirm
# and restore the route registration.
async def get_info():
    """Describe the API: model details, class list, endpoints, and usage."""
    model_details = {
        "architecture": "MobileNetV2 + Custom Head",
        "framework": "TensorFlow/Keras",
        "input_size": "224x224x3",
        "preprocessing": "RGB, Resize, Rescaling (internal)"
    }
    endpoint_map = {
        "/classify": "POST - Main classification endpoint (returns simple label)",
        "/classify/detailed": "POST - Detailed classification with confidence",
        "/health": "GET - Health check",
        "/info": "GET - API information"
    }
    usage_notes = {
        "esp32": "POST image to /classify endpoint",
        "curl_example": "curl -X POST -F 'file=@image.jpg' https://your-space-url.hf.space/classify"
    }
    return {
        "api_name": "Waste Classification API",
        "version": "1.0.0",
        "model": model_details,
        "classes": [label.capitalize() for label in class_labels],
        "endpoints": endpoint_map,
        "usage": usage_notes
    }
async def test_with_dummy():
    """Debug endpoint: classify random noise to verify the pipeline runs."""
    try:
        # Random uint8 "image" as a stand-in for real input.
        noise = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
        batch = np.expand_dims(noise.astype(np.float32), axis=0)

        scores = model.predict(batch, verbose=0)[0]
        best = np.argmax(scores)
        return {
            "test_status": "success",
            "predicted_class": class_labels[best].capitalize(),
            "confidence": float(scores[best]),
            "all_predictions": [float(p) for p in scores]
        }
    except Exception as e:
        # Best-effort endpoint: report the failure instead of raising.
        return {"test_status": "failed", "error": str(e)}
if __name__ == "__main__":
    # Hugging Face Spaces injects PORT; fall back to 7860 for local runs.
    serve_port = int(os.environ.get("PORT", 7860))
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=serve_port,
        log_level="info"
    )