File size: 3,146 Bytes
f475440
83d9968
 
 
 
 
f475440
 
83d9968
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f475440
83d9968
 
 
 
 
 
 
f475440
83d9968
 
 
f475440
83d9968
 
f475440
83d9968
 
 
 
 
 
 
 
 
 
f475440
 
 
 
83d9968
f475440
 
 
 
 
 
 
 
 
 
 
 
 
 
83d9968
 
 
 
 
 
 
 
 
 
f475440
83d9968
 
 
 
 
 
 
 
 
f475440
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
from fastapi import FastAPI, File, UploadFile, HTTPException, Query
from fastapi.responses import JSONResponse
from tensorflow.keras.models import load_model
from PIL import Image
import numpy as np
import requests
import io
from typing import Optional

# Initialize FastAPI app
# (the route handlers in this module register on this instance via @app.get / @app.post).
app = FastAPI()

# Load the pre-trained model
# This runs once, when the module is imported: a missing or unreadable model file
# raises here rather than on the first request, and the single loaded model object
# is then shared by every call to the /predict route. A relative path is resolved
# against the process's current working directory, not this file's directory.
model_path = 'dog_breed.h5'  # Update with your model path
model = load_model(model_path)

# Model output index -> breed name. enumerate() assigns indices in list order,
# so this list must stay in the same order as the model's output classes
# (presumably the label order used at training time — confirm before editing).
breed_names = dict(enumerate([
    'Beagle', 'Boxer', 'Bulldog', 'Dachshund', 'German Shepherd',
    'Golden Retriever', 'Labrador Retriever', 'Poodle', 'Rottweiler',
    'Yorkshire Terrier',
]))

# Function to preprocess the image
def preprocess_image(image: Image.Image, target_size=(150, 150)):
    """Convert a PIL image into a normalized single-image batch for the model.

    Args:
        image: Decoded PIL image in any mode (RGB, RGBA, L, P, ...).
        target_size: ``(width, height)`` passed to ``Image.resize``; defaults
            to the 150x150 size this module has always used.

    Returns:
        A NumPy array of shape ``(1, height, width, 3)`` with pixel values
        scaled from 0..255 into 0.0..1.0.
    """
    # PNG uploads are often RGBA, palette ("P") or grayscale ("L"). Without this
    # conversion np.asarray() yields 4 channels or a 2-D array and the model
    # input shape no longer matches.
    # NOTE(review): assumes the model was trained on 3-channel RGB input — confirm.
    if image.mode != "RGB":
        image = image.convert("RGB")
    image = image.resize(target_size)
    img_array = np.asarray(image) / 255.0
    # Add the leading batch axis expected by model.predict.
    return np.expand_dims(img_array, axis=0)

# Function to classify the breed
def classify_breed(image: Image.Image, model):
    """Return the breed name that *model* scores highest for *image*.

    The image is run through ``preprocess_image`` and ``model.predict``; the
    argmax of the prediction array indexes ``breed_names``. An index with no
    entry in that mapping yields ``"Unknown"``.
    """
    batch = preprocess_image(image)
    scores = model.predict(batch)
    # np.argmax with no axis works on the flattened array, which is fine for
    # the single-image batch produced by preprocess_image.
    top_index = np.argmax(scores)
    return breed_names.get(top_index, "Unknown")

# Function to fetch breed information from an API
def fetch_breed_info(breed_name: str, timeout: float = 10.0):
    """Search TheDogAPI for *breed_name* and return the decoded JSON body.

    Args:
        breed_name: Breed name to search for (e.g. ``"German Shepherd"``).
        timeout: Seconds to wait on connect/read. Without a timeout,
            ``requests`` may wait indefinitely on an unresponsive server.

    Returns:
        The decoded JSON body on HTTP 200, otherwise ``None``. ``None`` is also
        returned when the request fails at the network level or the body is
        not valid JSON, because breed details are supplementary information.
    """
    try:
        # Let requests build and percent-encode the query string instead of
        # interpolating the raw name into the URL.
        response = requests.get(
            "https://api.thedogapi.com/v1/breeds/search",
            params={"q": breed_name},
            timeout=timeout,
        )
        if response.status_code != 200:
            return None
        return response.json()
    except (requests.RequestException, ValueError):
        # ValueError covers JSON decoding failures of response.json().
        return None

