clementBE commited on
Commit
bb115d5
·
verified ·
1 Parent(s): 0be21d8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -27
app.py CHANGED
@@ -1,7 +1,4 @@
1
- import os
2
- import zipfile
3
- import tempfile
4
- import requests
5
  import numpy as np
6
  import pandas as pd
7
  from PIL import Image
@@ -11,9 +8,7 @@ from torchvision import transforms
11
  from torchvision.models import resnet50, ResNet50_Weights
12
  from sklearn.cluster import MiniBatchKMeans
13
  import matplotlib.pyplot as plt
14
- import io
15
  from datetime import datetime
16
-
17
  import gradio as gr
18
  from deepface import DeepFace
19
  import cv2
@@ -31,7 +26,6 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
  weights = ResNet50_Weights.DEFAULT
32
  model = resnet50(weights=weights).to(device)
33
  model.eval()
34
-
35
  transform = transforms.Compose([
36
  transforms.Resize(256),
37
  transforms.CenterCrop(224),
@@ -50,7 +44,7 @@ BASIC_COLORS = {
50
 
51
  def closest_basic_color(rgb):
52
  r,g,b = rgb
53
- min_dist = float("inf"); closest_color = None
54
  for name,(cr,cg,cb) in BASIC_COLORS.items():
55
  dist = (r-cr)**2 + (g-cg)**2 + (b-cb)**2
56
  if dist < min_dist:
@@ -73,7 +67,7 @@ def get_dominant_color(image,num_colors=5):
73
  # ---------------------------
74
  def classify_zip_and_analyze_color(zip_file):
75
  results = []
76
- images_list = [] # list of (image, label) for gallery
77
  zip_name = os.path.splitext(os.path.basename(zip_file.name))[0]
78
  date_str = datetime.now().strftime("%Y%m%d")
79
 
@@ -90,6 +84,7 @@ def classify_zip_and_analyze_color(zip_file):
90
  except:
91
  continue
92
 
 
93
  input_tensor = transform(image).unsqueeze(0).to(device)
94
  with torch.no_grad():
95
  output = model(input_tensor)
@@ -100,6 +95,7 @@ def classify_zip_and_analyze_color(zip_file):
100
  rgb, hex_color = get_dominant_color(image)
101
  basic_color = closest_basic_color(rgb)
102
 
 
103
  faces_data = []
104
  try:
105
  img_cv2 = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
@@ -112,13 +108,16 @@ def classify_zip_and_analyze_color(zip_file):
112
  except:
113
  faces_data=[]
114
 
 
 
 
115
  results.append((
116
  fname,
117
  ", ".join([p[0] for p in preds]),
118
  ", ".join([p[1] for p in preds]),
119
  hex_color,
120
  basic_color,
121
- faces_data
122
  ))
123
 
124
  df = pd.DataFrame(results, columns=["Filename","Top 3 Predictions","Confidence","Dominant Color","Basic Color","Face Info"])
@@ -131,8 +130,7 @@ def classify_zip_and_analyze_color(zip_file):
131
  fig1, ax1 = plt.subplots()
132
  color_counts = df["Basic Color"].value_counts()
133
  ax1.bar(color_counts.index, color_counts.values, color="skyblue")
134
- ax1.set_title("Basic Color Frequency")
135
- ax1.set_ylabel("Count")
136
  buf1 = io.BytesIO(); plt.savefig(buf1, format="png"); plt.close(fig1); buf1.seek(0); plot1_img = Image.open(buf1)
137
 
138
  fig2, ax2 = plt.subplots()
@@ -140,35 +138,33 @@ def classify_zip_and_analyze_color(zip_file):
140
  for p in df["Top 3 Predictions"]: preds_flat.extend(p.split(", "))
141
  pred_counts = pd.Series(preds_flat).value_counts().head(20)
142
  ax2.barh(pred_counts.index[::-1], pred_counts.values[::-1], color="salmon")
143
- ax2.set_title("Top Prediction Distribution")
144
- ax2.set_xlabel("Count")
145
  buf2 = io.BytesIO(); plt.savefig(buf2, format="png", bbox_inches="tight"); plt.close(fig2); buf2.seek(0); plot2_img = Image.open(buf2)
146
 
 
147
  ages_male, ages_female = [], []
148
  gender_confidence = {"Homme":0, "Femme":0}
149
  for face_list in df["Face Info"]:
150
- for face in face_list:
151
- age = face["age"]
152
- gender_dict = face["gender"]
153
- gender = max(gender_dict, key=gender_dict.get)
154
- conf = float(gender_dict[gender])/100
155
- weight = min(conf,0.9)
156
- gender_trans = "Homme" if gender=="Man" else "Femme"
157
- gender_confidence[gender_trans] += weight
158
- if gender_trans=="Homme": ages_male.append(age)
159
  else: ages_female.append(age)
160
 
161
  fig3, ax3 = plt.subplots()
162
  ax3.bar(gender_confidence.keys(), gender_confidence.values(), color=["lightblue","pink"])
163
- ax3.set_title("Gender Distribution (Weighted ≤90%)")
164
- ax3.set_ylabel("Sum of Confidence")
165
  buf3 = io.BytesIO(); plt.savefig(buf3, format="png"); plt.close(fig3); buf3.seek(0); plot3_img = Image.open(buf3)
166
 
167
  fig4, ax4 = plt.subplots()
168
  bins = range(0,101,5)
169
  ax4.hist([ages_male, ages_female], bins=bins, color=["lightblue","pink"], label=["Homme","Femme"], edgecolor="black")
170
- ax4.set_title("Age Distribution by Gender")
171
- ax4.set_xlabel("Age"); ax4.set_ylabel("Count"); ax4.legend()
172
  buf4 = io.BytesIO(); plt.savefig(buf4, format="png"); plt.close(fig4); buf4.seek(0); plot4_img = Image.open(buf4)
173
 
174
  return df, images_list, out_xlsx, plot1_img, plot2_img, plot3_img, plot4_img
 
1
+ import os, zipfile, tempfile, requests, io
 
 
 
2
  import numpy as np
3
  import pandas as pd
4
  from PIL import Image
 
8
  from torchvision.models import resnet50, ResNet50_Weights
9
  from sklearn.cluster import MiniBatchKMeans
10
  import matplotlib.pyplot as plt
 
11
  from datetime import datetime
 
12
  import gradio as gr
13
  from deepface import DeepFace
14
  import cv2
 
26
  weights = ResNet50_Weights.DEFAULT
27
  model = resnet50(weights=weights).to(device)
28
  model.eval()
 
29
  transform = transforms.Compose([
30
  transforms.Resize(256),
31
  transforms.CenterCrop(224),
 
44
 
45
  def closest_basic_color(rgb):
46
  r,g,b = rgb
47
+ min_dist, closest_color = float("inf"), None
48
  for name,(cr,cg,cb) in BASIC_COLORS.items():
49
  dist = (r-cr)**2 + (g-cg)**2 + (b-cb)**2
50
  if dist < min_dist:
 
67
  # ---------------------------
68
  def classify_zip_and_analyze_color(zip_file):
69
  results = []
70
+ images_list = []
71
  zip_name = os.path.splitext(os.path.basename(zip_file.name))[0]
72
  date_str = datetime.now().strftime("%Y%m%d")
73
 
 
84
  except:
85
  continue
86
 
87
+ # Classification
88
  input_tensor = transform(image).unsqueeze(0).to(device)
89
  with torch.no_grad():
90
  output = model(input_tensor)
 
95
  rgb, hex_color = get_dominant_color(image)
96
  basic_color = closest_basic_color(rgb)
97
 
98
+ # Face analysis
99
  faces_data = []
100
  try:
101
  img_cv2 = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
 
108
  except:
109
  faces_data=[]
110
 
111
+ # Convert faces data to readable string
112
+ faces_str = "; ".join([f"Age: {face['age']}, Gender: {'Homme' if face['gender']=='Man' else 'Femme'}, Emotion: {face['emotion']}" for face in faces_data])
113
+
114
  results.append((
115
  fname,
116
  ", ".join([p[0] for p in preds]),
117
  ", ".join([p[1] for p in preds]),
118
  hex_color,
119
  basic_color,
120
+ faces_str
121
  ))
122
 
123
  df = pd.DataFrame(results, columns=["Filename","Top 3 Predictions","Confidence","Dominant Color","Basic Color","Face Info"])
 
130
  fig1, ax1 = plt.subplots()
131
  color_counts = df["Basic Color"].value_counts()
132
  ax1.bar(color_counts.index, color_counts.values, color="skyblue")
133
+ ax1.set_title("Basic Color Frequency"); ax1.set_ylabel("Count")
 
134
  buf1 = io.BytesIO(); plt.savefig(buf1, format="png"); plt.close(fig1); buf1.seek(0); plot1_img = Image.open(buf1)
135
 
136
  fig2, ax2 = plt.subplots()
 
138
  for p in df["Top 3 Predictions"]: preds_flat.extend(p.split(", "))
139
  pred_counts = pd.Series(preds_flat).value_counts().head(20)
140
  ax2.barh(pred_counts.index[::-1], pred_counts.values[::-1], color="salmon")
141
+ ax2.set_title("Top Prediction Distribution"); ax2.set_xlabel("Count")
 
142
  buf2 = io.BytesIO(); plt.savefig(buf2, format="png", bbox_inches="tight"); plt.close(fig2); buf2.seek(0); plot2_img = Image.open(buf2)
143
 
144
+ # Gender and age
145
  ages_male, ages_female = [], []
146
  gender_confidence = {"Homme":0, "Femme":0}
147
  for face_list in df["Face Info"]:
148
+ if face_list.strip()=="":
149
+ continue
150
+ for face_str in face_list.split("; "):
151
+ parts = face_str.split(", ")
152
+ age = int(parts[0].split(": ")[1])
153
+ gender = parts[1].split(": ")[1]
154
+ conf = 0.9 # approximation for histogram
155
+ gender_confidence[gender] += conf
156
+ if gender=="Homme": ages_male.append(age)
157
  else: ages_female.append(age)
158
 
159
  fig3, ax3 = plt.subplots()
160
  ax3.bar(gender_confidence.keys(), gender_confidence.values(), color=["lightblue","pink"])
161
+ ax3.set_title("Gender Distribution"); ax3.set_ylabel("Sum of Confidence")
 
162
  buf3 = io.BytesIO(); plt.savefig(buf3, format="png"); plt.close(fig3); buf3.seek(0); plot3_img = Image.open(buf3)
163
 
164
  fig4, ax4 = plt.subplots()
165
  bins = range(0,101,5)
166
  ax4.hist([ages_male, ages_female], bins=bins, color=["lightblue","pink"], label=["Homme","Femme"], edgecolor="black")
167
+ ax4.set_title("Age Distribution by Gender"); ax4.set_xlabel("Age"); ax4.set_ylabel("Count"); ax4.legend()
 
168
  buf4 = io.BytesIO(); plt.savefig(buf4, format="png"); plt.close(fig4); buf4.seek(0); plot4_img = Image.open(buf4)
169
 
170
  return df, images_list, out_xlsx, plot1_img, plot2_img, plot3_img, plot4_img