# Provenance (copied from the hosting page): author "Penthes",
# commit ca400c7 "Update app.py" (verified).
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import tensorflow as tf
import numpy as np
from PIL import Image
import io
import logging
import uvicorn
import os
# --- Logging ---------------------------------------------------------------
# Root logger at INFO, plus a module-scoped logger used throughout this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Application -------------------------------------------------------------
# The Swagger UI is mounted at "/" so the interactive docs are the landing page.
app = FastAPI(
    title="Waste Classification API",
    description="API for classifying waste into categories: Glass, Metal, Organic, Paper, Plastic",
    version="1.0.0",
    docs_url="/",
)

# --- CORS --------------------------------------------------------------------
# Fully permissive cross-origin policy so browser clients on any origin can call in.
# NOTE(review): allow_credentials=True combined with a wildcard origin is unusual;
# none of the endpoints in this file read cookies or auth headers, so credentials
# look unnecessary — consider setting it to False.
_cors_settings = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_settings)

# --- Shared state --------------------------------------------------------------
# Keras model instance; stays None until the startup hook assigns it.
model = None

# IMPORTANT: index order must match the class order used at training time
# (alphabetical, as produced by image_dataset_from_directory).
class_labels = ["glass", "metal", "organic", "paper", "plastic"]
def _load_model_file(model_file):
    """Try every loading strategy on one file; return ``(model, source)`` or ``(None, None)``.

    Strategies, in order:
      1. ``tf.keras.models.load_model(..., compile=False)``.
      2. The same, passing ``InputLayer``/``Rescaling`` via ``custom_objects``.
      3. ``.h5`` files only: build ``create_model_architecture()`` and load the
         weights into it.

    A model is returned only when one of these strategies actually loaded the
    file's contents, so an untrained architecture never leaks out of here.
    """
    logger.info(f"Attempting to load: {model_file}")

    # Method 1: standard loading.
    try:
        model = tf.keras.models.load_model(model_file, compile=False)
        logger.info(f"✅ Loaded with standard method: {model_file}")
        return model, model_file
    except Exception as e1:
        logger.warning(f"Standard loading failed: {e1}")

    # Method 2: load with explicit custom objects (version-compatibility fallback).
    try:
        custom_objects = {
            'InputLayer': tf.keras.layers.InputLayer,
            'Rescaling': tf.keras.layers.Rescaling,
        }
        model = tf.keras.models.load_model(
            model_file,
            custom_objects=custom_objects,
            compile=False,
        )
        logger.info(f"✅ Loaded with custom objects: {model_file}")
        return model, model_file
    except Exception as e2:
        logger.warning(f"Custom objects loading failed: {e2}")

    # Method 3: weights only. The previous code built the architecture for any file
    # type but only loaded weights for .h5, so a .keras file "succeeded" with an
    # untrained head. Restrict this strategy to .h5.
    if not model_file.endswith('.h5'):
        logger.warning(f"Weights-only loading skipped (not an .h5 file): {model_file}")
        return None, None
    try:
        model = create_model_architecture()
        model.load_weights(model_file)
        logger.info(f"✅ Loaded weights only: {model_file}")
        return model, f"{model_file} (weights only)"
    except Exception as e3:
        # Discard the half-initialised architecture; the caller decides on the fallback.
        logger.warning(f"Weights loading failed: {e3}")
        return None, None


