| | import os
|
| | import sys
|
| | from concurrent.futures import ThreadPoolExecutor, as_completed
|
| | from typing import Tuple, Dict
|
| | from PIL import Image
|
| | import random
|
| |
|
| | import gradio as gr
|
| | import torch
|
| | from datasets import load_dataset
|
| |
|
| |
|
| | from model_handlers.basic_cnn_handler import BasicCNNModel
|
| | from model_handlers.hugging_face_handler import HuggingFaceModel
|
| | from model_handlers.xception_handler import XceptionModel
|
| |
|
| |
|
| |
|
| |
|
# Resolve all paths relative to this file so the app works from any CWD.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODELS_DIR = os.path.join(BASE_DIR, "models")

# Per-model checkpoint/weights directories under models/.
MODEL_1_DIR = os.path.join(MODELS_DIR, "basic_cnn")
MODEL_2_DIR = os.path.join(MODELS_DIR, "hugging_face")
MODEL_3_DIR = os.path.join(MODELS_DIR, "xception")

# Model handler instances, populated by load_models() at startup.
# Each remains None if its model fails to load, and the predict_* helpers
# degrade gracefully by returning an error triple in that case.
basic_cnn_model = None
hugging_face_model = None
xception_model = None

# Hugging Face dataset used by the "random image" button; populated by
# load_dataset_split() and left as None on failure (feature disabled).
dataset = None
DATASET_NAME = "AIOmarRehan/Vehicles"

# Display metadata for each model, rendered in the UI's info accordion.
MODELS_INFO = {
    "Model 1: Basic CNN": {
        "description": "Custom CNN architecture with 4 Conv blocks and BatchNorm",
        "path": MODEL_1_DIR,
        "handler_class": BasicCNNModel
    },
    "Model 2: Hugging Face Transformers (DeiT-Tiny | Meta)": {
        "description": "Pre-trained transformer-based model from Hugging Face (DeiT-Tiny | Meta)",
        "path": MODEL_2_DIR,
        "handler_class": HuggingFaceModel
    },
    "Model 3: Xception CNN": {
        "description": "Fine-tuned Xception architecture using timm library",
        "path": MODEL_3_DIR,
        "handler_class": XceptionModel
    }
}
|
| |
|
| |
|
| |
|
| |
|
def _try_load(step: int, display_name: str, short_name: str, handler_cls, path: str):
    """Instantiate one model handler, logging progress.

    Returns the handler instance, or None if construction raised.
    display_name is used in the "[n/3] Loading ..." line while short_name
    is used in the success/failure lines (the two differ for Model 2).
    """
    try:
        print(f"\n[{step}/3] Loading {display_name}...")
        model = handler_cls(path)
        print(f"{short_name} loaded successfully")
        return model
    except Exception as e:
        # A model that fails to load is disabled rather than fatal, so the
        # remaining models can still serve predictions.
        print(f"Failed to load {short_name}: {e}")
        return None


def load_models():
    """Load all three models at startup.

    Populates the module-level globals basic_cnn_model, hugging_face_model
    and xception_model; any of them is left as None when loading fails.
    """
    global basic_cnn_model, hugging_face_model, xception_model

    print("\n" + "=" * 60)
    print("Loading Models...")
    print("=" * 60)

    basic_cnn_model = _try_load(
        1, "Basic CNN Model", "Basic CNN Model", BasicCNNModel, MODEL_1_DIR
    )
    hugging_face_model = _try_load(
        2, "Hugging Face (DeiT-Tiny | Meta) Model", "Hugging Face Model",
        HuggingFaceModel, MODEL_2_DIR
    )
    xception_model = _try_load(
        3, "Xception Model", "Xception Model", XceptionModel, MODEL_3_DIR
    )

    print("\n" + "=" * 60)
    print("Model Loading Complete!")
    print("=" * 60 + "\n")
|
| |
|
| |
|
def load_dataset_split():
    """Fetch the vehicle dataset's train split for the random-image feature.

    Stores the split in the module-level global ``dataset``; on any failure
    the global is set to None and the feature is disabled.
    """
    global dataset

    try:
        print("\nLoading dataset from Hugging Face...")
        loaded = load_dataset(DATASET_NAME, split="train", trust_remote_code=True)
    except Exception as e:
        print(f"Failed to load dataset: {e}")
        print("Random image feature will be disabled")
        dataset = None
    else:
        dataset = loaded
        print(f"Dataset loaded successfully: {len(dataset)} images available")
