from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import StreamingResponse
from PIL import Image, ImageDraw, ImageFont, ImageOps
import numpy as np
import tensorflow as tf
import io
from typing import List, Tuple

app = FastAPI(title="Fruit Classifier API")

# Load the model once at import time.
# The path is relative, so it is resolved against the current working
# directory of the process.
MODEL_PATH = "fruit_classifier_model.h5"
try:
    model = tf.keras.models.load_model(MODEL_PATH)
    print(f"Model loaded successfully from {MODEL_PATH}")
except Exception as e:
    print(f"Error loading model: {e}")
    # Let the app start even if loading fails; the prediction endpoints
    # report the failure to clients instead.
    model = None

# Class names extracted from the training notebook; index i corresponds to
# output unit i of the model.
CLASS_NAMES = [
    'Apple', 'Apricots', 'Avocado', 'Banana', 'Blackberries', 'Blueberry',
    'Cantaloupe', 'Cherry', 'Coconut', 'Dates', 'Dragon fruit', 'Fig',
    'Grapes', 'Guava', 'Jackfruit', 'Kiwi', 'Lemons', 'Lychee', 'Mango',
    'Olive', 'Orange', 'Papaya', 'Pear', 'Persimmon', 'Pineapple', 'Plum',
    'Pomegranate', 'Rambutan', 'Raspberry', 'Salak', 'Sapodilla', 'Soursop',
    'Starfruit', 'Strawberry', 'Watermelon',
]

# Spatial size the model was trained on (image_size in the notebook).
IMAGE_SIZE = (224, 224)

# Smallest font size used for the overlay label; keeps tiny images from
# requesting a size-0 font.
_MIN_FONT_SIZE = 10


def preprocess_image(image: Image.Image) -> np.ndarray:
    """Convert a PIL image into a model-ready batch of one.

    The training notebook loaded data with
    ``tf.keras.utils.image_dataset_from_directory(..., image_size=(224, 224), ...)``
    and the model summary shows ``Rescaling`` and ``Normalization`` layers
    inside the network, so no scaling is done here: raw 0-255 pixel values
    are passed through, cast to float32.

    Args:
        image: Input image in any PIL mode; converted to RGB if needed.

    Returns:
        Array of shape ``(1, 224, 224, 3)`` with dtype float32.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")
    # NOTE(review): PIL's default resample filter may differ from the
    # interpolation used by image_dataset_from_directory during training;
    # confirm if accuracy on served images lags validation accuracy.
    image_resized = image.resize(IMAGE_SIZE)
    image_array = np.asarray(image_resized, dtype=np.float32)
    return np.expand_dims(image_array, axis=0)  # add batch dimension


def _load_image(contents: bytes) -> Image.Image:
    """Decode uploaded bytes into an upright RGB image.

    Raises:
        HTTPException: 400 if the bytes are not a decodable image.
    """
    try:
        image = Image.open(io.BytesIO(contents))
        # Honour the EXIF orientation tag so phone photos are classified and
        # annotated the right way up (also forces the pixel data to decode).
        image = ImageOps.exif_transpose(image)
        if image.mode != "RGB":
            image = image.convert("RGB")
    except (OSError, Image.DecompressionBombError) as e:
        # UnidentifiedImageError and truncated-file errors are OSError
        # subclasses: both mean the client sent a bad file.
        raise HTTPException(
            status_code=400, detail=f"Invalid image file: {e}"
        ) from e
    return image


def _classify(image: Image.Image) -> Tuple[str, float]:
    """Run the model on *image* and return ``(class name, top score)``.

    The score is the raw model output for the winning class; it is a
    probability only if the final layer is a softmax (not visible here).
    """
    predictions = model.predict(preprocess_image(image))
    scores = predictions[0]
    predicted_class_index = int(np.argmax(scores))
    return CLASS_NAMES[predicted_class_index], float(scores[predicted_class_index])


def _annotate_image(image: Image.Image, text: str) -> None:
    """Draw *text* in red with a 1px black outline at the top-left of *image*.

    The image is modified in place.
    """
    draw = ImageDraw.Draw(image)
    # Scale the label with the image height, clamped so small images never
    # ask for a size-0 font (which recent Pillow versions reject).
    font_size = max(_MIN_FONT_SIZE, image.height // 20)
    try:
        # arial.ttf is usually present on Windows; fall back elsewhere.
        font = ImageFont.truetype("arial.ttf", size=font_size)
    except (OSError, ValueError):
        font = ImageFont.load_default()

    x, y = 10, 10
    # Four diagonal offsets in black form an outline for legibility on any
    # background; the red text is drawn on top.
    for dx, dy in ((-1, -1), (1, -1), (-1, 1), (1, 1)):
        draw.text((x + dx, y + dy), text, font=font, fill="black")
    draw.text((x, y), text, font=font, fill="red")


@app.get("/")
def read_root():
    """Health/welcome endpoint."""
    return {"message": "Welcome to the Fruit Classifier API"}


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and return the result as JSON.

    Returns:
        ``{"prediction": str, "confidence": float, "filename": str}``.

    Raises:
        HTTPException: 400 for undecodable images, 500 if the model is not
            loaded or inference fails.
    """
    if model is None:
        raise HTTPException(status_code=500, detail="Model not loaded")

    try:
        contents = await file.read()
        image = _load_image(contents)
        predicted_class, confidence = _classify(image)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "prediction": predicted_class,
        "confidence": confidence,
        "filename": file.filename,
    }


@app.post("/predict_image")
async def predict_image(file: UploadFile = File(...)):
    """Classify an uploaded image and return it as a JPEG with the label drawn on.

    Raises:
        HTTPException: 400 for undecodable images, 500 if the model is not
            loaded or inference/rendering fails.
    """
    if model is None:
        raise HTTPException(status_code=500, detail="Model not loaded")

    try:
        contents = await file.read()
        image = _load_image(contents)
        predicted_class, confidence = _classify(image)
        _annotate_image(image, f"{predicted_class} ({confidence:.2f})")

        img_byte_arr = io.BytesIO()
        image.save(img_byte_arr, format='JPEG')
        img_byte_arr.seek(0)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return StreamingResponse(img_byte_arr, media_type="image/jpeg")


if __name__ == "__main__":
    # uvicorn is only needed when this file is executed directly; importing
    # it here keeps `uvicorn module:app` deployments unaffected.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)