sukhmani1303 commited on
Commit
248c98e
·
verified ·
1 Parent(s): cfef993

Upload tuberculosis ViT model with complete configuration

Browse files
Files changed (7) hide show
  1. README.md +115 -0
  2. config.json +16 -0
  3. handler.py +159 -0
  4. metadata.json +17 -0
  5. model.pt +3 -0
  6. pytorch_model.bin +3 -0
  7. requirements.txt +5 -0
README.md ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language: en
3
+ license: apache-2.0
4
+ tags:
5
+ - medical
6
+ - chest-xray
7
+ - tuberculosis
8
+ - vision-transformer
9
+ - pytorch
10
+ - image-classification
11
+ metrics:
12
+ - accuracy
13
+ - precision
14
+ - recall
15
+ - f1
16
+ pipeline_tag: image-classification
17
+ ---
18
+
19
+ # Tuberculosis Detection ViT Model
20
+
21
+ This repository contains a Vision Transformer (ViT) model for classifying chest X-ray images as **Normal** or **Tuberculosis**.
22
+
23
+ ## Model Details
24
+
25
+ - **Model Type**: Vision Transformer (ViT)
26
+ - **Framework**: PyTorch 2.0.1
27
+ - **Architecture**:
28
+ - Patch size: 16
29
+ - Number of patches: 196
30
+ - Projection dimension: 768
31
+ - Number of attention heads: 12
32
+ - Number of encoder blocks: 12
33
+ - **Input Size**: 224x224 RGB images
34
+ - **Classes**: Normal, Tuberculosis
35
+
36
+ ## Dataset
37
+ - **Source**: TB Chest Radiography Database
38
+ - **Preprocessing**: Grayscale conversion, CLAHE enhancement, Gaussian blur, RGB conversion, standardization
39
+ - **Training**: Early stopping with validation loss monitoring
40
+
41
+ ## Usage
42
+
43
+ ### Quick Start
44
+ ```python
45
+ from huggingface_hub import hf_hub_download
46
+ import torch
47
+ import cv2
48
+ import numpy as np
49
+
50
+ # Download the handler
51
+ handler_path = hf_hub_download(repo_id="sukhmani1303/tuberculosis-vit-model", filename="handler.py")
52
+
53
+ # Import and use
54
+ exec(open(handler_path).read())
55
+ classifier = TBClassifier()
56
+
57
+ # Load and predict
58
+ image = cv2.imread("path/to/chest_xray.jpg")
59
+ result = classifier.predict(image)
60
+ print(f"Prediction: {result['prediction']}, Confidence: {result['confidence']:.4f}")
61
+ ```
62
+
63
+ ### Loading the PyTorch Model
64
+ ```python
65
+ from huggingface_hub import hf_hub_download
66
+ import torch
67
+ import json
68
+
69
+ # Download model files
70
+ config_path = hf_hub_download(repo_id="sukhmani1303/tuberculosis-vit-model", filename="config.json")
71
+ model_path = hf_hub_download(repo_id="sukhmani1303/tuberculosis-vit-model", filename="pytorch_model.bin")
72
+
73
+ # Load configuration
74
+ with open(config_path, 'r') as f:
75
+ config = json.load(f)
76
+
77
+ # Initialize your ViT model class with config
78
+ # model = ViT(**config) # You'll need to have your ViT class available
79
+ # model.load_state_dict(torch.load(model_path, map_location='cpu'))
80
+ # model.eval()
81
+ ```
82
+
83
+ ## Performance
84
+ - **Training**: Early stopping with validation monitoring
85
+ - **Evaluation**: Tested on held-out validation set
86
+ - **Metrics**: Accuracy, Precision, Recall, F1-score available upon request
87
+
88
+ ## Preprocessing Pipeline
89
+ 1. Convert RGB to grayscale
90
+ 2. Apply CLAHE (Contrast Limited Adaptive Histogram Equalization)
91
+ 3. Apply Gaussian blur (5x5 kernel)
92
+ 4. Resize to 224x224
93
+ 5. Convert back to RGB format
94
+ 6. Normalize with z-score standardization
95
+
96
+ ## Medical Disclaimer
97
+ ⚠️ **IMPORTANT**: This model is for educational and research purposes only. It should not be used for actual medical diagnosis. Always consult qualified healthcare professionals for medical advice and diagnosis.
98
+
99
+ ## Citation
100
+ If you use this model in your research, please cite:
101
+ ```
102
+ @misc{tuberculosis-vit-model,
103
+ author = {Sukhmani},
104
+ title = {Tuberculosis Detection using Vision Transformer},
105
+ year = {2024},
106
+ publisher = {Hugging Face},
107
+ url = {https://huggingface.co/sukhmani1303/tuberculosis-vit-model}
108
+ }
109
+ ```
110
+
111
+ ## License
112
+ Apache 2.0
113
+
114
+ ## Contact
115
+ For questions or issues, please open an issue in the repository.
config.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "vit",
3
+ "patch_size": 16,
4
+ "num_patches": 196,
5
+ "projection_dim": 768,
6
+ "num_heads": 12,
7
+ "num_encoder": 12,
8
+ "input_size": 224,
9
+ "num_classes": 1,
10
+ "class_names": [
11
+ "Normal",
12
+ "Tuberculosis"
13
+ ],
14
+ "framework": "pytorch",
15
+ "torch_dtype": "float32"
16
+ }
handler.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import cv2
4
+ import numpy as np
5
+ import os
6
+
7
+ class TBClassifier:
8
+ """
9
+ Tuberculosis classifier using Vision Transformer
10
+ """
11
+ def __init__(self, model_path="model.pt", config_path="config.json"):
12
+ """
13
+ Initialize the classifier
14
+
15
+ Args:
16
+ model_path: Path to the TorchScript model file
17
+ config_path: Path to the configuration file
18
+ """
19
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
+
21
+ # Load configuration if available
22
+ self.config = None
23
+ if os.path.exists(config_path):
24
+ import json
25
+ with open(config_path, 'r') as f:
26
+ self.config = json.load(f)
27
+
28
+ # Load model
29
+ try:
30
+ self.model = torch.jit.load(model_path, map_location=self.device)
31
+ self.model.eval()
32
+ print(f"Model loaded successfully on {self.device}")
33
+ except Exception as e:
34
+ raise RuntimeError(f"Failed to load model: {str(e)}")
35
+
36
+ # Model configuration
37
+ self.class_names = self.config.get('class_names', ['Normal', 'Tuberculosis']) if self.config else ['Normal', 'Tuberculosis']
38
+ self.img_size = self.config.get('input_size', 224) if self.config else 224
39
+
40
+ print(f"Classifier initialized with classes: {self.class_names}")
41
+
42
+ def preprocess(self, image):
43
+ """
44
+ Preprocess input image for model inference
45
+
46
+ Args:
47
+ image: Input image as numpy array (BGR or RGB)
48
+
49
+ Returns:
50
+ Preprocessed tensor ready for model inference
51
+ """
52
+ try:
53
+ # Handle different input formats
54
+ if isinstance(image, str):
55
+ image = cv2.imread(image)
56
+
57
+ if image is None:
58
+ raise ValueError("Invalid image input")
59
+
60
+ # Convert to grayscale if needed
61
+ if len(image.shape) == 3 and image.shape[2] == 3:
62
+ # Assume BGR format from cv2
63
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
64
+
65
+ # Apply CLAHE for contrast enhancement
66
+ clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
67
+ image = clahe.apply(image)
68
+
69
+ # Apply Gaussian blur for noise reduction
70
+ image = cv2.GaussianBlur(image, (5, 5), 0)
71
+
72
+ # Resize to model input size
73
+ image = cv2.resize(image, (self.img_size, self.img_size), interpolation=cv2.INTER_LINEAR)
74
+
75
+ # Convert back to RGB format
76
+ image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
77
+
78
+ # Convert to tensor format (C, H, W)
79
+ image = np.moveaxis(image, -1, 0).astype(np.float32)
80
+
81
+ # Normalize
82
+ image = (image - image.mean()) / (image.std() + 1e-8)
83
+
84
+ # Add batch dimension and move to device
85
+ tensor = torch.tensor(image).unsqueeze(0).to(self.device)
86
+
87
+ return tensor
88
+
89
+ except Exception as e:
90
+ raise RuntimeError(f"Preprocessing failed: {str(e)}")
91
+
92
+ def predict(self, image, return_probs=False):
93
+ """
94
+ Predict tuberculosis from chest X-ray image
95
+
96
+ Args:
97
+ image: Input image (file path, numpy array, or PIL image)
98
+ return_probs: Whether to return raw probabilities
99
+
100
+ Returns:
101
+ Dictionary with prediction results
102
+ """
103
+ try:
104
+ # Preprocess image
105
+ processed_image = self.preprocess(image)
106
+
107
+ # Model inference
108
+ with torch.no_grad():
109
+ output = self.model(processed_image)
110
+
111
+ # Handle different output formats
112
+ if len(output.shape) > 1:
113
+ output = output.squeeze(-1)
114
+
115
+ prob = torch.sigmoid(output).item() # Ensure probability is in [0,1]
116
+
117
+ # Determine class
118
+ class_id = 1 if prob > 0.5 else 0
119
+ confidence = prob if class_id == 1 else 1 - prob
120
+ prediction = self.class_names[class_id]
121
+
122
+ result = {
123
+ "prediction": prediction,
124
+ "confidence": float(confidence),
125
+ "class_id": class_id
126
+ }
127
+
128
+ if return_probs:
129
+ result["raw_probability"] = float(prob)
130
+ result["probabilities"] = {
131
+ self.class_names[0]: float(1 - prob),
132
+ self.class_names[1]: float(prob)
133
+ }
134
+
135
+ return result
136
+
137
+ except Exception as e:
138
+ return {
139
+ "error": str(e),
140
+ "prediction": None,
141
+ "confidence": None
142
+ }
143
+
144
+ def batch_predict(self, images, return_probs=False):
145
+ """
146
+ Predict on multiple images
147
+
148
+ Args:
149
+ images: List of images
150
+ return_probs: Whether to return raw probabilities
151
+
152
+ Returns:
153
+ List of prediction results
154
+ """
155
+ results = []
156
+ for img in images:
157
+ result = self.predict(img, return_probs=return_probs)
158
+ results.append(result)
159
+ return results
metadata.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_name": "Tuberculosis ViT Classifier",
3
+ "version": "1.0.0",
4
+ "author": "sukhmani1303",
5
+ "description": "Vision Transformer for tuberculosis detection in chest X-rays",
6
+ "tags": [
7
+ "medical",
8
+ "chest-xray",
9
+ "tuberculosis",
10
+ "vision-transformer",
11
+ "pytorch"
12
+ ],
13
+ "license": "apache-2.0",
14
+ "framework": "pytorch",
15
+ "model_size_mb": 654.1593074798584,
16
+ "created_date": "2025-06-08"
17
+ }
model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8ceea603a3fd53322fcefff6a15cf0d326f55e80235f33bb98d12a62f41cba00
3
+ size 343282923
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aa2914422445a137de9749565af1a8d5e29327de2abb5411b6c44ac214f36578
3
+ size 342643362
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch>=2.0.0,<3.0.0
2
+ torchvision>=0.15.0
3
+ opencv-python-headless>=4.8.0
4
+ numpy>=1.24.0
5
+ huggingface-hub>=0.16.0