def load_model():
    """Load the trained TensorFlow/Keras model with version-compatibility handling.

    Tries each candidate path that exists on disk, skipping duplicates such as
    ``model/x`` vs ``./model/x``. Each file is passed to ``_load_model_file``.
    If nothing loads, falls back to ``create_model_architecture()`` (ImageNet
    backbone, untrained head) and logs a warning that predictions are
    unreliable. The resulting model is compiled with Adam /
    categorical cross-entropy before being returned.

    Returns:
        A compiled Keras model.

    Raises:
        RuntimeError: if even the fallback cannot be built or compiled. This is a
            subclass of ``Exception``, so existing ``except Exception`` handlers
            still catch it.
    """
    try:
        logger.info(f"TensorFlow version: {tf.__version__}")
        logger.info("=== DEBUGGING MODEL LOADING ===")
        model_files = [
            'model/waste_model.keras',
            'model/waste_model.h5',
            './model/waste_model.keras',
            './model/waste_model.h5',
        ]
        model = None
        loaded_from = None
        seen = set()
        for model_file in model_files:
            # normpath collapses "./model/x" and "model/x" so each file is tried once.
            key = os.path.normpath(model_file)
            if key in seen or not os.path.exists(model_file):
                continue
            seen.add(key)
            model, loaded_from = _load_model_file(model_file)
            if model is not None:
                break

        if model is None:
            logger.warning("All loading methods failed. Creating model from architecture...")
            model = create_model_architecture()
            logger.warning("⚠️ Using untrained model - predictions will be random!")
        else:
            logger.info(f"✅ Model loaded successfully from: {loaded_from}")

        # Compile (files were loaded with compile=False); not needed for predict().
        model.compile(
            optimizer='adam',
            loss='categorical_crossentropy',
            metrics=['accuracy'],
        )
        return model
    except Exception as e:
        logger.error(f"Critical error loading model: {e}")
        raise RuntimeError(f"Model loading failed: {e}") from e
def create_model_architecture(num_classes=None, input_shape=(224, 224, 3)):
    """Build the classifier architecture intended to match the training setup.

    Layout: Rescaling(1/255) -> frozen MobileNetV2 backbone (ImageNet weights,
    no top) -> GlobalAveragePooling2D -> Dropout(0.2) -> Dense(128, relu) ->
    Dropout(0.5) -> Dense(num_classes, softmax).

    Only the backbone carries pretrained weights; the dense head is freshly
    initialised, so the returned model is untrained unless weights are loaded.

    Args:
        num_classes: Width of the softmax output. Defaults to
            ``len(class_labels)`` so the head cannot drift from the label list.
        input_shape: ``(height, width, channels)`` of input images. Keras only
            ships ImageNet MobileNetV2 weights for certain square sizes.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model.

    Raises:
        Re-raises any exception from Keras (e.g. a weight download failure)
        after logging it.
    """
    if num_classes is None:
        num_classes = len(class_labels)
    try:
        # Backbone pretrained on ImageNet, used purely as a frozen feature extractor.
        base_model = tf.keras.applications.MobileNetV2(
            weights='imagenet',
            include_top=False,
            input_shape=input_shape,
        )
        base_model.trainable = False

        model = tf.keras.Sequential([
            # Inputs arrive as raw 0-255 floats; scaling happens inside the model.
            tf.keras.layers.Rescaling(1./255, input_shape=input_shape),
            base_model,
            tf.keras.layers.GlobalAveragePooling2D(),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(128, activation='relu'),
            tf.keras.layers.Dropout(0.5),
            tf.keras.layers.Dense(num_classes, activation='softmax'),
        ])
        logger.info("Created model architecture successfully")
        return model
    except Exception as e:
        logger.exception(f"Failed to create model architecture: {e}")
        raise
def preprocess_image(image_data, target_size=(224, 224)):
    """Decode raw image bytes into a single-image batch for the model.

    Steps:
      - decode with Pillow and convert to RGB;
      - resize to ``target_size`` with Lanczos resampling;
      - convert to float32 and add a leading batch dimension.

    No cropping is done here: the full frame is resized. Pixel values stay in
    [0, 255] because the model built in this file begins with a
    ``Rescaling(1./255)`` layer.

    NOTE(review): ``image_dataset_from_directory`` (referenced as the training
    loader) resizes with bilinear interpolation by default. If training relied
    on that default, Lanczos here introduces a small train/serve skew — confirm.

    Args:
        image_data: Encoded image bytes (JPEG, PNG, ...).
        target_size: ``(width, height)`` as accepted by ``PIL.Image.resize``.

    Returns:
        ``np.ndarray`` of shape ``(1, height, width, 3)``, dtype float32.

    Raises:
        HTTPException: status 400 if the bytes cannot be decoded or resized.
    """
    try:
        # Close the decoder's underlying buffer promptly; convert() yields an independent image.
        with Image.open(io.BytesIO(image_data)) as raw:
            image = raw.convert("RGB")
        image = image.resize(target_size, Image.Resampling.LANCZOS)
        pixels = np.array(image).astype("float32")
        return np.expand_dims(pixels, axis=0)
    except Exception as e:
        logger.error(f"Image preprocessing error: {e}")
        raise HTTPException(status_code=400, detail=f"Image preprocessing failed: {e}") from e