|
| |
|
| |
|
def get_random_image():
    """Return a random PIL image from the dataset, or None if unavailable.

    Lazily (re)loads the dataset on first use. Looks for the image under the
    conventional 'image'/'img' columns first, then falls back to scanning the
    sample for any PIL.Image value.
    """
    if dataset is None:
        print("Dataset not loaded, attempting to load...")
        load_dataset_split()

    if dataset is None:
        return None

    try:
        # Explicit empty-split guard: random.randrange/randint would raise
        # ValueError here, which the broad except below would report with a
        # confusing message.
        if len(dataset) == 0:
            print("Dataset is empty, no random image available")
            return None

        random_idx = random.randrange(len(dataset))
        sample = dataset[random_idx]

        if 'image' in sample:
            img = sample['image']
        elif 'img' in sample:
            img = sample['img']
        else:
            # Unknown schema: take the first PIL image value, if any.
            for value in sample.values():
                if isinstance(value, Image.Image):
                    img = value
                    break
            else:
                print(f"Could not find image in sample keys: {sample.keys()}")
                return None

        print(f"Loaded random image from index {random_idx}")
        return img
    except Exception as e:
        print(f"Error loading random image: {e}")
        return None
|
| |
|
| |
|
| |
|
| |
|
def predict_with_model_1(image: Image.Image) -> Tuple[str, float, Dict]:
    """Classify *image* with the Basic CNN model.

    Returns (label, confidence, per-class probability dict). If the model is
    not loaded or prediction raises, an error placeholder triple is returned
    instead of propagating the exception.
    """
    model = basic_cnn_model
    if model is None:
        return "Model 1: Error", 0.0, {}
    try:
        label, confidence, prob_dict = model.predict(image)
    except Exception as e:
        print(f"Error in Model 1 prediction: {e}")
        return "Error", 0.0, {}
    return label, confidence, prob_dict
|
| |
|
| |
|
def predict_with_model_2(image: Image.Image) -> Tuple[str, float, Dict]:
    """Classify *image* with the Hugging Face (DeiT-Tiny | Meta) model.

    Returns (label, confidence, per-class probability dict). If the model is
    not loaded or prediction raises, an error placeholder triple is returned
    instead of propagating the exception.
    """
    model = hugging_face_model
    if model is None:
        return "Model 2: Error", 0.0, {}
    try:
        label, confidence, prob_dict = model.predict(image)
    except Exception as e:
        print(f"Error in Model 2 prediction: {e}")
        return "Error", 0.0, {}
    return label, confidence, prob_dict
|
| |
|
| |
|
def predict_with_model_3(image: Image.Image) -> Tuple[str, float, Dict]:
    """Classify *image* with the Xception model.

    Returns (label, confidence, per-class probability dict). If the model is
    not loaded or prediction raises, an error placeholder triple is returned
    instead of propagating the exception.
    """
    model = xception_model
    if model is None:
        return "Model 3: Error", 0.0, {}
    try:
        label, confidence, prob_dict = model.predict(image)
    except Exception as e:
        print(f"Error in Model 3 prediction: {e}")
        return "Error", 0.0, {}
    return label, confidence, prob_dict
|
| |
|
| |
|
def _format_result(model_name: str, label: str, confidence: float) -> Dict:
    """Build the JSON row displayed for one model's prediction."""
    return {
        "Model": model_name,
        "Prediction": label,
        "Confidence": f"{confidence * 100:.2f}%"
    }


