# Author: rajkhanke — commit f475440 "Update app.py" (verified)
import io
import os
from typing import Optional

import numpy as np
import requests
from fastapi import FastAPI, File, UploadFile, HTTPException, Query
from fastapi.responses import JSONResponse
from PIL import Image
from tensorflow.keras.models import load_model
# Initialize FastAPI app; the route decorators below register on this instance.
app = FastAPI()
# Load the pre-trained Keras model once at import time so every request reuses it.
# The path may be overridden with the DOG_BREED_MODEL_PATH environment variable
# (an unset or empty value falls back to 'dog_breed.h5'). Relative paths resolve
# against the process's current working directory.
model_path = os.environ.get("DOG_BREED_MODEL_PATH") or 'dog_breed.h5'
model = load_model(model_path)
# Dictionary to map index to breed name
breed_names = {
0: 'Beagle', 1: 'Boxer', 2: 'Bulldog', 3: 'Dachshund', 4: 'German Shepherd',
5: 'Golden Retriever', 6: 'Labrador Retriever', 7: 'Poodle', 8: 'Rottweiler', 9: 'Yorkshire Terrier'
}
# Function to preprocess the image
def preprocess_image(image: Image.Image, target_size: tuple = (150, 150)):
    """Turn a PIL image into a single-image batch for the classifier.

    The image is converted to RGB, resized to ``target_size``, scaled from
    0..255 to 0..1, and given a leading batch axis.

    Args:
        image: Input image in any Pillow mode.
        target_size: ``(width, height)`` passed to ``Image.resize``; defaults to
            the 150x150 size this module has always used.

    Returns:
        A numpy array of shape ``(1, height, width, 3)`` with values in [0, 1].
    """
    # Uploaded PNGs can be RGBA or palette ("P") and some JPEGs are grayscale ("L").
    # Without this, the array has 4 channels or no channel axis at all. Assuming the
    # model expects 3-channel RGB input (TODO confirm), predict() would then fail.
    if image.mode != "RGB":
        image = image.convert("RGB")
    image = image.resize(target_size)
    img_array = np.array(image) / 255.0
    # Add the batch dimension expected by model.predict.
    return np.expand_dims(img_array, axis=0)
# Function to classify the breed
def classify_breed(image: Image.Image, model):
    """Predict the breed name for ``image`` using ``model``.

    The highest-scoring output index is looked up in ``breed_names``. The string
    "Unknown" is returned when that index has no entry.
    """
    batch = preprocess_image(image)
    scores = model.predict(batch)
    # argmax over the flattened scores; with a batch of one this is the top class.
    best_index = int(np.argmax(scores))
    return breed_names.get(best_index, "Unknown")
# Function to fetch breed information from an API
def fetch_breed_info(breed_name: str, timeout: float = 10.0):
    """Search TheDogAPI for ``breed_name`` and return the decoded JSON body.

    Args:
        breed_name: Breed name to search for; sent as the ``q`` query parameter.
        timeout: Seconds ``requests`` waits to connect and for each read before
            giving up.

    Returns:
        The decoded JSON payload on HTTP 200. Returns ``None`` in these cases:
        the request fails or times out, the status is not 200, or the body is
        not valid JSON.
    """
    url = 'https://api.thedogapi.com/v1/breeds/search'
    try:
        # Passing params= lets requests URL-encode the value correctly.
        # The timeout keeps a stalled upstream from blocking this call forever.
        response = requests.get(url, params={'q': breed_name}, timeout=timeout)
        if response.status_code != 200:
            return None
        return response.json()
    except (requests.RequestException, ValueError):
        # Breed info is optional enrichment: report "unavailable" (None), the same
        # as a non-200 response, instead of failing the whole prediction.
        return None
