clementBE commited on
Commit
7b2348f
·
verified ·
1 Parent(s): 92e5b44

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -177
app.py CHANGED
@@ -15,72 +15,48 @@ import io
15
  from datetime import datetime
16
 
17
  import gradio as gr
18
-
19
- # Face analysis
20
  from deepface import DeepFace
21
  import cv2
22
 
23
- # ---------------------------
24
- # Force CPU if no CUDA
25
- # ---------------------------
26
  if not torch.cuda.is_available():
27
  os.environ["CUDA_VISIBLE_DEVICES"] = ""
28
-
29
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
 
31
- # ---------------------------
32
- # Load ResNet50
33
- # ---------------------------
34
  weights = ResNet50_Weights.DEFAULT
35
  model = resnet50(weights=weights).to(device)
36
  model.eval()
37
 
38
- # ---------------------------
39
- # Transformations
40
- # ---------------------------
41
  transform = transforms.Compose([
42
  transforms.Resize(256),
43
  transforms.CenterCrop(224),
44
  transforms.ToTensor(),
45
- transforms.Normalize(mean=[0.485, 0.456, 0.406],
46
- std=[0.229, 0.224, 0.225]),
47
  ])
48
 
49
- # ---------------------------
50
- # ImageNet labels
51
- # ---------------------------
52
  LABELS_URL = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
53
  imagenet_classes = [line.strip() for line in requests.get(LABELS_URL).text.splitlines()]
54
 
55
- # ---------------------------
56
- # Color utilities
57
- # ---------------------------
58
  BASIC_COLORS = {
59
- "Red": (255, 0, 0),
60
- "Green": (0, 255, 0),
61
- "Blue": (0, 0, 255),
62
- "Yellow": (255, 255, 0),
63
- "Cyan": (0, 255, 255),
64
- "Magenta": (255, 0, 255),
65
- "Black": (0, 0, 0),
66
- "White": (255, 255, 255),
67
- "Gray": (128, 128, 128),
68
  }
69
 
70
  def closest_basic_color(rgb):
71
- r, g, b = rgb
72
- min_dist = float("inf")
73
- closest_color = None
74
- for name, (cr, cg, cb) in BASIC_COLORS.items():
75
- dist = (r - cr) ** 2 + (g - cg) ** 2 + (b - cb) ** 2
76
  if dist < min_dist:
77
  min_dist = dist
78
  closest_color = name
79
  return closest_color
80
 
81
- def get_dominant_color(image, num_colors=5):
82
- image = image.resize((100, 100))
83
- pixels = np.array(image).reshape(-1, 3)
84
  kmeans = MiniBatchKMeans(n_clusters=num_colors, random_state=0, n_init=5)
85
  kmeans.fit(pixels)
86
  dominant_color = kmeans.cluster_centers_[np.argmax(np.bincount(kmeans.labels_))]
@@ -88,25 +64,24 @@ def get_dominant_color(image, num_colors=5):
88
  hex_color = f"#{dominant_color[0]:02x}{dominant_color[1]:02x}{dominant_color[2]:02x}"
89
  return dominant_color, hex_color
90
 
91
- # ---------------------------
92
- # Core function
93
- # ---------------------------
94
  def classify_zip_and_analyze_color(zip_file):
95
  results = []
96
-
97
  zip_name = os.path.splitext(os.path.basename(zip_file.name))[0]
98
  date_str = datetime.now().strftime("%Y%m%d")
99
 
100
  with tempfile.TemporaryDirectory() as tmpdir:
101
- with zipfile.ZipFile(zip_file.name, 'r') as zip_ref:
102
  zip_ref.extractall(tmpdir)
103
 
104
  for fname in sorted(os.listdir(tmpdir)):
105
- if fname.lower().endswith(('.png', '.jpg', '.jpeg')):
106
- img_path = os.path.join(tmpdir, fname)
107
  try:
108
  image = Image.open(img_path).convert("RGB")
109
- except Exception:
 
110
  continue
111
 
112
  # Classification
@@ -114,40 +89,25 @@ def classify_zip_and_analyze_color(zip_file):
114
  with torch.no_grad():
115
  output = model(input_tensor)
116
  probs = F.softmax(output, dim=1)[0]
117
-
118
- top3_prob, top3_idx = torch.topk(probs, 3)
119
- preds = [(imagenet_classes[idx], f"{prob.item()*100:.2f}%") for idx, prob in zip(top3_idx, top3_prob)]
120
 
