Habith commited on
Commit
b2ef413
·
verified ·
1 Parent(s): 783b5ee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +276 -166
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import gradio as gr
2
  import torch
3
- from torch.utils.data import DataLoader
4
  from torchvision import transforms
5
  from transformers import (
6
  AutoImageProcessor,
@@ -8,158 +8,178 @@ from transformers import (
8
  TrainingArguments,
9
  Trainer
10
  )
11
- from datasets import load_dataset, DatasetDict
12
  import numpy as np
13
  from huggingface_hub import HfApi
14
  import os
15
  from PIL import Image as PILImage
16
 
17
  # Configuration
18
- HF_DATASET = "Ultralytics/Brain-tumor"
19
  CUSTOM_MODEL_NAME = "GoGenix_MRI_Brain"
20
  BASE_MODEL = "Falconsai/nsfw_image_detection"
21
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
 
23
- def train_and_save_model():
24
- """Train the model with explicit dataset format handling"""
 
 
 
 
 
 
 
 
25
 
26
- try:
27
- print("Loading Ultralytics/Brain-tumor dataset with explicit format...")
28
-
29
- # Try multiple loading methods to handle format detection issues
30
- dataset = None
31
-
32
- # Method 1: Try loading with explicit imagefolder format for all splits
33
  try:
34
- dataset = load_dataset(HF_DATASET, "imagefolder")
35
- print("✅ Loaded with 'imagefolder' format")
36
- except Exception as e1:
37
- print(f"❌ Method 1 failed: {e1}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
- # Method 2: Try loading without specific format
40
- try:
41
- dataset = load_dataset(HF_DATASET)
42
- print("✅ Loaded without specific format")
43
- except Exception as e2:
44
- print(f"❌ Method 2 failed: {e2}")
 
 
 
 
45
 
46
- # Method 3: Try loading with data_files specification
47
- try:
48
- dataset = load_dataset(
49
- HF_DATASET,
50
- data_files={
51
- 'train': ['**/train/**/*.jpg', '**/train/**/*.png', '**/train/**/*.jpeg'],
52
- 'validation': ['**/valid/**/*.jpg', '**/valid/**/*.png', '**/valid/**/*.jpeg'],
53
- 'test': ['**/test/**/*.jpg', '**/test/**/*.png', '**/test/**/*.jpeg']
54
- }
55
- )
56
- print(" Loaded with data_files specification")
57
- except Exception as e3:
58
- print(f"❌ Method 3 failed: {e3}")
59
- return f"All loading methods failed:\n1. {e1}\n2. {e2}\n3. {e3}"
60
-
61
- if dataset is None:
62
- return "❌ Could not load dataset with any method"
63
-
64
- print(f"Dataset splits available: {list(dataset.keys())}")
65
-
66
- # Check which splits we have and map them appropriately
67
- if 'train' in dataset and 'validation' in dataset:
68
- train_split = dataset['train']
69
- test_split = dataset['validation']
70
- print("Using 'train' and 'validation' splits")
71
- elif 'valid' in dataset and 'test' in dataset:
72
- train_split = dataset['valid']
73
- test_split = dataset['test']
74
- print("Using 'valid' and 'test' splits")
75
- elif 'train' in dataset and 'test' in dataset:
76
- train_split = dataset['train']
77
- test_split = dataset['test']
78
- print("Using 'train' and 'test' splits")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  else:
80
- available_splits = list(dataset.keys())
81
- return f"❌ Cannot determine train/test splits. Available splits: {available_splits}"
82
 
83
- print(f"Training samples: {len(train_split)}")
84
- print(f"Test samples: {len(test_split)}")
85
 
86
- # Analyze dataset structure
 
 
 
87
  if len(train_split) > 0:
88
  sample = train_split[0]
89
- print(f"Sample keys: {list(sample.keys())}")
90
- for key in sample.keys():
91
- print(f" {key}: {type(sample[key])}")
92
-
93
- # Determine if this is a classification or object detection dataset
94
- # For Ultralytics datasets, check if it has object detection format
95
- def detect_dataset_type(split):
96
- if len(split) == 0:
97
- return "empty"
98
 
99
- sample = split[0]
100
- if 'objects' in sample or 'bbox' in sample or 'labels' in sample and isinstance(sample.get('labels'), list):
101
- return "object_detection"
102
- elif 'label' in sample and isinstance(sample['label'], (int, float)):
103
- return "classification"
104
- elif 'image' in sample:
105
- return "image_only"
106
- else:
107
- return "unknown"
108
-
109
- train_type = detect_dataset_type(train_split)
110
- test_type = detect_dataset_type(test_split)
111
-
112
- print(f"Train dataset type: {train_type}")
113
- print(f"Test dataset type: {test_type}")
114
-
115
- # Convert to classification format
116
- def convert_to_classification(item):
117
- """Convert various formats to classification format"""
118
- image = item.get('image')
119
 
120
- # Handle different label formats
121
- if 'label' in item and isinstance(item['label'], (int, float)):
122
- label = int(item['label'])
123
- elif 'objects' in item or 'bbox' in item:
124
- # Object detection format - convert to binary classification
125
- # If there are objects/bboxes, it's tumor (1), else no tumor (0)
126
- label = 1 if (item.get('objects') or item.get('bbox')) else 0
127
- elif 'labels' in item and isinstance(item['labels'], list) and len(item['labels']) > 0:
128
- label = 1 # Has labels = tumor
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  else:
130
- label = 0 # No labels = no tumor
131
-
132
- return {
133
- 'image': image,
134
- 'label': label
135
- }
136
-
137
- print("Converting dataset to classification format...")
138
- train_classification = train_split.map(convert_to_classification)
139
- test_classification = test_split.map(convert_to_classification)
140
-
141
- # Count classes
142
- tumor_count = sum(1 for item in train_classification if item['label'] == 1)
143
- no_tumor_count = sum(1 for item in train_classification if item['label'] == 0)
144
-
145
- print(f"Tumor samples: {tumor_count}, No tumor samples: {no_tumor_count}")
146
-
147
- # Use binary classification
148
- class_names = ["no_tumor", "tumor"]
149
- num_classes = 2
150
-
151
- print(f"Using {num_classes} classes: {class_names}")
152
 
153
- # Define transforms
154
  transform = transforms.Compose([
155
  transforms.Resize((224, 224)),
156
- transforms.Grayscale(num_output_channels=3),
157
  transforms.ToTensor(),
158
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
159
  ])
160
 
161
- # Custom dataset class
162
- class MRIDataset(torch.utils.data.Dataset):
163
  def __init__(self, dataset, transform=None):
164
  self.dataset = dataset
165
  self.transform = transform
@@ -169,28 +189,42 @@ def train_and_save_model():
169
 
170
  def __getitem__(self, idx):
171
  item = self.dataset[idx]
172
- image = item['image']
173
- label = item['label']
174
 
175
- if self.transform:
176
- if not isinstance(image, PILImage.Image):
 
 
 
 
 
177
  image = PILImage.fromarray(image)
 
 
 
 
 
 
 
 
 
 
 
178
  image = self.transform(image)
179
 
180
  return image, label
181
 
182
  # Create datasets
183
- train_dataset = MRIDataset(train_classification, transform=transform)
184
- test_dataset = MRIDataset(test_classification, transform=transform)
185
 
186
- # Load model
 
 
187
  print("Loading base model...")
188
  model = AutoModelForImageClassification.from_pretrained(
189
  BASE_MODEL,
190
  num_labels=num_classes,
191
- ignore_mismatched_sizes=True,
192
- id2label={0: "no_tumor", 1: "tumor"},
193
- label2id={"no_tumor": 0, "tumor": 1}
194
  )
195
  processor = AutoImageProcessor.from_pretrained(BASE_MODEL)
196
  model.to(DEVICE)
@@ -198,10 +232,10 @@ def train_and_save_model():
198
  # Training arguments
199
  training_args = TrainingArguments(
200
  output_dir="./results",
201
- num_train_epochs=5, # Reduced for testing
202
  per_device_train_batch_size=8,
203
  per_device_eval_batch_size=8,
204
- warmup_steps=100,
205
  weight_decay=0.01,
206
  logging_dir="./logs",
207
  logging_steps=10,
@@ -210,10 +244,9 @@ def train_and_save_model():
210
  load_best_model_at_end=True,
211
  push_to_hub=True,
212
  hub_model_id=CUSTOM_MODEL_NAME,
213
- remove_unused_columns=False,
214
  )
215
 
216
- # Metrics
217
  def compute_metrics(eval_pred):
218
  predictions, labels = eval_pred
219
  predictions = np.argmax(predictions, axis=1)
@@ -224,79 +257,156 @@ def train_and_save_model():
224
  trainer = Trainer(
225
  model=model,
226
  args=training_args,
227
- train_dataset=train_dataset,
228
- eval_dataset=test_dataset,
229
  compute_metrics=compute_metrics,
230
  )
231
 
232
  # Start training
233
- print("Starting training...")
234
  train_result = trainer.train()
235
 
236
  # Save model
237
  trainer.save_model(f"./{CUSTOM_MODEL_NAME}")
238
  processor.save_pretrained(f"./{CUSTOM_MODEL_NAME}")
239
 
240
- # Push to hub
241
- trainer.push_to_hub(commit_message="Train Brain Tumor classifier")
242
 
243
- # Results
244
  train_accuracy = train_result.metrics.get('train_accuracy', 'N/A')
245
  eval_accuracy = train_result.metrics.get('eval_accuracy', 'N/A')
246
 
247
- result_message = f"""
248
- Training completed!
249
 
 
250
  Model: {CUSTOM_MODEL_NAME}
251
- Dataset: {HF_DATASET}
252
  Classes: {class_names}
253
- Training Accuracy: {train_accuracy}
254
- Validation Accuracy: {eval_accuracy}
 
 
 
255
  """
256
 
257
  return result_message
258
 
259
  except Exception as e:
260
  import traceback
261
- return f"❌ Error: {str(e)}\n\n{traceback.format_exc()}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
262
 
263
  def classify_mri(image):
264
- """Classify MRI image"""
265
  try:
 
266
  model = AutoModelForImageClassification.from_pretrained(CUSTOM_MODEL_NAME)
267
  processor = AutoImageProcessor.from_pretrained(CUSTOM_MODEL_NAME)
268
 
269
  model.to(DEVICE)
270
  model.eval()
271
 
 
272
  inputs = processor(image, return_tensors="pt").to(DEVICE)
273
 
 
274
  with torch.no_grad():
275
  outputs = model(**inputs)
276
  predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
277
 
278
- class_names = ["No Tumor", "Tumor Detected"]
279
- results = {class_names[i]: float(predictions[0][i]) for i in range(2)}
 
 
 
 
 
 
 
 
 
 
 
 
280
 
281
  return results
282
 
283
  except Exception as e:
284
- return f"⚠️ Error: {str(e)}"
285
 
286
- # Simple Gradio interface
287
- with gr.Blocks() as demo:
288
- gr.Markdown("# Brain Tumor Classification")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
289
 
290
- with gr.Tab("Train"):
291
- train_btn = gr.Button("Train Model")
292
- output = gr.Textbox(lines=10)
293
- train_btn.click(train_and_save_model, outputs=output)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
294
 
295
- with gr.Tab("Classify"):
296
- image = gr.Image(type="pil")
297
- classify_btn = gr.Button("Classify")
298
- result = gr.Label()
299
- classify_btn.click(classify_mri, inputs=image, outputs=result)
 
 
 
 
 
 
300
 
301
  if __name__ == "__main__":
302
  demo.launch()
 
1
  import gradio as gr
2
  import torch
3
+ from torch.utils.data import DataLoader, Dataset
4
  from torchvision import transforms
5
  from transformers import (
6
  AutoImageProcessor,
 
8
  TrainingArguments,
9
  Trainer
10
  )
11
+ from datasets import load_dataset
12
  import numpy as np
13
  from huggingface_hub import HfApi
14
  import os
15
  from PIL import Image as PILImage
16
 
17
  # Configuration
 
18
  CUSTOM_MODEL_NAME = "GoGenix_MRI_Brain"
19
  BASE_MODEL = "Falconsai/nsfw_image_detection"
20
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
 
22
+ # Your custom dataset selection
23
+ BRAIN_TUMOR_DATASETS = [
24
+ "PranomVignesh/MRI-Images-of-Brain-Tumor", # Your first choice
25
+ "Hemg/Brain-Tumor-MRI-Dataset", # Your second choice
26
+ "AntonXue/mcal-mri-brain-tumor", # Your third choice
27
+ ]
28
+
29
+ def find_working_dataset():
30
+ """Try your custom datasets and return the first one that works"""
31
+ working_datasets = []
32
 
33
+ for dataset_name in BRAIN_TUMOR_DATASETS:
 
 
 
 
 
 
34
  try:
35
+ print(f"Trying dataset: {dataset_name}")
36
+ dataset = load_dataset(dataset_name)
37
+
38
+ # Basic validation
39
+ splits = list(dataset.keys())
40
+ print(f"Found splits: {splits}")
41
+
42
+ # Check if dataset has content
43
+ if len(splits) == 0:
44
+ print(f"⚠️ {dataset_name} - No splits found")
45
+ continue
46
+
47
+ first_split = splits[0]
48
+ if len(dataset[first_split]) == 0:
49
+ print(f"⚠️ {dataset_name} - Empty dataset")
50
+ continue
51
+
52
+ # Check sample structure
53
+ sample = dataset[first_split][0]
54
+ sample_keys = list(sample.keys())
55
+ print(f"Sample keys: {sample_keys}")
56
 
57
+ if 'image' in sample_keys:
58
+ working_datasets.append({
59
+ 'name': dataset_name,
60
+ 'splits': splits,
61
+ 'sample_structure': sample_keys,
62
+ 'dataset': dataset
63
+ })
64
+ print(f"✅ {dataset_name} - VALID")
65
+ else:
66
+ print(f"⚠️ {dataset_name} - Missing 'image' key")
67
 
68
+ except Exception as e:
69
+ print(f"❌ {dataset_name} - Failed: {str(e)}")
70
+ continue
71
+
72
+ return working_datasets
73
+
74
+ def train_and_save_model():
75
+ """Train the model using your selected datasets"""
76
+
77
+ try:
78
+ print("Searching for compatible brain tumor datasets...")
79
+ working_datasets = find_working_dataset()
80
+
81
+ if not working_datasets:
82
+ return "❌ None of your selected datasets worked. Please check the dataset names or try different datasets."
83
+
84
+ # Use the first working dataset
85
+ selected_dataset = working_datasets[0]
86
+ dataset_name = selected_dataset['name']
87
+ splits = selected_dataset['splits']
88
+ dataset_obj = selected_dataset['dataset']
89
+
90
+ result_message = f"✅ Selected dataset: {dataset_name}\n"
91
+ result_message += f"Splits available: {splits}\n\n"
92
+
93
+ print(f"Using dataset: {dataset_name}")
94
+
95
+ # Determine which splits to use
96
+ train_split_key = None
97
+ test_split_key = None
98
+
99
+ # Prioritize standard split names
100
+ if 'train' in splits:
101
+ train_split_key = 'train'
102
+ elif 'training' in splits:
103
+ train_split_key = 'training'
104
+ elif 'Train' in splits:
105
+ train_split_key = 'Train'
106
+ else:
107
+ train_split_key = splits[0] # Use first available split
108
+
109
+ if 'test' in splits:
110
+ test_split_key = 'test'
111
+ elif 'validation' in splits:
112
+ test_split_key = 'validation'
113
+ elif 'valid' in splits:
114
+ test_split_key = 'valid'
115
+ elif 'Test' in splits:
116
+ test_split_key = 'Test'
117
+ elif len(splits) > 1:
118
+ test_split_key = splits[1] # Use second split
119
  else:
120
+ test_split_key = splits[0] # Use same split for train/test (will split later)
 
121
 
122
+ train_split = dataset_obj[train_split_key]
123
+ test_split = dataset_obj[test_split_key]
124
 
125
+ result_message += f"Using '{train_split_key}' split for training ({len(train_split)} samples)\n"
126
+ result_message += f"Using '{test_split_key}' split for testing ({len(test_split)} samples)\n\n"
127
+
128
+ # Analyze dataset in detail
129
  if len(train_split) > 0:
130
  sample = train_split[0]
131
+ result_message += f"Sample structure: {list(sample.keys())}\n"
 
 
 
 
 
 
 
 
132
 
133
+ # Check image properties
134
+ if 'image' in sample:
135
+ img = sample['image']
136
+ result_message += f"Image type: {type(img)}, size: {getattr(img, 'size', 'N/A')}\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
 
138
+ # Detect number of classes
139
+ if 'label' in sample:
140
+ unique_labels = set()
141
+ # Check first 100 samples for unique labels
142
+ for i in range(min(100, len(train_split))):
143
+ unique_labels.add(train_split[i]['label'])
144
+
145
+ num_classes = len(unique_labels)
146
+ result_message += f"Detected {num_classes} unique labels: {sorted(unique_labels)}\n"
147
+
148
+ # Try to get class names
149
+ if hasattr(train_split.features.get('label', None), 'names'):
150
+ class_names = train_split.features['label'].names
151
+ else:
152
+ # Map numeric labels to meaningful names
153
+ if num_classes == 2:
154
+ class_names = ["no_tumor", "tumor"]
155
+ elif num_classes == 3:
156
+ class_names = ["glioma", "meningioma", "pituitary"]
157
+ elif num_classes == 4:
158
+ class_names = ["glioma", "meningioma", "no_tumor", "pituitary"]
159
+ else:
160
+ class_names = [f"class_{i}" for i in range(num_classes)]
161
+
162
+ result_message += f"Using class names: {class_names}\n"
163
  else:
164
+ # Default to binary classification
165
+ num_classes = 2
166
+ class_names = ["no_tumor", "tumor"]
167
+ result_message += "No labels found, using binary classification\n"
168
+ else:
169
+ num_classes = 2
170
+ class_names = ["no_tumor", "tumor"]
171
+ result_message += "Empty dataset, using default binary classification\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
 
173
+ # Define transforms for MRI images
174
  transform = transforms.Compose([
175
  transforms.Resize((224, 224)),
176
+ transforms.Grayscale(num_output_channels=3), # Ensure 3 channels
177
  transforms.ToTensor(),
178
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
179
  ])
180
 
181
+ # Custom dataset class with robust error handling
182
+ class MRIDataset(Dataset):
183
  def __init__(self, dataset, transform=None):
184
  self.dataset = dataset
185
  self.transform = transform
 
189
 
190
  def __getitem__(self, idx):
191
  item = self.dataset[idx]
 
 
192
 
193
+ # Handle image
194
+ image = item.get('image')
195
+ if image is None:
196
+ # Create placeholder image if none exists
197
+ image = PILImage.new('RGB', (224, 224), color='gray')
198
+ elif not isinstance(image, PILImage.Image):
199
+ try:
200
  image = PILImage.fromarray(image)
201
+ except:
202
+ image = PILImage.new('RGB', (224, 224), color='gray')
203
+
204
+ # Handle label
205
+ label = item.get('label', 0)
206
+ if isinstance(label, (list, tuple)) and len(label) > 0:
207
+ label = label[0] # Take first element if label is a list
208
+ label = int(label) if label is not None else 0
209
+
210
+ # Apply transform
211
+ if self.transform:
212
  image = self.transform(image)
213
 
214
  return image, label
215
 
216
  # Create datasets
217
+ train_dataset_obj = MRIDataset(train_split, transform=transform)
218
+ test_dataset_obj = MRIDataset(test_split, transform=transform)
219
 
220
+ result_message += f"Final dataset - Train: {len(train_dataset_obj)}, Test: {len(test_dataset_obj)}\n\n"
221
+
222
+ # Load base model
223
  print("Loading base model...")
224
  model = AutoModelForImageClassification.from_pretrained(
225
  BASE_MODEL,
226
  num_labels=num_classes,
227
+ ignore_mismatched_sizes=True
 
 
228
  )
229
  processor = AutoImageProcessor.from_pretrained(BASE_MODEL)
230
  model.to(DEVICE)
 
232
  # Training arguments
233
  training_args = TrainingArguments(
234
  output_dir="./results",
235
+ num_train_epochs=10,
236
  per_device_train_batch_size=8,
237
  per_device_eval_batch_size=8,
238
+ warmup_steps=500,
239
  weight_decay=0.01,
240
  logging_dir="./logs",
241
  logging_steps=10,
 
244
  load_best_model_at_end=True,
245
  push_to_hub=True,
246
  hub_model_id=CUSTOM_MODEL_NAME,
 
247
  )
248
 
249
+ # Metrics function
250
  def compute_metrics(eval_pred):
251
  predictions, labels = eval_pred
252
  predictions = np.argmax(predictions, axis=1)
 
257
  trainer = Trainer(
258
  model=model,
259
  args=training_args,
260
+ train_dataset=train_dataset_obj,
261
+ eval_dataset=test_dataset_obj,
262
  compute_metrics=compute_metrics,
263
  )
264
 
265
  # Start training
266
+ result_message += "Starting training...\n"
267
  train_result = trainer.train()
268
 
269
  # Save model
270
  trainer.save_model(f"./{CUSTOM_MODEL_NAME}")
271
  processor.save_pretrained(f"./{CUSTOM_MODEL_NAME}")
272
 
273
+ # Push to Hugging Face Hub
274
+ trainer.push_to_hub(commit_message=f"Trained on {dataset_name}")
275
 
276
+ # Training results
277
  train_accuracy = train_result.metrics.get('train_accuracy', 'N/A')
278
  eval_accuracy = train_result.metrics.get('eval_accuracy', 'N/A')
279
 
280
+ result_message += f"""
281
+ 🎯 Training Completed Successfully!
282
 
283
+ Dataset: {dataset_name}
284
  Model: {CUSTOM_MODEL_NAME}
 
285
  Classes: {class_names}
286
+ Training Accuracy: {train_accuracy or 'N/A'}
287
+ Validation Accuracy: {eval_accuracy or 'N/A'}
288
+
289
+ Model has been saved and pushed to Hugging Face Hub.
290
+ You can now use the 'Classify MRI' tab to test the model.
291
  """
292
 
293
  return result_message
294
 
295
  except Exception as e:
296
  import traceback
297
+ error_details = traceback.format_exc()
298
+
299
+ error_msg = f"""
300
+ ❌ Training Failed
301
+
302
+ Error: {str(e)}
303
+
304
+ Datasets tried:
305
+ {BRAIN_TUMOR_DATASETS}
306
+
307
+ Please check:
308
+ 1. Dataset names are correct
309
+ 2. Internet connection
310
+ 3. Dataset accessibility
311
+
312
+ Error Details:
313
+ {error_details}
314
+ """
315
+ return error_msg
316
 
317
  def classify_mri(image):
318
+ """Classify a new MRI image using the trained model"""
319
  try:
320
+ # Load your custom model
321
  model = AutoModelForImageClassification.from_pretrained(CUSTOM_MODEL_NAME)
322
  processor = AutoImageProcessor.from_pretrained(CUSTOM_MODEL_NAME)
323
 
324
  model.to(DEVICE)
325
  model.eval()
326
 
327
+ # Preprocess image
328
  inputs = processor(image, return_tensors="pt").to(DEVICE)
329
 
330
+ # Predict
331
  with torch.no_grad():
332
  outputs = model(**inputs)
333
  predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
334
 
335
+ # Try to detect number of classes
336
+ num_classes = predictions.shape[1]
337
+
338
+ # Default class names based on number of classes
339
+ if num_classes == 2:
340
+ class_names = ["No Tumor", "Tumor Detected"]
341
+ elif num_classes == 3:
342
+ class_names = ["Glioma", "Meningioma", "Pituitary Tumor"]
343
+ elif num_classes == 4:
344
+ class_names = ["Glioma", "Meningioma", "No Tumor", "Pituitary Tumor"]
345
+ else:
346
+ class_names = [f"Class {i}" for i in range(num_classes)]
347
+
348
+ results = {class_names[i]: float(predictions[0][i]) for i in range(num_classes)}
349
 
350
  return results
351
 
352
  except Exception as e:
353
+ return f"⚠️ Model not trained yet or unavailable. Error: {str(e)}"
354
 
355
+ # Gradio Interface
356
+ with gr.Blocks(title="GoGenix MRI Brain Tumor Classifier") as demo:
357
+ gr.Markdown("# 🧠 GoGenix MRI Brain Tumor Classifier")
358
+ gr.Markdown("**Using Your Selected Datasets**")
359
+
360
+ with gr.Tab("🚀 Train Model"):
361
+ gr.Markdown("### Train with Your Custom Datasets")
362
+ gr.Markdown("Will try these datasets in order:")
363
+ for i, dataset in enumerate(BRAIN_TUMOR_DATASETS, 1):
364
+ gr.Markdown(f"{i}. `{dataset}`")
365
+
366
+ train_btn = gr.Button("Start Training", variant="primary", size="lg")
367
+ output_text = gr.Textbox(
368
+ label="Training Status",
369
+ lines=20,
370
+ placeholder="Click 'Start Training' to begin..."
371
+ )
372
+
373
+ train_btn.click(
374
+ fn=train_and_save_model,
375
+ outputs=output_text
376
+ )
377
 
378
+ with gr.Tab("🔍 Classify MRI"):
379
+ gr.Markdown("### Upload MRI Image for Classification")
380
+ gr.Markdown("**Note**: Requires successful training first")
381
+
382
+ image_input = gr.Image(
383
+ type="pil",
384
+ label="Brain MRI Scan",
385
+ height=300
386
+ )
387
+ classify_btn = gr.Button("Classify", variant="secondary")
388
+ result = gr.Label(
389
+ label="Brain Tumor Classification Results",
390
+ num_top_classes=4
391
+ )
392
+
393
+ classify_btn.click(
394
+ fn=classify_mri,
395
+ inputs=image_input,
396
+ outputs=result
397
+ )
398
 
399
+ with gr.Tab("📊 Your Datasets"):
400
+ gr.Markdown("### Your Selected Brain Tumor Datasets")
401
+ gr.Markdown("""
402
+ **Currently Using:**
403
+
404
+ 1. **PranomVignesh/MRI-Images-of-Brain-Tumor** - Primary choice
405
+ 2. **Hemg/Brain-Tumor-MRI-Dataset** - Secondary choice
406
+ 3. **AntonXue/mcal-mri-brain-tumor** - Tertiary choice
407
+
408
+ The system will try these in order and use the first one that works.
409
+ """)
410
 
411
  if __name__ == "__main__":
412
  demo.launch()