def predict_all_models(image: Image.Image):
    """Run all three models on *image* concurrently and build every UI output.

    Returns an 8-tuple matching the Gradio outputs list:
    (result_1, result_2, result_3, comparison_markdown,
     probs_1, probs_2, probs_3, consensus_html).
    When *image* is None, placeholder values are returned instead.
    """
    if image is None:
        empty_result = {"Model": "N/A", "Prediction": "No image", "Confidence": 0.0}
        empty_probs = {}
        empty_consensus = "<p>Please upload an image to see results</p>"
        return empty_result, empty_result, empty_result, "Please upload an image", empty_probs, empty_probs, empty_probs, empty_consensus

    print("\n" + "=" * 60)
    print("Running Predictions with All Models...")
    print("=" * 60)

    # The three handlers are independent, so fan the predictions out across
    # threads; .result() blocks until each completes.
    with ThreadPoolExecutor(max_workers=3) as executor:
        future_1 = executor.submit(predict_with_model_1, image)
        future_2 = executor.submit(predict_with_model_2, image)
        future_3 = executor.submit(predict_with_model_3, image)

        result_1_label, result_1_conf, result_1_probs = future_1.result()
        result_2_label, result_2_conf, result_2_probs = future_2.result()
        result_3_label, result_3_conf, result_3_probs = future_3.result()

    # Single helper replaces three hand-built, structurally identical dicts.
    result_1 = _format_result("Basic CNN", result_1_label, result_1_conf)
    result_2 = _format_result("Hugging Face (DeiT-Tiny | Meta)", result_2_label, result_2_conf)
    result_3 = _format_result("Xception", result_3_label, result_3_conf)

    all_agree = result_1_label == result_2_label == result_3_label

    if all_agree:
        consensus_html = f"""
        <div style="background-color: #d4edda; border: 2px solid #28a745; border-radius: 8px; padding: 20px; text-align: center;">
            <h3 style="color: #155724; margin: 0; font-size: 24px;">All Models Agree!</h3>
            <p style="color: #155724; margin: 10px 0 0 0; font-size: 18px; font-weight: bold;">{result_1_label}</p>
        </div>
        """
    else:
        consensus_html = f"""
        <div style="background-color: #f8d7da; border: 2px solid #dc3545; border-radius: 8px; padding: 20px; text-align: center;">
            <h3 style="color: #721c24; margin: 0; font-size: 24px;">Models Disagree</h3>
            <p style="color: #721c24; margin: 10px 0 0 0; font-size: 16px;">Check predictions below for details</p>
        </div>
        """

    comparison_text = f"""
## Comparison Results

**Model 1 (Basic CNN):** {result_1_label} ({result_1_conf * 100:.2f}%)

**Model 2 (Hugging Face (DeiT-Tiny | Meta)):** {result_2_label} ({result_2_conf * 100:.2f}%)

**Model 3 (Xception):** {result_3_label} ({result_3_conf * 100:.2f}%)
"""

    print(f"Prediction 1: {result_1_label} ({result_1_conf * 100:.2f}%)")
    print(f"Prediction 2: {result_2_label} ({result_2_conf * 100:.2f}%)")
    print(f"Prediction 3: {result_3_label} ({result_3_conf * 100:.2f}%)")
    print(f"Consensus: {'All agree!' if all_agree else 'Disagreement detected'}")
    print("=" * 60 + "\n")

    return result_1, result_2, result_3, comparison_text, result_1_probs, result_2_probs, result_3_probs, consensus_html