121
  # Dominant color
122
  rgb, hex_color = get_dominant_color(image)
123
  basic_color = closest_basic_color(rgb)
124
 
125
- # Face detection & characterization
126
  faces_data = []
127
  try:
128
  img_cv2 = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
129
- detected_faces = DeepFace.analyze(
130
- img_cv2, actions=["age", "gender", "emotion"], enforce_detection=False
131
- )
132
  if isinstance(detected_faces, list):
133
  for f in detected_faces:
134
- faces_data.append({
135
- "age": f["age"],
136
- "gender": f["gender"],
137
- "emotion": f["dominant_emotion"]
138
- })
139
  else:
140
- faces_data.append({
141
- "age": detected_faces["age"],
142
- "gender": detected_faces["gender"],
143
- "emotion": detected_faces["dominant_emotion"]
144
- })
145
- except Exception:
146
- faces_data = []
147
-
148
- # Thumbnail preview
149
- thumbnail = image.copy()
150
- thumbnail.thumbnail((64, 64))
151
 
152
  results.append((
153
  fname,
@@ -155,123 +115,41 @@ def classify_zip_and_analyze_color(zip_file):
155
  ", ".join([p[1] for p in preds]),
156
  hex_color,
157
  basic_color,
158
- faces_data,
159
- thumbnail
160
  ))
161
 
162
- # Build dataframe
163
- df = pd.DataFrame(results, columns=[
164
- "Filename", "Top 3 Predictions", "Confidence",
165
- "Dominant Color", "Basic Color", "Face Info", "Thumbnail"
166
- ])
167
 
168
- # Save XLSX with zip name + date
169
  out_xlsx = os.path.join(tempfile.gettempdir(), f"{zip_name}_{date_str}_results.xlsx")
170
- df.to_excel(out_xlsx, index=False)
171
-
172
- # ---------------------------
173
- # Plot 1: Basic color frequency
174
- # ---------------------------
175
- fig1, ax1 = plt.subplots()
176
- color_counts = df["Basic Color"].value_counts()
177
- ax1.bar(color_counts.index, color_counts.values, color="skyblue")
178
- ax1.set_title("Basic Color Frequency")
179
- ax1.set_ylabel("Count")
180
- buf1 = io.BytesIO()
181
- plt.savefig(buf1, format="png")
182
- plt.close(fig1)
183
- buf1.seek(0)
184
- plot1_img = Image.open(buf1)
185
-
186
- # ---------------------------
187
- # Plot 2: Top prediction distribution
188
- # ---------------------------
189
- fig2, ax2 = plt.subplots()
190
- preds_flat = []
191
- for p in df["Top 3 Predictions"]:
192
- preds_flat.extend(p.split(", "))
193
- pred_counts = pd.Series(preds_flat).value_counts().head(20)
194
- ax2.barh(pred_counts.index[::-1], pred_counts.values[::-1], color="salmon")
195
- ax2.set_title("Top Prediction Distribution")
196
- ax2.set_xlabel("Count")
197
- buf2 = io.BytesIO()
198
- plt.savefig(buf2, format="png", bbox_inches="tight")
199
- plt.close(fig2)
200
- buf2.seek(0)
201
- plot2_img = Image.open(buf2)
202
 
203
- # ---------------------------
204
- # Extract ages and genders
205
- # ---------------------------
206
- ages_male, ages_female = [], []
207
- gender_confidence = {"Homme": 0, "Femme": 0}
208
 
209
- for face_list in df["Face Info"]:
210
- for face in face_list:
211
- age = face["age"]
212
- gender_dict = face["gender"]
213
- gender = max(gender_dict, key=gender_dict.get)
214
- conf = float(gender_dict[gender]) / 100
215
- weight = min(conf, 0.9)
216
- gender_trans = "Homme" if gender == "Man" else "Femme"
217
- gender_confidence[gender_trans] += weight
218
- if gender_trans == "Homme":
219
- ages_male.append(age)
220
- else:
221
- ages_female.append(age)
222
 
223
- # ---------------------------
224
- # Plot 3: Gender distribution
225
- # ---------------------------
226
- fig3, ax3 = plt.subplots()
227
- ax3.bar(gender_confidence.keys(), gender_confidence.values(), color=["lightblue", "pink"])
228
- ax3.set_title("Gender Distribution (Weighted ≤90%)")
229
- ax3.set_ylabel("Sum of Confidence")
230
- buf3 = io.BytesIO()
231
- plt.savefig(buf3, format="png")
232
- plt.close(fig3)
233
- buf3.seek(0)
234
- plot3_img = Image.open(buf3)
235
 
