# Fruits_API / app.py
# Hugging Face Space by ihtesham0345 — "Initial commit with LFS" (rev 31082a9).
# import base64 # Not needed
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import StreamingResponse
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import tensorflow as tf
import io
from typing import List
app = FastAPI(title="Fruit Classifier API")

# Load the trained Keras model.
# NOTE(review): the path is relative — assumes the server is started from the
# directory containing the .h5 file; confirm against the deployment layout.
MODEL_PATH = "fruit_classifier_model.h5"
try:
    model = tf.keras.models.load_model(MODEL_PATH)
    print(f"Model loaded successfully from {MODEL_PATH}")
except Exception as e:
    print(f"Error loading model: {e}")
    # We allow the app to start even if model fails, but predict will fail
    # (both endpoints guard on `model is None` and answer with HTTP 500).
    model = None

# Class names extracted from the training notebook.
# Order must match the model's output-layer indices (35 classes).
CLASS_NAMES = [
    'Apple', 'Apricots', 'Avocado', 'Banana', 'Blackberries', 'Blueberry',
    'Cantaloupe', 'Cherry', 'Coconut', 'Dates', 'Dragon fruit', 'Fig',
    'Grapes', 'Guava', 'Jackfruit', 'Kiwi', 'Lemons', 'Lychee', 'Mango',
    'Olive', 'Orange', 'Papaya', 'Pear', 'Persimmon', 'Pineapple', 'Plum',
    'Pomegranate', 'Rambutan', 'Raspberry', 'Salak', 'Sapodilla', 'Soursop',
    'Starfruit', 'Strawberry', 'Watermelon'
]
def preprocess_image(image: "Image.Image") -> np.ndarray:
    """Resize a PIL image into the model's expected input tensor.

    The training pipeline used image_size=(224, 224) and the saved model
    contains its own Rescaling/Normalization layers, so pixel values are
    passed through in the 0-255 range. The original note said "converting
    to float32 is safer" but the code returned uint8; the cast is now
    explicit so the dtype fed to the model is deterministic regardless of
    the input image (the model's internal rescaling accepts either).

    Args:
        image: Source image in any PIL mode; converted to RGB if needed.

    Returns:
        A float32 array of shape (1, 224, 224, 3) — a batch of one.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")
    image_resized = image.resize((224, 224))
    # Cast to float32; values remain 0-255, the model normalizes internally.
    image_array = np.asarray(image_resized, dtype=np.float32)
    image_array = np.expand_dims(image_array, axis=0)  # Add batch dimension
    return image_array
@app.get("/")
def read_root():
    """Landing/health endpoint: returns a static welcome payload."""
    greeting = "Welcome to the Fruit Classifier API"
    return {"message": greeting}
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and return the top-1 label as JSON.

    Returns the predicted class name, its softmax confidence, and the
    original filename. Answers HTTP 500 when the model failed to load
    or any processing step raises.
    """
    if model is None:
        raise HTTPException(status_code=500, detail="Model not loaded")
    try:
        payload = await file.read()
        pil_image = Image.open(io.BytesIO(payload))
        batch = preprocess_image(pil_image)
        scores = model.predict(batch)[0]
        best = np.argmax(scores)
        return {
            "prediction": CLASS_NAMES[best],
            "confidence": float(scores[best]),
            "filename": file.filename,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/predict_image")
async def predict_image(file: UploadFile = File(...)):
    """Classify an uploaded image and stream back a JPEG annotated with the
    predicted label and confidence drawn in the top-left corner.

    Answers HTTP 500 when the model is unavailable or any processing step
    raises.
    """
    if model is None:
        raise HTTPException(status_code=500, detail="Model not loaded")
    try:
        contents = await file.read()
        original_image = Image.open(io.BytesIO(contents))
        # Ensure RGB (needed both for the model input and for JPEG output)
        if original_image.mode != "RGB":
            original_image = original_image.convert("RGB")
        processed_image = preprocess_image(original_image)
        predictions = model.predict(processed_image)
        predicted_class_index = np.argmax(predictions[0])
        confidence = float(predictions[0][predicted_class_index])
        predicted_class = CLASS_NAMES[predicted_class_index]
        # Draw on the original image
        draw = ImageDraw.Draw(original_image)
        # Try to load a nice font, otherwise default
        try:
            # Try loading a system font (Windows usually has arial);
            # the size scales with image height so the label stays legible.
            font = ImageFont.truetype("arial.ttf", size=int(original_image.height / 20))
        except IOError:
            font = ImageFont.load_default()
        text = f"{predicted_class} ({confidence:.2f})"
        # Calculate text position (top-left or centered-top)
        text_position = (10, 10)
        # Draw text with outline for better visibility:
        # four 1px-offset black copies behind the red foreground text.
        x, y = text_position
        outline_color = "black"
        text_color = "red"
        draw.text((x-1, y-1), text, font=font, fill=outline_color)
        draw.text((x+1, y-1), text, font=font, fill=outline_color)
        draw.text((x-1, y+1), text, font=font, fill=outline_color)
        draw.text((x+1, y+1), text, font=font, fill=outline_color)
        draw.text(text_position, text, font=font, fill=text_color)
        # Save to bytes
        img_byte_arr = io.BytesIO()
        original_image.save(img_byte_arr, format='JPEG')
        img_byte_arr.seek(0)
        return StreamingResponse(img_byte_arr, media_type="image/jpeg")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
    # Bug fix: `uvicorn` was used here without ever being imported, so
    # `python app.py` raised NameError. Import locally so the module can
    # still be served by an external ASGI runner without requiring
    # uvicorn at import time.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)