Habith committed on
Commit
2e9520d
·
verified ·
1 Parent(s): c84c4ba

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +273 -125
app.py CHANGED
@@ -8,119 +8,217 @@ from transformers import (
8
  TrainingArguments,
9
  Trainer
10
  )
11
- from datasets import load_dataset, Image
12
  import numpy as np
13
- from huggingface_hub import HfApi, create_repo
14
  import os
 
 
15
 
16
  # Configuration
17
- HF_DATASET = "Ultralytics/Brain-tumor" # Your dataset repo
18
- CUSTOM_MODEL_NAME = "GoGenix_MRI_Brain" # Your custom model name
19
  BASE_MODEL = "Falconsai/nsfw_image_detection"
20
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
 
22
- class MRIDataset(torch.utils.data.Dataset):
23
- def __init__(self, dataset, transform=None):
24
- self.dataset = dataset
25
- self.transform = transform
 
 
 
 
26
 
27
- def __len__(self):
28
- return len(self.dataset)
29
 
30
- def __getitem__(self, idx):
31
- item = self.dataset[idx]
32
- image = item['image']
33
- label = item['label']
34
 
35
- if self.transform:
36
- image = self.transform(image)
 
 
 
 
 
 
 
 
37
 
38
- return image, label
39
-
40
- def train_and_save_model():
41
- """Train the model and save as GoGenix_MRI_Brain"""
42
-
43
- # Load dataset from Hugging Face Hub
44
- print("Loading dataset from Hugging Face Hub...")
45
- dataset = load_dataset(HF_DATASET)
46
-
47
- # Get class names from dataset
48
- class_names = dataset['train'].features['label'].names
49
- print(f"Classes detected: {class_names}")
50
-
51
- # Define transforms for MRI images
52
- transform = transforms.Compose([
53
- transforms.Resize((224, 224)),
54
- transforms.Grayscale(num_output_channels=3),
55
- transforms.ToTensor(),
56
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
57
- ])
58
-
59
- # Create PyTorch datasets
60
- train_dataset = MRIDataset(dataset['train'], transform=transform)
61
- test_dataset = MRIDataset(dataset['test'], transform=transform)
62
-
63
- # Load base model
64
- print("Loading base model...")
65
- model = AutoModelForImageClassification.from_pretrained(
66
- BASE_MODEL,
67
- num_labels=len(class_names),
68
- ignore_mismatched_sizes=True
69
- )
70
- processor = AutoImageProcessor.from_pretrained(BASE_MODEL)
71
-
72
- model.to(DEVICE)
73
-
74
- # Training arguments
75
- training_args = TrainingArguments(
76
- output_dir="./results",
77
- num_train_epochs=10,
78
- per_device_train_batch_size=8,
79
- per_device_eval_batch_size=8,
80
- warmup_steps=500,
81
- weight_decay=0.01,
82
- logging_dir="./logs",
83
- logging_steps=10,
84
- evaluation_strategy="epoch",
85
- save_strategy="epoch",
86
- load_best_model_at_end=True,
87
- push_to_hub=True,
88
- hub_model_id=CUSTOM_MODEL_NAME,
89
- )
90
-
91
- # Custom compute_metrics function
92
- def compute_metrics(eval_pred):
93
- predictions, labels = eval_pred
94
- predictions = np.argmax(predictions, axis=1)
95
- return {"accuracy": (predictions == labels).mean()}
96
-
97
- # Create trainer
98
- trainer = Trainer(
99
- model=model,
100
- args=training_args,
101
- train_dataset=train_dataset,
102
- eval_dataset=test_dataset,
103
- compute_metrics=compute_metrics,
104
- )
105
-
106
- # Start training
107
- print("Starting training...")
108
- train_result = trainer.train()
109
-
110
- # Save metrics
111
- trainer.log_metrics("train", train_result.metrics)
112
- trainer.save_metrics("train", train_result.metrics)
113
- trainer.save_state()
114
-
115
- # Save model locally
116
- trainer.save_model(f"./{CUSTOM_MODEL_NAME}")
117
- processor.save_pretrained(f"./{CUSTOM_MODEL_NAME}")
118
-
119
- # Push to Hugging Face Hub
120
- print("Pushing model to Hugging Face Hub...")
121
- trainer.push_to_hub()
122
-
123
- return f"Training completed! Model saved as: {CUSTOM_MODEL_NAME}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
  def classify_mri(image):
