|
|
import io
import os
from typing import Optional

import numpy as np
import requests
from fastapi import FastAPI, File, UploadFile, HTTPException, Query
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from PIL import Image
from tensorflow.keras.models import load_model
|
|
|
|
|
|
|
|
app = FastAPI()

# Path to the trained Keras model. The default "dog_breed.h5" is resolved
# against the current working directory; set DOG_BREED_MODEL_PATH to load the
# model from another location without editing the code.
model_path = os.environ.get("DOG_BREED_MODEL_PATH", "dog_breed.h5")

# Loaded once at import time; every request reuses this same model object.
model = load_model(model_path)

# Maps the model's predicted class index to a human-readable breed name.
# NOTE(review): the index order must match the class order used when the
# model was trained — confirm against the training pipeline.
breed_names = {
    0: 'Beagle',
    1: 'Boxer',
    2: 'Bulldog',
    3: 'Dachshund',
    4: 'German Shepherd',
    5: 'Golden Retriever',
    6: 'Labrador Retriever',
    7: 'Poodle',
    8: 'Rottweiler',
    9: 'Yorkshire Terrier',
}
|
|
|
|
|
|
|
|
def preprocess_image(image: Image.Image, target_size: tuple = (150, 150)):
    """Convert a PIL image into a single-image batch for the model.

    Args:
        image: Decoded PIL image in any mode (RGB, RGBA, palette, grayscale...).
        target_size: ``(width, height)`` to resize to before inference.
            Defaults to ``(150, 150)``; presumably the model's training input
            size — TODO confirm against the trained model.

    Returns:
        A float32 numpy array of shape ``(1, height, width, 3)`` with pixel
        values scaled into ``[0, 1]``.
    """
    # PNG uploads are often RGBA or palette-based and some JPEGs are
    # grayscale; without forcing RGB the array would be (H, W, 4) or (H, W),
    # which does not match a 3-channel input layer.
    if image.mode != "RGB":
        image = image.convert("RGB")
    image = image.resize(target_size)
    img_array = np.asarray(image, dtype=np.float32) / 255.0
    # Add the leading batch dimension expected by model.predict.
    return np.expand_dims(img_array, axis=0)
|
|
|
|
|
|
|
|
def classify_breed(image: Image.Image, model):
    """Run ``model`` on ``image`` and return the top-scoring breed label.

    The image is preprocessed with ``preprocess_image`` before inference.
    Returns the ``breed_names`` entry for the highest-scoring class index,
    or ``"Unknown"`` when that index has no entry.
    """
    scores = model.predict(preprocess_image(image))
    # argmax over the flattened scores; with a single-image batch
    # (presumably shape (1, num_classes) — confirm) this is the top class.
    best_index = np.argmax(scores)
    return breed_names.get(best_index, "Unknown")
|
|
|
|
|
|
|
|
def fetch_breed_info(breed_name: str, timeout: float = 10.0):
    """Look up breed metadata from The Dog API breed search endpoint.

    This is best-effort enrichment: any failure yields ``None`` rather than
    an exception, so callers can fall back to a placeholder.

    Args:
        breed_name: Breed name to search for.
        timeout: Seconds to wait for the API before giving up.

    Returns:
        The decoded JSON body when the API answers HTTP 200 with valid JSON;
        otherwise ``None`` (non-200 status, network error, timeout, or a
        body that is not JSON).
    """
    try:
        response = requests.get(
            "https://api.thedogapi.com/v1/breeds/search",
            # Let requests URL-encode the query so names containing spaces
            # or reserved characters ("&", "#", "?") are sent intact.
            params={"q": breed_name},
            # Without a timeout a stalled API would hang the request forever.
            timeout=timeout,
        )
    except requests.RequestException:
        return None

    if response.status_code != 200:
        return None

    try:
        return response.json()
    except ValueError:  # requests' JSON decode errors subclass ValueError
        return None
|
|
|
|
|
|
|
|
# Content types accepted for direct uploads. "image/jpg" is not a registered
# MIME type but is accepted for clients that send it.
_ALLOWED_UPLOAD_TYPES = ("image/jpeg", "image/png", "image/jpg")

