Cruicis commited on
Commit
f8f6829
·
verified ·
1 Parent(s): be331fa

Upload 4 files

Browse files
Files changed (5) hide show
  1. .gitattributes +2 -0
  2. app.py +187 -0
  3. clean_list.json +0 -0
  4. color_features.json +3 -0
  5. image_subset.json +3 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ color_features.json filter=lfs diff=lfs merge=lfs -text
37
+ image_subset.json filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import seaborn as sns
3
+ plt.rcParams["figure.figsize"] = (15, 10)
4
+ plt.rcParams["figure.dpi"] = 125
5
+ plt.rcParams["font.size"] = 14
6
+ plt.rcParams['font.family'] = ['sans-serif']
7
+ plt.rcParams['font.sans-serif'] = ['DejaVu Sans']
8
+ plt.style.use('ggplot')
9
+ sns.set_style("whitegrid", {'axes.grid': False})
10
+ plt.rcParams['image.cmap'] = 'gray' # grayscale looks better
11
+ from pathlib import Path
12
+ import numpy as np
13
+ import pandas as pd
14
+ import os
15
+ from skimage.io import imread as imread
16
+ from skimage.util import montage
17
+ from PIL import Image
18
+ montage_rgb = lambda x: np.stack([montage(x[:, :, :, i]) for i in range(x.shape[3])], -1)
19
+ from skimage.color import label2rgb
20
+
21
+ image_dir = Path('D:/vscodeim/food classification/feature_food/')
22
+
23
+
24
+ mapping_file = Path('D:/vscodeim/food classification/feature_food/clean_list.json')
25
+ alleg_df = pd.read_json(mapping_file)
26
+
27
+ alleg_df['image_path'] = alleg_df['image_path'].map(lambda x: image_dir / 'new' / x)
28
+ print(alleg_df['image_path'].map(lambda x: x.exists()).value_counts())
29
+ allergens = alleg_df.columns[3:].tolist()
30
+ alleg_df.sample(2)
31
+
32
+ import os
33
+ from pathlib import Path
34
+
35
+
36
+ color_file = Path('D:/vscodeim/food classification/feature_food/color_features.json')
37
+ color_feat_df = pd.read_json(color_file)
38
+ color_feat_df['image_path'] = color_feat_df['image_path'].map(lambda x: image_dir / 'new' / x)
39
+
40
+ color_feat_dict = {c_row['image_path']: c_row['color_features'] for _, c_row in color_feat_df.iterrows()}
41
+ # add a new color feature column
42
+ alleg_df['color_features'] = alleg_df['image_path'].map(color_feat_dict.get)
43
+ alleg_df.sample(2)
44
+
45
+
46
+ co_all = np.corrcoef(np.stack(alleg_df[allergens].applymap(lambda x: 1 if x>0 else 0).values, 0).T)
47
+ fig, ax1 = plt.subplots(1, 1, figsize=(10, 10))
48
+ sns.heatmap(co_all, annot=True, fmt='2.1%', ax=ax1, cmap='RdBu', vmin=-1, vmax=1)
49
+ ax1.set_xticklabels(allergens, rotation=90)
50
+ ax1.set_yticklabels(allergens);
51
+
52
+
53
+ # package the allergens together
54
+ alleg_df['allergy_vec'] = alleg_df[allergens].applymap(lambda x: 1 if x>0 else 0).values.tolist()
55
+
56
+ from sklearn.model_selection import train_test_split
57
+ train_df, valid_df = train_test_split(alleg_df.drop(columns='ingredients_list'),
58
+ test_size=0.1,
59
+ random_state=2019,
60
+ stratify=alleg_df['allergy_vec'].map(lambda x: x[0:3]))
61
+
62
+ train_df.reset_index(inplace=True)
63
+ valid_df.reset_index(inplace=True)
64
+
65
+
66
+ print(train_df.shape[0], 'training images')
67
+ print(valid_df.shape[0], 'validation images')
68
+
69
+ train_x_vec = np.stack(train_df['color_features'].values, 0)
70
+ train_y_vec = np.stack(train_df['allergy_vec'], 0)
71
+
72
+
73
+
74
+ # from sklearn.pipeline import make_pipeline
75
+ # from sklearn.ensemble import RandomForestRegressor
76
+ # from sklearn.preprocessing import RobustScaler # 导入 RobustScaler
77
+
78
+ import joblib
79
+ loaded_rf_pipe = joblib.load('rf_model.joblib')
80
+
81
+ import streamlit as st
82
+ import pandas as pd
83
+ import numpy as np
84
+ from sklearn.metrics import roc_auc_score, roc_curve, accuracy_score
85
+ import matplotlib.pyplot as plt
86
+ from imageio import imread
87
+
88
+ @st.cache
89
+ def show_model_results(in_model, picture_number=None):
90
+ # Your data loading code here
91
+ # Ensure 'valid_df', 'valid_x_vec', 'valid_y_vec', 'train_df', 'train_x_vec', 'train_y_vec', 'allergens' are defined
92
+
93
+
94
+
95
+ valid_x_vec = np.stack(valid_df['color_features'].values, 0)
96
+ valid_y_vec = np.stack(valid_df['allergy_vec'], 0)
97
+
98
+ x_vec = valid_x_vec
99
+ y_vec = valid_y_vec
100
+
101
+
102
+ valid_pred = in_model.predict(x_vec)
103
+ valid_num = picture_number
104
+
105
+ fig, m_axs = plt.subplots(2, 2, figsize=(10, 10))
106
+ all_rows = []
107
+ ax1 = m_axs[0, 0]
108
+ for i, c_allergen in enumerate(allergens):
109
+ tpr, fpr, _ = roc_curve(y_vec[:, i], valid_pred[:, i])
110
+ auc = roc_auc_score(y_vec[:, i], valid_pred[:, i])
111
+ acc = accuracy_score(y_vec[:, i], valid_pred[:, i] > 0.5)
112
+ ax1.plot(tpr, fpr, '.-', label='{}: AUC {:0.2f}, Accuracy: {:2.0%}'.format(c_allergen, auc, acc))
113
+ all_rows += [{'allegen': c_allergen,
114
+ 'prediction': valid_pred[j, i],
115
+ 'class': 'Positive' if y_vec[j, i] > 0.5 else 'Negative'}
116
+ for j in range(valid_pred.shape[0])]
117
+
118
+ d_ax = m_axs[0, 1]
119
+ t_yp = np.mean(valid_pred, 0)
120
+ t_y = np.mean(y_vec, 0)
121
+ d_ax.barh(np.arange(len(allergens)) + 0.1, t_yp, alpha=0.5, label='Predicted')
122
+ d_ax.barh(np.arange(len(allergens)) - 0.1, t_y + 0.001, alpha=0.5, label='Ground Truth')
123
+ d_ax.set_xlim(0, 1)
124
+ d_ax.set_yticks(range(len(allergens)))
125
+ d_ax.set_yticklabels(allergens, rotation=0)
126
+ d_ax.set_title('Overall')
127
+ d_ax.legend()
128
+
129
+ ax1.legend()
130
+ for (_, c_row), (c_ax, d_ax) in zip(
131
+ valid_df.iloc[valid_num:valid_num+1].iterrows(),
132
+ m_axs[1:]):
133
+ c_ax.imshow(imread(c_row['image_path']))
134
+ c_ax.set_title(c_row['title'])
135
+ c_ax.axis('off')
136
+ t_yp = in_model.predict(np.expand_dims(c_row['color_features'], 0))
137
+ t_y = np.array(c_row['allergy_vec'])
138
+ d_ax.barh(np.arange(len(allergens)) + 0.1, t_yp[0], alpha=0.5, label='Predicted')
139
+ d_ax.barh(np.arange(len(allergens)) - 0.1, t_y + 0.001, alpha=0.5, label='Ground Truth')
140
+ d_ax.set_yticks(range(len(allergens)))
141
+ d_ax.set_yticklabels(allergens, rotation=0)
142
+ d_ax.set_xlim(0, 1)
143
+ d_ax.legend()
144
+
145
+ # 将当前图像添加到 Streamlit 页面
146
+ st.pyplot(fig)
147
+ return st.write("Completed")
148
+
149
+ # Assuming you have already defined 'rf_pipe' and 'valid_df' with image paths
150
+ image_paths = valid_df['image_path'].tolist()
151
+
152
+
153
+
154
+ # Streamlit app
155
+ def main():
156
+ st.title('Model Results')
157
+ st.write(f'<span style="font-size:20px;">This is a prototype, so we use the images from the test set as examples.</span>', unsafe_allow_html=True)
158
+
159
+ image_paths = valid_df['image_path'].tolist()
160
+
161
+ num_rows = 2
162
+ num_cols = 5
163
+ fig, axes = plt.subplots(num_rows, num_cols, figsize=(15, 6))
164
+
165
+ for i in range(num_rows):
166
+ for j in range(num_cols):
167
+ index = i * num_cols + j
168
+ image = imread(image_paths[index])
169
+ axes[i, j].imshow(image)
170
+ axes[i, j].axis('off')
171
+ axes[i, j].set_title(f'Image {index+1}')
172
+
173
+ st.pyplot(fig)
174
+
175
+ num_images = 10
176
+
177
+ # User interaction to select image
178
+ st.write(f'<span style="font-size:20px;">Enter the image number you want to analyze.</span>', unsafe_allow_html=True)
179
+
180
+ choice = st.number_input(f"Range (1-{num_images}): ", min_value=1, max_value=num_images)
181
+
182
+ # Show model results
183
+ if st.button('Show Results'):
184
+ show_model_results(loaded_rf_pipe, choice-1)
185
+
186
+ if __name__ == "__main__":
187
+ main()
clean_list.json ADDED
The diff for this file is too large to render. See raw diff
 
color_features.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1c77b1fb4901ec6abd8c19c9bacd2728cedaad6cb3e5a699da6808d5bf4bdb44
3
+ size 18577814
image_subset.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:040c523a6f1d15f1e78ba215766a82b71274bbf128885d840b68582ee40ad9a8
3
+ size 67109125