EzekielMW commited on
Commit
a6aced3
·
verified ·
1 Parent(s): e7bf265

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -27
app.py CHANGED
@@ -7,18 +7,17 @@ from sklearn.decomposition import PCA
7
  from scipy.signal import savgol_filter
8
  from math import pi
9
 
10
- # Ensure interactive backend for plotting
11
- plt.switch_backend('agg')
12
 
13
  # Load dataset
14
  df = pd.read_csv("milk_absorbance.csv")
15
  df.rename(columns={df.columns[0]: 'Label'}, inplace=True)
16
 
17
- # Gradio plot functions
18
  def plot_all():
19
  plots = []
20
 
21
- # Plot 1: Mean Spectra per Class
22
  fig1 = plt.figure(figsize=(12, 6))
23
  for label in df['Label'].unique():
24
  class_df = df[df['Label'] == label]
@@ -32,7 +31,7 @@ def plot_all():
32
  plt.tight_layout()
33
  plots.append(fig1)
34
 
35
- # Plot 2: Offset Mean Spectra
36
  fig2 = plt.figure(figsize=(12, 6))
37
  offset_step = 0.1
38
  for i, label in enumerate(df['Label'].unique()):
@@ -40,15 +39,15 @@ def plot_all():
40
  mean_spectrum = class_df.iloc[:, 1:].mean()
41
  offset = i * offset_step
42
  plt.plot(mean_spectrum.index.astype(int), mean_spectrum + offset, label=f'Label {label}')
43
- plt.title('Mean NIR Spectrum per Milk Ratio Class (with Offset)')
44
  plt.xlabel('Wavelength (nm)')
45
- plt.ylabel('Absorbance (Offset Applied)')
46
- plt.legend(title='Class (Milk Ratio)')
47
  plt.grid(True)
48
  plt.tight_layout()
49
  plots.append(fig2)
50
 
51
- # Plot 3: Radar Plot
52
  fig3 = plt.figure(figsize=(8, 8))
53
  ax = plt.subplot(111, polar=True)
54
  subset_cols = df.columns[1:][::20]
@@ -63,12 +62,12 @@ def plot_all():
63
  ax.fill(angles, values, alpha=0.1)
64
  ax.set_xticks(angles[:-1])
65
  ax.set_xticklabels(subset_cols.astype(int))
66
- plt.title('Radar Plot of Mean Spectra (Subset Wavelengths)')
67
  plt.legend(loc='upper right', bbox_to_anchor=(1.3, 1.1))
68
  plt.tight_layout()
69
  plots.append(fig3)
70
 
71
- # Plot 4: Cumulative PCA Explained Variance
72
  fig4 = plt.figure(figsize=(8, 5))
73
  X = df.iloc[:, 1:].values
74
  X_scaled = StandardScaler().fit_transform(X)
@@ -78,14 +77,14 @@ def plot_all():
78
  plt.plot(range(1, 21), explained, marker='o')
79
  plt.axhline(y=0.95, color='r', linestyle='--', label='95% Variance')
80
  plt.title('Cumulative Explained Variance by PCA')
81
- plt.xlabel('Number of Principal Components')
82
  plt.ylabel('Cumulative Variance')
83
  plt.legend()
84
  plt.grid(True)
85
  plt.tight_layout()
86
  plots.append(fig4)
87
 
88
- # Plot 5: Derivative + Normalized Spectra
89
  fig5 = plt.figure(figsize=(16, 8))
90
  y_vals = df['Label'].values
91
  wavelengths = df.columns[1:].astype(float)
@@ -99,38 +98,49 @@ def plot_all():
99
  indices = np.where(y_vals == label)[0]
100
  for i in indices:
101
  plt.plot(wavelengths, X_deriv_norm[i], color=color, alpha=0.3, label=f'Milk {label}' if i == indices[0] else '')
102
- plt.title("All Spectra After First Derivative + Normalization")
103
  plt.xlabel("Wavelength (nm)")
104
- plt.ylabel("Normalized First Derivative")
105
- plt.legend(title="Group")
106
  plt.grid(True)
107
  plt.tight_layout()
108
  plots.append(fig5)
109
 
110
- # Plot 6: Derivative Only (No Norm)
111
  fig6 = plt.figure(figsize=(16, 8))
112
  for label, color in zip(unique_labels, colors):
113
  indices = np.where(y_vals == label)[0]
114
  for i in indices:
115
  plt.plot(wavelengths, X_deriv[i], color=color, alpha=0.3, label=f'Milk {label}' if i == indices[0] else '')
116
- plt.title("All Spectra After First Derivative (No Normalization)")
117
  plt.xlabel("Wavelength (nm)")
118
- plt.ylabel("First Derivative Absorbance")
119
- plt.legend(title="Group")
120
  plt.grid(True)
121
  plt.tight_layout()
122
  plots.append(fig6)
123
 
124
- return [gr.Plot(plt_fig) for plt_fig in plots]
125
 
126
- # Gradio UI for Dataset Description
127
  with gr.Blocks() as demo:
128
  gr.Markdown("# 🧪 Dataset Description")
129
- gr.DataFrame(df.head(50), label="Preview of Raw Data")
130
- plot_button = gr.Button("Generate Spectroscopy Visualizations")
131
- out_gallery = gr.Gallery(label="All Plots", columns=2)
132
- plot_button.click(fn=plot_all, inputs=[], outputs=out_gallery)
133
 
