Spaces:
No application file
No application file
Upload 3 files
Browse files — this upload consists of the model weights, a README, and a test script
- 650556_95909_35_weights__.pt +3 -0
- README (3).md +61 -0
- test.py +170 -0
650556_95909_35_weights__.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7f942769cedcd6aecc0670a790709a6f2be1eeda6d8d5fb94976fc4a6e2ede22
|
| 3 |
+
size 44815245
|
README (3).md
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ✍️ Printed Word-Level Script Identification Multi-class (14-Class Model)
|
| 2 |
+
|
| 3 |
+
**Initial model** for printed-document word-level script separation across **13 Indic languages + English**.
|
| 4 |
+
The model is designed to classify word images into their respective script categories.
|
| 5 |
+
i.e. `Assamese, Bengali, English, Gujarati, Hindi, Kannada, Malayalam, Manipuri, Marathi, Punjabi, Tamil, Telugu, Urdu, Odia`
|
| 6 |
+
|
| 7 |
+
---
|
| 8 |
+
|
| 9 |
+
## 📊 Dataset Overview
|
| 10 |
+
|
| 11 |
+
- **Training samples**: ~650560
|
| 12 |
+
- **Validation samples**: ~95909
|
| 13 |
+
- (Test set used for evaluation)
|
| 14 |
+
|
| 15 |
+
---
|
| 16 |
+
|
| 17 |
+
## ⚙️ Training Setup
|
| 18 |
+
|
| 19 |
+
- **Model**: ResNet-18
|
| 20 |
+
- **Preprocessing**: Custom **binarization function** applied for improved feature extraction
|
| 21 |
+
- **Input size**: 224 × 224 RGB
|
| 22 |
+
- **Optimizer**: Adam
|
| 23 |
+
- **Loss function**: CrossEntropyLoss
|
| 24 |
+
- **Epochs**: model trained up to 35th epoch (weights shared)
|
| 25 |
+
|
| 26 |
+
---
|
| 27 |
+
|
| 28 |
+
## 📈 Results & Evaluation
|
| 29 |
+
|
| 30 |
+
The model was evaluated on the **test set**.
|
| 31 |
+
Accompanying this README, you will find PNG visualizations for:
|
| 32 |
+
|
| 33 |
+
- Confusion Matrix
|
| 34 |
+
- Per-class Precision, Recall, F1-Score
|
| 35 |
+
- Support vs Correct Predictions per class
|
| 36 |
+
- Top Misclassifications
|
| 37 |
+
|
| 38 |
+
These provide a detailed breakdown of model performance across all 14 classes.
|
| 39 |
+
|
| 40 |
+
---
|
| 41 |
+
|
| 42 |
+
## 📂 Included Files
|
| 43 |
+
|
| 44 |
+
- `model_weights/` → Trained ResNet-18 weights
|
| 45 |
+
- `wt_35_test_report/` → Evaluation visualizations (confusion matrix, metrics, misclassifications, etc.)
|
| 46 |
+
- `test.py` → Script used to run evaluation
|
| 47 |
+
|
| 48 |
+
---
|
| 49 |
+
|
| 50 |
+
## 🗂️ Class Labels
|
| 51 |
+
|
| 52 |
+
The model predicts among **14 classes**:
|
| 53 |
+
|
| 54 |
+
`Assamese, Bengali, English, Gujarati, Hindi, Kannada, Malayalam, Manipuri, Marathi, Punjabi, Tamil, Telugu, Urdu, Odia`
|
| 55 |
+
|
| 56 |
+
---
|
| 57 |
+
|
| 58 |
+
## 📝 Note
|
| 59 |
+
|
| 60 |
+
This is an **initial baseline model**.
|
| 61 |
+
Further improvements can be made by training on the complete dataset and tuning hyperparameters.
|
test.py
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from torchvision import datasets, transforms, models
|
| 5 |
+
from torch.utils.data import DataLoader
|
| 6 |
+
import matplotlib.pyplot as plt
|
| 7 |
+
from sklearn.metrics import confusion_matrix
|
| 8 |
+
import seaborn as sns
|
| 9 |
+
import numpy as np
|
| 10 |
+
import csv # Add this line to import the csv module
|
| 11 |
+
from PIL import Image # Added for binarization
|
| 12 |
+
|
| 13 |
+
# Binarization function
|
| 14 |
+
def binarize(img, threshold=128):
    """Binarize a PIL image: pixels strictly above *threshold* become 255, the rest 0.

    Args:
        img: input PIL image (any mode; converted to grayscale first).
        threshold: grayscale cutoff in [0, 255].

    Returns:
        A single-channel ('L' mode) black-and-white PIL image.
    """
    gray = img.convert('L')
    # point() on an 'L'-mode image already returns an 'L'-mode image,
    # so the original trailing .convert('L') was redundant and is dropped.
    return gray.point(lambda p: 255 if p > threshold else 0)
