clementBE commited on
Commit
517b23e
·
verified ·
1 Parent(s): 19ff977

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -118
app.py CHANGED
@@ -1,97 +1,12 @@
1
- import os
2
- import zipfile
3
- import tempfile
4
- import requests
5
- import numpy as np
6
- import pandas as pd
7
- from PIL import Image
8
- import torch
9
- import torch.nn.functional as F
10
- from torchvision import transforms
11
- from torchvision.models import resnet50, ResNet50_Weights
12
- from sklearn.cluster import MiniBatchKMeans
13
- import matplotlib.pyplot as plt
14
- import io
15
-
16
- import gradio as gr
17
-
18
- # Face analysis
19
- from deepface import DeepFace
20
- import cv2
21
-
22
- # ---------------------------
23
- # Force CPU if no CUDA
24
- # ---------------------------
25
- if not torch.cuda.is_available():
26
- os.environ["CUDA_VISIBLE_DEVICES"] = ""
27
-
28
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
-
30
- # ---------------------------
31
- # Load ResNet50
32
- # ---------------------------
33
- weights = ResNet50_Weights.DEFAULT
34
- model = resnet50(weights=weights).to(device)
35
- model.eval()
36
-
37
- # ---------------------------
38
- # Transformations
39
- # ---------------------------
40
- transform = transforms.Compose([
41
- transforms.Resize(256),
42
- transforms.CenterCrop(224),
43
- transforms.ToTensor(),
44
- transforms.Normalize(mean=[0.485, 0.456, 0.406],
45
- std=[0.229, 0.224, 0.225]),
46
- ])
47
-
48
- # ---------------------------
49
- # ImageNet labels
50
- # ---------------------------
51
- LABELS_URL = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
52
- imagenet_classes = [line.strip() for line in requests.get(LABELS_URL).text.splitlines()]
53
-
54
  # ---------------------------
55
- # Color utilities
56
- # ---------------------------
57
- BASIC_COLORS = {
58
- "Red": (255, 0, 0),
59
- "Green": (0, 255, 0),
60
- "Blue": (0, 0, 255),
61
- "Yellow": (255, 255, 0),
62
- "Cyan": (0, 255, 255),
63
- "Magenta": (255, 0, 255),
64
- "Black": (0, 0, 0),
65
- "White": (255, 255, 255),
66
- "Gray": (128, 128, 128),
67
- }
68
-
69
- def closest_basic_color(rgb):
70
- r, g, b = rgb
71
- min_dist = float("inf")
72
- closest_color = None
73
- for name, (cr, cg, cb) in BASIC_COLORS.items():
74
- dist = (r - cr) ** 2 + (g - cg) ** 2 + (b - cb) ** 2
75
- if dist < min_dist:
76
- min_dist = dist
77
- closest_color = name
78
- return closest_color
79
-
80
- def get_dominant_color(image, num_colors=5):
81
- image = image.resize((100, 100))
82
- pixels = np.array(image).reshape(-1, 3)
83
- kmeans = MiniBatchKMeans(n_clusters=num_colors, random_state=0, n_init=5)
84
- kmeans.fit(pixels)
85
- dominant_color = kmeans.cluster_centers_[np.argmax(np.bincount(kmeans.labels_))]
86
- dominant_color = tuple(dominant_color.astype(int))
87
- hex_color = f"#{dominant_color[0]:02x}{dominant_color[1]:02x}{dominant_color[2]:02x}"
88
- return dominant_color, hex_color
89
-
90
- # ---------------------------
91
- # Core function
92
  # ---------------------------
93
  def classify_zip_and_analyze_color(zip_file):
94
  results = []
 
 
 
 
95
 
96
  with tempfile.TemporaryDirectory() as tmpdir:
97
  with zipfile.ZipFile(zip_file.name, 'r') as zip_ref:
@@ -105,6 +20,11 @@ def classify_zip_and_analyze_color(zip_file):
105
  except Exception:
106
  continue
107
 
 
 
 
 
 
108
  # Classification
109
  input_tensor = transform(image).unsqueeze(0).to(device)
110
  with torch.no_grad():
@@ -118,9 +38,7 @@ def classify_zip_and_analyze_color(zip_file):
118
  rgb, hex_color = get_dominant_color(image)
119
  basic_color = closest_basic_color(rgb)
120
 
121
- # ---------------------------
122
  # Face detection & characterization
123
- # ---------------------------
124
  faces_data = []
125
  try:
126
  img_cv2 = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
@@ -131,13 +49,13 @@ def classify_zip_and_analyze_color(zip_file):
131
  for f in detected_faces:
132
  faces_data.append({
133
  "age": f["age"],
134
- "gender": f["gender"], # dict of probabilities
135
  "emotion": f["dominant_emotion"]
136
  })
137
  else:
138
  faces_data.append({
139
  "age": detected_faces["age"],
140
- "gender": detected_faces["gender"], # dict
141
  "emotion": detected_faces["dominant_emotion"]
142
  })