# Seconds to wait when downloading an image from a client-supplied URL.
_URL_FETCH_TIMEOUT_SECONDS = 10

_URL_FETCH_ERROR = "Unable to fetch image from provided URL."


def _download_image(url: str) -> bytes:
    """Download the raw bytes at ``url`` (blocking; run off the event loop).

    Raises:
        HTTPException: 400 if the request fails, times out, or returns a
            non-200 status.
    """
    # NOTE(security): ``url`` is client-controlled, so this makes a
    # server-side request to an arbitrary host (SSRF risk) and reads the
    # whole body with no size cap. Consider a host allow-list and a
    # streamed, size-limited read.
    try:
        resp = requests.get(url, timeout=_URL_FETCH_TIMEOUT_SECONDS)
    except requests.RequestException as exc:
        raise HTTPException(status_code=400, detail=_URL_FETCH_ERROR) from exc
    if resp.status_code != 200:
        raise HTTPException(status_code=400, detail=_URL_FETCH_ERROR)
    return resp.content


def _decode_image(data: bytes) -> Image.Image:
    """Decode ``data`` into a fully loaded PIL image.

    Raises:
        HTTPException: 400 if Pillow cannot identify or decode the bytes.
    """
    try:
        image = Image.open(io.BytesIO(data))
        # Image.open only parses the header; load() forces a full decode so
        # corrupt or truncated input is reported here as a client error
        # instead of failing later inside the model pipeline as a 500.
        image.load()
    except OSError as exc:  # PIL.UnidentifiedImageError subclasses OSError
        raise HTTPException(
            status_code=400, detail="Provided data is not a valid image."
        ) from exc
    return image


@app.post("/predict")
async def predict(
    file: Optional[UploadFile] = File(None),
    url: Optional[str] = Query(None, description="Public URL of the image"),
):
    """Classify the dog breed in an uploaded image or an image at ``url``.

    If both an upload and a URL are supplied, the uploaded file is used.

    Returns:
        JSONResponse with ``predicted_breed`` and ``breed_info`` (the first
        Dog API search hit, or a placeholder string when none is available).

    Raises:
        HTTPException: 400 for a missing image, unsupported upload type,
            unreachable URL, or undecodable image; 500 for any other failure.
    """
    try:
        if file is not None:
            if file.content_type not in _ALLOWED_UPLOAD_TYPES:
                raise HTTPException(
                    status_code=400,
                    detail="Invalid file type. Only JPG and PNG are allowed.",
                )
            data = await file.read()
        elif url is not None:
            # requests is blocking; run it in a worker thread so a slow
            # remote host does not stall every other request on the loop.
            data = await run_in_threadpool(_download_image, url)
        else:
            raise HTTPException(
                status_code=400,
                detail="No image provided. Please upload a file or provide a URL.",
            )

        image = _decode_image(data)

        # NOTE(review): inference still runs on the event loop as before;
        # offloading it would let model.predict run concurrently across
        # threads — confirm the model tolerates that before moving it.
        breed_name = classify_breed(image, model)

        # "Unknown" is classify_breed's fallback label, not a real breed,
        # so there is nothing meaningful to look up.
        breed_info = None
        if breed_name != "Unknown":
            breed_info = await run_in_threadpool(fetch_breed_info, breed_name)

        # The search endpoint's successful result is indexed as a list; guard
        # against any other JSON shape instead of failing with a KeyError.
        if isinstance(breed_info, list) and breed_info:
            info = breed_info[0]
        else:
            info = "No additional information available."

        return JSONResponse(
            content={"predicted_breed": breed_name, "breed_info": info}
        )
    except HTTPException:
        # Deliberate client errors raised above must keep their 4xx status;
        # without this clause they would be caught below and turned into 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
|
|
@app.get("/")
def read_root():
    """Return a static welcome message for the API root."""
    greeting = "Welcome to the Dog Breed Classification API!"
    return {"message": greeting}
|
|
|