126
  """Classify a new MRI image using the trained model"""
@@ -140,25 +238,47 @@ def classify_mri(image):
140
  outputs = model(**inputs)
141
  predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
142
 
143
- # Map to class names
144
- class_names = ["glioma", "meningioma", "no_tumor", "pituitary"]
145
- return {class_names[i]: float(predictions[0][i]) for i in range(len(class_names))}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
 
147
  except Exception as e:
148
- return f"Error: Model not trained yet or unavailable. Please train first."
149
 
150
  # Gradio Interface
151
  with gr.Blocks(title="GoGenix MRI Brain Tumor Classifier") as demo:
152
  gr.Markdown("# 🧠 GoGenix MRI Brain Tumor Classifier")
153
- gr.Markdown(f"Training **Falconsai/nsfw_image_detection** on Brain Tumor MRI dataset")
154
 
155
  with gr.Tab("🚀 Train Model"):
156
  gr.Markdown("### Train GoGenix_MRI_Brain Model")
157
- gr.Markdown(f"Dataset: `{HF_DATASET}`")
158
- gr.Markdown(f"Target Model: `{CUSTOM_MODEL_NAME}`")
159
 
160
- train_btn = gr.Button("Start Training", variant="primary")
161
- output_text = gr.Textbox(label="Training Status", lines=5)
 
 
 
 
162
 
163
  train_btn.click(
164
  fn=train_and_save_model,
@@ -166,26 +286,54 @@ with gr.Blocks(title="GoGenix MRI Brain Tumor Classifier") as demo:
166
  )
167
 
168
  with gr.Tab("🔍 Classify MRI"):
169
- gr.Markdown("### Upload MRI Image for Classification")
170
- image_input = gr.Image(type="pil", label="MRI Scan")
171
- classify_btn = gr.Button("Classify", variant="secondary")
172
- result = gr.Label(label="Tumor Classification Results")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
 
174
  classify_btn.click(
175
- fn=classify_mri,
176
  inputs=image_input,
177
- outputs=result
178
  )
179
 
180
- with gr.Tab("📊 Model Info"):
181
- gr.Markdown("### Model Information")
182
  gr.Markdown(f"""
183
- - **Base Model**: {BASE_MODEL}
184
- - **Custom Model**: {CUSTOM_MODEL_NAME}
185
- - **Dataset**: {HF_DATASET}
186
- - **Classes**: Glioma, Meningioma, No Tumor, Pituitary Tumor
187
- - **Device**: {DEVICE}
 
 
 
 
 
 
188
  """)
189
 
190
  if __name__ == "__main__":
191
- demo.launch(share=True)
 
8
  TrainingArguments,
9
  Trainer
10
  )
11
+ from datasets import load_dataset, Dataset, Image
12
  import numpy as np
13
+ from huggingface_hub import HfApi
14
  import os
15
+ import json
16
+ from PIL import Image as PILImage
17
 
18
  # Configuration
19
+ HF_DATASET = "Ultralytics/Brain-tumor"
20
+ CUSTOM_MODEL_NAME = "GoGenix_MRI_Brain"
21
  BASE_MODEL = "Falconsai/nsfw_image_detection"
22
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
 
24
def _yolo_to_classification(item):
    """Derive a binary classification label from a detection-style example.

    Args:
        item: One dataset example. Its optional ``'label'`` entry is read as
            the collection of bounding-box annotations for the image.

    Returns:
        ``{'label': 1}`` when at least one annotation is present, otherwise
        ``{'label': 0}``. Only the label is returned; ``Dataset.map`` merges
        it into the example, so the image column is not rewritten.
    """
    # NOTE(review): every annotation is counted as "tumor". Confirm the
    # dataset has no box class for negative findings; if it does, those scans
    # are labelled 1 here and the class id must be inspected instead.
    boxes = item.get('label', [])
    return {'label': 1 if boxes else 0}


