---
license: mit
datasets:
- AIOmarRehan/AnimalsDataset
---

# Animal Image Classification (TensorFlow & CNN)

> "A complete end-to-end pipeline for building, cleaning, preprocessing, training, evaluating, and deploying a deep CNN model for multi-class animal image classification."

This project is designed to be **clean**, **organized**, and **human-friendly**, showing the full machine-learning workflow — from **data validation** to **model evaluation & ROC curves**.

---

## Project Structure

| Component | Description |
|-----------|-------------|
| **Data Loading** | Reads and extracts the ZIP dataset from Google Drive |
| **EDA** | Class distribution, file integrity, image sizes, brightness, contrast, sample display |
| **Preprocessing** | Resizing, normalization, augmentation, hashing, cleaning corrupted files |
| **Model** | Deep custom CNN with BatchNorm, Dropout & L2 regularization |
| **Training** | Adam optimizer, LR scheduler, early stopping |
| **Evaluation** | Confusion matrix, classification report, ROC curves |
| **Export** | Saves the final `.h5` model |

---

## How to Run

### 1. Upload your dataset to Google Drive

Your dataset must be structured as:

```
Animals/
├── Cats/
├── Dogs/
├── Snakes/
```

### 2. Update the ZIP path

```python
zip_path = '/content/drive/MyDrive/Animals.zip'
extract_to = '/content/my_data'
```

### 3. Run the Notebook

Once executed, the script will:

- Mount Google Drive
- Extract images
- Build a DataFrame of paths
- Run EDA checks
- Clean and prepare images
- Train the CNN model
- Export results

---

## Data Preparation & EDA

This project performs **deep dataset validation**, including:

### Class Distribution

```python
class_count = df['class'].value_counts()
class_count.plot(kind='bar')
```

### Image Size Properties

```python
image_df['Channels'].value_counts().plot(kind='bar')
```

### Duplicate Image Detection

Using MD5 hashing:

```python
def get_hash(file_path):
    with open(file_path, 'rb') as f:
        return hashlib.md5(f.read()).hexdigest()
```

### Brightness & Contrast Issues

```python
stat = ImageStat.Stat(img.convert("L"))
brightness = stat.mean[0]
contrast = stat.stddev[0]
```

### Auto-fixing Poor Images

Brightness/contrast enhanced using:

```python
img = ImageEnhance.Brightness(img).enhance(1.5)
img = ImageEnhance.Contrast(img).enhance(1.5)
```

---

## Image Preprocessing

All images are resized to **256×256** and normalized to **[0,1]**.
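The resize-and-scale convention can be sanity-checked without TensorFlow; here is a minimal sketch using Pillow and NumPy on a synthetic image (random bytes, not project data):

```python
import numpy as np
from PIL import Image

# Synthetic 8-bit RGB "photo" standing in for a decoded dataset image
raw = np.random.randint(0, 256, size=(300, 400, 3), dtype=np.uint8)

img = Image.fromarray(raw).resize((256, 256))     # match the 256x256 target
arr = np.asarray(img, dtype=np.float32) / 255.0   # scale pixel values to [0, 1]

assert arr.shape == (256, 256, 3)
assert arr.min() >= 0.0 and arr.max() <= 1.0
```

The same shape and value-range invariants hold for the TensorFlow preprocessing function used in the actual pipeline.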
```python
def preprocess_image(path, target_size=(256, 256)):
    img = tf.io.read_file(path)
    img = tf.image.decode_image(img, channels=3)
    img = tf.image.resize(img, target_size)
    return tf.cast(img, tf.float32) / 255.0
```

### Data Augmentation

```python
data_augmentation = tf.keras.Sequential([
    tf.keras.layers.RandomFlip("horizontal"),
    tf.keras.layers.RandomRotation(0.1),
    tf.keras.layers.RandomZoom(0.1),
])
```

---

## CNN Model Architecture

Below is a simplified view of the model:

```
Conv2D (32)  → BatchNorm → Conv2D (32)  → BatchNorm → MaxPool → Dropout
Conv2D (64)  → BatchNorm → Conv2D (64)  → BatchNorm → MaxPool → Dropout
Conv2D (128) → BatchNorm → Conv2D (128) → BatchNorm → MaxPool → Dropout
Conv2D (256) → BatchNorm → Conv2D (256) → BatchNorm → MaxPool → Dropout
Flatten → Dense (softmax)
```

Example code:

```python
model.add(Conv2D(32, (3, 3), activation='relu', padding='same'))
model.add(BatchNormalization())
model.add(MaxPooling2D((2, 2)))
```

---

## Training

```python
epochs = 50
optimizer = Adam(learning_rate=0.0005)
model.compile(optimizer=optimizer,
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
```

### Callbacks

| Callback | Purpose |
|----------|---------|
| **ReduceLROnPlateau** | Automatically reduces the LR when `val_loss` stagnates |
| **EarlyStopping** | Stops training when there is no improvement |

---

## Model Evaluation

### Accuracy

```python
test_loss, test_accuracy = model.evaluate(test_ds)
```

### Classification Report

```python
print(classification_report(y_true, y_pred, target_names=le.classes_))
```

### Confusion Matrix

```python
sns.heatmap(cm, annot=True, cmap='Blues')
```

### ROC Curve (One-vs-Rest)

```python
fpr, tpr, _ = roc_curve(y_true_bin[:, i], y_probs[:, i])
```

---

## Saving the Model

```python
model.save("Animal_Classification.h5")
```

---

## Full Code Organization (High-Level)

| Step | Description |
|------|-------------|
| 1 | Import libraries, mount Drive |
| 2 | Extract ZIP |
| 3 | Build DataFrame |
| 4 | EDA & cleaning |
| 5 | Preprocessing & augmentation |
| 6 | Dataset pipeline (train/val/test) |
| 7 | CNN architecture |
| 8 | Training |
| 9 | Evaluation |
| 10 | Save model |

---

# Animal Image Classification – Complete Pipeline (README)

> "A clean dataset is half the model's accuracy. The rest is just engineering."

This project presents a **complete end-to-end deep learning pipeline** for **multi-class animal image classification** using TensorFlow/Keras. It includes everything from data extraction, cleaning, and analysis to model training, evaluation, and export.

---

## Table of Contents

| Section | Description |
|---------|-------------|
| **1. Project Overview** | What this project does & architecture overview |
| **2. Features** | Key capabilities of this pipeline |
| **3. Directory Structure** | Recommended project layout |
| **4. Installation** | How to install and run this project |
| **5. Dataset Processing** | Extraction, cleaning, inspections |
| **6. Exploratory Data Analysis** | Visualizations & summary statistics |
| **7. Preprocessing & Augmentation** | Data preparation logic |
| **8. CNN Model Architecture** | Layers, blocks, hyperparameters |
| **9. Training & Callbacks** | How the model is trained |
| **10. Evaluation Metrics** | Reports, ROC curve, confusion matrix |
| **11. Model Export** | Saving and downloading the model |
| **12. Code Examples** | Important snippets explained |

---

## 1. Project Overview

This project builds a **Convolutional Neural Network (CNN)** to classify images of animals into multiple categories.
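Concretely, "multi-class" means the network ends in a softmax over class indices, and decoding predictions back to label names is a small post-processing step. A minimal NumPy sketch, where the class names and probability values are illustrative (taken from the Cats/Dogs/Snakes example, not from real model output):

```python
import numpy as np

class_names = ["Cats", "Dogs", "Snakes"]   # example classes from this README
probs = np.array([[0.1, 0.7, 0.2],         # hypothetical softmax outputs,
                  [0.8, 0.1, 0.1]])        # one row per image

pred_idx = probs.argmax(axis=1)                     # index of the top class
pred_labels = [class_names[i] for i in pred_idx]
# pred_labels == ["Dogs", "Cats"]
```

The same `argmax` decode is what feeds `y_pred` into the classification report and confusion matrix later in the pipeline.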
The process includes:

- Dataset extraction from Google Drive
- Data validation (duplicates, corrupt files, mislabeled images)
- Image enhancement & cleaning
- Class distribution analysis
- Image size analysis and outlier detection
- Data augmentation
- CNN model training with regularization
- Performance evaluation using multiple metrics
- Model export to `.h5`

The pipeline is designed to be **robust, explainable, and production-friendly**.

---

## 2. Features

| Feature | Description |
|---------|-------------|
| **Automatic Dataset Extraction** | Unzips and loads images from Google Drive |
| **Image Validation** | Detects duplicates, corrupted images, and mislabeled files |
| **Data Cleaning** | Brightness/contrast enhancement for dark or overexposed samples |
| **EDA Visualizations** | Class distribution, size, color modes, outliers |
| **TensorFlow Dataset Pipeline** | Efficient `tf.data` batching & prefetching |
| **Deep CNN Model** | 32 → 64 → 128 → 256 filters with batch norm and dropout |
| **Model Evaluation Dashboard** | Confusion matrix, ROC curves, precision/recall/F1 |
| **Model Export** | Saves the final model as `Animal_Classification.h5` |

---

## 3. Recommended Directory Structure

```text
Animal-Classification
┣ data
┃ ┗ Animals (extracted folders)
┣ notebooks
┣ src
┃ ┣ preprocessing.py
┃ ┣ model.py
┃ ┗ utils.py
┣ README.md
┗ requirements.txt
```

---

## 4. Installation

```bash
pip install tensorflow pandas matplotlib seaborn scikit-learn pillow tqdm
```

If you are using **Google Colab**, the project already supports:

- `google.colab.drive`
- `google.colab.files`

---

## 5. Dataset Extraction & Loading

Example snippet:

```python
zip_path = '/content/drive/MyDrive/Animals.zip'
extract_to = '/content/my_data'

with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall(extract_to)
```

Images are collected into a DataFrame:

```python
paths = [(path.parts[-2], path.name, str(path))
         for path in Path(extract_to).rglob('*.*')]
df = pd.DataFrame(paths, columns=['class', 'image', 'full_path'])
```

---

## 6. Exploratory Data Analysis

Examples of generated visualizations:

- Bar plot of class distribution
- Pie chart of percentage per class
- Scatter plots of image width and height
- Image mode (RGB/grayscale) distribution

Example:

```python
plt.figure(figsize=(32, 16))
class_count.plot(kind='bar')
```

---

## 7. Preprocessing & Augmentation

### Preprocessing function

```python
def preprocess_image(path, target_size=(256, 256)):
    img = tf.io.read_file(path)
    img = tf.image.decode_image(img, channels=3)
    img = tf.image.resize(img, target_size)
    return tf.cast(img, tf.float32) / 255.0
```

### Augmentation

```python
data_augmentation = tf.keras.Sequential([
    tf.keras.layers.RandomFlip("horizontal"),
    tf.keras.layers.RandomRotation(0.1),
    tf.keras.layers.RandomZoom(0.1),
])
```

---

## 8. CNN Model Architecture

| Block | Layers |
|-------|--------|
| **Block 1** | Conv2D(32) → BN → Conv2D(32) → BN → MaxPool → Dropout(0.2) |
| **Block 2** | Conv2D(64) → BN → Conv2D(64) → BN → MaxPool → Dropout(0.3) |
| **Block 3** | Conv2D(128) → BN → Conv2D(128) → BN → MaxPool → Dropout(0.4) |
| **Block 4** | Conv2D(256) → BN → Conv2D(256) → BN → MaxPool → Dropout(0.5) |
| **Output** | Flatten → Dense(num_classes, softmax) |

Example snippet:

```python
model.add(Conv2D(64, (3, 3), activation='relu', padding='same'))
model.add(BatchNormalization())
```

---

## 9. Training

```python
optimizer = Adam(learning_rate=0.0005)
model.compile(optimizer=optimizer,
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
```

Using callbacks:

```python
ReduceLROnPlateau(...)
EarlyStopping(...)
```
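A concrete configuration might look like the following; the monitor, factor, and patience values here are illustrative assumptions, not taken from the project notebook:

```python
import tensorflow as tf

# Illustrative settings (assumptions, not from the project code)
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss', factor=0.5, patience=3, min_lr=1e-6)
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', patience=8, restore_best_weights=True)

# Both would then be passed to training as:
# model.fit(train_ds, validation_data=val_ds, epochs=50,
#           callbacks=[reduce_lr, early_stop])
```

`restore_best_weights=True` makes early stopping hand back the best checkpoint seen, which pairs well with a decaying learning rate.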
---

## 10. Evaluation Metrics

This project computes:

- Precision, recall, F1 (macro & per class)
- Confusion matrix (heatmap)
- ROC curves (one-vs-rest)
- Macro-average ROC

Example:

```python
cm = confusion_matrix(y_true, y_pred)
sns.heatmap(cm, annot=True)
```

---

## 11. Model Export

```python
model.save("Animal_Classification.h5")
files.download("Animal_Classification.h5")
```

---

## 12. Example Snippets

### Checking corrupted files

```python
try:
    with Image.open(path) as img:
        img.verify()
except Exception:
    corrupted.append(path)
```

### Filtering duplicate images

```python
df['file_hash'] = df['full_path'].apply(get_hash)
df_unique = df.drop_duplicates(subset='file_hash')
```
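The hashing and deduplication snippets compose into a small end-to-end check. A self-contained sketch using temporary files with synthetic bytes (not real images), and plain Python instead of pandas:

```python
import hashlib
import tempfile
from pathlib import Path

def get_hash(file_path):
    # MD5 of the raw file bytes, as used for duplicate detection above
    with open(file_path, 'rb') as f:
        return hashlib.md5(f.read()).hexdigest()

with tempfile.TemporaryDirectory() as tmp:
    a = Path(tmp, 'a.jpg'); a.write_bytes(b'same bytes')
    b = Path(tmp, 'b.jpg'); b.write_bytes(b'same bytes')   # byte-identical duplicate
    c = Path(tmp, 'c.jpg'); c.write_bytes(b'different')

    seen, unique = set(), []
    for p in (a, b, c):
        h = get_hash(p)
        if h not in seen:          # keep only the first file per hash
            seen.add(h)
            unique.append(p)
# unique keeps a and c; the duplicate b is dropped
```

This is the same first-occurrence-wins behavior as `drop_duplicates(subset='file_hash')`, which keeps the first row per hash by default.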