stevekor commited on
Commit
016cbfe
·
1 Parent(s): 5d42efd

Restore legacy plot outputs (match 42bf0f)

Browse files
Files changed (1) hide show
  1. app.py +65 -104
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import gradio as gr
2
  import os
3
- from PIL import Image, ImageDraw
4
  import numpy as np
5
  import pickle
6
  import io
@@ -79,53 +79,41 @@ def _ensure_hf_repo_cloned(repo_url: str, repo_dir: str) -> None:
79
 
80
  subprocess.run(["git", "clone", "--depth", "1", clone_url, repo_dir], check=True, env=env)
81
 
82
-
83
- def _make_error_image(message: str, size=(420, 420)) -> Image.Image:
84
- # A simple fallback to avoid blank/failed Gradio image renders.
85
- img = Image.new("RGB", size, color=(255, 255, 255))
86
- draw = ImageDraw.Draw(img)
87
- draw.text((12, 12), message[:800], fill=(0, 0, 0))
88
- return img
89
-
90
  #################### BEAM PREDICTION #########################}
91
  def beam_prediction_task(data_percentage, task_complexity, theme='Dark'):
92
  # Folder naming convention based on input_type, data_percentage, and task_complexity
93
  raw_folder = f"images/raw_{data_percentage/100:.1f}_{task_complexity}"
94
  embeddings_folder = f"images/embedding_{data_percentage/100:.1f}_{task_complexity}"
95
 
96
- # Process raw confusion matrix
97
- raw_cm = compute_average_confusion_matrix(raw_folder, num_labels=int(task_complexity))
98
  if raw_cm is not None:
99
- raw_img = plot_confusion_matrix_beamPred(
 
100
  raw_cm,
101
  classes=np.arange(raw_cm.shape[0]),
102
  title=f"Confusion Matrix (Raw Channels)\n{data_percentage}% data, {task_complexity} beams",
103
- save_path=None,
104
  theme=theme,
105
  )
 
106
  else:
107
- raw_img = _make_error_image(
108
- f"No data found for Raw Channels.\n\n"
109
- f"Expected CSVs under: {raw_folder}\n"
110
- f"Try a different data percentage / beam count."
111
- )
112
 
113
- # Process embeddings confusion matrix
114
- embeddings_cm = compute_average_confusion_matrix(embeddings_folder, num_labels=int(task_complexity))
115
  if embeddings_cm is not None:
116
- embeddings_img = plot_confusion_matrix_beamPred(
 
117
  embeddings_cm,
118
  classes=np.arange(embeddings_cm.shape[0]),
119
  title=f"Confusion Matrix (LWM Embeddings)\n{data_percentage}% data, {task_complexity} beams",
120
- save_path=None,
121
  theme=theme,
122
  )
 
123
  else:
124
- embeddings_img = _make_error_image(
125
- f"No data found for LWM Embeddings.\n\n"
126
- f"Expected CSVs under: {embeddings_folder}\n"
127
- f"Try a different data percentage / beam count."
128
- )
129
 
130
  return raw_img, embeddings_img
131
 
@@ -148,7 +136,7 @@ def compute_f1_score(cm):
148
  f1 = np.nan_to_num(f1) # Replace NaN with 0
149
  return np.mean(f1) # Return the mean F1-score across all classes
150
 
151
- def plot_confusion_matrix_beamPred(cm, classes, title, save_path=None, theme='Dark'):
152
  # Compute the average F1-score
153
  avg_f1 = compute_f1_score(cm)
154
 