134
- demo.launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False)
135
 
 
 
 
 
 
 
 
136
 
 
 
 
 
 
 
 
 
7
  from scipy.signal import savgol_filter
8
  from math import pi
9
 
10
+ plt.switch_backend('agg') # Required for headless environments
 
11
 
12
  # Load dataset
13
  df = pd.read_csv("milk_absorbance.csv")
14
  df.rename(columns={df.columns[0]: 'Label'}, inplace=True)
15
 
16
+ # Function to generate all plots
17
  def plot_all():
18
  plots = []
19
 
20
+ # 1. Mean Spectra per Class
21
  fig1 = plt.figure(figsize=(12, 6))
22
  for label in df['Label'].unique():
23
  class_df = df[df['Label'] == label]
 
31
  plt.tight_layout()
32
  plots.append(fig1)
33
 
34
+ # 2. Offset Mean Spectra
35
  fig2 = plt.figure(figsize=(12, 6))
36
  offset_step = 0.1
37
  for i, label in enumerate(df['Label'].unique()):
 
39
  mean_spectrum = class_df.iloc[:, 1:].mean()
40
  offset = i * offset_step
41
  plt.plot(mean_spectrum.index.astype(int), mean_spectrum + offset, label=f'Label {label}')
42
+ plt.title('Offset Mean NIR Spectra')
43
  plt.xlabel('Wavelength (nm)')
44
+ plt.ylabel('Offset Absorbance')
45
+ plt.legend()
46
  plt.grid(True)
47
  plt.tight_layout()
48
  plots.append(fig2)
49
 
50
+ # 3. Radar Plot
51
  fig3 = plt.figure(figsize=(8, 8))
52
  ax = plt.subplot(111, polar=True)
53
  subset_cols = df.columns[1:][::20]
 
62
  ax.fill(angles, values, alpha=0.1)
63
  ax.set_xticks(angles[:-1])
64
  ax.set_xticklabels(subset_cols.astype(int))
65
+ plt.title('Radar Plot of Mean Spectra (Subset)')
66
  plt.legend(loc='upper right', bbox_to_anchor=(1.3, 1.1))
67
  plt.tight_layout()
68
  plots.append(fig3)
69
 
70
+ # 4. PCA Cumulative Variance
71
  fig4 = plt.figure(figsize=(8, 5))
72
  X = df.iloc[:, 1:].values
73
  X_scaled = StandardScaler().fit_transform(X)
 
77
  plt.plot(range(1, 21), explained, marker='o')
78
  plt.axhline(y=0.95, color='r', linestyle='--', label='95% Variance')
79
  plt.title('Cumulative Explained Variance by PCA')
80
+ plt.xlabel('Principal Components')
81
  plt.ylabel('Cumulative Variance')
82
  plt.legend()
83
  plt.grid(True)
84
  plt.tight_layout()
85
  plots.append(fig4)
86
 
87
+ # 5. Derivative + Normalized
88
  fig5 = plt.figure(figsize=(16, 8))
89
  y_vals = df['Label'].values
90
  wavelengths = df.columns[1:].astype(float)
 
98
  indices = np.where(y_vals == label)[0]
99
  for i in indices:
100
  plt.plot(wavelengths, X_deriv_norm[i], color=color, alpha=0.3, label=f'Milk {label}' if i == indices[0] else '')
101
+ plt.title("Spectra After 1st Derivative + Normalization")
102
  plt.xlabel("Wavelength (nm)")
103
+ plt.ylabel("Normalized Derivative")
104
+ plt.legend(title="Class")
105
  plt.grid(True)
106
  plt.tight_layout()
107
  plots.append(fig5)
108
 
109
+ # 6. Derivative Only (No Norm)
110
  fig6 = plt.figure(figsize=(16, 8))
111
  for label, color in zip(unique_labels, colors):
112
  indices = np.where(y_vals == label)[0]
113
  for i in indices:
114
  plt.plot(wavelengths, X_deriv[i], color=color, alpha=0.3, label=f'Milk {label}' if i == indices[0] else '')
115
+ plt.title("Spectra After 1st Derivative (No Normalization)")
116
  plt.xlabel("Wavelength (nm)")
117
+ plt.ylabel("Derivative Absorbance")
118
+ plt.legend(title="Class")
119
  plt.grid(True)
120
  plt.tight_layout()
121
  plots.append(fig6)
122
 
123
+ return plots
124
 
125
+ # Gradio UI
126
  with gr.Blocks() as demo:
127
  gr.Markdown("# 🧪 Dataset Description")
128
+ gr.DataFrame(df.head(10), label="📋 Preview of Milk Spectroscopy Data")
 
 
 
129
 
130
+ plot_button = gr.Button("📊 Generate Spectroscopy Visualizations")
131
 
132
+ # Individual Plot Outputs
133
+ plot1 = gr.Plot(label="Mean Spectra")
134
+ plot2 = gr.Plot(label="Offset Mean Spectra")
135
+ plot3 = gr.Plot(label="Radar Plot")
136
+ plot4 = gr.Plot(label="PCA Variance")
137
+ plot5 = gr.Plot(label="Derivative + Normalized")
138
+ plot6 = gr.Plot(label="Derivative Only")
139
 
140
+ plot_button.click(
141
+ fn=plot_all,
142
+ inputs=[],
143
+ outputs=[plot1, plot2, plot3, plot4, plot5, plot6]
144
+ )
145
+
146
+ demo.launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False)