"""Dog breed classification API.

Exposes a FastAPI app whose ``/predict`` endpoint classifies a dog image
(either an uploaded file or a public URL) with a pre-trained Keras model and
enriches the prediction with data from TheDogAPI breed search.
"""

import io
from typing import Optional, Tuple
from urllib.parse import urlparse

import numpy as np
import requests
from fastapi import FastAPI, File, HTTPException, Query, UploadFile
from fastapi.responses import JSONResponse
from PIL import Image
from tensorflow.keras.models import load_model

# Initialize FastAPI app
app = FastAPI()

# Load the pre-trained model once at import time so every request reuses it.
model_path = 'dog_breed.h5'  # Update with your model path
model = load_model(model_path)

# Seconds to wait on any outbound HTTP call. Without a timeout, a slow or
# unresponsive remote host would block the request indefinitely.
REQUEST_TIMEOUT = 10

# Content types accepted for uploaded files.
ALLOWED_CONTENT_TYPES = {"image/jpeg", "image/png", "image/jpg"}

# Dictionary to map the model's output index to a breed name
breed_names = {
    0: 'Beagle',
    1: 'Boxer',
    2: 'Bulldog',
    3: 'Dachshund',
    4: 'German Shepherd',
    5: 'Golden Retriever',
    6: 'Labrador Retriever',
    7: 'Poodle',
    8: 'Rottweiler',
    9: 'Yorkshire Terrier',
}


def preprocess_image(image: Image.Image, target_size: Tuple[int, int] = (150, 150)) -> np.ndarray:
    """Turn a PIL image into a single-item batch of normalized pixel values.

    Args:
        image: Source image in any PIL mode.
        target_size: ``(width, height)`` passed to ``Image.resize``;
            defaults to the 150x150 size used originally.

    Returns:
        Array of shape ``(1, height, width, 3)`` with values in ``[0, 1]``.
    """
    # Force 3 channels: PNGs may be RGBA, grayscale or palette-based, which
    # would otherwise produce an array with a different channel count.
    # (Presumably the model was trained on RGB input — confirm.)
    image = image.convert("RGB").resize(target_size)
    img_array = np.asarray(image, dtype=np.float32) / 255.0
    return np.expand_dims(img_array, axis=0)


def classify_breed(image: Image.Image, model) -> str:
    """Predict the breed name for *image* using *model*.

    Returns the name mapped from the highest-scoring output index, or
    ``"Unknown"`` if that index is not in ``breed_names``.
    """
    img_array = preprocess_image(image)
    predictions = model.predict(img_array)
    # argmax over the flattened predictions; valid because the batch holds one image.
    predicted_class_index = int(np.argmax(predictions))
    return breed_names.get(predicted_class_index, "Unknown")


def fetch_breed_info(breed_name: str):
    """Look up *breed_name* via TheDogAPI breed search.

    This is best-effort enrichment: any network error, non-200 status or
    undecodable body yields ``None`` instead of failing the prediction.

    Returns:
        The decoded JSON body on success, otherwise ``None``.
    """
    try:
        # `params=` URL-encodes the query (breed names contain spaces).
        response = requests.get(
            'https://api.thedogapi.com/v1/breeds/search',
            params={'q': breed_name},
            timeout=REQUEST_TIMEOUT,
        )
    except requests.RequestException:
        return None
    if response.status_code != 200:
        return None
    try:
        return response.json()
    except ValueError:  # requests' JSON decode errors subclass ValueError
        return None


def _open_image(source) -> Image.Image:
    """Open and fully decode an image, mapping bad data to an HTTP 400."""
    try:
        image = Image.open(source)
        # Image.open is lazy; load() forces decoding so corrupt data fails here.
        image.load()
    except (OSError, Image.DecompressionBombError):
        # UnidentifiedImageError (not an image) is an OSError subclass.
        raise HTTPException(status_code=400, detail="Provided data is not a valid image.") from None
    return image


def _load_uploaded_image(file: UploadFile) -> Image.Image:
    """Validate the upload's declared content type and decode it."""
    if file.content_type not in ALLOWED_CONTENT_TYPES:
        raise HTTPException(status_code=400, detail="Invalid file type. Only JPG and PNG are allowed.")
    return _open_image(file.file)


def _load_image_from_url(url: str) -> Image.Image:
    """Download an image from a caller-supplied http(s) URL and decode it."""
    # NOTE(security): fetching arbitrary caller-supplied URLs is a server-side
    # request forgery (SSRF) vector — the server may reach internal hosts, and
    # redirects are followed. Consider an allow-list or blocking private
    # address ranges if this service is exposed publicly.
    if urlparse(url).scheme not in ("http", "https"):
        raise HTTPException(status_code=400, detail="URL must use http or https.")
    try:
        resp = requests.get(url, timeout=REQUEST_TIMEOUT)
    except requests.RequestException:
        raise HTTPException(status_code=400, detail="Unable to fetch image from provided URL.") from None
    if resp.status_code != 200:
        raise HTTPException(status_code=400, detail="Unable to fetch image from provided URL.")
    return _open_image(io.BytesIO(resp.content))


# API route for prediction
@app.post("/predict")
def predict(
    file: Optional[UploadFile] = File(None),
    url: Optional[str] = Query(None, description="Public URL of the image"),
):
    """Classify a dog image supplied as an upload or a URL.

    Declared as a plain ``def`` (not ``async def``) on purpose: it performs
    blocking HTTP calls and model inference, so FastAPI runs it in its
    threadpool instead of stalling the event loop.

    Raises:
        HTTPException(400): no image given, unsupported type, unreachable
            URL, or undecodable image data.
        HTTPException(500): any unexpected failure during processing.
    """
    try:
        # Determine input method: file has priority over URL.
        if file is not None:
            image = _load_uploaded_image(file)
        elif url is not None:
            image = _load_image_from_url(url)
        else:
            raise HTTPException(status_code=400, detail="No image provided. Please upload a file or provide a URL.")

        breed_name = classify_breed(image, model)
        breed_info = fetch_breed_info(breed_name)

        # The search endpoint returns a list of matches; report the first one.
        if isinstance(breed_info, list) and breed_info:
            info = breed_info[0]
        else:
            info = "No additional information available."

        response = {
            "predicted_breed": breed_name,
            "breed_info": info,
        }
        return JSONResponse(content=response)
    except HTTPException:
        # Deliberate client errors must keep their status code, not become 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


# Root route
@app.get("/")
def read_root():
    """Return a static welcome message."""
    return {"message": "Welcome to the Dog Breed Classification API!"}