clementBE commited on
Commit
64c0fa8
·
verified ·
1 Parent(s): 79d7042

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -53
app.py CHANGED
@@ -73,7 +73,7 @@ def get_dominant_color(image,num_colors=5):
73
  # ---------------------------
74
  def classify_zip_and_analyze_color(zip_file):
75
  results = []
76
- images_dict = {} # store images for preview
77
  zip_name = os.path.splitext(os.path.basename(zip_file.name))[0]
78
  date_str = datetime.now().strftime("%Y%m%d")
79
 
@@ -90,7 +90,6 @@ def classify_zip_and_analyze_color(zip_file):
90
  except:
91
  continue
92
 
93
- # Classification
94
  input_tensor = transform(image).unsqueeze(0).to(device)
95
  with torch.no_grad():
96
  output = model(input_tensor)
@@ -98,11 +97,9 @@ def classify_zip_and_analyze_color(zip_file):
98
  top3_prob, top3_idx = torch.topk(probs,3)
99
  preds = [(imagenet_classes[idx], f"{prob.item()*100:.2f}%") for idx,prob in zip(top3_idx, top3_prob)]
100
 
101
- # Dominant color
102
  rgb, hex_color = get_dominant_color(image)
103
  basic_color = closest_basic_color(rgb)
104
 
105
- # Face detection
106
  faces_data = []
107
  try:
108
  img_cv2 = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
@@ -129,39 +126,27 @@ def classify_zip_and_analyze_color(zip_file):
129
  df.to_excel(out_xlsx,index=False)
130
 
131
  # ---------------------------
132
- # Plot 1: Basic color frequency
133
  # ---------------------------
 
134
  fig1, ax1 = plt.subplots()
135
  color_counts = df["Basic Color"].value_counts()
136
  ax1.bar(color_counts.index, color_counts.values, color="skyblue")
137
  ax1.set_title("Basic Color Frequency")
138
  ax1.set_ylabel("Count")
139
- buf1 = io.BytesIO()
140
- plt.savefig(buf1, format="png")
141
- plt.close(fig1)
142
- buf1.seek(0)
143
- plot1_img = Image.open(buf1)
144
 
145
- # ---------------------------
146
- # Plot 2: Top prediction distribution
147
- # ---------------------------
148
  fig2, ax2 = plt.subplots()
149
  preds_flat = []
150
- for p in df["Top 3 Predictions"]:
151
- preds_flat.extend(p.split(", "))
152
  pred_counts = pd.Series(preds_flat).value_counts().head(20)
153
  ax2.barh(pred_counts.index[::-1], pred_counts.values[::-1], color="salmon")
154
  ax2.set_title("Top Prediction Distribution")
155
  ax2.set_xlabel("Count")
156
- buf2 = io.BytesIO()
157
- plt.savefig(buf2, format="png", bbox_inches="tight")
158
- plt.close(fig2)
159
- buf2.seek(0)
160
- plot2_img = Image.open(buf2)
161
 
162
- # ---------------------------
163
- # Gender & Age
164
- # ---------------------------
165
  ages_male, ages_female = [], []
166
  gender_confidence = {"Homme":0, "Femme":0}
167
  for face_list in df["Face Info"]:
@@ -173,55 +158,41 @@ def classify_zip_and_analyze_color(zip_file):
173
  weight = min(conf,0.9)
174
  gender_trans = "Homme" if gender=="Man" else "Femme"
175
  gender_confidence[gender_trans] += weight
176
- if gender_trans=="Homme":
177
- ages_male.append(age)
178
- else:
179
- ages_female.append(age)
180
 
181
- # Gender distribution
182
  fig3, ax3 = plt.subplots()
183
  ax3.bar(gender_confidence.keys(), gender_confidence.values(), color=["lightblue","pink"])
184
  ax3.set_title("Gender Distribution (Weighted ≤90%)")
185
  ax3.set_ylabel("Sum of Confidence")
186
- buf3 = io.BytesIO()
187
- plt.savefig(buf3, format="png")
188
- plt.close(fig3)
189
- buf3.seek(0)
190
- plot3_img = Image.open(buf3)
191
 
