rajkhanke commited on
Commit
83d9968
·
verified ·
1 Parent(s): 03e9b2c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -0
app.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, File, UploadFile, HTTPException
2
+ from fastapi.responses import JSONResponse
3
+ from tensorflow.keras.models import load_model
4
+ from PIL import Image
5
+ import numpy as np
6
+ import requests
7
+
8
+ # Initialize FastAPI app
9
+ app = FastAPI()
10
+
11
+ # Load the pre-trained model
12
+ model_path = 'dog_breed.h5' # Update with your model path
13
+ model = load_model(model_path)
14
+
15
+ # Dictionary to map index to breed name
16
+ breed_names = {
17
+ 0: 'Beagle', 1: 'Boxer', 2: 'Bulldog', 3: 'Dachshund', 4: 'German Shepherd',
18
+ 5: 'Golden Retriever', 6: 'Labrador Retriever', 7: 'Poodle', 8: 'Rottweiler', 9: 'Yorkshire Terrier'
19
+ }
20
+
21
+ # Function to preprocess the image
22
+ def preprocess_image(image):
23
+ image = image.resize((150, 150))
24
+ img_array = np.array(image)
25
+ img_array = img_array / 255.0
26
+ img_array = np.expand_dims(img_array, axis=0)
27
+ return img_array
28
+
29
+ # Function to classify the breed
30
+ def classify_breed(image, model):
31
+ img_array = preprocess_image(image)
32
+ predictions = model.predict(img_array)
33
+ predicted_class_index = np.argmax(predictions)
34
+ return breed_names[predicted_class_index]
35
+
36
+ # Function to fetch breed information from an API
37
+ def fetch_breed_info(breed_name):
38
+ url = f'https://api.thedogapi.com/v1/breeds/search?q={breed_name}'
39
+ response = requests.get(url)
40
+ if response.status_code == 200:
41
+ breed_info = response.json()
42
+ return breed_info
43
+ else:
44
+ return None
45
+
46
+ # API route for prediction
47
+ @app.post("/predict")
48
+ async def predict(file: UploadFile = File(...)):
49
+ try:
50
+ # Check file type
51
+ if file.content_type not in ["image/jpeg", "image/png", "image/jpg"]:
52
+ raise HTTPException(status_code=400, detail="Invalid file type. Only JPG and PNG are allowed.")
53
+
54
+ # Read and process the image
55
+ image = Image.open(file.file)
56
+
57
+ # Classify the breed
58
+ breed_name = classify_breed(image, model)
59
+
60
+ # Fetch breed information
61
+ breed_info = fetch_breed_info(breed_name)
62
+
63
+ # Prepare response
64
+ response = {
65
+ "predicted_breed": breed_name,
66
+ "breed_info": breed_info[0] if breed_info else "No additional information available."
67
+ }
68
+ return JSONResponse(content=response)
69
+
70
+ except Exception as e:
71
+ raise HTTPException(status_code=500, detail=str(e))
72
+
73
+ # Root route
74
+ @app.get("/")
75
+ def read_root():
76
+ return {"message": "Welcome to the Dog Breed Classification API!"}