Spaces:
Sleeping
Sleeping
| import os | |
| import numpy as np | |
| import tensorflow as tf | |
| from flask import Flask, request, render_template, send_from_directory | |
| from werkzeug.utils import secure_filename | |
| from PIL import UnidentifiedImageError | |
| from huggingface_hub import hf_hub_download # <-- IMPORT THE HUGGING FACE LIBRARY | |
# --- 1. CONFIGURATION ---
# Flask application object; configuration and routes below attach to it.
app = Flask(__name__)

# Uploaded images are stored in an "uploads" directory next to this file.
_APP_DIR = os.path.dirname(os.path.abspath(__file__))
UPLOAD_FOLDER = os.path.join(_APP_DIR, 'uploads')
os.makedirs(UPLOAD_FOLDER, exist_ok=True)  # no error if it already exists
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER

# Side length, in pixels, of the square image handed to preprocessing.
IMG_SIZE = 224
# --- 2. DOWNLOAD AND LOAD MODELS/CLASS NAMES FROM HUGGING FACE HUB ---
print("--- Downloading models and class names from Hugging Face Hub... ---")

# Bind every name the request handlers read up front, so that a failure
# anywhere below can never leave one of them undefined.
breed_model = None
gatekeeper_model = None
class_names = []

try:
    # IMPORTANT: Update this with your Hugging Face username and the name of
    # your model repository.
    REPO_ID = "JanithDeshan24/Dog-Breed-Identifier"

    # hf_hub_download returns the local path of the downloaded file.
    BREED_MODEL_PATH = hf_hub_download(repo_id=REPO_ID, filename="dog_breed_project_model.h5")
    GATEKEEPER_MODEL_PATH = hf_hub_download(repo_id=REPO_ID, filename="gatekeeper_model.h5")
    CLASS_NAMES_PATH = hf_hub_download(repo_id=REPO_ID, filename="class_names.txt")

    print("Models and class names will be loaded from the following paths:")
    print(f"Breed Model: {BREED_MODEL_PATH}")
    print(f"Gatekeeper Model: {GATEKEEPER_MODEL_PATH}")
    print(f"Class Names: {CLASS_NAMES_PATH}")

    # Expert model: breed classification.
    breed_model = tf.keras.models.load_model(BREED_MODEL_PATH)
    print("✅ Dog Breed (Expert) model loaded.")

    # Gatekeeper model: dog vs. not-dog classification.
    gatekeeper_model = tf.keras.models.load_model(GATEKEEPER_MODEL_PATH)
    print("✅ Gatekeeper (Dog vs. Not-Dog) model loaded.")

    # One class name per line; the line order is used as the class index.
    # The encoding is explicit so the result does not depend on the host locale.
    with open(CLASS_NAMES_PATH, 'r', encoding='utf-8') as f:
        class_names = [line.strip() for line in f]
    print(f"✅ Class names loaded. Found {len(class_names)} classes.")
except Exception as e:
    # Best-effort startup: report the problem and leave the app in a
    # consistent "not loaded" state (a partial load is discarded).
    print(f"❌ Error loading models: {e}")
    breed_model = None
    gatekeeper_model = None
    class_names = []

print("--- Setup complete ---")
| # --- 3. IMAGE PREPROCESSING FUNCTION (CONSISTENT WITH TRAINING) --- | |
def preprocess_uploaded_image(filepath, img_size):
    """
    Turn an image file on disk into model-ready input batches.

    The file is decoded as a 3-channel image (animations are not expanded),
    padded and resized to ``img_size`` x ``img_size``, and wrapped in a batch
    of one. That batch is then run through the MobileNetV2 and the ResNetV2
    ``preprocess_input`` functions separately.

    Returns:
        ``(gatekeeper_input, breed_model_input)`` on success, or
        ``(None, None)`` if the file cannot be read or decoded as an image
        or any other error occurs.
    """
    try:
        # Raw bytes -> RGB tensor.
        raw_bytes = tf.io.read_file(filepath)
        decoded = tf.image.decode_image(raw_bytes, channels=3, expand_animations=False)

        # Letterbox to a square so the aspect ratio is not distorted.
        squared = tf.image.resize_with_pad(decoded, img_size, img_size)

        # Add a leading batch axis (batch of 1).
        batch = tf.expand_dims(squared, 0)

        # Each model gets its own copy, preprocessed the way it expects.
        mobilenet_prep = tf.keras.applications.mobilenet_v2.preprocess_input
        resnet_prep = tf.keras.applications.resnet_v2.preprocess_input
        return mobilenet_prep(tf.identity(batch)), resnet_prep(tf.identity(batch))
    except (UnidentifiedImageError, tf.errors.InvalidArgumentError):
        # The file is not a valid image; signal that quietly to the caller.
        return None, None
    except Exception as e:
        print(f"An unexpected error occurred during preprocessing: {e}")
        return None, None