143
  except Exception:
@@ -155,13 +73,14 @@ def classify_zip_and_analyze_color(zip_file):
155
  # Build dataframe
156
  df = pd.DataFrame(results, columns=["Filename", "Top 3 Predictions", "Confidence", "Dominant Color", "Basic Color", "Face Info"])
157
 
158
- # Save XLSX
159
- out_xlsx = os.path.join(tempfile.gettempdir(), "results.xlsx")
160
  df.to_excel(out_xlsx, index=False)
161
 
162
  # ---------------------------
163
- # Plot 1: Basic color frequency
164
  # ---------------------------
 
165
  fig1, ax1 = plt.subplots()
166
  color_counts = df["Basic Color"].value_counts()
167
  ax1.bar(color_counts.index, color_counts.values, color="skyblue")
@@ -173,9 +92,7 @@ def classify_zip_and_analyze_color(zip_file):
173
  buf1.seek(0)
174
  plot1_img = Image.open(buf1)
175
 
176
- # ---------------------------
177
- # Plot 2: Top prediction distribution
178
- # ---------------------------
179
  fig2, ax2 = plt.subplots()
180
  preds_flat = []
181
  for p in df["Top 3 Predictions"]:
@@ -190,30 +107,24 @@ def classify_zip_and_analyze_color(zip_file):
190
  buf2.seek(0)
191
  plot2_img = Image.open(buf2)
192
 
193
- # ---------------------------
194
- # Extract age and weighted gender (confidence ≤ 0.9)
195
- # ---------------------------
196
  ages = []
197
  gender_confidence = {"Man": 0, "Woman": 0}
198
-
199
  for face_list in df["Face Info"]:
200
- for face in face_list: # each face is a dict
201
  ages.append(face["age"])
202
- gender_dict = face["gender"] # dict of probabilities
203
  gender = max(gender_dict, key=gender_dict.get)
204
- conf = float(gender_dict[gender]) / 100 # convert % to 0-1
205
  weight = min(conf, 0.9)
206
  if gender in gender_confidence:
207
  gender_confidence[gender] += weight
208
  else:
209
  gender_confidence[gender] = weight
210
 
211
- # ---------------------------
212
- # Plot 3: Gender distribution (weighted ≤ 0.9)
213
- # ---------------------------
214
  fig3, ax3 = plt.subplots()
215
  ax3.bar(gender_confidence.keys(), gender_confidence.values(), color=["lightblue", "pink"])
216
- ax3.set_title("Gender Distribution (Weighted ≤90% Confidence)")
217
  ax3.set_ylabel("Sum of Confidence")
218
  buf3 = io.BytesIO()
219
  plt.savefig(buf3, format="png")
@@ -221,9 +132,7 @@ def classify_zip_and_analyze_color(zip_file):
221
  buf3.seek(0)
222
  plot3_img = Image.open(buf3)
223
 
224
- # ---------------------------
225
- # Plot 4: Age distribution
226
- # ---------------------------
227
  fig4, ax4 = plt.subplots()
228
  ax4.hist(ages, bins=range(0, 101, 5), color="lightgreen", edgecolor="black")
229
  ax4.set_title("Age Distribution")
@@ -235,10 +144,10 @@ def classify_zip_and_analyze_color(zip_file):
235
  buf4.seek(0)
236
  plot4_img = Image.open(buf4)
237
 
238
- return df, out_xlsx, plot1_img, plot2_img, plot3_img, plot4_img
239
 
240
  # ---------------------------
241
- # Gradio Interface
242
  # ---------------------------
