MultivexAI commited on
Commit
51a15c5
·
verified ·
1 Parent(s): 683e275

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +128 -0
README.md CHANGED
@@ -1,3 +1,131 @@
1
  ---
2
  license: apache-2.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: apache-2.0
3
+ pipeline_tag: image-classification
4
+ library_name: pytorch
5
+ tags:
6
+ - mnist
7
+ - robust
8
+ - open-set
9
+ - computer-vision
10
+ model-index:
11
+ - name: RobustMNIST-v1.0
12
+ results:
13
+ - task:
14
+ type: image-classification
15
+ dataset:
16
+ name: MNIST
17
+ type: mnist
18
+ metrics:
19
+ - name: Accuracy (Clean)
20
+ type: accuracy
21
+ value: 99.51
22
+ - name: Accuracy (Extreme OOD)
23
+ type: accuracy
24
+ value: 92.33
25
  ---
26
+
27
+ # RobustMNIST (v1.0)
28
+
29
+ <div align="left">
30
+ <img src="https://cdn-uploads.huggingface.co/production/uploads/64b433c3faa3181a5e98c87c/nex9yBr2wmq88q9UwHULO.png" width="1000">
31
+ </div>
32
+
33
+ `RobustMNIST` is a lightweight, 11-class convolutional neural network designed for handwritten digit recognition. Unlike standard models, this architecture is built to handle out-of-distribution (OOD) inputs and extreme image corruption through a dedicated **"Unknown"** class.
34
+
35
+ ### Model Details
36
+ * **Developed by:** MultivexAI
37
+ * **Task:** Open-Set Handwritten Digit Recognition
38
+ * **Architecture:** 6-Layer Gated CNN (approx. 430k parameters)
39
+ * **Classes:** 11 (0–9 for standard digits, **10 for "Unknown"**)
40
+ * **Input:** 1x28x28 grayscale image.
41
+
42
+ ---
43
+
44
+ ## The "Unknown" Class (Class 10)
45
+
46
+ Traditional MNIST models often guess a digit confidently even when the input is just random noise or a shape that isn't a number.
47
+
48
+ RobustMNIST introduces **Class 10**, representing the "Unknown" domain.
49
+ - **In-Distribution:** For clean digits, the model predicts classes 0–9 with high accuracy, while maintaining a 15-20% uncertainty margin for Class 10.
50
+ - **Out-of-Distribution:** When an image is severely corrupted (noise, stains, blurs) or represents a non-digit shape, the model's confidence shifts entirely to Class 10.
51
+
52
+ ---
53
+
54
+ ## Performance Metrics
55
+
56
+ Evaluation on standard MNIST and extreme corruption sets:
57
+
58
+ | Set | Accuracy |
59
+ | :--- | :--- |
60
+ | **Clean MNIST Test Set** | 99.51% |
61
+ | **Extreme OOD / Corrupted Set** | 92.33% |
62
+
63
+ ---
64
+
65
+ ### Limitations & Expectations
66
+ While titled **RobustMNIST**, it is important to clarify that "robust" does not mean "invincible." This is a small-scale model designed to demonstrate OOD detection, not a perfect safety system.
67
+
68
+ - **No 100% Guarantee:** Like all neural networks, this model can and will make mistakes.
69
+ - **The "Robust" Definition:** In this context, robustness refers to the model's *improved* resistance to noise and its ability to express uncertainty via the "Unknown" class compared to standard classifiers. It is not an absolute shield against all possible adversarial or geometric attacks.
70
+ - **Semantic Edge Cases:** Certain transformations, such as rotating a "6" until it looks like a "9" or mirroring asymmetric digits create mathematical ambiguities. We acknowledge these limits; at this parameter count, the model prioritizes identifying structured digits over handling every possible topological distortion.
71
+ - **Research Scope:** This is a 1.0 release focused on balancing clean accuracy with OOD calibration. We agree that edge cases exist where the model may still fail or default to "Unknown" unexpectedly.
72
+
73
+ ## Usage
74
+
75
+ To use this model, ensure you have `model.py` and `model.pt` in your directory.
76
+
77
+ ### Simple Test Script (`test.py`)
78
+
79
+ This script picks a random digit from the MNIST test set and runs a prediction.
80
+
81
+ Tested on **Python 3.12**.
82
+
83
+ ```python
84
+ import torch
85
+ import torch.nn.functional as F
86
+ from torchvision import datasets, transforms
87
+ from model import HierarchicalNetwork
88
+
89
+ # Execution configurations
90
+ PARAMETER_PATH = "model.pt"
91
+ HARDWARE_TARGET = torch.device("cuda" if torch.cuda.is_available() else "cpu")
92
+
93
+ def execute():
94
+ # initialize architecture
95
+ processor = HierarchicalNetwork(out_dims=11).to(HARDWARE_TARGET)
96
+
97
+ # load state parameters
98
+ state_data = torch.load(PARAMETER_PATH, map_location=HARDWARE_TARGET)
99
+ weights = state_data.get('state_dict', state_data)
100
+ processor.load_state_dict(weights)
101
+ processor.eval()
102
+
103
+ # pull random sample
104
+ dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())
105
+ sample_index = torch.randint(0, len(dataset), (1,)).item()
106
+ input_tensor, ground_truth = dataset[sample_index]
107
+
108
+ # compute projections
109
+ formatted_input= input_tensor.unsqueeze(0).to(HARDWARE_TARGET)
110
+ with torch.inference_mode():
111
+ raw_outputs = processor(formatted_input)
112
+ probabilities = F.softmax(raw_outputs, dim=1).cpu().numpy()[0]
113
+
114
+ # compile outputs
115
+ predicted_class = probabilities.argmax()
116
+ category_names = [str(i) for i in range(10)]+["Unknown"]
117
+
118
+ print("\n" + "="*30)
119
+ print(f"Sample Index : {sample_index}")
120
+ print(f"True Label : {ground_truth}")
121
+ print(f"Prediction : {category_names[predicted_class]}")
122
+ print(f"Confidence : {probabilities[predicted_class] * 100:.2f}%")
123
+ print("=" * 30)
124
+
125
+ if __name__ == "__main__":
126
+ execute()
127
+ ```
128
+
129
+ ---
130
+
131
+ **Released by MultivexAI** | Licensed under Apache-2.0