192
- # Age distribution
193
  fig4, ax4 = plt.subplots()
194
  bins = range(0,101,5)
195
  ax4.hist([ages_male, ages_female], bins=bins, color=["lightblue","pink"], label=["Homme","Femme"], edgecolor="black")
196
  ax4.set_title("Age Distribution by Gender")
197
- ax4.set_xlabel("Age")
198
- ax4.set_ylabel("Count")
199
- ax4.legend()
200
- buf4 = io.BytesIO()
201
- plt.savefig(buf4, format="png")
202
- plt.close(fig4)
203
- buf4.seek(0)
204
- plot4_img = Image.open(buf4)
205
 
206
- return df, list(images_dict.keys()), images_dict, out_xlsx, plot1_img, plot2_img, plot3_img, plot4_img
207
 
208
  # ---------------------------
209
  # Preview callback
210
  # ---------------------------
211
- def show_preview(selected_file, images_state):
212
- if images_state is None or selected_file is None:
213
  return None
214
- return images_state.get(selected_file, None)
 
215
 
216
  # ---------------------------
217
  # Gradio interface
218
  # ---------------------------
219
  with gr.Blocks() as demo:
220
  uploaded_zip = gr.File(label="Upload ZIP of images", file_types=[".zip"])
221
- analyze_btn = gr.Button("Run Analysis") # Run button just after upload
222
 
223
- output_df = gr.Dataframe(headers=["Filename","Top 3 Predictions","Confidence","Dominant Color","Basic Color","Face Info"])
224
- image_dropdown = gr.Dropdown(label="Select image to preview")
225
  image_preview = gr.Image(label="Image Preview")
226
  download_file = gr.File(label="Download XLSX")
227
  images_state = gr.State()
@@ -232,15 +203,15 @@ with gr.Blocks() as demo:
232
  plot4 = gr.Image(label="Age Distribution by Gender")
233
 
234
  def run_analysis(zip_file):
235
- df, filenames, images_dict, out_xlsx, p1, p2, p3, p4 = classify_zip_and_analyze_color(zip_file)
236
- return df, filenames, images_dict, out_xlsx, p1, p2, p3, p4
237
 
238
  analyze_btn.click(
239
  run_analysis,
240
  inputs=uploaded_zip,
241
- outputs=[output_df, image_dropdown, images_state, download_file, plot1, plot2, plot3, plot4]
242
  )
243
 
244
- image_dropdown.change(show_preview, inputs=[image_dropdown, images_state], outputs=image_preview)
245
 
246
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
73
  # ---------------------------
74
  def classify_zip_and_analyze_color(zip_file):
75
  results = []
76
+ images_dict = {}
77
  zip_name = os.path.splitext(os.path.basename(zip_file.name))[0]
78
  date_str = datetime.now().strftime("%Y%m%d")
79
 
 
90
  except:
91
  continue
92
 
 
93
  input_tensor = transform(image).unsqueeze(0).to(device)
94
  with torch.no_grad():
95
  output = model(input_tensor)
 
97
  top3_prob, top3_idx = torch.topk(probs,3)
98
  preds = [(imagenet_classes[idx], f"{prob.item()*100:.2f}%") for idx,prob in zip(top3_idx, top3_prob)]
99
 
 
100
  rgb, hex_color = get_dominant_color(image)
101
  basic_color = closest_basic_color(rgb)
102
 
 
103
  faces_data = []
104
  try:
105
  img_cv2 = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
 
126
  df.to_excel(out_xlsx,index=False)
127
 
128
  # ---------------------------
129
+ # Plots
130
  # ---------------------------
131
+ # Basic color
132
  fig1, ax1 = plt.subplots()
133
  color_counts = df["Basic Color"].value_counts()
134
  ax1.bar(color_counts.index, color_counts.values, color="skyblue")
135
  ax1.set_title("Basic Color Frequency")
136
  ax1.set_ylabel("Count")
137
+ buf1 = io.BytesIO(); plt.savefig(buf1, format="png"); plt.close(fig1); buf1.seek(0); plot1_img = Image.open(buf1)
 
 
 
 
138
 