243
  demo = gr.Interface(
244
  fn=classify_zip_and_analyze_color,
@@ -246,14 +155,15 @@ demo = gr.Interface(
246
  outputs=[
247
  gr.Dataframe(headers=["Filename", "Top 3 Predictions", "Confidence", "Dominant Color", "Basic Color", "Face Info"]),
248
  gr.File(label="Download XLSX"),
 
249
  gr.Image(type="pil", label="Basic Color Frequency"),
250
  gr.Image(type="pil", label="Top Prediction Distribution"),
251
  gr.Image(type="pil", label="Gender Distribution (Weighted ≤90%)"),
252
  gr.Image(type="pil", label="Age Distribution"),
253
  ],
254
  title="Image Classifier with Color & Face Analysis",
255
- description="Upload a ZIP of images. Classifies images, analyzes dominant color, and detects/characterizes faces (age, gender, emotion).",
256
  )
257
 
258
  if __name__ == "__main__":
259
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # ---------------------------
2
+ # Core function with gallery and renamed XLSX
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  # ---------------------------
4
  def classify_zip_and_analyze_color(zip_file):
5
  results = []
6
+ thumbnails = []
7
+
8
+ # Get base name of zip to rename XLSX
9
+ zip_basename = os.path.splitext(os.path.basename(zip_file.name))[0]
10
 
11
  with tempfile.TemporaryDirectory() as tmpdir:
12
  with zipfile.ZipFile(zip_file.name, 'r') as zip_ref:
 
20
  except Exception:
21
  continue
22
 
23
+ # Create small thumbnail for gallery
24
+ thumb = image.copy()
25
+ thumb.thumbnail((100, 100))
26
+ thumbnails.append(thumb)
27
+
28
  # Classification
29
  input_tensor = transform(image).unsqueeze(0).to(device)
30
  with torch.no_grad():
 
38
  rgb, hex_color = get_dominant_color(image)
39
  basic_color = closest_basic_color(rgb)
40
 
 
41
  # Face detection & characterization
 
42
  faces_data = []
43
  try:
44
  img_cv2 = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
 
49
  for f in detected_faces:
50
  faces_data.append({
51
  "age": f["age"],
52
+ "gender": f["gender"],
53
  "emotion": f["dominant_emotion"]
54
  })
55
  else:
56
  faces_data.append({
57
  "age": detected_faces["age"],
58
+ "gender": detected_faces["gender"],
59
  "emotion": detected_faces["dominant_emotion"]
60
  })
61
  except Exception:
 
73
  # Build dataframe
74
  df = pd.DataFrame(results, columns=["Filename", "Top 3 Predictions", "Confidence", "Dominant Color", "Basic Color", "Face Info"])
75
 
76
+ # Save XLSX with zip name
77
+ out_xlsx = os.path.join(tempfile.gettempdir(), f"{zip_basename}_results.xlsx")
78
  df.to_excel(out_xlsx, index=False)
79
 
80
  # ---------------------------
81
+ # Plotting code (same as before)
82
  # ---------------------------
83
+ # Basic color frequency
84
  fig1, ax1 = plt.subplots()
85
  color_counts = df["Basic Color"].value_counts()
86
  ax1.bar(color_counts.index, color_counts.values, color="skyblue")
 
92
  buf1.seek(0)
93
  plot1_img = Image.open(buf1)
94
 
95
+ # Top prediction distribution
 
 
96
  fig2, ax2 = plt.subplots()
97
  preds_flat = []
98
  for p in df["Top 3 Predictions"]:
 
107
  buf2.seek(0)
108
  plot2_img = Image.open(buf2)
109
 
110
+ # Gender distribution
 
 
111
  ages = []
112
  gender_confidence = {"Man": 0, "Woman": 0}
 
113
  for face_list in df["Face Info"]:
114
+ for face in face_list:
115
  ages.append(face["age"])
116
+ gender_dict = face["gender"]
117
  gender = max(gender_dict, key=gender_dict.get)
118
+ conf = float(gender_dict[gender]) / 100
119
  weight = min(conf, 0.9)
120
  if gender in gender_confidence:
121
  gender_confidence[gender] += weight
122
  else:
123
  gender_confidence[gender] = weight
124
 
 
 
 
125
  fig3, ax3 = plt.subplots()
126
  ax3.bar(gender_confidence.keys(), gender_confidence.values(), color=["lightblue", "pink"])
127
+ ax3.set_title("Gender Distribution (Weighted ≤90%)")
128
  ax3.set_ylabel("Sum of Confidence")
129
  buf3 = io.BytesIO()
130
  plt.savefig(buf3, format="png")
 
132
  buf3.seek(0)
133
  plot3_img = Image.open(buf3)
134
 
135
+ # Age distribution
 
 
136
  fig4, ax4 = plt.subplots()
137
  ax4.hist(ages, bins=range(0, 101, 5), color="lightgreen", edgecolor="black")
138
  ax4.set_title("Age Distribution")
 
144
  buf4.seek(0)
145
  plot4_img = Image.open(buf4)
146
 
147
+ return df, out_xlsx, thumbnails, plot1_img, plot2_img, plot3_img, plot4_img
148
 
149
  # ---------------------------
150
+ # Gradio Interface with gallery
151
  # ---------------------------
152
  demo = gr.Interface(
153
  fn=classify_zip_and_analyze_color,
 
155
  outputs=[
156
  gr.Dataframe(headers=["Filename", "Top 3 Predictions", "Confidence", "Dominant Color", "Basic Color", "Face Info"]),
157
  gr.File(label="Download XLSX"),
158
+ gr.Gallery(label="Thumbnails", elem_id="thumbnail-gallery").style(grid=[5], height="auto"),
159
  gr.Image(type="pil", label="Basic Color Frequency"),
160
  gr.Image(type="pil", label="Top Prediction Distribution"),
161
  gr.Image(type="pil", label="Gender Distribution (Weighted ≤90%)"),
162
  gr.Image(type="pil", label="Age Distribution"),
163
  ],
164
  title="Image Classifier with Color & Face Analysis",
165
+ description="Upload a ZIP of images. Classifies images, analyzes dominant color, detects faces, and displays thumbnails.",
166
  )
167
 
168
  if __name__ == "__main__":
169
+ demo.launch(server_name="0.0.0.0", server_port=7860)