# Author: JanithDeshan24
# Last commit: "Update app.py" (da5a49b, verified)
import os
import uuid

import numpy as np
import tensorflow as tf
from flask import Flask, request, render_template, send_from_directory
from huggingface_hub import hf_hub_download # <-- IMPORT THE HUGGING FACE LIBRARY
from PIL import UnidentifiedImageError
from werkzeug.utils import secure_filename
# --- 1. CONFIGURATION ---
app = Flask(__name__)

# Uploaded images are stored next to this file and served back via /uploads/<name>.
UPLOAD_FOLDER = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'uploads')
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER

# Reject request bodies larger than this (Flask responds with 413) so a single
# upload cannot fill the disk or exhaust memory when the image is decoded.
MAX_UPLOAD_BYTES = 16 * 1024 * 1024
app.config['MAX_CONTENT_LENGTH'] = MAX_UPLOAD_BYTES

os.makedirs(UPLOAD_FOLDER, exist_ok=True)

# Side length in pixels of the square image fed to both models.
IMG_SIZE = 224
# --- 2. DOWNLOAD AND LOAD MODELS/CLASS NAMES FROM HUGGING FACE HUB ---
# Hugging Face Hub repository (<username>/<repo>) hosting the trained models
# and the class-name list.
REPO_ID = "JanithDeshan24/Dog-Breed-Identifier"

# Fallback state if anything below fails. index() refuses to predict while
# either model is None, so the server still starts and reports the problem
# to the user instead of crashing at import time.
breed_model = None
gatekeeper_model = None
class_names = []

print("--- Downloading models and class names from Hugging Face Hub... ---")
try:
    # hf_hub_download returns the local path of each downloaded file.
    BREED_MODEL_PATH = hf_hub_download(repo_id=REPO_ID, filename="dog_breed_project_model.h5")
    GATEKEEPER_MODEL_PATH = hf_hub_download(repo_id=REPO_ID, filename="gatekeeper_model.h5")
    CLASS_NAMES_PATH = hf_hub_download(repo_id=REPO_ID, filename="class_names.txt")
    print("Models and class names will be loaded from the following paths:")
    print(f"Breed Model: {BREED_MODEL_PATH}")
    print(f"Gatekeeper Model: {GATEKEEPER_MODEL_PATH}")
    print(f"Class Names: {CLASS_NAMES_PATH}")

    # Expert model: ranks dog breeds.
    breed_model = tf.keras.models.load_model(BREED_MODEL_PATH)
    print("✅ Dog Breed (Expert) model loaded.")

    # Gatekeeper model: decides dog vs. not-dog before the breed model runs.
    gatekeeper_model = tf.keras.models.load_model(GATEKEEPER_MODEL_PATH)
    print("✅ Gatekeeper (Dog vs. Not-Dog) model loaded.")

    # One class name per line; line i names breed-model output index i.
    with open(CLASS_NAMES_PATH, 'r', encoding='utf-8') as f:
        class_names = [line.strip() for line in f]
    print(f"✅ Class names loaded. Found {len(class_names)} classes.")
except Exception as e:
    # Top-level boundary: log and keep serving with models disabled. Reset
    # everything so a partial load (e.g. breed model OK, gatekeeper failed)
    # never leaves a half-initialised pipeline.
    print(f"❌ Error loading models: {e}")
    breed_model = None
    gatekeeper_model = None
    class_names = []
print("--- Setup complete ---")
# --- 3. IMAGE PREPROCESSING FUNCTION (CONSISTENT WITH TRAINING) ---
def preprocess_uploaded_image(filepath, img_size):
    """Turn an image file on disk into one input batch per model.

    The file is decoded as 3-channel RGB (grayscale sources are expanded to
    three channels; ``expand_animations=False`` yields a single image even for
    animated files), resized with padding to ``img_size`` x ``img_size`` so
    the aspect ratio is kept, and wrapped in a batch of one. That batch is then
    normalised twice: with MobileNetV2 preprocessing for the gatekeeper and
    with ResNetV2 preprocessing for the breed model.

    Args:
        filepath: path of the uploaded file.
        img_size: side length in pixels of the square output image.

    Returns:
        ``(gatekeeper_input, breed_model_input)`` on success, or
        ``(None, None)`` if the file could not be read or decoded.
    """
    try:
        raw_bytes = tf.io.read_file(filepath)
        rgb_image = tf.image.decode_image(raw_bytes, channels=3, expand_animations=False)
        square_image = tf.image.resize_with_pad(rgb_image, img_size, img_size)
        single_batch = tf.expand_dims(square_image, 0)

        mobilenet_preprocess = tf.keras.applications.mobilenet_v2.preprocess_input
        resnet_preprocess = tf.keras.applications.resnet_v2.preprocess_input
        # tf.identity hands each preprocessor its own copy of the batch.
        gatekeeper_batch = mobilenet_preprocess(tf.identity(single_batch))
        breed_batch = resnet_preprocess(tf.identity(single_batch))
        return gatekeeper_batch, breed_batch
    except (UnidentifiedImageError, tf.errors.InvalidArgumentError):
        # Not a decodable image.
        return None, None
    except Exception as e:
        print(f"An unexpected error occurred during preprocessing: {e}")
        return None, None