# Decode an image fully so corrupt or non-image input fails here as a 400.
def _open_image(fp) -> Image.Image:
    """Open and fully load an image from a file-like object, or raise HTTP 400."""
    try:
        image = Image.open(fp)
        # Image.open is lazy; load() forces decoding so truncated data errors now.
        image.load()
    except (OSError, Image.DecompressionBombError) as exc:
        raise HTTPException(status_code=400, detail="Provided data is not a valid image.") from exc
    return image


# Validate an uploaded file's MIME type and decode it.
def _image_from_upload(file: UploadFile) -> Image.Image:
    """Return the decoded image from an upload, or raise HTTP 400 on bad input."""
    # "image/jpg" is not a registered MIME type but is accepted for lenient clients.
    if file.content_type not in ["image/jpeg", "image/png", "image/jpg"]:
        raise HTTPException(status_code=400, detail="Invalid file type. Only JPG and PNG are allowed.")
    return _open_image(file.file)


# Download an image from a client-supplied URL and decode it.
def _image_from_url(url: str, timeout: float = 10.0) -> Image.Image:
    """Fetch ``url`` and return the decoded image, or raise HTTP 400 on failure."""
    # SECURITY NOTE(review): the server fetches an arbitrary client-supplied URL,
    # which is an SSRF vector (internal hosts, cloud metadata endpoints). Consider
    # restricting schemes/hosts before exposing this publicly.
    try:
        resp = requests.get(url, timeout=timeout)
    except requests.RequestException:
        # Covers invalid URLs, DNS/connection failures and timeouts.
        raise HTTPException(status_code=400, detail="Unable to fetch image from provided URL.") from None
    if resp.status_code != 200:
        raise HTTPException(status_code=400, detail="Unable to fetch image from provided URL.")
    return _open_image(io.BytesIO(resp.content))


# API route for prediction
@app.post("/predict")
async def predict(
    file: Optional[UploadFile] = File(None),
    url: Optional[str] = Query(None, description="Public URL of the image"),
):
    """Classify the dog breed in an uploaded image or in an image at a public URL.

    An uploaded ``file`` takes priority over ``url`` when both are supplied.

    Returns:
        A JSONResponse with ``predicted_breed``. It also has ``breed_info``: the
        first TheDogAPI search hit, or a fallback message when none is available.

    Raises:
        HTTPException(400): in any of these cases:
            - no image was provided;
            - the file type is unsupported;
            - the URL could not be fetched;
            - the data is not a readable image.
        HTTPException(500): any other unexpected failure, e.g. during inference.
    """
    # Local import keeps this block self-contained; fastapi is already a dependency.
    from fastapi.concurrency import run_in_threadpool

    try:
        # Determine input method: file has priority over URL.
        if file is not None:
            image = _image_from_upload(file)
        elif url is not None:
            # requests is blocking; run it in a worker thread so the event loop
            # keeps serving other requests while a slow remote host responds.
            image = await run_in_threadpool(_image_from_url, url)
        else:
            raise HTTPException(status_code=400, detail="No image provided. Please upload a file or provide a URL.")

        # Classify the breed.
        # NOTE(review): inference still runs on the event loop; consider offloading
        # it once thread-safety of this Keras model has been confirmed.
        breed_name = classify_breed(image, model)

        # Fetch breed information (blocking HTTP call, so run it off the event loop).
        breed_info = await run_in_threadpool(fetch_breed_info, breed_name)

        # Only index into the payload when it is a non-empty list of matches.
        response = {
            "predicted_breed": breed_name,
            "breed_info": breed_info[0] if isinstance(breed_info, list) and breed_info else "No additional information available.",
        }
        return JSONResponse(content=response)
    except HTTPException:
        # Deliberate client errors (400s) must reach the client unchanged instead
        # of being re-wrapped as 500s by the generic handler below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
# Root route
@app.get("/")
def read_root():
    """Return a static JSON welcome payload for the API root."""
    welcome = "Welcome to the Dog Breed Classification API!"
    return dict(message=welcome)