Cruicis commited on
Commit
9ab8368
·
verified ·
1 Parent(s): 7d8cbc4

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -190
app.py DELETED
@@ -1,190 +0,0 @@
1
- import matplotlib.pyplot as plt
2
- import seaborn as sns
3
- plt.rcParams["figure.figsize"] = (15, 10)
4
- plt.rcParams["figure.dpi"] = 125
5
- plt.rcParams["font.size"] = 14
6
- plt.rcParams['font.family'] = ['sans-serif']
7
- plt.rcParams['font.sans-serif'] = ['DejaVu Sans']
8
- plt.style.use('ggplot')
9
- sns.set_style("whitegrid", {'axes.grid': False})
10
- plt.rcParams['image.cmap'] = 'gray' # grayscale looks better
11
- from pathlib import Path
12
- import numpy as np
13
- import pandas as pd
14
- import os
15
- from skimage.io import imread as imread
16
- from skimage.util import montage
17
- from PIL import Image
18
- montage_rgb = lambda x: np.stack([montage(x[:, :, :, i]) for i in range(x.shape[3])], -1)
19
- from skimage.color import label2rgb
20
-
21
- image_dir = Path('D:/vscodeim/food classification/feature_food/')
22
-
23
-
24
- mapping_file = Path('D:/vscodeim/food classification/feature_food/clean_list.json')
25
- alleg_df = pd.read_json(mapping_file)
26
-
27
- alleg_df['image_path'] = alleg_df['image_path'].map(lambda x: image_dir / 'subset' / x)
28
- print(alleg_df['image_path'].map(lambda x: x.exists()).value_counts())
29
- allergens = alleg_df.columns[3:].tolist()
30
- alleg_df.sample(2)
31
-
32
- import os
33
- from pathlib import Path
34
-
35
-
36
- color_file = Path('D:/vscodeim/food classification/feature_food/color_features.json')
37
- color_feat_df = pd.read_json(color_file)
38
- color_feat_df['image_path'] = color_feat_df['image_path'].map(lambda x: image_dir / 'subset' / x)
39
-
40
- color_feat_dict = {c_row['image_path']: c_row['color_features'] for _, c_row in color_feat_df.iterrows()}
41
- # add a new color feature column
42
- alleg_df['color_features'] = alleg_df['image_path'].map(color_feat_dict.get)
43
- alleg_df.sample(2)
44
-
45
-
46
- co_all = np.corrcoef(np.stack(alleg_df[allergens].applymap(lambda x: 1 if x>0 else 0).values, 0).T)
47
- fig, ax1 = plt.subplots(1, 1, figsize=(10, 10))
48
- sns.heatmap(co_all, annot=True, fmt='2.1%', ax=ax1, cmap='RdBu', vmin=-1, vmax=1)
49
- ax1.set_xticklabels(allergens, rotation=90)
50
- ax1.set_yticklabels(allergens);
51
-
52
-
53
- # package the allergens together
54
- alleg_df['allergy_vec'] = alleg_df[allergens].applymap(lambda x: 1 if x>0 else 0).values.tolist()
55
-
56
- from sklearn.model_selection import train_test_split
57
- train_df, valid_df = train_test_split(alleg_df.drop(columns='ingredients_list'),
58
- test_size=0.1,
59
- random_state=2019,
60
- stratify=alleg_df['allergy_vec'].map(lambda x: x[0:3]))
61
-
62
- train_df.reset_index(inplace=True)
63
- valid_df.reset_index(inplace=True)
64
-
65
-
66
- print(train_df.shape[0], 'training images')
67
- print(valid_df.shape[0], 'validation images')
68
-
69
- train_x_vec = np.stack(train_df['color_features'].values, 0)
70
- train_y_vec = np.stack(train_df['allergy_vec'], 0)
71
-
72
-
73
-
74
-
75
-
76
- from sklearn.pipeline import make_pipeline
77
- from sklearn.ensemble import RandomForestRegressor
78
- from sklearn.preprocessing import RobustScaler # 导入 RobustScaler
79
-
80
- rf_pipe = make_pipeline(RobustScaler(), RandomForestRegressor(n_estimators=15))
81
- rf_pipe.fit(train_x_vec, train_y_vec)
82
-
83
-
84
- import streamlit as st
85
- import pandas as pd
86
- import numpy as np
87
- from sklearn.metrics import roc_auc_score, roc_curve, accuracy_score
88
- import matplotlib.pyplot as plt
89
- from imageio import imread
90
-
91
- @st.cache
92
- def show_model_results(in_model, picture_number=None):
93
- # Your data loading code here
94
- # Ensure 'valid_df', 'valid_x_vec', 'valid_y_vec', 'train_df', 'train_x_vec', 'train_y_vec', 'allergens' are defined
95
-
96
-
97
-
98
- valid_x_vec = np.stack(valid_df['color_features'].values, 0)
99
- valid_y_vec = np.stack(valid_df['allergy_vec'], 0)
100
-
101
- x_vec = valid_x_vec
102
- y_vec = valid_y_vec
103
-
104
-
105
- valid_pred = in_model.predict(x_vec)
106
- valid_num = picture_number
107
-
108
- fig, m_axs = plt.subplots(2, 2, figsize=(10, 10))
109
- all_rows = []
110
- ax1 = m_axs[0, 0]
111
- for i, c_allergen in enumerate(allergens):
112
- tpr, fpr, _ = roc_curve(y_vec[:, i], valid_pred[:, i])
113
- auc = roc_auc_score(y_vec[:, i], valid_pred[:, i])
114
- acc = accuracy_score(y_vec[:, i], valid_pred[:, i] > 0.5)
115
- ax1.plot(tpr, fpr, '.-', label='{}: AUC {:0.2f}, Accuracy: {:2.0%}'.format(c_allergen, auc, acc))
116
- all_rows += [{'allegen': c_allergen,
117
- 'prediction': valid_pred[j, i],
118
- 'class': 'Positive' if y_vec[j, i] > 0.5 else 'Negative'}
119
- for j in range(valid_pred.shape[0])]
120
-
121
- d_ax = m_axs[0, 1]
122
- t_yp = np.mean(valid_pred, 0)
123
- t_y = np.mean(y_vec, 0)
124
- d_ax.barh(np.arange(len(allergens)) + 0.1, t_yp, alpha=0.5, label='Predicted')
125
- d_ax.barh(np.arange(len(allergens)) - 0.1, t_y + 0.001, alpha=0.5, label='Ground Truth')
126
- d_ax.set_xlim(0, 1)
127
- d_ax.set_yticks(range(len(allergens)))
128
- d_ax.set_yticklabels(allergens, rotation=0)
129
- d_ax.set_title('Overall')
130
- d_ax.legend()
131
-
132
- ax1.legend()
133
- for (_, c_row), (c_ax, d_ax) in zip(
134
- valid_df.iloc[valid_num:valid_num+1].iterrows(),
135
- m_axs[1:]):
136
- c_ax.imshow(imread(c_row['image_path']))
137
- c_ax.set_title(c_row['title'])
138
- c_ax.axis('off')
139
- t_yp = in_model.predict(np.expand_dims(c_row['color_features'], 0))
140
- t_y = np.array(c_row['allergy_vec'])
141
- d_ax.barh(np.arange(len(allergens)) + 0.1, t_yp[0], alpha=0.5, label='Predicted')
142
- d_ax.barh(np.arange(len(allergens)) - 0.1, t_y + 0.001, alpha=0.5, label='Ground Truth')
143
- d_ax.set_yticks(range(len(allergens)))
144
- d_ax.set_yticklabels(allergens, rotation=0)
145
- d_ax.set_xlim(0, 1)
146
- d_ax.legend()
147
-
148
- # 将当前图像添加到 Streamlit 页面
149
- st.pyplot(fig)
150
- return st.write("Completed")
151
-
152
- # Assuming you have already defined 'rf_pipe' and 'valid_df' with image paths
153
- image_paths = valid_df['image_path'].tolist()
154
-
155
-
156
- # Streamlit app
157
- # Streamlit app
158
- def main():
159
- st.title('Model Results')
160
- st.write(f'<span style="font-size:20px;">This is a prototype, so we use the images from the test set as examples.</span>', unsafe_allow_html=True)
161
-
162
- image_paths = valid_df['image_path'].tolist()
163
-
164
- num_rows = 2
165
- num_cols = 5
166
- fig, axes = plt.subplots(num_rows, num_cols, figsize=(15, 6))
167
-
168
- for i in range(num_rows):
169
- for j in range(num_cols):
170
- index = i * num_cols + j + 21 # Starting from index 21
171
- image = imread(image_paths[index])
172
- axes[i, j].imshow(image)
173
- axes[i, j].axis('off')
174
- axes[i, j].set_title(f'Image {index-20}')
175
-
176
- st.pyplot(fig)
177
-
178
- num_images = 10
179
-
180
- # User interaction to select image
181
- st.write(f'<span style="font-size:20px;">Enter the image number you want to analyze.</span>', unsafe_allow_html=True)
182
-
183
- choice = st.number_input(f"Range (1-{num_images}): ", min_value=1, max_value=num_images)
184
-
185
- # Show model results
186
- if st.button('Show Results'):
187
- show_model_results(rf_pipe, choice+20)
188
-
189
- if __name__ == "__main__":
190
- main()