|
| 17 |
+
|
| 18 |
+
# Constants
|
| 19 |
+
|
| 20 |
+
# Function to get class names and counts
|
| 21 |
+
# Function to get class names and counts
def get_class_info(directory):
    """Return the class layout of an ImageFolder-style directory.

    Args:
        directory: root laid out as directory/<class_name>/<sample files>.

    Returns:
        (sorted list of class names, {class name: number of entries inside it}).

    Only subdirectories are treated as classes; stray files (e.g. .DS_Store)
    are ignored so the list stays aligned with torchvision's ImageFolder,
    which also only considers directories.
    """
    classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
    class_lengths = {
        name: len(os.listdir(os.path.join(directory, name)))
        for name in classes
    }
    return classes, class_lengths
|
| 25 |
+
|
| 26 |
+
class ImageFolderWithPaths(datasets.ImageFolder):
    """ImageFolder variant whose items also carry the source file path.

    Each item is (image, label, path) instead of the usual (image, label),
    which lets downstream reporting attach predictions to filenames.
    """

    def __getitem__(self, index):
        # Delegate image loading / transforms to the stock ImageFolder.
        image, label = super(ImageFolderWithPaths, self).__getitem__(index)
        # self.imgs holds (path, class_index) pairs in dataset order.
        path = self.imgs[index][0]
        return image, label, path