236
- # ---------------------------
237
- # Plot 4: Age distribution by gender
238
- # ---------------------------
239
- fig4, ax4 = plt.subplots()
240
- bins = range(0, 101, 5)
241
- ax4.hist([ages_male, ages_female], bins=bins, color=["lightblue", "pink"], label=["Homme", "Femme"], edgecolor="black")
242
- ax4.set_title("Age Distribution by Gender")
243
- ax4.set_xlabel("Age")
244
- ax4.set_ylabel("Count")
245
- ax4.legend()
246
- buf4 = io.BytesIO()
247
- plt.savefig(buf4, format="png")
248
- plt.close(fig4)
249
- buf4.seek(0)
250
- plot4_img = Image.open(buf4)
251
 
252
- return df, out_xlsx, plot1_img, plot2_img, plot3_img, plot4_img
 
253
 
254
- # ---------------------------
255
- # Gradio Interface
256
- # ---------------------------
257
- demo = gr.Interface(
258
- fn=classify_zip_and_analyze_color,
259
- inputs=gr.File(file_types=[".zip"], label="Upload ZIP of images"),
260
- outputs=[
261
- gr.Dataframe(
262
- headers=["Filename", "Top 3 Predictions", "Confidence",
263
- "Dominant Color", "Basic Color", "Face Info", "Thumbnail"],
264
- datatype=["str","str","str","str","str","str","pil"]
265
- ),
266
- gr.File(label="Download XLSX"),
267
- gr.Image(type="pil", label="Basic Color Frequency"),
268
- gr.Image(type="pil", label="Top Prediction Distribution"),
269
- gr.Image(type="pil", label="Gender Distribution (Weighted ≤90%)"),
270
- gr.Image(type="pil", label="Age Distribution by Gender"),
271
- ],
272
- title="Image Classifier with Color & Face Analysis",
273
- description="Upload a ZIP of images. Classifies images, analyzes dominant color, detects/characterizes faces (age, gender, emotion), and shows thumbnails.",
274
- )
275
 
276
- if __name__ == "__main__":
277
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
15
  from datetime import datetime
16
 
17
  import gradio as gr
 
 
18
  from deepface import DeepFace
19
  import cv2
20
 
21
+ # CUDA setup
 
 
22
  if not torch.cuda.is_available():
23
  os.environ["CUDA_VISIBLE_DEVICES"] = ""
 
24
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
 
26
+ # Load model
 
 
27
  weights = ResNet50_Weights.DEFAULT
28
  model = resnet50(weights=weights).to(device)
29
  model.eval()
30
 
 
 
 
31
  transform = transforms.Compose([
32
  transforms.Resize(256),
33
  transforms.CenterCrop(224),
34
  transforms.ToTensor(),
35
+ transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
 
36
  ])
37
 
 
 
 
38
  LABELS_URL = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
39
  imagenet_classes = [line.strip() for line in requests.get(LABELS_URL).text.splitlines()]
40
 
 
 
 
41
  BASIC_COLORS = {
42
+ "Red": (255,0,0), "Green":(0,255,0), "Blue":(0,0,255),
43
+ "Yellow":(255,255,0), "Cyan":(0,255,255), "Magenta":(255,0,255),
44
+ "Black":(0,0,0), "White":(255,255,255), "Gray":(128,128,128),
 
 
 
 
 
 
45
  }
46
 
47
  def closest_basic_color(rgb):
48
+ r,g,b = rgb
49
+ min_dist = float("inf"); closest_color = None
50
+ for name,(cr,cg,cb) in BASIC_COLORS.items():
51
+ dist = (r-cr)**2 + (g-cg)**2 + (b-cb)**2
 
52
  if dist < min_dist:
53
  min_dist = dist
54
  closest_color = name
55
  return closest_color
56
 
57
+ def get_dominant_color(image,num_colors=5):
58
+ image = image.resize((100,100))
59
+ pixels = np.array(image).reshape(-1,3)
60
  kmeans = MiniBatchKMeans(n_clusters=num_colors, random_state=0, n_init=5)
61
  kmeans.fit(pixels)
62
  dominant_color = kmeans.cluster_centers_[np.argmax(np.bincount(kmeans.labels_))]
 
