Asim02 commited on
Commit
b82820a
·
verified ·
1 Parent(s): fdb0a9f

Upload 3 files

Browse files

This upload consists of model weights, a README, and a test script.

Files changed (3) hide show
  1. 650556_95909_35_weights__.pt +3 -0
  2. README (3).md +61 -0
  3. test.py +170 -0
650556_95909_35_weights__.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7f942769cedcd6aecc0670a790709a6f2be1eeda6d8d5fb94976fc4a6e2ede22
3
+ size 44815245
README (3).md ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ✍️ Printed Word-Level Script Identification Multi-class (14-Class Model)
2
+
3
+ **initial model** for printed-document word-level script separation across **13 Indic languages + English**.
4
+ The model is designed to classify word images into their respective script categories.
5
+ i.e. `Assamese, Bengali, English, Gujarati, Hindi, Kannada, Malayalam, Manipuri, Marathi, Punjabi, Tamil, Telugu, Urdu, Odia`
6
+
7
+ ---
8
+
9
+ ## 📊 Dataset Overview
10
+
11
+ - **Training samples**: ~650560
12
+ - **Validation samples**: ~95909
13
+ - (Test set used for evaluation)
14
+
15
+ ---
16
+
17
+ ## ⚙️ Training Setup
18
+
19
+ - **Model**: ResNet-18
20
+ - **Preprocessing**: Custom **binarization function** applied for improved feature extraction
21
+ - **Input size**: 224 × 224 RGB
22
+ - **Optimizer**: Adam
23
+ - **Loss function**: CrossEntropyLoss
24
+ - **Epochs**: model trained up to 35th epoch (weights shared)
25
+
26
+ ---
27
+
28
+ ## 📈 Results & Evaluation
29
+
30
+ The model was evaluated on the **test set**.
31
+ Accompanying this README, you will find PNG visualizations for:
32
+
33
+ - Confusion Matrix
34
+ - Per-class Precision, Recall, F1-Score
35
+ - Support vs Correct Predictions per class
36
+ - Top Misclassifications
37
+
38
+ These provide a detailed breakdown of model performance across all 14 classes.
39
+
40
+ ---
41
+
42
+ ## 📂 Included Files
43
+
44
+ - `model_weights/` → Trained ResNet-18 weights
45
+ - `wt_35_test_report/` → Evaluation visualizations (confusion matrix, metrics, misclassifications, etc.)
46
+ - `test.py` → Script used to run evaluation
47
+
48
+ ---
49
+
50
+ ## 🗂️ Class Labels
51
+
52
+ The model predicts among **14 classes**:
53
+
54
+ `Assamese, Bengali, English, Gujarati, Hindi, Kannada, Malayalam, Manipuri, Marathi, Punjabi, Tamil, Telugu, Urdu, Odia`
55
+
56
+ ---
57
+
58
+ ## 📝 Note
59
+
60
+ This is an **initial baseline model**.
61
+ Further improvements can be made by training on the complete dataset and tuning hyperparameters.
test.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ from torchvision import datasets, transforms, models
5
+ from torch.utils.data import DataLoader
6
+ import matplotlib.pyplot as plt
7
+ from sklearn.metrics import confusion_matrix
8
+ import seaborn as sns
9
+ import numpy as np
10
+ import csv # Add this line to import the csv module
11
+ from PIL import Image # Added for binarization
12
+
13
# Binarization function
def binarize(img, threshold=128):
    """Reduce *img* to pure black/white using a fixed grayscale threshold.

    Pixels strictly above *threshold* become 255 (white); all others 0 (black).
    """
    gray = img.convert('L')
    bw = gray.point(lambda value: 255 if value > threshold else 0)
    # Re-assert single-channel mode on the result, as the original did.
    return bw.convert('L')
17
+
18
+ # Constants
19
+
20
# Function to get class names and counts
def get_class_info(directory):
    """Return (sorted entry names, {name: number of files inside}) for *directory*.

    Each immediate sub-folder of *directory* is treated as one class; the count
    is the number of entries inside that sub-folder.
    """
    classes = sorted(os.listdir(directory))
    class_lengths = {}
    for class_name in classes:
        class_dir = os.path.join(directory, class_name)
        class_lengths[class_name] = len(os.listdir(class_dir))
    return classes, class_lengths