# --- 4. FLASK ROUTES ---
def _save_upload(upload):
    """Persist an uploaded file in UPLOAD_FOLDER under a unique, sanitized name.

    Args:
        upload: the file object taken from ``request.files``.

    Returns:
        tuple[str, str]: the stored file name (what ``/uploads/<filename>``
        serves) and its full path on disk.
    """
    safe_name = secure_filename(upload.filename)
    # A random prefix stops two users uploading e.g. "dog.jpg" at the same time
    # from overwriting each other's file. secure_filename() may also return ""
    # (e.g. for names with no ASCII characters); without a fallback the path
    # would be the upload directory itself and save() would fail.
    token = uuid.uuid4().hex
    stored_name = f"{token}_{safe_name}" if safe_name else token
    stored_path = os.path.join(app.config['UPLOAD_FOLDER'], stored_name)
    upload.save(stored_path)
    return stored_name, stored_path


def _predict_top_breeds(breed_img, k=3):
    """Rank dog breeds for one preprocessed image using rotation TTA.

    The image is predicted at 0, 90, 180 and 270 degrees in a single batch,
    and the breed model's outputs are averaged before ranking.

    Args:
        breed_img: single-image batch produced by preprocess_uploaded_image
            for the breed model.
        k: maximum number of breeds to return; capped at the number of model
            outputs.

    Returns:
        list[dict]: up to ``k`` entries ``{"name": ..., "confidence": "NN.NN%"}``,
        most likely first.
    """
    # Images are padded to a square, so every rotation keeps the same shape
    # and the four views can be concatenated into one batch.
    rotations = [breed_img] + [tf.image.rot90(breed_img, k=turns) for turns in (1, 2, 3)]
    tta_batch = tf.concat(rotations, axis=0)
    batch_predictions = breed_model.predict(tta_batch, verbose=0)
    mean_scores = tf.reduce_mean(batch_predictions, axis=0)

    # tf.math.top_k rejects k larger than the number of classes.
    k = min(k, int(mean_scores.shape[-1]))
    top_values, top_indices = tf.math.top_k(mean_scores, k=k)

    top_breeds = []
    for score, class_index in zip(top_values.numpy(), top_indices.numpy()):
        breed_name = class_names[class_index]
        confidence = score * 100
        top_breeds.append({"name": breed_name.replace("_", " ").title(), "confidence": f"{confidence:.2f}%"})
    return top_breeds


@app.route('/', methods=['GET', 'POST'])
def index():
    """Render the upload page (GET) or classify an uploaded image (POST).

    For POST, the image is saved, checked by the gatekeeper model and, if it
    looks like a dog, ranked by the breed model. Any validation problem is
    reported through the template's ``error`` variable.
    """
    if request.method != 'POST':
        return render_template('index.html')

    if breed_model is None or gatekeeper_model is None:
        return render_template('index.html', error="Models are not loaded. Please check the server logs.")
    if 'file' not in request.files:
        return render_template('index.html', error="No file part in the request.")
    upload = request.files['file']
    if not upload.filename:
        return render_template('index.html', error="No file selected.")

    filename, filepath = _save_upload(upload)

    # --- PREDICTION PIPELINE ---
    gatekeeper_img, breed_img = preprocess_uploaded_image(filepath, IMG_SIZE)
    if gatekeeper_img is None:
        # The file is never displayed, so don't keep it. Best effort: a failed
        # delete must not hide the real error from the user.
        try:
            os.remove(filepath)
        except OSError:
            pass
        return render_template('index.html', error="Invalid or corrupted image file. Please try another.")

    # Step 1: the gatekeeper's single output is treated as the dog score;
    # above 0.5 means "dog", and 1 - score is the not-dog confidence.
    gatekeeper_pred = float(gatekeeper_model.predict(gatekeeper_img, verbose=0)[0][0])
    if gatekeeper_pred > 0.5:
        # Step 2: rank breeds with test-time augmentation.
        return render_template('index.html',
                               is_dog=True,
                               predictions=_predict_top_breeds(breed_img),
                               uploaded_image=filename)

    not_dog_confidence = (1 - gatekeeper_pred) * 100
    return render_template('index.html',
                           is_dog=False,
                           prediction_text="This doesn't look like a dog.",
                           confidence_text=f"({not_dog_confidence:.2f}% sure it's not a dog)",
                           uploaded_image=filename)
@app.route('/uploads/<filename>')
def uploaded_file(filename):
    """Return a previously uploaded image so the results page can display it.

    Files are looked up only inside the configured upload folder.
    """
    upload_dir = app.config['UPLOAD_FOLDER']
    response = send_from_directory(upload_dir, filename)
    return response
if __name__ == '__main__':
    # The Werkzeug debugger lets anyone who can reach the page run arbitrary
    # Python, so debug mode is opt-in (FLASK_DEBUG=1/true/yes) rather than
    # always on. HOST/PORT default to Flask's own defaults.
    debug_enabled = os.environ.get('FLASK_DEBUG', '').strip().lower() in {'1', 'true', 'yes'}
    app.run(
        host=os.environ.get('HOST', '127.0.0.1'),
        port=int(os.environ.get('PORT', '5000')),
        debug=debug_enabled,
    )