139
+ # Top predictions
 
 
140
  fig2, ax2 = plt.subplots()
141
  preds_flat = []
142
+ for p in df["Top 3 Predictions"]: preds_flat.extend(p.split(", "))
 
143
  pred_counts = pd.Series(preds_flat).value_counts().head(20)
144
  ax2.barh(pred_counts.index[::-1], pred_counts.values[::-1], color="salmon")
145
  ax2.set_title("Top Prediction Distribution")
146
  ax2.set_xlabel("Count")
147
+ buf2 = io.BytesIO(); plt.savefig(buf2, format="png", bbox_inches="tight"); plt.close(fig2); buf2.seek(0); plot2_img = Image.open(buf2)
 
 
 
 
148
 
149
+ # Gender and age
 
 
150
  ages_male, ages_female = [], []
151
  gender_confidence = {"Homme":0, "Femme":0}
152
  for face_list in df["Face Info"]:
 
158
  weight = min(conf,0.9)
159
  gender_trans = "Homme" if gender=="Man" else "Femme"
160
  gender_confidence[gender_trans] += weight
161
+ if gender_trans=="Homme": ages_male.append(age)
162
+ else: ages_female.append(age)
 
 
163
 
 
164
  fig3, ax3 = plt.subplots()
165
  ax3.bar(gender_confidence.keys(), gender_confidence.values(), color=["lightblue","pink"])
166
  ax3.set_title("Gender Distribution (Weighted ≤90%)")
167
  ax3.set_ylabel("Sum of Confidence")
168
+ buf3 = io.BytesIO(); plt.savefig(buf3, format="png"); plt.close(fig3); buf3.seek(0); plot3_img = Image.open(buf3)
 
 
 
 
169
 
 
170
  fig4, ax4 = plt.subplots()
171
  bins = range(0,101,5)
172
  ax4.hist([ages_male, ages_female], bins=bins, color=["lightblue","pink"], label=["Homme","Femme"], edgecolor="black")
173
  ax4.set_title("Age Distribution by Gender")
174
+ ax4.set_xlabel("Age"); ax4.set_ylabel("Count"); ax4.legend()
175
+ buf4 = io.BytesIO(); plt.savefig(buf4, format="png"); plt.close(fig4); buf4.seek(0); plot4_img = Image.open(buf4)
 
 
 
 
 
 
176
 
177
+ return df, images_dict, out_xlsx, plot1_img, plot2_img, plot3_img, plot4_img
178
 
179
  # ---------------------------
180
  # Preview callback
181
  # ---------------------------
182
+ def show_preview(selected_row, images_state):
183
+ if images_state is None or selected_row is None:
184
  return None
185
+ filename = selected_row[0] # first column is filename
186
+ return images_state.get(filename, None)
187
 
188
  # ---------------------------
189
  # Gradio interface
190
  # ---------------------------
191
  with gr.Blocks() as demo:
192
  uploaded_zip = gr.File(label="Upload ZIP of images", file_types=[".zip"])
193
+ analyze_btn = gr.Button("Run Analysis")
194
 
195
+ output_df = gr.Dataframe(headers=["Filename","Top 3 Predictions","Confidence","Dominant Color","Basic Color","Face Info"], interactive=True)
 
196
  image_preview = gr.Image(label="Image Preview")
197
  download_file = gr.File(label="Download XLSX")
198
  images_state = gr.State()
 
203
  plot4 = gr.Image(label="Age Distribution by Gender")
204
 
205
  def run_analysis(zip_file):
206
+ df, images_dict, out_xlsx, p1, p2, p3, p4 = classify_zip_and_analyze_color(zip_file)
207
+ return df, images_dict, out_xlsx, p1, p2, p3, p4
208
 
209
  analyze_btn.click(
210
  run_analysis,
211
  inputs=uploaded_zip,
212
+ outputs=[output_df, images_state, download_file, plot1, plot2, plot3, plot4]
213
  )
214
 
215
+ output_df.select(show_preview, inputs=[output_df, images_state], outputs=image_preview)
216
 
217
  demo.launch(server_name="0.0.0.0", server_port=7860)