---
license: mit
datasets:
- AIOmarRehan/AnimalsDataset
---
# Animal Image Classification (TensorFlow & CNN)
> "A complete end‑to‑end pipeline for building, cleaning, preprocessing, training, evaluating, and deploying a deep CNN model for multi‑class animal image classification."

This project is designed to be **clean**, **organized**, and **human-friendly**, showing the full machine-learning workflow, from **data validation** to **model evaluation & ROC curves**.

---
## Pipeline Components
| Component | Description |
|----------|-------------|
| **Data Loading** | Reads and extracts the ZIP dataset from Google Drive |
| **EDA** | Class distribution, file integrity, image sizes, brightness, contrast, sample display |
| **Preprocessing** | Resizing, normalization, augmentation, duplicate detection (hashing), removal of corrupted files |
| **Model** | Deep custom CNN with BatchNorm, Dropout & L2 Regularization |
| **Training** | Adam optimizer, learning-rate reduction on plateau, early stopping |
| **Evaluation** | Confusion matrix, classification report, ROC curves |
| **Export** | Saves final `.h5` model |
---
## How to Run
### 1. Upload your dataset to Google Drive
Your dataset must be a ZIP archive (e.g. `Animals.zip`) with one subfolder per class:
```
Animals/
├── Cats/
├── Dogs/
└── Snakes/
```
### 2. Update the ZIP path
```python
zip_path = '/content/drive/MyDrive/Animals.zip'
extract_to = '/content/my_data'
```
### 3. Run the Notebook
When run, the notebook will:
- Mount Google Drive
- Extract images
- Build a DataFrame of paths
- Run EDA checks
- Clean and prepare images
- Train the CNN model
- Export the trained model (`.h5`)
---
## Data Preparation & EDA
Before training, the notebook validates the dataset with the following checks:
### Class Distribution
```python
class_count = df['class'].value_counts()
class_count.plot(kind='bar')
```
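
The bar chart shows raw counts; percentages and a single imbalance number are often easier to compare across dataset versions. A minimal standard-library sketch, assuming the labels are available as a plain list (e.g. `df['class'].tolist()`); `class_summary` and the max/min imbalance ratio are illustrative additions, not part of the original notebook:

```python
from collections import Counter


def class_summary(labels):
    """Return per-class counts, percentages, and the max/min imbalance ratio."""
    counts = Counter(labels)
    total = sum(counts.values())
    # Share of the dataset held by each class (the values a pie chart shows)
    percentages = {cls: 100.0 * n / total for cls, n in counts.items()}
    # Largest class size divided by smallest; 1.0 means perfectly balanced
    imbalance_ratio = max(counts.values()) / min(counts.values())
    return counts, percentages, imbalance_ratio


# Toy labels, not the real dataset
labels = ["Cats"] * 50 + ["Dogs"] * 30 + ["Snakes"] * 20
counts, percentages, ratio = class_summary(labels)
print(percentages)  # {'Cats': 50.0, 'Dogs': 30.0, 'Snakes': 20.0}
print(ratio)        # 2.5
```
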
### Image Size Properties
```python
image_df['Channels'].value_counts().plot(kind='bar')
```
### Duplicate Image Detection
Byte-identical duplicates are detected with MD5 hashing:
```python
import hashlib

def get_hash(file_path):
    # MD5 of the raw bytes: identical files produce identical hashes
    with open(file_path, 'rb') as f:
        return hashlib.md5(f.read()).hexdigest()
```
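
`get_hash` reads each file into memory in one call, which is fine for typical photos. A variant sketch that hashes in fixed-size chunks and groups byte-identical files; `get_hash_chunked`, `find_duplicate_groups`, and the 1 MiB chunk size are hypothetical names and choices. It yields the same MD5 digest as `get_hash`. MD5 only catches exact byte-level copies; a resized or re-encoded copy of the same photo hashes differently.

```python
import hashlib
from collections import defaultdict


def get_hash_chunked(file_path, chunk_size=1 << 20):
    """MD5 of a file read in 1 MiB chunks; same digest as hashing it in one go."""
    md5 = hashlib.md5()
    with open(file_path, "rb") as f:
        # iter() keeps calling f.read(chunk_size) until it returns b"" at end of file
        for chunk in iter(lambda: f.read(chunk_size), b""):
            md5.update(chunk)
    return md5.hexdigest()


def find_duplicate_groups(paths):
    """Map each content hash shared by 2+ files to the list of those files."""
    groups = defaultdict(list)
    for path in paths:
        groups[get_hash_chunked(path)].append(path)
    return {digest: files for digest, files in groups.items() if len(files) > 1}
```
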
### Brightness & Contrast Issues
```python
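from PIL import ImageStat

# img: a PIL.Image opened from one dataset file
stat = ImageStat.Stat(img.convert("L"))  # grayscale statistics
brightness = stat.mean[0]    # mean pixel intensity (0-255)
contrast = stat.stddev[0]    # spread of intensities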
```
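
Turning these two numbers into a decision needs thresholds, which the original does not state. A sketch with **hypothetical** cut-offs that works on a plain list of 0-255 grayscale values (so it runs without Pillow); the population standard deviation used here should correspond to the `stddev` that `ImageStat` reports:

```python
import statistics

# Hypothetical thresholds; tune them on your own data
DARK_THRESHOLD = 50
BRIGHT_THRESHOLD = 205
LOW_CONTRAST_THRESHOLD = 30


def brightness_contrast(gray_pixels):
    """Mean and population standard deviation of 0-255 grayscale values."""
    return statistics.fmean(gray_pixels), statistics.pstdev(gray_pixels)


def flag_image(gray_pixels):
    """Return a list of quality flags for one image."""
    brightness, contrast = brightness_contrast(gray_pixels)
    flags = []
    if brightness < DARK_THRESHOLD:
        flags.append("too_dark")
    elif brightness > BRIGHT_THRESHOLD:
        flags.append("too_bright")
    if contrast < LOW_CONTRAST_THRESHOLD:
        flags.append("low_contrast")
    return flags


print(flag_image([10, 20, 30, 40]))  # ['too_dark', 'low_contrast']
print(flag_image([0, 255, 0, 255]))  # []
```
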
### Auto‑fixing poor images
Images with poor brightness or contrast are enhanced with:
```python
from PIL import ImageEnhance

img = ImageEnhance.Brightness(img).enhance(1.5)  # factor 1.0 = unchanged
img = ImageEnhance.Contrast(img).enhance(1.5)    # factor 1.0 = unchanged
```
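
A factor of 1.5 is applied uniformly, so it helps dark images but can wash out ones that are already bright. The per-pixel arithmetic below approximates what the two enhancers do, based on my reading of Pillow's implementation (Brightness blends with a black image, Contrast blends with a flat gray at the grayscale mean); exact rounding may differ:

```python
def enhance_brightness(pixel, factor):
    """Blend with black: out = factor * pixel, clipped to [0, 255] (approximation)."""
    return max(0, min(255, round(factor * pixel)))


def enhance_contrast(pixels, factor):
    """Blend with flat gray at the mean: out = mean + factor * (p - mean) (approximation)."""
    mean = round(sum(pixels) / len(pixels))
    return [max(0, min(255, round(mean + factor * (p - mean)))) for p in pixels]


print(enhance_brightness(100, 1.5))       # 150
print(enhance_brightness(200, 1.5))       # 255 (clipped)
print(enhance_contrast([100, 200], 1.5))  # [75, 225]
```

The clipping in the second line is why it makes sense to enhance only the images singled out by the brightness/contrast check rather than the whole dataset.
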
---
## Image Preprocessing
All images are resized to **256×256** and normalized to **[0,1]**.
```python
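import tensorflow as tf

def preprocess_image(path, target_size=(256, 256)):
    img = tf.io.read_file(path)
    # expand_animations=False always yields a 3-D tensor (first frame of a GIF),
    # which tf.image.resize needs when this runs inside Dataset.map
    img = tf.image.decode_image(img, channels=3, expand_animations=False)
    img = tf.image.resize(img, target_size)  # returns float32
    return tf.cast(img, tf.float32) / 255.0  # scale to [0, 1]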
```
### Data Augmentation
```python
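import tensorflow as tf

data_augmentation = tf.keras.Sequential([
    tf.keras.layers.RandomFlip("horizontal"),  # random left-right mirror
    tf.keras.layers.RandomRotation(0.1),       # up to ±10% of a full turn
    tf.keras.layers.RandomZoom(0.1),           # zoom in/out by up to 10%
])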
```
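
The layer arguments are fractions, not degrees or percentages. Per the Keras documentation, the `RandomRotation` factor is a fraction of a full turn (2π), and a single float passed to `RandomZoom` is used as both the lower and upper bound. The arithmetic for the values above:

```python
import math

rotation_factor = 0.1  # RandomRotation(0.1)
zoom_factor = 0.1      # RandomZoom(0.1)

# The rotation factor is a fraction of 2*pi radians, applied in both directions
max_rotation_deg = math.degrees(rotation_factor * 2 * math.pi)

# A single zoom float becomes the range [-factor, +factor]
zoom_range_pct = (-zoom_factor * 100, zoom_factor * 100)

print(f"rotation: up to ±{max_rotation_deg:.0f}°")
print(f"zoom range: {zoom_range_pct[0]:+.0f}% to {zoom_range_pct[1]:+.0f}%")
```

So each training image may be rotated by up to about 36° either way, which is worth keeping in mind for classes where orientation matters.
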
---
## CNN Model Architecture
Below is a simplified view of the model:
```
Conv2D (32) → BatchNorm → Conv2D (32) → BatchNorm → MaxPool → Dropout
Conv2D (64) → BatchNorm → Conv2D (64) → BatchNorm → MaxPool → Dropout
Conv2D (128) → BatchNorm → Conv2D (128) → BatchNorm → MaxPool → Dropout
Conv2D (256) → BatchNorm → Conv2D (256) → BatchNorm → MaxPool → Dropout
Flatten → Dense (softmax)
```
Example code:
```python
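from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, BatchNormalization, MaxPooling2D
model = Sequential()
model.add(Conv2D(32, (3, 3), activation='relu', padding='same'))
model.add(BatchNormalization())
model.add(MaxPooling2D((2, 2)))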
```
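
Assuming 256×256×3 inputs (as produced by `preprocess_image`), `padding='same'` on every convolution, and one `MaxPooling2D((2, 2))` per block as in the diagram, the spatial size halves four times. A small sketch of that bookkeeping; `num_classes = 3` is a hypothetical value matching the example folder layout:

```python
def feature_map_sizes(input_size=256, filters=(32, 64, 128, 256)):
    """Output shape of each block: 'same' convs keep H/W, 2x2 max-pooling halves them."""
    size = input_size
    shapes = []
    for f in filters:
        size //= 2  # MaxPooling2D((2, 2)) at the end of the block
        shapes.append((size, size, f))
    return shapes


shapes = feature_map_sizes()
h, w, c = shapes[-1]
flatten_units = h * w * c
num_classes = 3  # hypothetical, e.g. Cats / Dogs / Snakes
dense_params = flatten_units * num_classes + num_classes  # weights + biases

print(shapes)         # [(128, 128, 32), (64, 64, 64), (32, 32, 128), (16, 16, 256)]
print(flatten_units)  # 65536
print(dense_params)   # 196611
```

The Flatten layer therefore feeds 65,536 features into the softmax layer. `GlobalAveragePooling2D` is a common alternative that would feed it 256 features instead.
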
---
## Training
```python
from tensorflow.keras.optimizers import Adam

epochs = 50
optimizer = Adam(learning_rate=0.0005)
model.compile(optimizer=optimizer,
              loss='sparse_categorical_crossentropy',  # integer class labels
              metrics=['accuracy'])
```
### Callbacks
| Callback | Purpose |
|----------|---------|
| **ReduceLROnPlateau** | Auto‑reduce LR when val_loss stagnates |
| **EarlyStopping** | Stop training when the monitored metric stops improving |
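
The original does not list the callbacks' arguments, so the values below (`lr_factor=0.5`, `lr_patience=2`, `stop_patience=4`, `min_lr=1e-6`) are **hypothetical**. This is a simplified model of how the two callbacks interact on a sequence of validation losses, ignoring Keras options such as `min_delta`, `cooldown`, and `restore_best_weights`; it is not the Keras implementation:

```python
def simulate_callbacks(val_losses, lr=0.0005, lr_factor=0.5, lr_patience=2,
                       stop_patience=4, min_lr=1e-6):
    """Return (LR used in each epoch, epoch at which training stops or None).

    All defaults except lr=0.0005 are hypothetical.
    """
    best = float("inf")
    wait_lr = wait_stop = 0
    lrs = []
    for epoch, loss in enumerate(val_losses):
        lrs.append(lr)
        if loss < best:
            # Improvement resets both patience counters
            best, wait_lr, wait_stop = loss, 0, 0
            continue
        wait_lr += 1
        wait_stop += 1
        if wait_stop >= stop_patience:
            return lrs, epoch  # EarlyStopping ends training after this epoch
        if wait_lr >= lr_patience:
            # ReduceLROnPlateau: scale the LR down and restart its own counter
            lr = max(lr * lr_factor, min_lr)
            wait_lr = 0
    return lrs, None


# Toy validation losses that stop improving after epoch 5
losses = [1.0, 0.8, 0.7, 0.72, 0.71, 0.69, 0.70, 0.70, 0.71, 0.72]
lrs, stop_epoch = simulate_callbacks(losses)
print(sorted(set(lrs), reverse=True))  # [0.0005, 0.00025, 0.000125]
print(stop_epoch)                      # 9
```

Because the learning-rate patience is shorter than the early-stopping patience, the model gets at least one lower learning rate before training is stopped.
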
---
## Model Evaluation
### Accuracy
```python
test_loss, test_accuracy = model.evaluate(test_ds)
```
### Classification Report
```python
print(classification_report(y_true, y_pred, target_names=le.classes_))
```
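
`classification_report` compares hard labels, so `y_pred` is typically the argmax of the softmax outputs. The sketch below shows that step and recomputes per-class precision, recall, and F1 plus their unweighted (macro) averages from scratch, to make the report's rows concrete; `argmax` and `per_class_prf` are hypothetical helpers on toy values, not scikit-learn functions:

```python
def argmax(row):
    """Index of the largest probability in one softmax output."""
    return max(range(len(row)), key=row.__getitem__)


def per_class_prf(y_true, y_pred, num_classes):
    """Per-class (precision, recall, F1) and the macro average of each."""
    report = []
    for c in range(num_classes):
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        report.append((precision, recall, f1))
    macro = tuple(sum(r[i] for r in report) / num_classes for i in range(3))
    return report, macro


# Toy softmax outputs for 3 classes
y_probs = [[0.8, 0.1, 0.1], [0.2, 0.7, 0.1], [0.3, 0.4, 0.3], [0.1, 0.2, 0.7]]
y_true = [0, 1, 0, 2]
y_pred = [argmax(row) for row in y_probs]
report, macro = per_class_prf(y_true, y_pred, num_classes=3)
print(y_pred)                                # [0, 1, 1, 2]
print(tuple(round(v, 3) for v in macro))     # (0.833, 0.833, 0.778)
```
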
### Confusion Matrix
```python
cm = confusion_matrix(y_true, y_pred)  # sklearn.metrics
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')  # fmt='d': integer counts
```
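
For reference, the plotted matrix follows scikit-learn's convention: row *i* is the true class and column *j* the predicted class, so the diagonal holds correct predictions and each row sums to that class's support. A from-scratch sketch on toy labels (`confusion_matrix_counts` is a hypothetical helper):

```python
def confusion_matrix_counts(y_true, y_pred, num_classes):
    """cm[i][j] = number of samples with true class i predicted as class j."""
    cm = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        cm[t][p] += 1
    return cm


cm = confusion_matrix_counts([0, 1, 0, 2, 2], [0, 1, 1, 2, 0], num_classes=3)
for row in cm:
    print(row)
# [1, 1, 0]
# [0, 1, 0]
# [1, 0, 1]
```
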
### ROC Curve (One-vs-Rest)
```python
for i in range(y_probs.shape[1]):  # one binary (class i vs. rest) curve per class
    fpr, tpr, _ = roc_curve(y_true_bin[:, i], y_probs[:, i])
```
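
`roc_curve` handles one binary problem at a time: column `i` of the one-hot labels (`y_true_bin`) against column `i` of the predicted probabilities. A standard-library sketch of the whole one-vs-rest loop, with AUC computed by the trapezoidal rule and a macro average over classes; the helper names are hypothetical and the inputs are toy values:

```python
def roc_points(y_binary, scores):
    """(fpr, tpr) lists from sweeping the threshold high to low.

    Assumes y_binary contains at least one positive and one negative.
    """
    pos = sum(y_binary)
    neg = len(y_binary) - pos
    pairs = sorted(zip(scores, y_binary), reverse=True)
    tp = fp = 0
    fpr, tpr = [0.0], [0.0]
    for idx, (score, label) in enumerate(pairs):
        tp += label
        fp += 1 - label
        # Emit a point only after every sample tied at this score is counted
        if idx + 1 == len(pairs) or pairs[idx + 1][0] != score:
            fpr.append(fp / neg)
            tpr.append(tp / pos)
    return fpr, tpr


def trapezoid_auc(fpr, tpr):
    """Area under the piecewise-linear ROC curve."""
    return sum((fpr[k] - fpr[k - 1]) * (tpr[k] + tpr[k - 1]) / 2
               for k in range(1, len(fpr)))


def one_vs_rest_auc(y_true, y_probs, num_classes):
    """Per-class AUCs (class i vs. all others) and their unweighted mean."""
    aucs = []
    for i in range(num_classes):
        y_bin = [1 if t == i else 0 for t in y_true]  # column i of y_true_bin
        scores = [row[i] for row in y_probs]          # column i of y_probs
        aucs.append(trapezoid_auc(*roc_points(y_bin, scores)))
    return aucs, sum(aucs) / num_classes


# Toy labels and softmax outputs for 3 classes
y_true = [0, 0, 1, 2]
y_probs = [[0.7, 0.2, 0.1], [0.3, 0.6, 0.1], [0.4, 0.5, 0.1], [0.1, 0.2, 0.7]]
aucs, macro_auc = one_vs_rest_auc(y_true, y_probs, num_classes=3)
print([round(a, 3) for a in aucs])  # [0.75, 0.667, 1.0]
print(round(macro_auc, 3))          # 0.806
```

Averaging per-class AUCs matches what `roc_auc_score(..., multi_class='ovr', average='macro')` reports; a macro-average ROC *curve* built by interpolating TPRs onto a shared FPR grid can give a slightly different area.
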
---
## Saving the Model
```python
model.save("Animal_Classification.h5")
```
---
## Full Code Organization (High-Level)
| Step | Description |
|------|-------------|
| 1 | Import libraries, mount drive |
| 2 | Extract ZIP |
| 3 | Build DataFrame |
| 4 | EDA & cleaning |
| 5 | Preprocessing & augmentation |
| 6 | Dataset pipeline (train/val/test) |
| 7 | CNN architecture |
| 8 | Training |
| 9 | Evaluation |
|10 | Save model |
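
Step 6 splits the data into train/validation/test sets. The original does not say how; one common option is a stratified split, sketched below with **hypothetical** 70/15/15 fractions and a fixed seed. `stratified_split` is an illustrative helper; scikit-learn's `train_test_split(..., stratify=labels)` does the same job for a two-way split.

```python
import random
from collections import defaultdict


def stratified_split(items, labels, val_frac=0.15, test_frac=0.15, seed=42):
    """Split items into train/val/test while keeping class proportions.

    The fractions and seed are hypothetical defaults.
    """
    by_class = defaultdict(list)
    for item, label in zip(items, labels):
        by_class[label].append(item)
    rng = random.Random(seed)  # fixed seed -> reproducible split
    splits = {"train": [], "val": [], "test": []}
    for label, group in by_class.items():
        rng.shuffle(group)
        n_test = round(len(group) * test_frac)
        n_val = round(len(group) * val_frac)
        splits["test"] += [(x, label) for x in group[:n_test]]
        splits["val"] += [(x, label) for x in group[n_test:n_test + n_val]]
        splits["train"] += [(x, label) for x in group[n_test + n_val:]]
    return splits


paths = [f"img_{i}.jpg" for i in range(100)]
labels = ["Cats"] * 60 + ["Dogs"] * 40
splits = stratified_split(paths, labels)
print({k: len(v) for k, v in splits.items()})  # {'train': 70, 'val': 15, 'test': 15}
```

Running duplicate removal before the split keeps identical images from landing in both the training and test sets.
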
---
# Animal Image Classification – Complete Pipeline (README)
> "A clean dataset is half the model’s accuracy. The rest is just engineering."

This project presents a **complete end-to-end deep learning pipeline** for **multi-class animal image classification** using TensorFlow/Keras. It covers every step, from data extraction, cleaning, and analysis to model training, evaluation, and export.

---
## Table of Contents
| Section | Description |
|--------|-------------|
| **1. Project Overview** | What this project does & architecture overview |
| **2. Features** | Key capabilities of this pipeline |
| **3. Recommended Directory Structure** | Recommended project layout |
| **4. Installation** | Dependencies and Google Colab notes |
| **5. Dataset Extraction & Loading** | Unzipping the dataset and building the file DataFrame |
| **6. Exploratory Data Analysis** | Visualizations & summary statistics |
| **7. Preprocessing & Augmentation** | Data preparation logic |
| **8. CNN Model Architecture** | Layers, blocks, hyperparameters |
| **9. Training** | Optimizer, loss, and callbacks |
| **10. Evaluation Metrics** | Reports, ROC curves, confusion matrix |
| **11. Model Export** | Saving and downloading the model |
| **12. Example Snippets** | Corrupted-file and duplicate checks |
---
## 1. Project Overview
This project builds a **Convolutional Neural Network (CNN)** to classify images of animals into multiple categories. The process includes:
- Dataset extraction from Google Drive
- Data validation (duplicates, corrupt files, mislabeled images)
- Image enhancement & cleaning
- Class distribution analysis
- Image size analysis and outlier detection
- Data augmentation
- CNN model training with regularization
- Performance evaluation using multiple metrics
- Model export to `.h5`

The pipeline is designed to be **robust, explainable, and production-friendly**.

---
## 2. Features
| Feature | Description |
|---------|-------------|
| **Automatic Dataset Extraction** | Unzips and loads images from Google Drive |
| **Image Validation** | Detects duplicates, corrupted images, and mislabeled files |
| **Data Cleaning** | Brightness/contrast enhancements for dark or overexposed samples |
| **EDA Visualizations** | Class distribution, size, color modes, outliers |
| **TensorFlow Dataset Pipeline** | `tf.data` input pipeline with batching & prefetching |
| **Deep CNN Model** | 32 → 64 → 128 → 256 filters with batch norm and dropout |
| **Evaluation Reports** | Confusion matrix, ROC curves, precision/recall/F1 |
| **Model Export** | Saves final model as `Animal_Classification.h5` |
---
## 3. Recommended Directory Structure
```text
Animal-Classification
┣ data
┃ ┗ Animals (extracted folders)
┣ notebooks
┣ src
┃ ┣ preprocessing.py
┃ ┣ model.py
┃ ┗ utils.py
┣ README.md
┗ requirements.txt
```
---
## 4. Installation
```bash
pip install tensorflow pandas matplotlib seaborn scikit-learn pillow tqdm
```
On **Google Colab**, the notebook also uses two Colab-only modules:
- `google.colab.drive` to mount Google Drive
- `google.colab.files` to download the exported model
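
A quick way to check the environment before running the notebook is to look up each package's import name without importing it. The mapping below follows the install command above (note that `scikit-learn` is imported as `sklearn` and `pillow` as `PIL`); `missing_packages` is a hypothetical helper:

```python
import importlib.util

# pip package name -> name used in `import`
REQUIRED = {
    "tensorflow": "tensorflow",
    "pandas": "pandas",
    "matplotlib": "matplotlib",
    "seaborn": "seaborn",
    "scikit-learn": "sklearn",
    "pillow": "PIL",
    "tqdm": "tqdm",
}


def missing_packages(required=REQUIRED):
    """Return the pip names whose top-level module cannot be found."""
    return [pip_name for pip_name, module in required.items()
            if importlib.util.find_spec(module) is None]


missing = missing_packages()
if missing:
    print("Missing, install with: pip install " + " ".join(missing))
```
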
---
## 5. Dataset Extraction & Loading
Example snippet:
```python
import zipfile

zip_path = '/content/drive/MyDrive/Animals.zip'
extract_to = '/content/my_data'
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall(extract_to)
```
Images are collected into a DataFrame:
```python
from pathlib import Path
import pandas as pd
paths = [(path.parts[-2], path.name, str(path)) for path in Path(extract_to).rglob('*.*')]
df = pd.DataFrame(paths, columns=['class', 'image', 'full_path'])
```
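
`rglob('*.*')` matches any file with a dot in its name, so extras such as `.DS_Store` or the `__MACOSX/` metadata folder that macOS adds to ZIP archives would end up in the DataFrame with a class label. A stricter sketch that keeps only image extensions and skips hidden entries; the extension allow-list is an assumption:

```python
from pathlib import Path

# Hypothetical allow-list; extend it if the dataset uses other formats
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp"}


def collect_images(root):
    """Return (class, image, full_path) rows for image files under root."""
    root = Path(root)
    rows = []
    for path in sorted(root.rglob("*")):
        rel_parts = path.relative_to(root).parts
        # Skip hidden entries (.DS_Store, ._* resource forks) and macOS ZIP metadata
        if any(part.startswith(".") or part == "__MACOSX" for part in rel_parts):
            continue
        if not path.is_file() or path.suffix.lower() not in IMAGE_EXTENSIONS:
            continue
        # Same labelling rule as above: the parent folder name is the class
        rows.append((path.parent.name, path.name, str(path)))
    return rows
```

The rows have the same shape as before, so `pd.DataFrame(collect_images(extract_to), columns=['class', 'image', 'full_path'])` could stand in for the two lines above.
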
---
## 6. Exploratory Data Analysis
Examples of generated visualizations:
- Barplot of class distribution
- Pie chart of percentage per class
- Scatter plots of image width and height
- Image mode (RGB/Gray) distribution
Example:
```python
import matplotlib.pyplot as plt
class_count = df['class'].value_counts()  # images per class
plt.figure(figsize=(32,16))
class_count.plot(kind='bar')
```
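
The width/height scatter plots show outliers visually; the original does not state how they are identified numerically. One common rule, sketched here as an assumption, is Tukey's fences: flag values more than 1.5 × IQR below the first quartile or above the third. Apply it to widths and heights separately; the widths below are toy values:

```python
import statistics


def iqr_outliers(values, k=1.5):
    """Indices of values outside [Q1 - k*IQR, Q3 + k*IQR] (Tukey's fences)."""
    # statistics.quantiles uses the 'exclusive' method by default
    q1, _, q3 = statistics.quantiles(values, n=4)
    iqr = q3 - q1
    low, high = q1 - k * iqr, q3 + k * iqr
    return [i for i, v in enumerate(values) if v < low or v > high]


# Toy image widths in pixels; k=1.5 is the conventional Tukey constant
widths = [256, 260, 250, 255, 258, 252, 1920, 64]
print(iqr_outliers(widths))  # [6, 7]
```
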
---
## 7. Preprocessing & Augmentation
### Preprocessing function
```python
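import tensorflow as tf

def preprocess_image(path, target_size=(256,256)):
    img = tf.io.read_file(path)
    # expand_animations=False keeps the output 3-D so resize works in Dataset.map
    img = tf.image.decode_image(img, channels=3, expand_animations=False)
    img = tf.image.resize(img, target_size)
    return tf.cast(img, tf.float32)/255.0  # scale to [0, 1]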
```
### Augmentation
```python
import tensorflow as tf

data_augmentation = tf.keras.Sequential([
    tf.keras.layers.RandomFlip("horizontal"),
    tf.keras.layers.RandomRotation(0.1),
    tf.keras.layers.RandomZoom(0.1),
])
```
---
## 8. CNN Model Architecture
| Block | Layers |
|------|---------|
| **Block 1** | Conv2D(32) → BN → Conv2D(32) → BN → MaxPool → Dropout(0.2) |
| **Block 2** | Conv2D(64) → BN → Conv2D(64) → BN → MaxPool → Dropout(0.3) |
| **Block 3** | Conv2D(128) → BN → Conv2D(128) → BN → MaxPool → Dropout(0.4) |
| **Block 4** | Conv2D(256) → BN → Conv2D(256) → BN → MaxPool → Dropout(0.5) |
| **Output** | Flatten → Dense(num_classes, softmax) |
Example snippet:
```python
from tensorflow.keras.layers import Conv2D, BatchNormalization

model.add(Conv2D(64, (3, 3), activation='relu', padding='same'))
model.add(BatchNormalization())
```
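
Assuming 3×3 kernels (as in the snippet) and RGB input, the parameter count of each block follows directly from the layer shapes. Dropout, max-pooling, and L2 regularization add no parameters; BatchNorm contributes four per channel, two of which (the moving mean and variance) show up as non-trainable in `model.summary()`:

```python
def conv2d_params(kernel, in_ch, out_ch):
    """kernel x kernel x in_ch x out_ch weights plus one bias per filter."""
    return kernel * kernel * in_ch * out_ch + out_ch


def batchnorm_params(ch):
    """gamma, beta (trainable) and moving mean, moving variance (non-trainable)."""
    return 4 * ch


block_params = []
in_ch = 3  # RGB input
for filters in (32, 64, 128, 256):
    # Conv2D -> BN -> Conv2D -> BN; MaxPool and Dropout have no parameters
    block_params.append(conv2d_params(3, in_ch, filters) + batchnorm_params(filters)
                        + conv2d_params(3, filters, filters) + batchnorm_params(filters))
    in_ch = filters

print(block_params)       # [10400, 55936, 222464, 887296]
print(sum(block_params))  # 1176096
```
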
---
## 9. Training
```python
from tensorflow.keras.optimizers import Adam

optimizer = Adam(learning_rate=0.0005)
model.compile(optimizer=optimizer,
              loss='sparse_categorical_crossentropy', metrics=['accuracy'])
```
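
`sparse_categorical_crossentropy` takes integer class ids rather than one-hot vectors, plus the probability vector from the softmax layer. A sketch of both pieces on toy values: the encoding mirrors scikit-learn's `LabelEncoder` (which assigns ids in sorted order, matching `le.classes_` in the classification report), and the per-sample loss is the negative log of the probability assigned to the true class. Keras averages this over the batch and clips probabilities to avoid `log(0)`; the sketch does neither:

```python
import math

class_names = ["Snakes", "Cats", "Dogs", "Cats"]
# Like LabelEncoder: sorted unique names -> integer ids
classes = sorted(set(class_names))       # ['Cats', 'Dogs', 'Snakes']
to_id = {name: i for i, name in enumerate(classes)}
y = [to_id[name] for name in class_names]  # [2, 0, 1, 0]


def sparse_categorical_crossentropy(label, probs):
    """Loss for one sample: -log of the probability given to the true class."""
    return -math.log(probs[label])


softmax_output = [0.7, 0.2, 0.1]
print(round(sparse_categorical_crossentropy(0, softmax_output), 4))  # 0.3567
print(round(sparse_categorical_crossentropy(2, softmax_output), 4))  # 2.3026
```
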
Using callbacks:
```python
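from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping

# Arguments elided; both callbacks are passed to model.fit(callbacks=[...])
ReduceLROnPlateau(...)
EarlyStopping(...)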
```
---
## 10. Evaluation Metrics
This project computes:
- Precision, Recall, F1 (macro & per class)
- Confusion matrix (heatmap)
- ROC curves (one-vs-rest)
- Macro-average ROC
Example:
```python
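from sklearn.metrics import confusion_matrix
import seaborn as sns
cm = confusion_matrix(y_true, y_pred)
sns.heatmap(cm, annot=True, fmt='d')  # fmt='d' prints integer counts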
```
---
## 11. Model Export
```python
from google.colab import files  # Colab only

model.save("Animal_Classification.h5")
files.download("Animal_Classification.h5")
```
---
## 12. Example Snippets
### Checking corrupted files
```python
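from PIL import Image

corrupted = []
for path in df['full_path']:
    try:
        with Image.open(path) as img:
            img.verify()  # integrity check without decoding pixel data
    except Exception:  # any failure to open or verify marks the file as corrupted
        corrupted.append(path)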
```
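
`img.verify()` checks file integrity without decoding the pixel data, and Pillow's documentation notes that the image must be reopened before further use. A cheap standard-library pre-check, added here as a supplement rather than a replacement, catches files whose content matches no common image signature and JPEGs that appear cut off. The signature table covers only a few formats, and the end-marker test is a heuristic (some valid JPEGs carry trailing bytes):

```python
import os

# Leading bytes ("magic numbers") of a few common image formats
SIGNATURES = {
    b"\xff\xd8\xff": "jpeg",
    b"\x89PNG\r\n\x1a\n": "png",
    b"GIF87a": "gif",
    b"GIF89a": "gif",
    b"BM": "bmp",
}


def sniff_format(path):
    """Format name whose signature matches the file header, or None."""
    with open(path, "rb") as f:
        header = f.read(8)
    for magic, fmt in SIGNATURES.items():
        if header.startswith(magic):
            return fmt
    return None


def jpeg_looks_truncated(path):
    """Heuristic: a complete JPEG normally ends with the EOI marker FF D9."""
    if os.path.getsize(path) < 4:
        return True
    with open(path, "rb") as f:
        f.seek(-2, os.SEEK_END)  # last two bytes of the file
        return f.read(2) != b"\xff\xd9"
```
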
### Filtering duplicate images
```python
df['file_hash'] = df['full_path'].apply(get_hash)
df_unique = df.drop_duplicates(subset='file_hash')
```
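
`drop_duplicates` keeps the first row of each identical group (its default is `keep='first'`), so when the same file sits in two class folders, which label survives depends on row order. A sketch that lists such cross-class conflicts for manual review before deduplicating; `label_conflicts` is a hypothetical helper, and like the hashing above it only catches byte-identical files:

```python
import hashlib
from collections import defaultdict


def label_conflicts(rows):
    """rows: (class, full_path) pairs.

    Return {hash: sorted paths} for identical files that carry different labels.
    """
    labels_by_hash = defaultdict(set)
    paths_by_hash = defaultdict(list)
    for cls, path in rows:
        with open(path, "rb") as f:
            digest = hashlib.md5(f.read()).hexdigest()
        labels_by_hash[digest].add(cls)
        paths_by_hash[digest].append(path)
    return {h: sorted(paths_by_hash[h])
            for h, labels in labels_by_hash.items() if len(labels) > 1}
```
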