class _MRIDataset(torch.utils.data.Dataset):
    """Torch dataset that yields transformed images in the form Trainer batches.

    Defined at module level (rather than inside the training function) so it
    stays importable/picklable if data-loader worker processes are enabled.
    """

    def __init__(self, dataset, transform=None):
        """Wrap *dataset* (indexable, items with 'image' and 'label')."""
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        image = item['image']

        if self.transform:
            # The torchvision pipeline below starts from a PIL image.
            if not isinstance(image, PILImage.Image):
                image = PILImage.fromarray(image)
            image = self.transform(image)

        # Trainer's default collator batches mappings and forwards the keys
        # to model(**batch); a bare (image, label) tuple cannot be collated
        # and would carry no `pixel_values` / `labels` names for the model.
        return {'pixel_values': image, 'labels': item['label']}


def _compute_metrics(eval_pred):
    """Return ``{"accuracy": ...}`` from a (logits, labels) evaluation pair."""
    logits, labels = eval_pred
    predicted = np.argmax(logits, axis=1)
    return {"accuracy": float((predicted == labels).mean())}


def train_and_save_model():
    """Fine-tune BASE_MODEL as a tumor / no-tumor classifier and publish it.

    Loads HF_DATASET, converts its detection annotations into binary labels,
    trains on the 'valid' split, evaluates on the 'test' split, saves the
    model and processor under ``./{CUSTOM_MODEL_NAME}`` and pushes the model
    to the Hugging Face Hub.

    Returns:
        A human-readable status string: a training report on success, or an
        error description (including the traceback) on failure. Exceptions
        are not propagated because the string is shown directly in the UI.
    """
    try:
        print("Loading Ultralytics/Brain-tumor dataset (YOLO format)...")
        dataset = load_dataset(HF_DATASET)
        print(f"Dataset splits available: {list(dataset.keys())}")

        if 'valid' not in dataset or 'test' not in dataset:
            return "❌ Error: Dataset must contain 'valid' and 'test' splits"

        # The 'valid' split is used as the training data, 'test' for evaluation.
        train_split = dataset['valid']
        test_split = dataset['test']

        # Log the shape of one example to make schema problems easy to spot.
        print("Analyzing YOLO dataset structure...")
        if len(train_split) > 0:
            sample = train_split[0]
            print(f"Sample keys: {list(sample.keys())}")
            if 'image' in sample:
                print(f"Image type: {type(sample['image'])}")
            if 'label' in sample:
                print(f"Label type: {type(sample['label'])}")
                if isinstance(sample['label'], list) and len(sample['label']) > 0:
                    print(f"First label sample: {sample['label'][0]}")

        print("Converting YOLO labels to classification...")
        train_classification = train_split.map(_yolo_to_classification)
        test_classification = test_split.map(_yolo_to_classification)

        # Read the label column directly: iterating whole examples would
        # decode every image just to count two integers.
        train_labels = list(train_classification['label'])
        tumor_count = sum(train_labels)
        no_tumor_count = len(train_labels) - tumor_count
        print(f"Training set - Tumors: {tumor_count}, No tumors: {no_tumor_count}")

        class_names = ["no_tumor", "tumor"]
        num_classes = len(class_names)
        print(f"Using binary classification: {class_names}")

        # The processor is saved next to the model, so anything that loads
        # it later normalizes with *its* statistics. Train with the same
        # values to avoid train/inference skew; fall back to the ImageNet
        # constants when the processor does not expose them.
        processor = AutoImageProcessor.from_pretrained(BASE_MODEL)
        mean = getattr(processor, 'image_mean', None) or [0.485, 0.456, 0.406]
        std = getattr(processor, 'image_std', None) or [0.229, 0.224, 0.225]
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.Grayscale(num_output_channels=3),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])

        train_dataset = _MRIDataset(train_classification, transform=transform)
        test_dataset = _MRIDataset(test_classification, transform=transform)
        print(f"Training samples: {len(train_dataset)}")
        print(f"Test samples: {len(test_dataset)}")

        print("Loading base model...")
        model = AutoModelForImageClassification.from_pretrained(
            BASE_MODEL,
            num_labels=num_classes,
            # The base checkpoint's head has a different class count.
            ignore_mismatched_sizes=True,
            id2label=dict(enumerate(class_names)),
            label2id={name: idx for idx, name in enumerate(class_names)},
        )
        model.to(DEVICE)

        # NOTE(review): newer transformers releases rename
        # `evaluation_strategy` to `eval_strategy` — confirm against the
        # version this app is pinned to.
        training_args = TrainingArguments(
            output_dir="./results",
            num_train_epochs=10,
            per_device_train_batch_size=8,
            per_device_eval_batch_size=8,
            warmup_steps=500,
            weight_decay=0.01,
            logging_dir="./logs",
            logging_steps=10,
            evaluation_strategy="epoch",
            save_strategy="epoch",
            load_best_model_at_end=True,
            push_to_hub=True,
            hub_model_id=CUSTOM_MODEL_NAME,
            remove_unused_columns=False,
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=test_dataset,
            compute_metrics=_compute_metrics,
        )

        print("Starting training...")
        train_result = trainer.train()

        # train() only reports loss/runtime figures; accuracy has to come
        # from an explicit evaluation pass.
        eval_metrics = trainer.evaluate()

        trainer.save_model(f"./{CUSTOM_MODEL_NAME}")
        processor.save_pretrained(f"./{CUSTOM_MODEL_NAME}")

        print("Pushing model to Hugging Face Hub...")
        trainer.push_to_hub(commit_message="Train Brain Tumor classifier (YOLO to Classification)")

        train_loss = train_result.metrics.get('train_loss', 'N/A')
        eval_accuracy = eval_metrics.get('eval_accuracy', 'N/A')

        return "\n".join([
            "✅ Training completed successfully!",
            "",
            f"Model: {CUSTOM_MODEL_NAME}",
            f"Dataset: {HF_DATASET} (YOLO format)",
            "Task: Binary Classification (Tumor Detection)",
            f"Classes: {', '.join(class_names)}",
            f"Training Samples: {len(train_dataset)}",
            f"Test Samples: {len(test_dataset)}",
            f"Training Loss: {train_loss}",
            f"Validation Accuracy: {eval_accuracy}",
            f"Tumor/No-Tumor Ratio: {tumor_count}/{no_tumor_count}",
            "",
            "Model has been saved and pushed to Hugging Face Hub.",
        ])

    except Exception as e:
        # Top-level UI boundary: report the failure in the status box
        # instead of letting the click handler crash.
        import traceback

        return "\n".join([
            "❌ Error during training:",
            "",
            f"Error Type: {type(e).__name__}",
            f"Error Message: {str(e)}",
            "",
            "Detailed Traceback:",
            traceback.format_exc(),
        ])