# Content types accepted for direct uploads to /predict.
_ALLOWED_UPLOAD_TYPES = ("image/jpeg", "image/png", "image/jpg")

# Seconds to wait when downloading an image from a client-supplied URL.
_URL_FETCH_TIMEOUT = 10.0


def _decode_image(stream) -> Image.Image:
    """Open and fully decode an image from a binary stream.

    Raises:
        HTTPException: 400 if Pillow cannot identify or decode the data.
    """
    try:
        image = Image.open(stream)
        # Image.open is lazy; load() forces decoding so truncated or corrupt
        # data is reported here as a client error instead of a later 500.
        image.load()
    except (OSError, Image.DecompressionBombError):
        raise HTTPException(status_code=400, detail="Unable to decode the provided image.") from None
    return image


def _load_uploaded_image(file: UploadFile) -> Image.Image:
    """Check an uploaded file's declared content type, then decode it."""
    if file.content_type not in _ALLOWED_UPLOAD_TYPES:
        raise HTTPException(status_code=400, detail="Invalid file type. Only JPG and PNG are allowed.")
    return _decode_image(file.file)


def _load_image_from_url(url: str) -> Image.Image:
    """Download an image from *url* and decode it.

    NOTE(review): the server fetches an arbitrary client-supplied URL. That lets
    clients reach internal hosts (SSRF) and trigger unbounded downloads.
    Consider a host allow-list and a size cap before exposing this publicly.
    """
    try:
        resp = requests.get(url, timeout=_URL_FETCH_TIMEOUT)
    except requests.RequestException:
        # Unreachable hosts, timeouts and malformed/unsupported URLs are client errors.
        raise HTTPException(status_code=400, detail="Unable to fetch image from provided URL.") from None
    if resp.status_code != 200:
        raise HTTPException(status_code=400, detail="Unable to fetch image from provided URL.")
    return _decode_image(io.BytesIO(resp.content))


# API route for prediction
@app.post("/predict")
async def predict(
    file: Optional[UploadFile] = File(None),
    url: Optional[str] = Query(None, description="Public URL of the image")
):
    """Classify a dog image and attach breed details from TheDogAPI.

    The image comes from an uploaded ``file`` (preferred) or a public ``url``.

    Returns:
        JSONResponse with ``predicted_breed``, plus ``breed_info``. ``breed_info``
        is the first search match, or a fallback message when none is available.

    Raises:
        HTTPException: 400 for missing, unsupported, unreachable or
            undecodable images; 500 for any other unexpected failure.
    """
    # NOTE(review): model inference and the blocking requests calls run on the
    # event loop inside this async handler, stalling concurrent requests while
    # they execute. Consider a plain ``def`` route or offloading to a threadpool.
    try:
        # Determine input method: file has priority over URL.
        if file is not None:
            image = _load_uploaded_image(file)
        elif url is not None:
            image = _load_image_from_url(url)
        else:
            raise HTTPException(status_code=400, detail="No image provided. Please upload a file or provide a URL.")

        # Classify the breed
        breed_name = classify_breed(image, model)

        # Breed details are supplementary: a failed lookup must not discard a
        # successful prediction.
        try:
            breed_info = fetch_breed_info(breed_name)
        except requests.RequestException:
            breed_info = None

        # The search result is presumably a list of matches (the original code
        # indexed [0]); anything else falls back to the message.
        if isinstance(breed_info, list) and breed_info:
            details = breed_info[0]
        else:
            details = "No additional information available."

        response = {
            "predicted_breed": breed_name,
            "breed_info": details,
        }
        return JSONResponse(content=response)

    except HTTPException:
        # Re-raise deliberate client errors unchanged; otherwise the generic
        # handler below would turn every 400 into a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

# Root route: a static welcome payload, handy as a liveness check.
@app.get("/")
def read_root():
    """Return the API's welcome message."""
    greeting = "Welcome to the Dog Breed Classification API!"
    return dict(message=greeting)