@app.get("/health")
async def health_check():
    """Health check: verify the model is loaded and returns a sane prediction.

    Runs one prediction on random input and checks that the output row sums to
    ~1 (a valid probability distribution).

    Note: a softmax head yields outputs summing to 1 whether or not its weights
    are trained. This check therefore validates the output shape and values
    only; it cannot tell the trained model from the untrained fallback. The old
    "real_model"/"dummy_model" status was always "dummy_model" for that reason.
    """
    if model is None:
        return {
            "status": "unhealthy",
            "error": "Model not loaded",
            "model_loaded": False,
            "classes": class_labels,
        }
    try:
        dummy_input = np.random.random((1, 224, 224, 3)).astype(np.float32)
        prediction = model.predict(dummy_input, verbose=0)
        # bool() converts numpy.bool_ into a plain bool that JSON can serialise.
        output_valid = bool(np.allclose(prediction.sum(), 1.0))
        return {
            "status": "healthy",
            "model_status": "output_valid" if output_valid else "output_invalid",
            "model_loaded": True,
            "classes": class_labels,
            "input_shape": "(224, 224, 3)",
            "model_type": "TensorFlow/Keras MobileNetV2",
            "prediction_sample": prediction[0].tolist(),  # probabilities for the random input
        }
    except Exception as e:
        return {
            "status": "unhealthy",
            "error": str(e),
            "model_loaded": model is not None,
            "classes": class_labels,
        }
@app.on_event("startup")
async def startup_event():
    """Startup hook: list deployment files, load the model, then smoke-test it.

    Assigns the module-level ``model``. Any failure is logged and re-raised so
    the application does not start with a broken model.
    """
    global model

    def _log_listing(directory, prefix=""):
        # One log line per directory entry, each indented by a single space.
        for entry in os.listdir(directory):
            logger.info(f" {prefix}{entry}")

    try:
        logger.info("Available files in root:")
        _log_listing('.')
        if os.path.exists('model'):
            logger.info("Available files in model/ directory:")
            _log_listing('model', prefix="model/")

        model = load_model()
        logger.info("API startup complete")

        # Smoke test: a single prediction on random noise must not raise.
        probe = np.random.random((1, 224, 224, 3)).astype(np.float32)
        model.predict(probe, verbose=0)
        logger.info("Model test prediction successful")
    except Exception as exc:
        logger.error(f"Startup failed: {exc}")
        raise
@app.post("/classify")
async def classify_image(file: UploadFile = File(...)):
    """
    Primary classification endpoint, used by the ESP32 client.

    Example:
        curl -X POST -F "file=@image.jpg" https://your-space-url.hf.space/classify

    Responses:
        200 -> {"label": "<Class>"} with the label capitalised.
        400 -> missing/non-image content type, empty upload, or undecodable image.
        500 -> {"error": "Classification failed: ..."} for any other failure.
    """
    try:
        declared_type = file.content_type
        if not (declared_type and declared_type.startswith('image/')):
            logger.warning(f"Invalid file type: {declared_type}")
            raise HTTPException(status_code=400, detail="File must be an image")

        payload = await file.read()
        if not payload:
            raise HTTPException(status_code=400, detail="Empty image file")
        logger.info(f"Processing image: {file.filename}, size: {len(payload)} bytes")

        scores = model.predict(preprocess_image(payload), verbose=0)[0]
        best = np.argmax(scores)
        label = class_labels[best]
        logger.info(f"Prediction: {label} (confidence: {float(scores[best]):.3f})")

        # The ESP32 expects the capitalised class name only.
        return {"label": label.capitalize()}
    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"Classification error: {str(exc)}")
        return JSONResponse(
            status_code=500,
            content={"error": f"Classification failed: {str(exc)}"},
        )