222
 
223
  def classify_mri(image):
224
  """Classify a new MRI image using the trained model"""
 
238
  outputs = model(**inputs)
239
  predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
240
 
241
+ # Binary classification results
242
+ class_names = ["No Tumor", "Tumor Detected"]
243
+ results = {
244
+ class_names[0]: float(predictions[0][0]), # No tumor probability
245
+ class_names[1]: float(predictions[0][1]) # Tumor probability
246
+ }
247
+
248
+ # Add diagnostic information
249
+ tumor_prob = float(predictions[0][1])
250
+ if tumor_prob > 0.7:
251
+ diagnosis = "🟢 Likely no tumor"
252
+ elif tumor_prob > 0.3:
253
+ diagnosis = "🟡 Uncertain - consult specialist"
254
+ else:
255
+ diagnosis = "🔴 Possible tumor detected"
256
+
257
+ return {
258
+ "classification": results,
259
+ "diagnosis": diagnosis,
260
+ "tumor_probability": tumor_prob
261
+ }
262
 
263
  except Exception as e:
264
+ return f"⚠️ Model not trained yet or unavailable. Error: {str(e)}"
265
 
266
  # Gradio Interface
267
  with gr.Blocks(title="GoGenix MRI Brain Tumor Classifier") as demo:
268
  gr.Markdown("# 🧠 GoGenix MRI Brain Tumor Classifier")