@@ -185,67 +173,47 @@ def plot_confusion_matrix_beamPred(cm, classes, title, save_path=None, theme='Da
185
  plt.xlabel('Predicted label', color=text_color, fontsize=20)
186
  plt.tight_layout()
187
 
188
- buf = io.BytesIO()
189
- plt.savefig(buf, format="png", transparent=True)
190
  plt.close()
191
- buf.seek(0)
192
- return Image.open(buf)
193
 
194
- def compute_average_confusion_matrix(folder, num_labels=None):
195
- """Average confusion matrices across CSVs in a folder.
196
-
197
- For Beam Prediction, the correct class set is defined by the selected beam count
198
- (task_complexity). Inferring the number of labels from unique targets can shrink
199
- the matrix when some beams never appear in a particular split.
200
- """
201
-
202
- if not os.path.isdir(folder):
203
- return None
204
-
205
- csv_files = [f for f in os.listdir(folder) if f.endswith(".csv")]
206
- if not csv_files:
207
- return None
208
-
209
- # If num_labels isn't specified, infer from max label index seen.
210
- if num_labels is None:
211
- inferred = 0
212
- for file in csv_files:
213
- data = pd.read_csv(os.path.join(folder, file))
214
- y_true = pd.to_numeric(data.get("Target"), errors="coerce").fillna(-1).astype(int)
215
- y_pred = pd.to_numeric(data.get("Top-1 Prediction"), errors="coerce").fillna(-1).astype(int)
216
- true_max = int(y_true.max()) if len(y_true) else -1
217
- pred_max = int(y_pred.max()) if len(y_pred) else -1
218
- local_max = max(true_max, pred_max)
219
- inferred = max(inferred, local_max + 1)
220
- num_labels = inferred
221
-
222
- try:
223
- num_labels = int(num_labels)
224
- except Exception:
225
- return None
226
- if num_labels <= 0:
227
- return None
228
 
 
229
  confusion_matrices = []
230
- for file in csv_files:
231
- data = pd.read_csv(os.path.join(folder, file))
232
- y_true = pd.to_numeric(data.get("Target"), errors="coerce").fillna(-1).astype(int)
233
- y_pred = pd.to_numeric(data.get("Top-1 Prediction"), errors="coerce").fillna(-1).astype(int)
234
-
235
- # Filter invalid/out-of-range labels to avoid skew.
236
- valid = (y_true >= 0) & (y_true < num_labels) & (y_pred >= 0) & (y_pred < num_labels)
237
- y_true = y_true[valid]
238
- y_pred = y_pred[valid]
239
-
240
- if len(y_true) == 0:
241
- continue
242
 
243
- cm = confusion_matrix(y_true, y_pred, labels=np.arange(num_labels))
244
- confusion_matrices.append(cm)
 
 
 
 
245
 
246
- if not confusion_matrices:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
  return None
248
- return np.mean(confusion_matrices, axis=0)
249
 
250
  ########################## LOS/NLOS CLASSIFICATION #############################3
251
 
@@ -305,11 +273,12 @@ def plot_confusion_matrix_from_csv(csv_file_path, title, save_path, light_mode=F
305
  plt.xlabel('Predicted label', color=text_color, fontsize=14)
306
  plt.tight_layout()
307
 
308
- buf = io.BytesIO()
309
- plt.savefig(buf, format="png", transparent=True)
310
  plt.close()
311
- buf.seek(0)
312
- return Image.open(buf)
 
313
 
314
  # Function to load confusion matrix based on percentage and input_type
315
  def display_confusion_matrices_los(percentage):
@@ -322,28 +291,20 @@ def display_confusion_matrices_los(percentage):
322
  # Process raw confusion matrix
323
  raw_csv_file = os.path.join(raw_folder, f"test_predictions_raw_{percentage/100:.3f}_los.csv")
324
  raw_cm_img_path = os.path.join(raw_folder, "confusion_matrix_raw.png")
325
- try:
326
- raw_img = plot_confusion_matrix_from_csv(
327
- raw_csv_file,
328
- f"Confusion Matrix (Raw Channels)\n{percentage:.1f}% data",
329
- raw_cm_img_path,
330
- )
331
- except Exception as exc:
332
- raw_img = _make_error_image(f"Failed to load Raw CSV:\n{raw_csv_file}\n\n{exc}")
333
 
334
  # Process embeddings confusion matrix
335
  embeddings_csv_file = os.path.join(embeddings_folder, f"test_predictions_embedding_{percentage/100:.3f}_los.csv")
336
  embeddings_cm_img_path = os.path.join(embeddings_folder, "confusion_matrix_embeddings.png")
337
- try:
338
- embeddings_img = plot_confusion_matrix_from_csv(
339
- embeddings_csv_file,
340
- f"Confusion Matrix (LWM Embeddings)\n{percentage:.1f}% data",
341
- embeddings_cm_img_path,
342
- )
343
- except Exception as exc:
344
- embeddings_img = _make_error_image(
345
- f"Failed to load Embedding CSV:\n{embeddings_csv_file}\n\n{exc}"
346
- )
347
 
348
  return raw_img, embeddings_img
349
 
 
1
  import gradio as gr
2
  import os
3
+ from PIL import Image
4
  import numpy as np
5
  import pickle
6
  import io
 
79
 
80
  subprocess.run(["git", "clone", "--depth", "1", clone_url, repo_dir], check=True, env=env)
81
 
 
 
 
 
 
 
 
 
82
  #################### BEAM PREDICTION #########################}
83
  def beam_prediction_task(data_percentage, task_complexity, theme='Dark'):
84
  # Folder naming convention based on input_type, data_percentage, and task_complexity
85
  raw_folder = f"images/raw_{data_percentage/100:.1f}_{task_complexity}"
86
  embeddings_folder = f"images/embedding_{data_percentage/100:.1f}_{task_complexity}"
87
 
88
+ # Process raw confusion matrix (match legacy behavior)
89
+ raw_cm = compute_average_confusion_matrix(raw_folder)
90
  if raw_cm is not None:
91
+ raw_cm_path = os.path.join(raw_folder, "confusion_matrix_raw.png")
92
+ plot_confusion_matrix_beamPred(
93
  raw_cm,
94
  classes=np.arange(raw_cm.shape[0]),
95
  title=f"Confusion Matrix (Raw Channels)\n{data_percentage}% data, {task_complexity} beams",
96
+ save_path=raw_cm_path,
97
  theme=theme,
98
  )
99
+ raw_img = Image.open(raw_cm_path)
100
  else:
101
+ raw_img = None
 
 
 
 
102
 
103
+ # Process embeddings confusion matrix (match legacy behavior)
104
+ embeddings_cm = compute_average_confusion_matrix(embeddings_folder)
105
  if embeddings_cm is not None:
106
+ embeddings_cm_path = os.path.join(embeddings_folder, "confusion_matrix_embeddings.png")
107
+ plot_confusion_matrix_beamPred(
108
  embeddings_cm,
109
  classes=np.arange(embeddings_cm.shape[0]),
110
  title=f"Confusion Matrix (LWM Embeddings)\n{data_percentage}% data, {task_complexity} beams",
111
+ save_path=embeddings_cm_path,
112
  theme=theme,
113
  )
114
+ embeddings_img = Image.open(embeddings_cm_path)
115
  else:
116
+ embeddings_img = None
 
 
 
 
117
 
118
  return raw_img, embeddings_img
119
 
 
136
  f1 = np.nan_to_num(f1) # Replace NaN with 0
137
  return np.mean(f1) # Return the mean F1-score across all classes
138
 
139
+ def plot_confusion_matrix_beamPred(cm, classes, title, save_path, theme='Dark'):
140
  # Compute the average F1-score
141
  avg_f1 = compute_f1_score(cm)
142
 
 
173
  plt.xlabel('Predicted label', color=text_color, fontsize=20)
174
  plt.tight_layout()
175
 
176
+ plt.savefig(save_path, transparent=True) # Transparent to blend with the site background
 
177
  plt.close()
 
 
178
 
179
+ # Return the saved image
180
+ return Image.open(save_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
 
182
+ def compute_average_confusion_matrix(folder):
183
  confusion_matrices = []
184
+ max_num_labels = 0
 
 
 
 
 
 
 
 
 
 
 
185
 
186
+ # First pass to determine the maximum number of labels
187
+ for file in os.listdir(folder):
188
+ if file.endswith(".csv"):
189
+ data = pd.read_csv(os.path.join(folder, file))
190
+ num_labels = len(np.unique(data["Target"]))
191
+ max_num_labels = max(max_num_labels, num_labels)
192
 
193
+ # Second pass to calculate the confusion matrices and pad if necessary
194
+ for file in os.listdir(folder):
195
+ if file.endswith(".csv"):
196
+ data = pd.read_csv(os.path.join(folder, file))
197
+ y_true = data["Target"]
198
+ y_pred = data["Top-1 Prediction"]
199
+ num_labels = len(np.unique(y_true))
200
+
201
+ # Compute confusion matrix
202
+ cm = confusion_matrix(y_true, y_pred, labels=np.arange(max_num_labels))
203
+
204
+ # If the confusion matrix is smaller, pad it to match the largest size
205
+ if cm.shape[0] < max_num_labels:
206
+ padded_cm = np.zeros((max_num_labels, max_num_labels))
207
+ padded_cm[: cm.shape[0], : cm.shape[1]] = cm
208
+ confusion_matrices.append(padded_cm)
209
+ else:
210
+ confusion_matrices.append(cm)
211
+
212
+ if confusion_matrices:
213
+ avg_cm = np.mean(confusion_matrices, axis=0)
214
+ return avg_cm
215
+ else:
216
  return None
 
217
 
218
  ########################## LOS/NLOS CLASSIFICATION #############################3
219
 
 
273
  plt.xlabel('Predicted label', color=text_color, fontsize=14)
274
  plt.tight_layout()
275
 
276
+ # Save the plot as an image
277
+ plt.savefig(save_path, transparent=True) # Use transparent to blend with the website
278
  plt.close()
279
+
280
+ # Return the saved image
281
+ return Image.open(save_path)
282
 
283
  # Function to load confusion matrix based on percentage and input_type
284
  def display_confusion_matrices_los(percentage):
 
291
  # Process raw confusion matrix
292
  raw_csv_file = os.path.join(raw_folder, f"test_predictions_raw_{percentage/100:.3f}_los.csv")
293
  raw_cm_img_path = os.path.join(raw_folder, "confusion_matrix_raw.png")
294
+ raw_img = plot_confusion_matrix_from_csv(
295
+ raw_csv_file,
296
+ f"Confusion Matrix (Raw Channels)\n{percentage:.1f}% data",
297
+ raw_cm_img_path,
298
+ )
 
 
 
299
 
300
  # Process embeddings confusion matrix
301
  embeddings_csv_file = os.path.join(embeddings_folder, f"test_predictions_embedding_{percentage/100:.3f}_los.csv")
302
  embeddings_cm_img_path = os.path.join(embeddings_folder, "confusion_matrix_embeddings.png")
303
+ embeddings_img = plot_confusion_matrix_from_csv(
304
+ embeddings_csv_file,
305
+ f"Confusion Matrix (LWM Embeddings)\n{percentage:.1f}% data",
306
+ embeddings_cm_img_path,
307
+ )
 
 
 
 
 
308
 
309
  return raw_img, embeddings_img
310