|
| |
|
| |
|
| |
|
| |
|
def build_interface() -> gr.Blocks:
    """Assemble and return the Gradio Blocks UI.

    Layout: header markdown, a collapsible model-info accordion, the image
    input with predict/random buttons, three JSON result panels, a markdown
    comparison summary, an HTML consensus banner, and three probability
    labels. Event wiring: predict on button click AND on image change;
    the random-image button fills the image input (which in turn triggers
    the change handler).
    """
    with gr.Blocks(
        title="PyTorch Unified Model Comparison",
        theme=gr.themes.Soft()
    ) as demo:

        gr.Markdown("""
        # PyTorch Unified Model Comparison

        Upload an image and compare predictions from three different PyTorch models **simultaneously**.

        This tool helps you understand how different architectures (Basic CNN, Transformers, Xception)
        classify the same image and identify where they agree or disagree.
        """)

        # Collapsed by default; descriptions come from MODELS_INFO so the
        # accordion stays in sync with the configured handlers.
        with gr.Accordion("Model Information", open=False):
            gr.Markdown(f"""
            ### Model 1: Basic CNN
            - **Description:** {MODELS_INFO['Model 1: Basic CNN']['description']}
            - **Architecture:** 4 Conv blocks + BatchNorm + Global Avg Pooling
            - **Input Size:** 224×224

            ### Model 2: Hugging Face Transformers (DeiT-Tiny | Meta)
            - **Description:** {MODELS_INFO['Model 2: Hugging Face Transformers (DeiT-Tiny | Meta)']['description']}
            - **Framework:** transformers library

            ### Model 3: Xception CNN
            - **Description:** {MODELS_INFO['Model 3: Xception CNN']['description']}
            - **Architecture:** Fine-tuned Xception with timm
            """)

        # Input column: PIL image (upload or webcam) plus the two actions.
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(
                    type="pil",
                    label="Upload Image",
                    sources=["upload", "webcam"]
                )
                predict_btn = gr.Button("Predict with All Models", variant="primary", size="lg")
                random_img_btn = gr.Button("Load Random Image from Dataset", variant="secondary", size="lg")

        gr.Markdown("## Results")

        # One JSON panel per model (label/confidence rows from predict_all_models).
        with gr.Row():
            with gr.Column():
                result_1_box = gr.JSON(label="Model 1: Basic CNN")
            with gr.Column():
                result_2_box = gr.JSON(label="Model 2: Hugging Face (DeiT-Tiny)")
            with gr.Column():
                result_3_box = gr.JSON(label="Model 3: Xception")

        comparison_output = gr.Markdown(label="Comparison Summary")

        # Agreement/disagreement banner, filled with HTML by predict_all_models.
        consensus_output = gr.HTML(value="<p></p>")

        gr.Markdown("## Class Probabilities")

        # gr.Label renders a probability dict as a ranked bar list.
        with gr.Row():
            with gr.Column():
                probs_1 = gr.Label(label="Model 1: Basic CNN | Probabilities")
            with gr.Column():
                probs_2 = gr.Label(label="Model 2: DeiT-Tiny | Meta | Probabilities")
            with gr.Column():
                probs_3 = gr.Label(label="Model 3: Xception | Probabilities")

        # Explicit predict button: runs all three models on the current image.
        predict_btn.click(
            fn=predict_all_models,
            inputs=image_input,
            outputs=[result_1_box, result_2_box, result_3_box, comparison_output, probs_1, probs_2, probs_3, consensus_output]
        )

        # Auto-predict whenever the image changes (upload, webcam capture,
        # or the random-image button updating the component).
        image_input.change(
            fn=predict_all_models,
            inputs=image_input,
            outputs=[result_1_box, result_2_box, result_3_box, comparison_output, probs_1, probs_2, probs_3, consensus_output]
        )

        # Pull a random sample from the HF dataset into the image input;
        # prediction then fires via the change handler above.
        random_img_btn.click(
            fn=get_random_image,
            inputs=None,
            outputs=image_input
        )

        gr.Markdown("""
        ---

        **Available Classes:** Auto Rickshaws | Bikes | Cars | Motorcycles | Planes | Ships | Trains

        **Dataset:** Random images are loaded from [AIOmarRehan/Vehicles](https://huggingface.co/datasets/AIOmarRehan/Vehicles) on Hugging Face

        This unified application allows real-time comparison of three different deep learning models
        to understand their individual strengths and weaknesses.
        """)

    return demo
|
| |
|
| |
|
| |
|
| |
|
| | if __name__ == "__main__":
|
| |
|
| | load_models()
|
| |
|
| |
|
| | load_dataset_split()
|
| |
|
| |
|
| | demo = build_interface()
|
| |
|
| | server_name = os.getenv("GRADIO_SERVER_NAME", "0.0.0.0")
|
| | server_port = int(os.getenv("GRADIO_SERVER_PORT", "7860"))
|
| |
|
| | print(f"\nLaunching Gradio Interface on {server_name}:{server_port}")
|
| | print("Open your browser and navigate to http://localhost:7860\n")
|
| |
|
| | demo.launch(
|
| | server_name=server_name,
|
| | server_port=server_port,
|
| | share=False,
|
| | show_error=True
|
| | ) |