25
+
26
class ImageFolderWithPaths(datasets.ImageFolder):
    """ImageFolder variant whose items additionally carry the source file path.

    Each item is ``(image, label, path)`` instead of ``(image, label)``, so
    downstream reporting can tie predictions back to individual files.
    """

    def __getitem__(self, index):
        # Fetch the standard (image, label) pair from the parent class,
        # then append the originating file path from the samples list.
        image, label = super().__getitem__(index)
        path = self.imgs[index][0]
        return image, label, path
35
+
36
# 1. Data Loading and Transformation
def load_data(TEST, batch_size=10):
    """Build the evaluation DataLoader for the test directory.

    Args:
        TEST: Root directory laid out as one sub-folder per class.
        batch_size: Batch size for the loader (default 10).

    Returns:
        (test_loader, classes, num_samples) where *classes* is the sorted
        list of class-folder names discovered by ImageFolder.
    """
    transform_test = transforms.Compose([
        # Binarize first so resizing interpolates the clean black/white image.
        transforms.Lambda(lambda img: binarize(img, threshold=128)),
        transforms.Resize((224, 224)),
        # ResNet expects 3 channels, so expand grayscale back to RGB.
        transforms.Lambda(lambda img: img.convert('RGB')),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    test_dataset = ImageFolderWithPaths(root=TEST, transform=transform_test)
    # shuffle=False keeps predictions aligned with file order for reporting.
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    # Use ImageFolder's own class list: it is already sorted and only contains
    # directories, unlike a raw os.listdir of the root (which would also pick
    # up stray files).
    classes = test_dataset.classes
    return test_loader, classes, len(test_dataset)
50
+
51
# 2. Model Loading and Configuration
def load_model(model_weights_path, num_classes, device):
    """Create a ResNet-18 with a *num_classes* head and load saved weights.

    Args:
        model_weights_path: Path to a state_dict saved via torch.save.
        num_classes: Number of output classes for the final FC layer.
        device: torch.device the model should live on.

    Returns:
        The model in eval() mode on *device*.
    """
    model = models.resnet18(weights=None)  # architecture only, no pretrained weights
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    # map_location lets GPU-saved checkpoints load on CPU-only hosts;
    # without it torch.load raises when CUDA is unavailable.
    state_dict = torch.load(model_weights_path, map_location=device)
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()
    return model
59
+
60
# 3. Evaluation for Testing
def evaluate(model, test_iterator, criterion, device):
    """Run *model* over *test_iterator* and report average loss and accuracy.

    Each batch from the iterator is (data, labels, paths); paths are ignored.
    Returns (avg_loss, accuracy_percent).
    """
    model.eval()
    total_loss = 0.0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for data, labels, _ in test_iterator:
            data = data.to(device)
            labels = labels.to(device)
            logits = model(data)
            total_loss += criterion(logits, labels).item()

            preds = logits.argmax(dim=1)
            n_seen += labels.size(0)
            n_correct += (preds == labels).sum().item()

    # Overall accuracy across every sample; loss averaged per batch.
    accuracy = 100. * n_correct / n_seen
    avg_loss = total_loss / len(test_iterator)
    print(f'\nTesting completed. Average Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%\n')
    return avg_loss, accuracy
83
+
84
# 4. Prediction and Label Gathering
def get_all_predictions(model, iterator, device):
    """Collect predictions, ground-truth labels, and file paths for all samples.

    Returns three parallel lists: (predictions, labels, file paths).
    """
    model.eval()
    preds_out = []
    labels_out = []
    paths_out = []  # file path of every evaluated image

    with torch.no_grad():
        for data, labels, paths in iterator:
            logits = model(data.to(device))
            batch_preds = torch.max(logits, 1)[1]
            preds_out.extend(batch_preds.cpu().numpy())
            # Labels never left the CPU, so no .cpu() round-trip is needed.
            labels_out.extend(labels.numpy())
            paths_out.extend(paths)
    return preds_out, labels_out, paths_out
101
+
102
# 5. Confusion Matrix Creation and Visualization
def plot_and_save_confusion_matrix(true_labels, predictions, classes, save_path):
    """Render a labeled confusion-matrix heatmap and save it to *save_path*.

    Args:
        true_labels: Sequence of ground-truth class indices.
        predictions: Sequence of predicted class indices.
        classes: Class names used for the axis tick labels.
        save_path: Output image path (directories are created if needed).
    """
    # Only create the parent directory when the path actually has one;
    # os.makedirs('') raises FileNotFoundError for a bare filename.
    target_dir = os.path.dirname(save_path)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)
    conf_mat = confusion_matrix(true_labels, predictions)
    plt.figure(figsize=(10, 10))
    sns.heatmap(conf_mat, annot=True, fmt='d', cmap='Blues', xticklabels=classes, yticklabels=classes)
    plt.xlabel('Predicted Labels')
    plt.ylabel('True Labels')
    plt.title('Confusion Matrix')
    plt.savefig(save_path)
    plt.close()  # Close the plot to free memory
    print(f"Confusion matrix saved to {save_path}")
115
+
116
# Main Execution
def main():
    """Evaluate the saved ResNet-18 checkpoint on the test set and save reports.

    Side effects: prints per-class counts and metrics, writes a per-image
    prediction CSV and a confusion-matrix PNG into RESULTS_DIR.
    """
    TEST_DIR = '/home/pola/01_printed_document/printed_13class/data/test'
    MODEL_WEIGHTS_PATH = "/home/pola/01_printed_document/printed_13class/model_weights/650556_95909_35_weights.pt"
    RESULTS_DIR = '/home/pola/01_printed_document/printed_13class/results/'
    os.makedirs(RESULTS_DIR, exist_ok=True)

    # Per-class image counts for the *test* directory.
    test_classes, test_class_lengths = get_class_info(TEST_DIR)
    max_class_name_length = max(len(name) for name in test_classes)

    print(f"{'Class Name'.ljust(max_class_name_length)} | {'Test Images'.ljust(16)} ")
    print('-' * (max_class_name_length + 36))
    for class_name in test_classes:  # already sorted by get_class_info
        count = test_class_lengths.get(class_name, 0)
        print(f"{class_name.ljust(max_class_name_length)} | {str(count).ljust(16)} ")
    # Fixed: the counts describe the test set, not the training set.
    print(f"\nTotal images in Test Dataset: {sum(test_class_lengths.values())}")

    # Initialize
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    criterion = nn.CrossEntropyLoss()

    # Load Data — load_data also returns the authoritative sorted class list.
    test_loader, classes, num_test_samples = load_data(TEST_DIR)

    # Load Model
    model = load_model(MODEL_WEIGHTS_PATH, len(classes), device)
    print("Loaded model with classes:", classes)

    # Evaluate Model
    test_loss, test_acc = evaluate(model, test_loader, criterion, device)

    # Gather per-image predictions for the CSV and confusion matrix.
    all_preds, all_labels, all_filenames = get_all_predictions(model, test_loader, device)

    # Dynamic filename construction for the output artifacts.
    filename_details = f"(unknown)_{num_test_samples}_smpls_{test_acc:.2f}_pct_acc_{len(classes)}_classes"
    csv_filename = f"{RESULTS_DIR}/pred_{filename_details}.csv"
    confusion_matrix_filename = f"{RESULTS_DIR}conf_{filename_details}_.png"

    # Save all details to a CSV file
    with open(csv_filename, 'w', newline='') as csvfile:
        csvwriter = csv.writer(csvfile)
        csvwriter.writerow(['Filename', 'True Label', 'Predicted Label'])
        for fname, true_label, predicted_label in zip(all_filenames, all_labels, all_preds):
            csvwriter.writerow([fname, classes[true_label], classes[predicted_label]])

    print(f"All prediction details saved in CSV: {csv_filename}")
    plot_and_save_confusion_matrix(all_labels, all_preds, classes, confusion_matrix_filename)


if __name__ == "__main__":
    main()