Habith committed on
Commit
783b5ee
·
verified ·
1 Parent(s): 2e9520d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +133 -170
app.py CHANGED
@@ -8,11 +8,10 @@ from transformers import (
8
  TrainingArguments,
9
  Trainer
10
  )
11
- from datasets import load_dataset, Dataset, Image
12
  import numpy as np
13
  from huggingface_hub import HfApi
14
  import os
15
- import json
16
  from PIL import Image as PILImage
17
 
18
  # Configuration
@@ -22,73 +21,136 @@ BASE_MODEL = "Falconsai/nsfw_image_detection"
22
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
 
24
  def train_and_save_model():
25
- """Train the model using YOLO format dataset"""
26
 
27
  try:
28
- print("Loading Ultralytics/Brain-tumor dataset (YOLO format)...")
29
 
30
- # Load the dataset
31
- dataset = load_dataset(HF_DATASET)
32
 
33
- print(f"Dataset splits available: {list(dataset.keys())}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
- # Check dataset structure
36
- if 'valid' not in dataset or 'test' not in dataset:
37
- return "❌ Error: Dataset must contain 'valid' and 'test' splits"
38
 
39
- train_split = dataset['valid']
40
- test_split = dataset['test']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
- print("Analyzing YOLO dataset structure...")
 
43
 
44
- # For YOLO datasets, we need to check if images and labels are separate
45
- # Let's examine the structure
46
  if len(train_split) > 0:
47
  sample = train_split[0]
48
  print(f"Sample keys: {list(sample.keys())}")
 
 
 
 
 
 
 
 
49
 
50
- # Check if it's YOLO format (has image path and labels path)
51
- if 'image' in sample:
52
- print(f"Image type: {type(sample['image'])}")
53
- if 'label' in sample:
54
- print(f"Label type: {type(sample['label'])}")
55
- if isinstance(sample['label'], list) and len(sample['label']) > 0:
56
- print(f"First label sample: {sample['label'][0]}")
57
-
58
- # Since Ultralytics datasets are typically for object detection,
59
- # we'll convert them to classification by checking if tumor is present
60
- def yolo_to_classification(item):
61
- """Convert YOLO object detection labels to classification labels"""
62
- image = item['image']
63
- labels = item.get('label', [])
 
 
 
 
 
 
64
 
65
- # For binary classification: 0 = no tumor, 1 = tumor present
66
- # If there are any labels (bounding boxes), it means tumor is present
67
- has_tumor = 1 if labels and len(labels) > 0 else 0
 
 
 
 
 
 
 
 
68
 
69
  return {
70
  'image': image,
71
- 'label': has_tumor
72
  }
73
 
74
- # Apply conversion
75
- print("Converting YOLO labels to classification...")
76
- train_classification = train_split.map(yolo_to_classification)
77
- test_classification = test_split.map(yolo_to_classification)
78
 
79
- # Count tumor vs no_tumor
80
  tumor_count = sum(1 for item in train_classification if item['label'] == 1)
81
  no_tumor_count = sum(1 for item in train_classification if item['label'] == 0)
82
 
83
- print(f"Training set - Tumors: {tumor_count}, No tumors: {no_tumor_count}")
84
 
85
- # Define class names for binary classification
86
  class_names = ["no_tumor", "tumor"]
87
  num_classes = 2
88
 
89
- print(f"Using binary classification: {class_names}")
90
 
91
- # Define transforms for MRI images
92
  transform = transforms.Compose([
93
  transforms.Resize((224, 224)),
94
  transforms.Grayscale(num_output_channels=3),
@@ -111,21 +173,17 @@ def train_and_save_model():
111
  label = item['label']
112
 
113
  if self.transform:
114
- # Ensure image is PIL Image
115
  if not isinstance(image, PILImage.Image):
116
  image = PILImage.fromarray(image)
117
  image = self.transform(image)
118
 
119
  return image, label
120
 
121
- # Create dataset objects
122
  train_dataset = MRIDataset(train_classification, transform=transform)
123
  test_dataset = MRIDataset(test_classification, transform=transform)
124
 
125
- print(f"Training samples: {len(train_dataset)}")
126
- print(f"Test samples: {len(test_dataset)}")
127
-
128
- # Load base model
129
  print("Loading base model...")
130
  model = AutoModelForImageClassification.from_pretrained(
131
  BASE_MODEL,
@@ -140,10 +198,10 @@ def train_and_save_model():
140
  # Training arguments
141
  training_args = TrainingArguments(
142
  output_dir="./results",
143
- num_train_epochs=10,
144
  per_device_train_batch_size=8,
145
  per_device_eval_batch_size=8,
146
- warmup_steps=500,
147
  weight_decay=0.01,
148
  logging_dir="./logs",
149
  logging_steps=10,
@@ -155,7 +213,7 @@ def train_and_save_model():
155
  remove_unused_columns=False,
156
  )
157
 
158
- # Metrics function
159
  def compute_metrics(eval_pred):
160
  predictions, labels = eval_pred
161
  predictions = np.argmax(predictions, axis=1)
@@ -175,165 +233,70 @@ def train_and_save_model():
175
  print("Starting training...")
176
  train_result = trainer.train()
177
 
178
- # Save model locally
179
  trainer.save_model(f"./{CUSTOM_MODEL_NAME}")
180
  processor.save_pretrained(f"./{CUSTOM_MODEL_NAME}")
181
 
182
- # Push to Hugging Face Hub
183
- print("Pushing model to Hugging Face Hub...")
184
- trainer.push_to_hub(commit_message="Train Brain Tumor classifier (YOLO to Classification)")
185
 
186
- # Display training results
187
  train_accuracy = train_result.metrics.get('train_accuracy', 'N/A')
188
  eval_accuracy = train_result.metrics.get('eval_accuracy', 'N/A')
189
 
190
  result_message = f"""
191
- ✅ Training completed successfully!
192
 
193
  Model: {CUSTOM_MODEL_NAME}
194
- Dataset: {HF_DATASET} (YOLO format)
195
- Task: Binary Classification (Tumor Detection)
196
- Classes: {', '.join(class_names)}
197
- Training Samples: {len(train_dataset)}
198
- Test Samples: {len(test_dataset)}
199
  Training Accuracy: {train_accuracy}
200
  Validation Accuracy: {eval_accuracy}
201
- Tumor/No-Tumor Ratio: {tumor_count}/{no_tumor_count}
202
-
203
- Model has been saved and pushed to Hugging Face Hub.
204
  """
205
 
206
  return result_message
207
 
208
  except Exception as e:
209
  import traceback
210
- error_details = traceback.format_exc()
211
-
212
- error_message = f"""
213
- ❌ Error during training:
214
-
215
- Error Type: {type(e).__name__}
216
- Error Message: {str(e)}
217
-
218
- Detailed Traceback:
219
- {error_details}
220
- """
221
- return error_message
222
 
223
  def classify_mri(image):
224
- """Classify a new MRI image using the trained model"""
225
  try:
226
- # Load your custom model
227
  model = AutoModelForImageClassification.from_pretrained(CUSTOM_MODEL_NAME)
228
  processor = AutoImageProcessor.from_pretrained(CUSTOM_MODEL_NAME)
229
 
230
  model.to(DEVICE)
231
  model.eval()
232
 
233
- # Preprocess image
234
  inputs = processor(image, return_tensors="pt").to(DEVICE)
235
 
236
- # Predict
237
  with torch.no_grad():
238
  outputs = model(**inputs)
239
  predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
240
 
241
- # Binary classification results
242
  class_names = ["No Tumor", "Tumor Detected"]
243
- results = {
244
- class_names[0]: float(predictions[0][0]), # No tumor probability
245
- class_names[1]: float(predictions[0][1]) # Tumor probability
246
- }
247
-
248
- # Add diagnostic information
249
- tumor_prob = float(predictions[0][1])
250
- if tumor_prob > 0.7:
251
- diagnosis = "🟢 Likely no tumor"
252
- elif tumor_prob > 0.3:
253
- diagnosis = "🟡 Uncertain - consult specialist"
254
- else:
255
- diagnosis = "🔴 Possible tumor detected"
256
 
257
- return {
258
- "classification": results,
259
- "diagnosis": diagnosis,
260
- "tumor_probability": tumor_prob
261
- }
262
 
263
  except Exception as e:
264
- return f"⚠️ Model not trained yet or unavailable. Error: {str(e)}"
265
 
266
- # Gradio Interface
267
- with gr.Blocks(title="GoGenix MRI Brain Tumor Classifier") as demo:
268
- gr.Markdown("# 🧠 GoGenix MRI Brain Tumor Classifier")
269
- gr.Markdown(f"**Dataset**: {HF_DATASET} (YOLO Format) | **Base Model**: {BASE_MODEL}")
270
-
271
- with gr.Tab("🚀 Train Model"):
272
- gr.Markdown("### Train GoGenix_MRI_Brain Model")
273
- gr.Markdown(f"Using YOLO format dataset: `{HF_DATASET}`")
274
- gr.Markdown("**Note**: Converting object detection labels to binary classification")
275
-
276
- train_btn = gr.Button("Start Training", variant="primary", size="lg")
277
- output_text = gr.Textbox(
278
- label="Training Status",
279
- lines=20,
280
- placeholder="Click 'Start Training' to begin..."
281
- )
282
-
283
- train_btn.click(
284
- fn=train_and_save_model,
285
- outputs=output_text
286
- )
287
 
288
- with gr.Tab("🔍 Classify MRI"):
289
- gr.Markdown("### Upload MRI Image for Tumor Detection")
290
- gr.Markdown("**Binary Classification**: Tumor vs No Tumor")
291
-
292
- image_input = gr.Image(
293
- type="pil",
294
- label="Brain MRI Scan",
295
- height=300
296
- )
297
- classify_btn = gr.Button("Analyze Scan", variant="secondary")
298
-
299
- with gr.Row():
300
- result_label = gr.Label(
301
- label="Classification Results",
302
- num_top_classes=2
303
- )
304
- diagnosis_text = gr.Textbox(
305
- label="Diagnostic Suggestion",
306
- interactive=False
307
- )
308
-
309
- def process_classification(image):
310
- result = classify_mri(image)
311
- if isinstance(result, dict) and 'classification' in result:
312
- return result['classification'], result.get('diagnosis', '')
313
- else:
314
- return {"Error": 1.0}, result
315
-
316
- classify_btn.click(
317
- fn=process_classification,
318
- inputs=image_input,
319
- outputs=[result_label, diagnosis_text]
320
- )
321
 
322
- with gr.Tab("📊 Dataset Info"):
323
- gr.Markdown("### YOLO Dataset Information")
324
- gr.Markdown(f"""
325
- **Dataset**: {HF_DATASET}
326
- **Format**: YOLO (You Only Look Once) Object Detection
327
- **Original Structure**:
328
- - `images/` folder: Contains MRI scans
329
- - `labels/` folder: Contains bounding box annotations
330
-
331
- **Converted to**: Binary Classification
332
- - **No Tumor**: No bounding boxes in labels
333
- - **Tumor**: One or more bounding boxes present
334
-
335
- **Splits**: test, valid
336
- """)
337
 
338
  if __name__ == "__main__":
339
  demo.launch()
 
8
  TrainingArguments,
9
  Trainer
10
  )
11
+ from datasets import load_dataset, DatasetDict
12
  import numpy as np
13
  from huggingface_hub import HfApi
14
  import os
 
15
  from PIL import Image as PILImage
16
 
17
  # Configuration
 
21
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
 
23
  def train_and_save_model():
24
+ """Train the model with explicit dataset format handling"""
25
 
26
  try:
27
+ print("Loading Ultralytics/Brain-tumor dataset with explicit format...")
28
 
29
+ # Try multiple loading methods to handle format detection issues
30
+ dataset = None
31
 
32
+ # Method 1: Try loading with explicit imagefolder format for all splits
33
+ try:
34
+ dataset = load_dataset(HF_DATASET, "imagefolder")
35
+ print("✅ Loaded with 'imagefolder' format")
36
+ except Exception as e1:
37
+ print(f"❌ Method 1 failed: {e1}")
38
+
39
+ # Method 2: Try loading without specific format
40
+ try:
41
+ dataset = load_dataset(HF_DATASET)
42
+ print("✅ Loaded without specific format")
43
+ except Exception as e2:
44
+ print(f"❌ Method 2 failed: {e2}")
45
+
46
+ # Method 3: Try loading with data_files specification
47
+ try:
48
+ dataset = load_dataset(
49
+ HF_DATASET,
50
+ data_files={
51
+ 'train': ['**/train/**/*.jpg', '**/train/**/*.png', '**/train/**/*.jpeg'],
52
+ 'validation': ['**/valid/**/*.jpg', '**/valid/**/*.png', '**/valid/**/*.jpeg'],
53
+ 'test': ['**/test/**/*.jpg', '**/test/**/*.png', '**/test/**/*.jpeg']
54
+ }
55
+ )
56
+ print("✅ Loaded with data_files specification")
57
+ except Exception as e3:
58
+ print(f"❌ Method 3 failed: {e3}")
59
+ return f"All loading methods failed:\n1. {e1}\n2. {e2}\n3. {e3}"
60
+
61
+ if dataset is None:
62
+ return "❌ Could not load dataset with any method"
63
 
64
+ print(f"Dataset splits available: {list(dataset.keys())}")
 
 
65
 
66
+ # Check which splits we have and map them appropriately
67
+ if 'train' in dataset and 'validation' in dataset:
68
+ train_split = dataset['train']
69
+ test_split = dataset['validation']
70
+ print("Using 'train' and 'validation' splits")
71
+ elif 'valid' in dataset and 'test' in dataset:
72
+ train_split = dataset['valid']
73
+ test_split = dataset['test']
74
+ print("Using 'valid' and 'test' splits")
75
+ elif 'train' in dataset and 'test' in dataset:
76
+ train_split = dataset['train']
77
+ test_split = dataset['test']
78
+ print("Using 'train' and 'test' splits")
79
+ else:
80
+ available_splits = list(dataset.keys())
81
+ return f"❌ Cannot determine train/test splits. Available splits: {available_splits}"
82
 
83
+ print(f"Training samples: {len(train_split)}")
84
+ print(f"Test samples: {len(test_split)}")
85
 
86
+ # Analyze dataset structure
 
87
  if len(train_split) > 0:
88
  sample = train_split[0]
89
  print(f"Sample keys: {list(sample.keys())}")
90
+ for key in sample.keys():
91
+ print(f" {key}: {type(sample[key])}")
92
+
93
+ # Determine if this is a classification or object detection dataset
94
+ # For Ultralytics datasets, check if it has object detection format
95
+ def detect_dataset_type(split):
96
+ if len(split) == 0:
97
+ return "empty"
98
 
99
+ sample = split[0]
100
+ if 'objects' in sample or 'bbox' in sample or 'labels' in sample and isinstance(sample.get('labels'), list):
101
+ return "object_detection"
102
+ elif 'label' in sample and isinstance(sample['label'], (int, float)):
103
+ return "classification"
104
+ elif 'image' in sample:
105
+ return "image_only"
106
+ else:
107
+ return "unknown"
108
+
109
+ train_type = detect_dataset_type(train_split)
110
+ test_type = detect_dataset_type(test_split)
111
+
112
+ print(f"Train dataset type: {train_type}")
113
+ print(f"Test dataset type: {test_type}")
114
+
115
+ # Convert to classification format
116
+ def convert_to_classification(item):
117
+ """Convert various formats to classification format"""
118
+ image = item.get('image')
119
 
120
+ # Handle different label formats
121
+ if 'label' in item and isinstance(item['label'], (int, float)):
122
+ label = int(item['label'])
123
+ elif 'objects' in item or 'bbox' in item:
124
+ # Object detection format - convert to binary classification
125
+ # If there are objects/bboxes, it's tumor (1), else no tumor (0)
126
+ label = 1 if (item.get('objects') or item.get('bbox')) else 0
127
+ elif 'labels' in item and isinstance(item['labels'], list) and len(item['labels']) > 0:
128
+ label = 1 # Has labels = tumor
129
+ else:
130
+ label = 0 # No labels = no tumor
131
 
132
  return {
133
  'image': image,
134
+ 'label': label
135
  }
136
 
137
+ print("Converting dataset to classification format...")
138
+ train_classification = train_split.map(convert_to_classification)
139
+ test_classification = test_split.map(convert_to_classification)
 
140
 
141
+ # Count classes
142
  tumor_count = sum(1 for item in train_classification if item['label'] == 1)
143
  no_tumor_count = sum(1 for item in train_classification if item['label'] == 0)
144
 
145
+ print(f"Tumor samples: {tumor_count}, No tumor samples: {no_tumor_count}")
146
 
147
+ # Use binary classification
148
  class_names = ["no_tumor", "tumor"]
149
  num_classes = 2
150
 
151
+ print(f"Using {num_classes} classes: {class_names}")
152
 
153
+ # Define transforms
154
  transform = transforms.Compose([
155
  transforms.Resize((224, 224)),
156
  transforms.Grayscale(num_output_channels=3),
 
173
  label = item['label']
174
 
175
  if self.transform:
 
176
  if not isinstance(image, PILImage.Image):
177
  image = PILImage.fromarray(image)
178
  image = self.transform(image)
179
 
180
  return image, label
181
 
182
+ # Create datasets
183
  train_dataset = MRIDataset(train_classification, transform=transform)
184
  test_dataset = MRIDataset(test_classification, transform=transform)
185
 
186
+ # Load model
 
 
 
187
  print("Loading base model...")
188
  model = AutoModelForImageClassification.from_pretrained(
189
  BASE_MODEL,
 
198
  # Training arguments
199
  training_args = TrainingArguments(
200
  output_dir="./results",
201
+ num_train_epochs=5, # Reduced for testing
202
  per_device_train_batch_size=8,
203
  per_device_eval_batch_size=8,
204
+ warmup_steps=100,
205
  weight_decay=0.01,
206
  logging_dir="./logs",
207
  logging_steps=10,
 
213
  remove_unused_columns=False,
214
  )
215
 
216
+ # Metrics
217
  def compute_metrics(eval_pred):
218
  predictions, labels = eval_pred
219
  predictions = np.argmax(predictions, axis=1)
 
233
  print("Starting training...")
234
  train_result = trainer.train()
235
 
236
+ # Save model
237
  trainer.save_model(f"./{CUSTOM_MODEL_NAME}")
238
  processor.save_pretrained(f"./{CUSTOM_MODEL_NAME}")
239
 
240
+ # Push to hub
241
+ trainer.push_to_hub(commit_message="Train Brain Tumor classifier")
 
242
 
243
+ # Results
244
  train_accuracy = train_result.metrics.get('train_accuracy', 'N/A')
245
  eval_accuracy = train_result.metrics.get('eval_accuracy', 'N/A')
246
 
247
  result_message = f"""
248
+ ✅ Training completed!
249
 
250
  Model: {CUSTOM_MODEL_NAME}
251
+ Dataset: {HF_DATASET}
252
+ Classes: {class_names}
 
 
 
253
  Training Accuracy: {train_accuracy}
254
  Validation Accuracy: {eval_accuracy}
 
 
 
255
  """
256
 
257
  return result_message
258
 
259
  except Exception as e:
260
  import traceback
261
+ return f"❌ Error: {str(e)}\n\n{traceback.format_exc()}"
 
 
 
 
 
 
 
 
 
 
 
262
 
263
  def classify_mri(image):
264
+ """Classify MRI image"""
265
  try:
 
266
  model = AutoModelForImageClassification.from_pretrained(CUSTOM_MODEL_NAME)
267
  processor = AutoImageProcessor.from_pretrained(CUSTOM_MODEL_NAME)
268
 
269
  model.to(DEVICE)
270
  model.eval()
271
 
 
272
  inputs = processor(image, return_tensors="pt").to(DEVICE)
273
 
 
274
  with torch.no_grad():
275
  outputs = model(**inputs)
276
  predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
277
 
 
278
  class_names = ["No Tumor", "Tumor Detected"]
279
+ results = {class_names[i]: float(predictions[0][i]) for i in range(2)}
 
 
 
 
 
 
 
 
 
 
 
 
280
 
281
+ return results
 
 
 
 
282
 
283
  except Exception as e:
284
+ return f"⚠️ Error: {str(e)}"
285
 
286
# Gradio front end: one tab triggers training, the other classifies an image.
with gr.Blocks() as demo:
    gr.Markdown(value="# Brain Tumor Classification")

    # Training tab: the button runs train_and_save_model and shows its
    # returned text in the textbox.
    with gr.Tab(label="Train"):
        train_btn = gr.Button(value="Train Model")
        output = gr.Textbox(lines=10)
        train_btn.click(fn=train_and_save_model, outputs=output)

    # Classification tab: the uploaded image is passed to classify_mri as a
    # PIL image and the result is rendered in the label component.
    with gr.Tab(label="Classify"):
        image = gr.Image(type="pil")
        classify_btn = gr.Button(value="Classify")
        result = gr.Label()
        classify_btn.click(fn=classify_mri, inputs=image, outputs=result)


# Start the web app only when executed as a script.
if __name__ == "__main__":
    demo.launch()