|
| 35 |
+
|
| 36 |
+
# 1. Data Loading and Transformation
|
| 37 |
+
# 1. Data Loading and Transformation
def load_data(TEST, batch_size=10):
    """Build the evaluation DataLoader for a class-per-folder test directory.

    Args:
        TEST: root directory laid out as TEST/<class_name>/<image files>.
        batch_size: number of samples per batch (order is preserved).

    Returns:
        (test_loader, class names in label-index order, total sample count).
    """
    transform_test = transforms.Compose([
        transforms.Lambda(lambda img: binarize(img, threshold=128)),
        transforms.Resize((224, 224)),
        # ResNet expects 3 channels; replicate the binarized gray channel.
        transforms.Lambda(lambda img: img.convert('RGB')),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    test_dataset = ImageFolderWithPaths(root=TEST, transform=transform_test)
    # shuffle=False keeps predictions aligned with file order for reporting.
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    # Take the class list from the dataset itself so names match label indices
    # exactly; sorted(os.listdir(...)) could include stray non-directory
    # entries and silently shift the name -> index mapping.
    classes = test_dataset.classes
    return test_loader, classes, len(test_dataset)
|
| 50 |
+
|
| 51 |
+
# 2. Model Loading and Configuration
|
| 52 |
+
# 2. Model Loading and Configuration
def load_model(model_weights_path, num_classes, device):
    """Build a ResNet-18 with a *num_classes* head and load trained weights.

    Args:
        model_weights_path: path to a state_dict checkpoint (.pt).
        num_classes: output dimension of the replacement fc layer.
        device: torch.device to place the model on.

    Returns:
        The model in eval() mode on *device*.
    """
    model = models.resnet18(weights=None)  # architecture only, no pretrained weights
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    # map_location lets a CUDA-saved checkpoint load on a CPU-only machine;
    # without it torch.load raises when the saving device is unavailable.
    state_dict = torch.load(model_weights_path, map_location=device)
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()
    return model
|
| 59 |
+
|
| 60 |
+
# 3. Evaluation for Testing
|
| 61 |
+
# 3. Evaluation for Testing
def evaluate(model, test_iterator, criterion, device):
    """Run one full pass over *test_iterator* and report loss/accuracy.

    Args:
        model: classifier producing per-class logits.
        test_iterator: yields (images, labels, paths) batches.
        criterion: loss function over (logits, labels).
        device: torch.device for inputs.

    Returns:
        (average batch loss, accuracy percentage).
    """
    model.eval()
    total_loss = 0.0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for images, targets, _ in test_iterator:
            images, targets = images.to(device), targets.to(device)
            logits = model(images)
            total_loss += criterion(logits, targets).item()

            predictions = logits.argmax(dim=1)
            n_seen += targets.size(0)
            n_correct += (predictions == targets).sum().item()

    # Overall metrics across every batch.
    accuracy = 100. * n_correct / n_seen
    avg_loss = total_loss / len(test_iterator)
    print(f'\nTesting completed. Average Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%\n')
    return avg_loss, accuracy
|
| 83 |
+
|
| 84 |
+
# 4. Prediction and Label Gathering
|
| 85 |
+
# 4. Prediction and Label Gathering
def get_all_predictions(model, iterator, device):
    """Collect predictions, ground-truth labels, and file paths for every sample.

    Args:
        model: classifier producing per-class logits.
        iterator: yields (images, labels, paths) batches.
        device: torch.device for inputs.

    Returns:
        (predicted class indices, true class indices, file paths), each a flat
        list in the iterator's order.
    """
    model.eval()
    preds, truths, filenames = [], [], []

    with torch.no_grad():
        for images, targets, paths in iterator:
            logits = model(images.to(device))
            preds.extend(logits.argmax(dim=1).cpu().numpy())
            # Labels were never moved to device, so no .cpu() needed here.
            truths.extend(targets.numpy())
            filenames.extend(paths)
    return preds, truths, filenames
|
| 101 |
+
|
| 102 |
+
# 5. Confusion Matrix Creation and Visualization
|
| 103 |
+
# 5. Confusion Matrix Creation and Visualization
def plot_and_save_confusion_matrix(true_labels, predictions, classes, save_path):
    """Render a confusion-matrix heatmap and save it as an image file.

    Args:
        true_labels: ground-truth class indices.
        predictions: predicted class indices (same length/order as true_labels).
        classes: class names used as axis tick labels.
        save_path: output image path; its parent directory is created if needed.
    """
    # os.makedirs('') raises FileNotFoundError, so only create the parent
    # directory when the path actually contains one (bare filenames are OK).
    parent = os.path.dirname(save_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    conf_mat = confusion_matrix(true_labels, predictions)
    plt.figure(figsize=(10, 10))
    sns.heatmap(conf_mat, annot=True, fmt='d', cmap='Blues', xticklabels=classes, yticklabels=classes)
    plt.xlabel('Predicted Labels')
    plt.ylabel('True Labels')
    plt.title('Confusion Matrix')
    plt.savefig(save_path)
    plt.close()  # Close the plot to free memory
    print(f"Confusion matrix saved to {save_path}")
|
| 115 |
+
|
| 116 |
+
# Main Execution
|
| 117 |
+
# Main Execution
def main():
    """Evaluate the trained model on the test set; save a prediction CSV and a confusion matrix."""
    TEST_DIR = '/home/pola/01_printed_document/printed_13class/data/test'
    MODEL_WEIGHTS_PATH = "/home/pola/01_printed_document/printed_13class/model_weights/650556_95909_35_weights.pt"
    # Tag output files with the checkpoint's basename (sans extension) —
    # previously this was computed but never used and "(unknown)" was hardcoded.
    weights_name = os.path.splitext(os.path.basename(MODEL_WEIGHTS_PATH))[0]
    RESULTS_DIR = '/home/pola/01_printed_document/printed_13class/results/'
    os.makedirs(RESULTS_DIR, exist_ok=True)

    # Per-class sample counts for a quick sanity-check table.
    test_classes, test_class_lengths = get_class_info(TEST_DIR)
    max_class_name_length = max(len(name) for name in test_classes)

    print(f"{'Class Name'.ljust(max_class_name_length)} | {'Test Images'.ljust(16)} ")
    print('-' * (max_class_name_length + 36))
    for class_name in test_classes:
        count = test_class_lengths.get(class_name, 0)
        print(f"{class_name.ljust(max_class_name_length)} | {str(count).ljust(16)} ")
    # This directory is the *test* split, so label the total accordingly
    # (the old message said "Training Dataset").
    print(f"\nTotal images in Test Dataset: {sum(test_class_lengths.values())}")

    # Initialize
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    criterion = nn.CrossEntropyLoss()

    # Load Data — class names come from the dataset in label-index order,
    # so we don't recompute them from os.listdir here.
    test_loader, classes, num_test_samples = load_data(TEST_DIR)

    # Load Model
    model = load_model(MODEL_WEIGHTS_PATH, len(classes), device)
    print("Loaded model with classes:", classes)

    # Evaluate Model
    test_loss, test_acc = evaluate(model, test_loader, criterion, device)

    # Gather per-sample predictions for the CSV report.
    all_preds, all_labels, all_filenames = get_all_predictions(model, test_loader, device)

    # Dynamic filename construction: checkpoint name, sample count, accuracy,
    # class count. os.path.join avoids the old double-slash / missing-slash mix.
    filename_details = f"{weights_name}_{num_test_samples}_smpls_{test_acc:.2f}_pct_acc_{len(classes)}_classes"
    csv_filename = os.path.join(RESULTS_DIR, f"pred_{filename_details}.csv")
    confusion_matrix_filename = os.path.join(RESULTS_DIR, f"conf_{filename_details}_.png")

    # Save all details to a CSV file
    with open(csv_filename, 'w', newline='') as csvfile:
        csvwriter = csv.writer(csvfile)
        csvwriter.writerow(['Filename', 'True Label', 'Predicted Label'])
        for image_path, true_label, predicted_label in zip(all_filenames, all_labels, all_preds):
            csvwriter.writerow([image_path, classes[true_label], classes[predicted_label]])

    print(f"All prediction details saved in CSV: {csv_filename}")
    plot_and_save_confusion_matrix(all_labels, all_preds, classes, confusion_matrix_filename)
|
| 168 |
+
|
| 169 |
+
# Run the evaluation only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|