64
  hex_color = f"#{dominant_color[0]:02x}{dominant_color[1]:02x}{dominant_color[2]:02x}"
65
  return dominant_color, hex_color
66
 
67
+ # Main function
 
 
68
  def classify_zip_and_analyze_color(zip_file):
69
  results = []
70
+ images_dict = {} # store images by filename for preview
71
  zip_name = os.path.splitext(os.path.basename(zip_file.name))[0]
72
  date_str = datetime.now().strftime("%Y%m%d")
73
 
74
  with tempfile.TemporaryDirectory() as tmpdir:
75
+ with zipfile.ZipFile(zip_file.name,'r') as zip_ref:
76
  zip_ref.extractall(tmpdir)
77
 
78
  for fname in sorted(os.listdir(tmpdir)):
79
+ if fname.lower().endswith(('.png','.jpg','.jpeg')):
80
+ img_path = os.path.join(tmpdir,fname)
81
  try:
82
  image = Image.open(img_path).convert("RGB")
83
+ images_dict[fname] = image.copy() # save for preview
84
+ except:
85
  continue
86
 
87
  # Classification
 
89
  with torch.no_grad():
90
  output = model(input_tensor)
91
  probs = F.softmax(output, dim=1)[0]
92
+ top3_prob, top3_idx = torch.topk(probs,3)
93
+ preds = [(imagenet_classes[idx], f"{prob.item()*100:.2f}%") for idx,prob in zip(top3_idx, top3_prob)]
 
94
 
95
  # Dominant color
96
  rgb, hex_color = get_dominant_color(image)
97
  basic_color = closest_basic_color(rgb)
98
 
99
+ # Face detection
100
  faces_data = []
101
  try:
102
  img_cv2 = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
103
+ detected_faces = DeepFace.analyze(img_cv2, actions=["age","gender","emotion"], enforce_detection=False)
 
 
104
  if isinstance(detected_faces, list):
105
  for f in detected_faces:
106
+ faces_data.append({"age": f["age"], "gender": f["gender"], "emotion": f["dominant_emotion"]})
 
 
 
 
107
  else:
108
+ faces_data.append({"age": detected_faces["age"], "gender": detected_faces["gender"], "emotion": detected_faces["dominant_emotion"]})
109
+ except:
110
+ faces_data=[]
 
 
 
 
 
 
 
 
111
 
112
  results.append((
113
  fname,
 
115
  ", ".join([p[1] for p in preds]),
116
  hex_color,
117
  basic_color,
118
+ faces_data
 
119
  ))
120
 
121
+ # DataFrame
122
+ df = pd.DataFrame(results, columns=["Filename","Top 3 Predictions","Confidence","Dominant Color","Basic Color","Face Info"])
 
 
 
123
 
124
+ # XLSX output
125
  out_xlsx = os.path.join(tempfile.gettempdir(), f"{zip_name}_{date_str}_results.xlsx")
126
+ df.to_excel(out_xlsx,index=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
+ return df, images_dict, out_xlsx
 
 
 
 
129
 
130
+ # Callback to show image preview when clicking filename
131
+ def show_preview(filename, images_dict):
132
+ if filename in images_dict:
133
+ return images_dict[filename]
134
+ else:
135
+ return None
 
 
 
 
 
 
 
136
 
137
+ # Gradio interface
138
+ with gr.Blocks() as demo:
139
+ uploaded_zip = gr.File(label="Upload ZIP of images", file_types=[".zip"])
140
+ output_df = gr.Dataframe(headers=["Filename","Top 3 Predictions","Confidence","Dominant Color","Basic Color","Face Info"])
141
+ image_preview = gr.Image(label="Image Preview")
142
+ download_file = gr.File(label="Download XLSX")
 
 
 
 
 
 
143
 
144
+ # Run analysis
145
+ def run_analysis(zip_file):
146
+ df, images_dict, out_xlsx = classify_zip_and_analyze_color(zip_file)
147
+ return df, images_dict, out_xlsx
 
 
 
 
 
 
 
 
 
 
 
148
 
149
+ analyze_btn = gr.Button("Run Analysis")
150
+ analyze_btn.click(run_analysis, inputs=uploaded_zip, outputs=[output_df, "state", download_file])
151
 
152
+ # Update preview when clicking filename
153
+ output_df.select(show_preview, inputs=[output_df, "state"], outputs=image_preview)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
 
155
+ demo.launch(server_name="0.0.0.0", server_port=7860)