---
license: mit
datasets:
- AIOmarRehan/AnimalsDataset
---

# Animal Image Classification (TensorFlow & CNN)

> "A complete end-to-end pipeline for building, cleaning, preprocessing, training, evaluating, and deploying a deep CNN model for multi-class animal image classification."

This project is designed to be **clean**, **organized**, and **human-friendly**. It shows the full machine-learning workflow, from **data validation** to **model evaluation and ROC curves**.

---

## Project Structure

| Component | Description |
|----------|-------------|
| **Data Loading** | Reads and extracts the ZIP dataset from Google Drive |
| **EDA** | Class distribution, file integrity, image sizes, brightness, contrast, samples display |
| **Preprocessing** | Resizing, normalization, augmentation, hashing, cleaning corrupted files |
| **Model** | Deep custom CNN with BatchNorm, Dropout & L2 Regularization |
| **Training** | Adam optimizer, LR scheduler, Early stopping |
| **Evaluation** | Confusion matrix, classification report, ROC curves |
| **Export** | Saves final `.h5` model |

---

## How to Run

### 1. Upload your dataset to Google Drive

Upload the dataset as a ZIP archive. Inside the archive, images must be organized with one folder per class:

```
Animals/
├── Cats/
├── Dogs/
└── Snakes/
```

### 2. Update the ZIP path

```python
zip_path = '/content/drive/MyDrive/Animals.zip'  # ZIP archive on the mounted Drive
extract_to = '/content/my_data'                  # local folder the archive is unpacked into
```

### 3. Run the Notebook

When run from top to bottom, the notebook will:

- Mount Google Drive
- Extract images
- Build a DataFrame of paths
- Run EDA checks
- Clean and prepare images
- Train the CNN model
- Export results

---

## Data Preparation & EDA

This project runs several **dataset validation** checks:

### Class Distribution

```python
# Count images per class and plot the counts
class_count = df['class'].value_counts()
class_count.plot(kind='bar')
```

### Image Size Properties

```python
# Distribution of channel counts per image
image_df['Channels'].value_counts().plot(kind='bar')
```

### Duplicate Image Detection

Byte-identical files are detected by comparing MD5 hashes:

```python
import hashlib

def get_hash(file_path):
    # Hash the raw bytes so identical files produce the same digest
    with open(file_path, 'rb') as f:
        return hashlib.md5(f.read()).hexdigest()
```
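
Reading each file fully into memory is fine for typical photos but wasteful for large files. The sketch below is one possible variant, not the notebook's code: it hashes in fixed-size chunks and groups paths by digest so duplicates can be reviewed before anything is dropped. The names `hash_file` and `find_duplicates` and the 1 MiB chunk size are illustrative assumptions.

```python
import hashlib
from collections import defaultdict


def hash_file(file_path, chunk_size=1 << 20):
    """Return the MD5 hex digest of a file, read in chunks (hypothetical helper)."""
    digest = hashlib.md5()
    with open(file_path, "rb") as f:
        # iter() with a sentinel keeps calling read() until it returns b"".
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def find_duplicates(paths):
    """Group paths by content hash; keep only groups with more than one file."""
    groups = defaultdict(list)
    for path in paths:
        groups[hash_file(path)].append(str(path))
    return {h: sorted(p) for h, p in groups.items() if len(p) > 1}
```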
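### Brightness & Contrast Issues

```python
from PIL import ImageStat

# `img` is an already opened PIL image; statistics are taken on its grayscale ("L") version
stat = ImageStat.Stat(img.convert("L"))
brightness = stat.mean[0]    # mean pixel value
contrast = stat.stddev[0]    # standard deviation of pixel values
```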
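
The two statistics above are the mean and standard deviation of the grayscale pixel values. A minimal sketch of how they might be turned into quality flags follows. It works on a plain sequence of 0-255 values so it runs without Pillow, and the thresholds (`dark=40`, `bright=215`, `low_contrast=20`) are illustrative assumptions, not values taken from the notebook.

```python
from statistics import fmean, pstdev


def flag_image(gray_pixels, dark=40, bright=215, low_contrast=20):
    """Return (brightness, contrast, flags) for 0-255 grayscale values.

    Thresholds are hypothetical defaults; tune them on your own data.
    """
    brightness = fmean(gray_pixels)   # analogous to stat.mean[0]
    contrast = pstdev(gray_pixels)    # population std dev, analogous to stat.stddev[0]
    flags = []
    if brightness < dark:
        flags.append("too_dark")
    elif brightness > bright:
        flags.append("too_bright")
    if contrast < low_contrast:
        flags.append("low_contrast")
    return brightness, contrast, flags
```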
### Auto-fixing poor images

Brightness and contrast are enhanced using:

```python
from PIL import ImageEnhance

# A factor of 1.0 leaves the image unchanged; values above 1.0 increase the property
img = ImageEnhance.Brightness(img).enhance(1.5)
img = ImageEnhance.Contrast(img).enhance(1.5)
```

---

## Image Preprocessing

All images are resized to **256×256** and normalized to **[0,1]**.

```python
import tensorflow as tf

def preprocess_image(path, target_size=(256, 256)):
    img = tf.io.read_file(path)
    # expand_animations=False yields a 3-D tensor with a known rank, which resize needs
    img = tf.image.decode_image(img, channels=3, expand_animations=False)
    img = tf.image.resize(img, target_size)
    # Scale pixel values from [0, 255] to [0, 1]
    return tf.cast(img, tf.float32) / 255.0
```

### Data Augmentation

```python
data_augmentation = tf.keras.Sequential([
    tf.keras.layers.RandomFlip("horizontal"),
    tf.keras.layers.RandomRotation(0.1),
    tf.keras.layers.RandomZoom(0.1),
])
```

---

## CNN Model Architecture

Below is a simplified view of the model:

```
Conv2D (32)  → BatchNorm → Conv2D (32)  → BatchNorm → MaxPool → Dropout
Conv2D (64)  → BatchNorm → Conv2D (64)  → BatchNorm → MaxPool → Dropout
Conv2D (128) → BatchNorm → Conv2D (128) → BatchNorm → MaxPool → Dropout
Conv2D (256) → BatchNorm → Conv2D (256) → BatchNorm → MaxPool → Dropout
Flatten → Dense (softmax)
```

Example code:

```python
model.add(Conv2D(32, (3,3), activation='relu', padding='same'))
model.add(BatchNormalization())
model.add(MaxPooling2D((2,2)))
```
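
To make the diagram concrete, the sketch below traces tensor shapes through the four blocks in plain Python. It assumes a 256×256×3 input (the preprocessing size), `padding='same'` on every convolution, and one 2×2 max-pool per block, as in the example snippet. `num_classes=3` matches the three example folders and is only an illustration. Under these assumptions, `Flatten` hands the final `Dense` layer 65,536 features.

```python
def trace_shapes(input_size=256, filters=(32, 64, 128, 256), num_classes=3):
    """Trace (height, width, channels) after each block; a shape-only sketch."""
    size = input_size
    shapes = []
    for f in filters:
        # 'same' convolutions keep height/width; the 2x2 max-pool halves them.
        size //= 2
        shapes.append((size, size, f))
    flat = size * size * filters[-1]                 # width of the Flatten output
    dense_params = flat * num_classes + num_classes  # Dense weights + biases
    return shapes, flat, dense_params


shapes, flat, dense_params = trace_shapes()
print(shapes)        # [(128, 128, 32), (64, 64, 64), (32, 32, 128), (16, 16, 256)]
print(flat)          # 65536
print(dense_params)  # 196611
```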

---

## Training

```python
from tensorflow.keras.optimizers import Adam

epochs = 50
optimizer = Adam(learning_rate=0.0005)
model.compile(optimizer=optimizer,
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
```

### Callbacks

| Callback | Purpose |
|----------|---------|
| **ReduceLROnPlateau** | Reduces the learning rate when `val_loss` stagnates |
| **EarlyStopping** | Stops training when the monitored metric stops improving |

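
To show how the two callbacks interact, here is a plain-Python sketch of the patience logic. It is not the Keras implementation, which also supports options such as `min_delta` and `cooldown`. The values of `factor`, `lr_patience`, `stop_patience`, and `min_lr` are assumptions, since this README does not state them; the starting learning rate is the one from the training snippet.

```python
def simulate_callbacks(val_losses, lr=0.0005, factor=0.5, lr_patience=2,
                       stop_patience=4, min_lr=1e-6):
    """Return [(epoch, val_loss, lr)] until early stopping would trigger."""
    best = float("inf")
    lr_wait = stop_wait = 0
    history = []
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            # Improvement: remember it and reset both patience counters.
            best, lr_wait, stop_wait = loss, 0, 0
        else:
            lr_wait += 1
            stop_wait += 1
            if lr_wait >= lr_patience:
                # Plateau: shrink the learning rate, but never below min_lr.
                lr = max(lr * factor, min_lr)
                lr_wait = 0
        history.append((epoch, loss, lr))
        if stop_wait >= stop_patience:
            break  # no improvement for stop_patience epochs
    return history
```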

---

## Model Evaluation

### Accuracy

```python
test_loss, test_accuracy = model.evaluate(test_ds)
```

### Classification Report

```python
print(classification_report(y_true, y_pred, target_names=le.classes_))
```

### Confusion Matrix

```python
# fmt='d' prints the cell values as integer counts
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
```

### ROC Curve (One-vs-Rest)

```python
# Computed once per class index i, from binarized labels and predicted probabilities
fpr, tpr, _ = roc_curve(y_true_bin[:, i], y_probs[:, i])
```
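
One-vs-rest treats each class in turn as "positive" and all other classes as "negative". The sketch below computes the resulting area under the ROC curve without scikit-learn, using the fact that AUC equals the probability that a random positive sample is scored above a random negative one. The function name is hypothetical, and the pairwise loop is quadratic, so it is meant for illustration rather than large test sets.

```python
def one_vs_rest_auc(y_true, y_probs, class_index):
    """AUC for one class against the rest, via pairwise ranking."""
    pos = [p[class_index] for t, p in zip(y_true, y_probs) if t == class_index]
    neg = [p[class_index] for t, p in zip(y_true, y_probs) if t != class_index]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative sample")
    # Count positive/negative pairs ranked correctly; ties count as half.
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))
```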

---

## Saving the Model

```python
model.save("Animal_Classification.h5")
```

---

## Full Code Organization (High-Level)

| Step | Description |
|------|-------------|
| 1 | Import libraries, mount drive |
| 2 | Extract ZIP |
| 3 | Build DataFrame |
| 4 | EDA & cleaning |
| 5 | Preprocessing & augmentation |
| 6 | Dataset pipeline (train/val/test) |
| 7 | CNN architecture |
| 8 | Training |
| 9 | Evaluation |
| 10 | Save model |

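
Step 6 splits the data into train, validation, and test sets. This README does not give the ratios or the method, so the following is only a sketch of one common approach: a per-class (stratified) split that keeps class proportions similar across the three sets. The function name, the 15%/15% defaults, and the seed are assumptions.

```python
import random
from collections import defaultdict


def stratified_split(samples, val_frac=0.15, test_frac=0.15, seed=42):
    """Split (path, label) pairs into train/val/test, class by class."""
    by_class = defaultdict(list)
    for path, label in samples:
        by_class[label].append(path)

    rng = random.Random(seed)  # local RNG so the split is reproducible
    splits = {"train": [], "val": [], "test": []}
    for label, paths in sorted(by_class.items()):
        paths = sorted(paths)
        rng.shuffle(paths)
        n_val = int(len(paths) * val_frac)
        n_test = int(len(paths) * test_frac)
        splits["val"] += [(p, label) for p in paths[:n_val]]
        splits["test"] += [(p, label) for p in paths[n_val:n_val + n_test]]
        splits["train"] += [(p, label) for p in paths[n_val + n_test:]]
    return splits
```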

---

# Animal Image Classification – Complete Pipeline (README)

> "A clean dataset is half the model’s accuracy. The rest is just engineering."

This project presents a **complete end-to-end deep learning pipeline** for **multi-class animal image classification** using TensorFlow/Keras. It covers every stage, from data extraction, cleaning, and analysis to model training, evaluation, and export.

---

## Table of Contents

| Section | Description |
|--------|-------------|
| **1. Project Overview** | What this project does & architecture overview |
| **2. Features** | Key capabilities of this pipeline |
| **3. Recommended Directory Structure** | Recommended project layout |
| **4. Installation** | How to install and run this project |
| **5. Dataset Extraction & Loading** | Extraction, cleaning, inspections |
| **6. Exploratory Data Analysis** | Visualizations & summary statistics |
| **7. Preprocessing & Augmentation** | Data preparation logic |
| **8. CNN Model Architecture** | Layers, blocks, hyperparameters |
| **9. Training** | How the model is trained, including callbacks |
| **10. Evaluation Metrics** | Reports, ROC curve, confusion matrix |
| **11. Model Export** | Saving and downloading the model |
| **12. Example Snippets** | Important snippets explained |

---

## 1. Project Overview

This project builds a **Convolutional Neural Network (CNN)** to classify images of animals into multiple categories. The process includes:

- Dataset extraction from Google Drive
- Data validation (duplicates, corrupt files, mislabeled images)
- Image enhancement & cleaning
- Class distribution analysis
- Image size analysis and outlier detection
- Data augmentation
- CNN model training with regularization
- Performance evaluation using multiple metrics
- Model export to `.h5`

The pipeline is designed to be **robust, explainable, and production-friendly**.

---

## 2. Features

| Feature | Description |
|---------|-------------|
| **Automatic Dataset Extraction** | Unzips and loads images from Google Drive |
| **Image Validation** | Detects duplicates, corrupted images, and mislabeled files |
| **Data Cleaning** | Brightness/contrast enhancements for dark or overexposed samples |
| **EDA Visualizations** | Class distribution, size, color modes, outliers |
| **TensorFlow Dataset Pipeline** | Batching and prefetching with `tf.data` |
| **Deep CNN Model** | 32 → 64 → 128 → 256 filters with batch norm and dropout |
| **Model Evaluation Dashboard** | Confusion matrix, ROC curves, precision/recall/F1 |
| **Model Export** | Saves final model as `Animal_Classification.h5` |

---

## 3. Recommended Directory Structure

```text
Animal-Classification
┣ data
┃ ┗ Animals (extracted folders)
┣ notebooks
┣ src
┃ ┣ preprocessing.py
┃ ┣ model.py
┃ ┗ utils.py
┣ README.md
┗ requirements.txt
```

---

## 4. Installation

```bash
pip install tensorflow pandas matplotlib seaborn scikit-learn pillow tqdm
```

On **Google Colab**, the notebook also uses two Colab-only modules:

- `google.colab.drive` (mounting Google Drive)
- `google.colab.files` (downloading the exported model)

---

## 5. Dataset Extraction & Loading

Example snippet:

```python
import zipfile

zip_path = '/content/drive/MyDrive/Animals.zip'
extract_to = '/content/my_data'

with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall(extract_to)
```

Images are collected into a DataFrame with one row per file (class folder, file name, full path):

```python
from pathlib import Path
import pandas as pd

paths = [(path.parts[-2], path.name, str(path)) for path in Path(extract_to).rglob('*.*')]
df = pd.DataFrame(paths, columns=['class','image','full_path'])
```
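
The pattern `'*.*'` matches any name containing a dot, so stray non-image files (for example `Thumbs.db` or a text file) would also end up in the DataFrame. A hedged alternative is to filter by extension first. In this sketch, `collect_images` and the `IMAGE_EXTENSIONS` set are assumptions, not part of the notebook; the returned rows can be passed to `pd.DataFrame(rows, columns=['class', 'image', 'full_path'])`.

```python
from pathlib import Path

# Assumed set of extensions; adjust to the formats in your dataset.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".gif"}


def collect_images(root):
    """Return (class, file name, full path) for image files under root."""
    rows = []
    for path in sorted(Path(root).rglob("*")):
        # Keep regular files whose extension (case-insensitive) is an image type.
        if path.is_file() and path.suffix.lower() in IMAGE_EXTENSIONS:
            rows.append((path.parent.name, path.name, str(path)))
    return rows
```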

---

## 6. Exploratory Data Analysis

Examples of generated visualizations:

- Bar plot of class distribution
- Pie chart of percentage per class
- Scatter plots of image width and height
- Image mode (RGB/Gray) distribution

Example:

```python
import matplotlib.pyplot as plt

plt.figure(figsize=(32,16))
class_count.plot(kind='bar')
```

---

## 7. Preprocessing & Augmentation

### Preprocessing function

```python
import tensorflow as tf

def preprocess_image(path, target_size=(256,256)):
    img = tf.io.read_file(path)
    # expand_animations=False yields a 3-D tensor with a known rank, which resize needs
    img = tf.image.decode_image(img, channels=3, expand_animations=False)
    img = tf.image.resize(img, target_size)
    # Scale pixel values from [0, 255] to [0, 1]
    return tf.cast(img, tf.float32)/255.0
```

### Augmentation

```python
data_augmentation = tf.keras.Sequential([
    tf.keras.layers.RandomFlip("horizontal"),
    tf.keras.layers.RandomRotation(0.1),
    tf.keras.layers.RandomZoom(0.1),
])
```

---

## 8. CNN Model Architecture

| Block | Layers |
|------|---------|
| **Block 1** | Conv2D(32) → BN → Conv2D(32) → BN → MaxPool → Dropout(0.2) |
| **Block 2** | Conv2D(64) → BN → Conv2D(64) → BN → MaxPool → Dropout(0.3) |
| **Block 3** | Conv2D(128) → BN → Conv2D(128) → BN → MaxPool → Dropout(0.4) |
| **Block 4** | Conv2D(256) → BN → Conv2D(256) → BN → MaxPool → Dropout(0.5) |
| **Output** | Flatten → Dense(num_classes, softmax) |

Example snippet:

```python
model.add(Conv2D(64,(3,3),activation='relu',padding='same'))
model.add(BatchNormalization())
```

---

## 9. Training

```python
optimizer = Adam(learning_rate=0.0005)
model.compile(optimizer=optimizer,
              loss='sparse_categorical_crossentropy', metrics=['accuracy'])
```

Using callbacks:

```python
ReduceLROnPlateau(...)
EarlyStopping(...)
```

---

## 10. Evaluation Metrics

This project computes:

- Precision, Recall, F1 (macro & per class)
- Confusion matrix (heatmap)
- ROC curves (one-vs-rest)
- Macro-average ROC

Example:

```python
cm = confusion_matrix(y_true, y_pred)
# fmt='d' prints the cell values as integer counts
sns.heatmap(cm, annot=True, fmt='d')
```
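
The per-class metrics listed above can all be read off the confusion matrix. The sketch below shows the arithmetic in plain Python, assuming the scikit-learn layout where `cm[i][j]` counts samples of true class `i` predicted as class `j`. The function name is hypothetical; in the notebook these numbers come from `classification_report`.

```python
def per_class_metrics(cm):
    """Return ([(precision, recall, f1) per class], macro_f1) for a square matrix."""
    n = len(cm)
    metrics = []
    for k in range(n):
        tp = cm[k][k]
        fp = sum(cm[i][k] for i in range(n)) - tp   # column sum minus diagonal
        fn = sum(cm[k]) - tp                        # row sum minus diagonal
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = (2 * precision * recall / (precision + recall)
              if precision + recall else 0.0)
        metrics.append((precision, recall, f1))
    # Macro average: unweighted mean of the per-class F1 scores.
    macro_f1 = sum(m[2] for m in metrics) / n
    return metrics, macro_f1
```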

---

## 11. Model Export

```python
from google.colab import files  # Colab only

model.save("Animal_Classification.h5")
files.download("Animal_Classification.h5")
```

---

## 12. Example Snippets

### Checking corrupted files

```python
from PIL import Image

corrupted = []
for path in df['full_path']:
    try:
        with Image.open(path) as img:
            img.verify()  # raises if the file is broken
    except Exception:
        corrupted.append(path)
```
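
`Image.verify()` requires Pillow and opens every file. As a cheap first pass, file signatures can be checked with the standard library alone. This is a different and weaker technique than `verify()`: it catches empty or mislabeled files but not truncated image data. `SIGNATURES` and `sniff_image_type` are hypothetical names used only for this sketch.

```python
# Leading "magic bytes" of a few common image formats.
SIGNATURES = {
    "jpeg": b"\xff\xd8\xff",
    "png": b"\x89PNG\r\n\x1a\n",
    "gif": b"GIF8",
    "bmp": b"BM",
}


def sniff_image_type(path):
    """Return the detected format name, or None if no known signature matches."""
    with open(path, "rb") as f:
        header = f.read(8)  # the longest signature above is 8 bytes
    for name, magic in SIGNATURES.items():
        if header.startswith(magic):
            return name
    return None
```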
### Filtering duplicate images

```python
# Hash every file, then keep the first row for each distinct hash
df['file_hash'] = df['full_path'].apply(get_hash)
df_unique = df.drop_duplicates(subset='file_hash')
```