269
+ gr.Markdown(f"**Dataset**: {HF_DATASET} (YOLO Format) | **Base Model**: {BASE_MODEL}")
270
 
271
  with gr.Tab("🚀 Train Model"):
272
  gr.Markdown("### Train GoGenix_MRI_Brain Model")
273
+ gr.Markdown(f"Using YOLO format dataset: `{HF_DATASET}`")
274
+ gr.Markdown("**Note**: Converting object detection labels to binary classification")
275
 
276
+ train_btn = gr.Button("Start Training", variant="primary", size="lg")
277
+ output_text = gr.Textbox(
278
+ label="Training Status",
279
+ lines=20,
280
+ placeholder="Click 'Start Training' to begin..."
281
+ )
282
 
283
  train_btn.click(
284
  fn=train_and_save_model,
 
286
  )
287
 
288
  with gr.Tab("🔍 Classify MRI"):
289
+ gr.Markdown("### Upload MRI Image for Tumor Detection")
290
+ gr.Markdown("**Binary Classification**: Tumor vs No Tumor")
291
+
292
+ image_input = gr.Image(
293
+ type="pil",
294
+ label="Brain MRI Scan",
295
+ height=300
296
+ )
297
+ classify_btn = gr.Button("Analyze Scan", variant="secondary")
298
+
299
+ with gr.Row():
300
+ result_label = gr.Label(
301
+ label="Classification Results",
302
+ num_top_classes=2
303
+ )
304
+ diagnosis_text = gr.Textbox(
305
+ label="Diagnostic Suggestion",
306
+ interactive=False
307
+ )
308
+
309
def process_classification(image):
    """Adapt the classify_mri() result to the two Gradio output components.

    Args:
        image: The uploaded scan, or ``None`` when nothing was uploaded.

    Returns:
        A ``(label_scores, diagnosis_text)`` pair. On success these are the
        'classification' mapping and the 'diagnosis' string from
        classify_mri(). On any failure the label is ``{"Error": 1.0}`` and
        the text carries the explanation.
    """
    # Clicking the button with no upload passes None; answer directly
    # instead of sending it through the model path.
    if image is None:
        return {"Error": 1.0}, "⚠️ Please upload an MRI scan before analyzing."

    result = classify_mri(image)
    if isinstance(result, dict) and 'classification' in result:
        return result['classification'], result.get('diagnosis', '')

    # classify_mri() reports failures as a message rather than a result dict;
    # str() keeps the textbox value a string whatever came back.
    return {"Error": 1.0}, str(result)
315
 
316
  classify_btn.click(
317
+ fn=process_classification,
318
  inputs=image_input,
319
+ outputs=[result_label, diagnosis_text]
320
  )
321
 
322
+ with gr.Tab("📊 Dataset Info"):
323
+ gr.Markdown("### YOLO Dataset Information")
324
  gr.Markdown(f"""
325
+ **Dataset**: {HF_DATASET}
326
+ **Format**: YOLO (You Only Look Once) Object Detection
327
+ **Original Structure**:
328
+ - `images/` folder: Contains MRI scans
329
+ - `labels/` folder: Contains bounding box annotations
330
+
331
+ **Converted to**: Binary Classification
332
+ - **No Tumor**: No bounding boxes in labels
333
+ - **Tumor**: One or more bounding boxes present
334
+
335
+ **Splits**: test, valid
336
  """)
337
 
338
  if __name__ == "__main__":
339
+ demo.launch()