@app.post("/classify/detailed")
async def classify_detailed(file: UploadFile = File(...)):
    """
    Detailed classification endpoint with confidence scores.

    Returns the predicted label, its confidence (percent, 2 d.p.), the
    percentage for every class, and static model info.

    Errors:
        400 -> missing/non-image content type, empty upload, or undecodable image.
        500 -> any other failure during prediction.
    """
    try:
        if not file.content_type or not file.content_type.startswith('image/'):
            raise HTTPException(status_code=400, detail="File must be an image")

        image_data = await file.read()
        # Same guard as /classify, for a consistent error message on empty uploads.
        if not image_data:
            raise HTTPException(status_code=400, detail="Empty image file")
        processed_image = preprocess_image(image_data)

        probabilities = model.predict(processed_image, verbose=0)[0]
        predicted_class_index = int(np.argmax(probabilities))
        predicted_class = class_labels[predicted_class_index]
        confidence = float(probabilities[predicted_class_index])

        # Indexed by label position so a model/label-count mismatch still fails loudly.
        all_probs = {
            label.capitalize(): round(float(probabilities[i]) * 100, 2)
            for i, label in enumerate(class_labels)
        }
        return {
            "label": predicted_class.capitalize(),
            "confidence": round(confidence * 100, 2),
            "all_probabilities": all_probs,
            "model_info": {
                "architecture": "MobileNetV2",
                "input_size": "224x224",
                "classes": len(class_labels),
            },
            "status": "success",
        }
    except HTTPException:
        # Client errors keep their 4xx status instead of being rewrapped as 500.
        raise
    except Exception as e:
        logger.error(f"Detailed classification error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Classification failed: {str(e)}") from e
@app.get("/info")
async def get_info():
    """Describe the API: model metadata, supported classes, endpoints and usage."""
    model_details = {
        "architecture": "MobileNetV2 + Custom Head",
        "framework": "TensorFlow/Keras",
        "input_size": "224x224x3",
        "preprocessing": "RGB, Resize, Rescaling (internal)",
    }
    endpoint_help = {
        "/classify": "POST - Main classification endpoint (returns simple label)",
        "/classify/detailed": "POST - Detailed classification with confidence",
        "/health": "GET - Health check",
        "/info": "GET - API information",
    }
    usage_examples = {
        "esp32": "POST image to /classify endpoint",
        "curl_example": "curl -X POST -F 'file=@image.jpg' https://your-space-url.hf.space/classify",
    }
    return {
        "api_name": "Waste Classification API",
        "version": "1.0.0",
        "model": model_details,
        "classes": list(map(str.capitalize, class_labels)),
        "endpoints": endpoint_help,
        "usage": usage_examples,
    }
@app.post("/test")
async def test_with_dummy():
    """Debug endpoint: classify a 224x224 RGB image of uniform random noise.

    Always answers with HTTP 200. On failure it returns
    ``{"test_status": "failed", "error": ...}`` rather than raising, because this
    is a diagnostic aid.
    """
    try:
        # randint's upper bound is exclusive: 256 covers the full uint8 range 0..255
        # (the previous bound of 255 never produced white pixels).
        dummy_image = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)
        dummy_input = np.expand_dims(dummy_image.astype(np.float32), axis=0)

        predictions = model.predict(dummy_input, verbose=0)
        predicted_class_index = np.argmax(predictions[0])
        predicted_class = class_labels[predicted_class_index]
        return {
            "test_status": "success",
            "predicted_class": predicted_class.capitalize(),
            "confidence": float(predictions[0][predicted_class_index]),
            "all_predictions": [float(p) for p in predictions[0]],
        }
    except Exception as e:
        return {"test_status": "failed", "error": str(e)}
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on every interface.
    # The port comes from $PORT and falls back to 7860.
    listen_port = int(os.getenv("PORT", "7860"))
    uvicorn.run(app, host="0.0.0.0", port=listen_port, log_level="info")