Cruicis commited on
Commit
8fdbedb
·
verified ·
1 Parent(s): a96b0bd

Upload 4 files

Browse files
Files changed (5) hide show
  1. .gitattributes +2 -0
  2. app.py +190 -0
  3. clean_list.json +0 -0
  4. color_features.json +3 -0
  5. image_subset.json +3 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ color_features.json filter=lfs diff=lfs merge=lfs -text
37
+ image_subset.json filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import seaborn as sns
3
+ plt.rcParams["figure.figsize"] = (15, 10)
4
+ plt.rcParams["figure.dpi"] = 125
5
+ plt.rcParams["font.size"] = 14
6
+ plt.rcParams['font.family'] = ['sans-serif']
7
+ plt.rcParams['font.sans-serif'] = ['DejaVu Sans']
8
+ plt.style.use('ggplot')
9
+ sns.set_style("whitegrid", {'axes.grid': False})
10
+ plt.rcParams['image.cmap'] = 'gray' # grayscale looks better
11
+ from pathlib import Path
12
+ import numpy as np
13
+ import pandas as pd
14
+ import os
15
+ from skimage.io import imread as imread
16
+ from skimage.util import montage
17
+ from PIL import Image
18
+ montage_rgb = lambda x: np.stack([montage(x[:, :, :, i]) for i in range(x.shape[3])], -1)
19
+ from skimage.color import label2rgb
20
+
21
+ image_dir = Path('D:/vscodeim/food classification/feature_food/')
22
+
23
+
24
+ mapping_file = Path('D:/vscodeim/food classification/feature_food/clean_list.json')
25
+ alleg_df = pd.read_json(mapping_file)
26
+
27
+ alleg_df['image_path'] = alleg_df['image_path'].map(lambda x: image_dir / 'subset' / x)
28
+ print(alleg_df['image_path'].map(lambda x: x.exists()).value_counts())
29
+ allergens = alleg_df.columns[3:].tolist()
30
+ alleg_df.sample(2)
31
+
32
+ import os
33
+ from pathlib import Path
34
+
35
+
36
+ color_file = Path('D:/vscodeim/food classification/feature_food/color_features.json')
37
+ color_feat_df = pd.read_json(color_file)
38
+ color_feat_df['image_path'] = color_feat_df['image_path'].map(lambda x: image_dir / 'subset' / x)
39
+
40
+ color_feat_dict = {c_row['image_path']: c_row['color_features'] for _, c_row in color_feat_df.iterrows()}
41
+ # add a new color feature column
42
+ alleg_df['color_features'] = alleg_df['image_path'].map(color_feat_dict.get)
43
+ alleg_df.sample(2)
44
+
45
+
46
+ co_all = np.corrcoef(np.stack(alleg_df[allergens].applymap(lambda x: 1 if x>0 else 0).values, 0).T)
47
+ fig, ax1 = plt.subplots(1, 1, figsize=(10, 10))
48
+ sns.heatmap(co_all, annot=True, fmt='2.1%', ax=ax1, cmap='RdBu', vmin=-1, vmax=1)
49
+ ax1.set_xticklabels(allergens, rotation=90)
50
+ ax1.set_yticklabels(allergens);
51
+
52
+
53
+ # package the allergens together
54
+ alleg_df['allergy_vec'] = alleg_df[allergens].applymap(lambda x: 1 if x>0 else 0).values.tolist()
55
+
56
+ from sklearn.model_selection import train_test_split
57
+ train_df, valid_df = train_test_split(alleg_df.drop(columns='ingredients_list'),
58
+ test_size=0.1,
59
+ random_state=2019,
60
+ stratify=alleg_df['allergy_vec'].map(lambda x: x[0:3]))
61
+
62
+ train_df.reset_index(inplace=True)
63
+ valid_df.reset_index(inplace=True)
64
+
65
+
66
+ print(train_df.shape[0], 'training images')
67
+ print(valid_df.shape[0], 'validation images')
68
+
69
+ train_x_vec = np.stack(train_df['color_features'].values, 0)
70
+ train_y_vec = np.stack(train_df['allergy_vec'], 0)
71
+
72
+
73
+
74
+
75
+
76
+ from sklearn.pipeline import make_pipeline
77
+ from sklearn.ensemble import RandomForestRegressor
78
+ from sklearn.preprocessing import RobustScaler # 导入 RobustScaler
79
+
80
+ rf_pipe = make_pipeline(RobustScaler(), RandomForestRegressor(n_estimators=15))
81
+ rf_pipe.fit(train_x_vec, train_y_vec)
82
+
83
+
84
+ import streamlit as st
85
+ import pandas as pd
86
+ import numpy as np
87
+ from sklearn.metrics import roc_auc_score, roc_curve, accuracy_score
88
+ import matplotlib.pyplot as plt
89
+ from imageio import imread
90
+
91
+ @st.cache
92
+ def show_model_results(in_model, picture_number=None):
93
+ # Your data loading code here
94
+ # Ensure 'valid_df', 'valid_x_vec', 'valid_y_vec', 'train_df', 'train_x_vec', 'train_y_vec', 'allergens' are defined
95
+
96
+
97
+
98
+ valid_x_vec = np.stack(valid_df['color_features'].values, 0)
99
+ valid_y_vec = np.stack(valid_df['allergy_vec'], 0)
100
+
101
+ x_vec = valid_x_vec
102
+ y_vec = valid_y_vec
103
+
104
+
105
+ valid_pred = in_model.predict(x_vec)
106
+ valid_num = picture_number
107
+
108
+ fig, m_axs = plt.subplots(2, 2, figsize=(10, 10))
109
+ all_rows = []
110
+ ax1 = m_axs[0, 0]
111
+ for i, c_allergen in enumerate(allergens):
112
+ tpr, fpr, _ = roc_curve(y_vec[:, i], valid_pred[:, i])
113
+ auc = roc_auc_score(y_vec[:, i], valid_pred[:, i])
114
+ acc = accuracy_score(y_vec[:, i], valid_pred[:, i] > 0.5)
115
+ ax1.plot(tpr, fpr, '.-', label='{}: AUC {:0.2f}, Accuracy: {:2.0%}'.format(c_allergen, auc, acc))
116
+ all_rows += [{'allegen': c_allergen,
117
+ 'prediction': valid_pred[j, i],
118
+ 'class': 'Positive' if y_vec[j, i] > 0.5 else 'Negative'}
119
+ for j in range(valid_pred.shape[0])]
120
+
121
+ d_ax = m_axs[0, 1]
122
+ t_yp = np.mean(valid_pred, 0)
123
+ t_y = np.mean(y_vec, 0)
124
+ d_ax.barh(np.arange(len(allergens)) + 0.1, t_yp, alpha=0.5, label='Predicted')
125
+ d_ax.barh(np.arange(len(allergens)) - 0.1, t_y + 0.001, alpha=0.5, label='Ground Truth')
126
+ d_ax.set_xlim(0, 1)
127
+ d_ax.set_yticks(range(len(allergens)))
128
+ d_ax.set_yticklabels(allergens, rotation=0)
129
+ d_ax.set_title('Overall')
130
+ d_ax.legend()
131
+
132
+ ax1.legend()
133
+ for (_, c_row), (c_ax, d_ax) in zip(
134
+ valid_df.iloc[valid_num:valid_num+1].iterrows(),
135
+ m_axs[1:]):
136
+ c_ax.imshow(imread(c_row['image_path']))
137
+ c_ax.set_title(c_row['title'])
138
+ c_ax.axis('off')
139
+ t_yp = in_model.predict(np.expand_dims(c_row['color_features'], 0))
140
+ t_y = np.array(c_row['allergy_vec'])
141
+ d_ax.barh(np.arange(len(allergens)) + 0.1, t_yp[0], alpha=0.5, label='Predicted')
142
+ d_ax.barh(np.arange(len(allergens)) - 0.1, t_y + 0.001, alpha=0.5, label='Ground Truth')
143
+ d_ax.set_yticks(range(len(allergens)))
144
+ d_ax.set_yticklabels(allergens, rotation=0)
145
+ d_ax.set_xlim(0, 1)
146
+ d_ax.legend()
147
+
148
+ # 将当前图像添加到 Streamlit 页面
149
+ st.pyplot(fig)
150
+ return st.write("Completed")
151
+
152
+ # Assuming you have already defined 'rf_pipe' and 'valid_df' with image paths
153
+ image_paths = valid_df['image_path'].tolist()
154
+
155
+
156
+ # Streamlit app
157
+ # Streamlit app
158
+ def main():
159
+ st.title('Model Results')
160
+ st.write(f'<span style="font-size:20px;">This is a prototype, so we use the images from the test set as examples.</span>', unsafe_allow_html=True)
161
+
162
+ image_paths = valid_df['image_path'].tolist()
163
+
164
+ num_rows = 2
165
+ num_cols = 5
166
+ fig, axes = plt.subplots(num_rows, num_cols, figsize=(15, 6))
167
+
168
+ for i in range(num_rows):
169
+ for j in range(num_cols):
170
+ index = i * num_cols + j + 21 # Starting from index 21
171
+ image = imread(image_paths[index])
172
+ axes[i, j].imshow(image)
173
+ axes[i, j].axis('off')
174
+ axes[i, j].set_title(f'Image {index-20}')
175
+
176
+ st.pyplot(fig)
177
+
178
+ num_images = 10
179
+
180
+ # User interaction to select image
181
+ st.write(f'<span style="font-size:20px;">Enter the image number you want to analyze.</span>', unsafe_allow_html=True)
182
+
183
+ choice = st.number_input(f"Range (1-{num_images}): ", min_value=1, max_value=num_images)
184
+
185
+ # Show model results
186
+ if st.button('Show Results'):
187
+ show_model_results(rf_pipe, choice+20)
188
+
189
+ if __name__ == "__main__":
190
+ main()
clean_list.json ADDED
The diff for this file is too large to render. See raw diff
 
color_features.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1c77b1fb4901ec6abd8c19c9bacd2728cedaad6cb3e5a699da6808d5bf4bdb44
3
+ size 18577814
image_subset.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:040c523a6f1d15f1e78ba215766a82b71274bbf128885d840b68582ee40ad9a8
3
+ size 67109125