| # --- 4. FLASK ROUTES --- | |
def _predict_top_breeds(breed_img, top_k=3):
    """Return up to ``top_k`` breed guesses for a preprocessed image batch.

    Uses test-time augmentation: the image is predicted at 0, 90, 180 and
    270 degrees in a single batch and the four outputs are averaged.

    Returns:
        A list of ``{"name": ..., "confidence": ...}`` dicts, best first,
        where ``confidence`` is a percentage string such as ``"97.31%"``.
    """
    rotated = [tf.image.rot90(breed_img, k=turns) for turns in (1, 2, 3)]
    tta_batch = tf.concat([breed_img] + rotated, axis=0)

    # One forward pass for all four views, then average across the views.
    batch_predictions = breed_model.predict(tta_batch)
    breed_predictions = tf.reduce_mean(batch_predictions, axis=0)

    # Never ask for more entries than the model has outputs.
    k = min(top_k, int(breed_predictions.shape[0]))
    top_k_values, top_k_indices = tf.math.top_k(breed_predictions, k=k)

    # Convert to plain Python numbers before indexing and formatting.
    return [
        {
            "name": class_names[int(class_idx)].replace("_", " ").title(),
            "confidence": f"{float(score) * 100:.2f}%",
        }
        for score, class_idx in zip(top_k_values.numpy(), top_k_indices.numpy())
    ]


# NOTE(review): no route registration is visible for this view even though
# it inspects request.method; '/' with GET and POST is assumed -- confirm.
@app.route('/', methods=['GET', 'POST'])
def index():
    """Render the upload page and, on POST, classify the uploaded image.

    GET renders ``index.html`` with no context. POST saves the uploaded file
    to the upload folder, preprocesses it, asks the gatekeeper model whether
    it is a dog and, if so, asks the breed model for the top breeds.

    Template context on POST is one of:
        * ``error`` -- a message when the request or the image is unusable;
        * ``is_dog=True``, ``predictions``, ``uploaded_image``;
        * ``is_dog=False``, ``prediction_text``, ``confidence_text``,
          ``uploaded_image``.
    """
    if request.method != 'POST':
        return render_template('index.html')

    if breed_model is None or gatekeeper_model is None:
        return render_template('index.html', error="Models are not loaded. Please check the server logs.")

    if 'file' not in request.files:
        return render_template('index.html', error="No file part in the request.")

    file = request.files['file']
    if not file.filename:
        return render_template('index.html', error="No file selected.")

    # secure_filename may strip everything, leaving '', which would make the
    # save target the upload directory itself.
    filename = secure_filename(file.filename)
    if not filename:
        return render_template('index.html', error="Invalid file name. Please rename the file and try again.")

    filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
    file.save(filepath)

    # --- PREDICTION PIPELINE ---
    gatekeeper_img, breed_img = preprocess_uploaded_image(filepath, IMG_SIZE)
    if gatekeeper_img is None:
        # Do not keep undecodable uploads in the served upload folder.
        try:
            os.remove(filepath)
        except OSError:
            pass
        return render_template('index.html', error="Invalid or corrupted image file. Please try another.")

    # Step 1: the gatekeeper score; values above 0.5 are treated as "dog".
    gatekeeper_pred = float(gatekeeper_model.predict(gatekeeper_img)[0][0])

    if gatekeeper_pred > 0.5:
        # Step 2: breed classification with test-time augmentation.
        top_breeds = _predict_top_breeds(breed_img, top_k=3)
        return render_template('index.html',
                               is_dog=True,
                               predictions=top_breeds,
                               uploaded_image=filename)

    # Not a dog: report how confident the gatekeeper is about that.
    not_dog_confidence = (1 - gatekeeper_pred) * 100
    return render_template('index.html',
                           is_dog=False,
                           prediction_text="This doesn't look like a dog.",
                           confidence_text=f"({not_dog_confidence:.2f}% sure it's not a dog)",
                           uploaded_image=filename)
# NOTE(review): no route registration is visible for this view; the URL rule
# below is assumed -- confirm it against how the template builds the image URL.
@app.route('/uploads/<filename>')
def uploaded_file(filename):
    """Serve a file from the upload folder so the webpage can display it.

    Args:
        filename: Name of the file inside ``app.config['UPLOAD_FOLDER']``.

    Returns:
        The response produced by ``send_from_directory`` for that file.
    """
    return send_from_directory(app.config['UPLOAD_FOLDER'], filename)
if __name__ == '__main__':
    # Debug mode turns on the interactive debugger, which must never be
    # reachable by untrusted clients, so it is opt-in through FLASK_DEBUG
    # instead of being hard-coded on.
    debug_enabled = os.environ.get('FLASK_DEBUG', '').strip().lower() in ('1', 'true', 'yes', 'on')
    app.run(debug=debug_enabled)