clementBE committed on
Commit
265fcb2
·
verified ·
1 Parent(s): 0ef5858

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +166 -136
app.py CHANGED
@@ -1,4 +1,7 @@
1
- import os, zipfile, tempfile, io
 
 
 
2
  import numpy as np
3
  import pandas as pd
4
  from PIL import Image
@@ -8,13 +11,21 @@ from torchvision import transforms
8
  from torchvision.models import resnet50, ResNet50_Weights
9
  from sklearn.cluster import MiniBatchKMeans
10
  import matplotlib.pyplot as plt
 
 
 
11
  import gradio as gr
 
 
 
12
  import cv2
13
- import requests
14
 
15
  # ---------------------------
16
- # Device
17
  # ---------------------------
 
 
 
18
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
 
20
  # ---------------------------
@@ -23,35 +34,53 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
  weights = ResNet50_Weights.DEFAULT
24
  model = resnet50(weights=weights).to(device)
25
  model.eval()
 
 
 
 
26
  transform = transforms.Compose([
27
  transforms.Resize(256),
28
  transforms.CenterCrop(224),
29
  transforms.ToTensor(),
30
- transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
 
31
  ])
32
 
 
 
 
33
  LABELS_URL = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
34
  imagenet_classes = [line.strip() for line in requests.get(LABELS_URL).text.splitlines()]
35
 
 
 
 
36
  BASIC_COLORS = {
37
- "Red": (255,0,0), "Green":(0,255,0), "Blue":(0,0,255),
38
- "Yellow":(255,255,0), "Cyan":(0,255,255), "Magenta":(255,0,255),
39
- "Black":(0,0,0), "White":(255,255,255), "Gray":(128,128,128),
 
 
 
 
 
 
40
  }
41
 
42
  def closest_basic_color(rgb):
43
- r,g,b = rgb
44
- min_dist, closest_color = float("inf"), None
45
- for name,(cr,cg,cb) in BASIC_COLORS.items():
46
- dist = (r-cr)**2 + (g-cg)**2 + (b-cb)**2
 
47
  if dist < min_dist:
48
  min_dist = dist
49
  closest_color = name
50
  return closest_color
51
 
52
- def get_dominant_color(image,num_colors=5):
53
- image = image.resize((100,100))
54
- pixels = np.array(image).reshape(-1,3)
55
  kmeans = MiniBatchKMeans(n_clusters=num_colors, random_state=0, n_init=5)
56
  kmeans.fit(pixels)
57
  dominant_color = kmeans.cluster_centers_[np.argmax(np.bincount(kmeans.labels_))]
@@ -60,151 +89,152 @@ def get_dominant_color(image,num_colors=5):
60
  return dominant_color, hex_color
61
 
62
  # ---------------------------
63
- # OpenCV DNN Face + Gender
64
- # ---------------------------
65
- os.makedirs("models", exist_ok=True)
66
-
67
- # Face detection model
68
- FACE_PROTO = "models/deploy.prototxt"
69
- FACE_MODEL = "models/res10_300x300_ssd_iter_140000_fp16.caffemodel"
70
- if not os.path.exists(FACE_PROTO):
71
- r = requests.get("https://raw.githubusercontent.com/opencv/opencv/master/samples/dnn/face_detector/deploy.prototxt"); open(FACE_PROTO,"wb").write(r.content)
72
- if not os.path.exists(FACE_MODEL):
73
- r = requests.get("https://raw.githubusercontent.com/opencv/opencv_3rdparty/master/res10_300x300_ssd_iter_140000_fp16.caffemodel"); open(FACE_MODEL,"wb").write(r.content)
74
-
75
- # Gender model
76
- GENDER_PROTO = "models/deploy_gender.prototxt"
77
- GENDER_MODEL = "models/gender_net.caffemodel"
78
- if not os.path.exists(GENDER_PROTO):
79
- r = requests.get("https://raw.githubusercontent.com/spmallick/learnopencv/master/AgeGender/deploy_gender.prototxt"); open(GENDER_PROTO,"wb").write(r.content)
80
- if not os.path.exists(GENDER_MODEL):
81
- r = requests.get("https://raw.githubusercontent.com/spmallick/learnopencv/master/AgeGender/gender_net.caffemodel"); open(GENDER_MODEL,"wb").write(r.content)
82
-
83
- face_net = cv2.dnn.readNet(FACE_MODEL, FACE_PROTO)
84
- gender_net = cv2.dnn.readNet(GENDER_MODEL, GENDER_PROTO)
85
- GENDER_LIST = ["Homme","Femme"]
86
-
87
- def detect_faces_and_gender(image):
88
- img = np.array(image)[:, :, ::-1] # PIL RGB -> BGR
89
- h, w = img.shape[:2]
90
- blob = cv2.dnn.blobFromImage(img, 1.0, (300,300), [104,117,123], swapRB=False)
91
- face_net.setInput(blob)
92
- detections = face_net.forward()
93
- faces_data = []
94
-
95
- for i in range(detections.shape[2]):
96
- confidence = detections[0,0,i,2]
97
- if confidence > 0.5:
98
- box = detections[0,0,i,3:7] * np.array([w,h,w,h])
99
- x1,y1,x2,y2 = box.astype(int)
100
- x1,y1,x2,y2 = max(0,x1), max(0,y1), min(w,x2), min(h,y2)
101
- face_img = img[y1:y2, x1:x2]
102
- if face_img.size == 0:
103
- continue
104
- face_blob = cv2.dnn.blobFromImage(face_img, 1.0, (227,227),
105
- [78.4263377603, 87.7689143744, 114.895847746], swapRB=False)
106
- gender_net.setInput(face_blob)
107
- gender_preds = gender_net.forward()
108
- gender = GENDER_LIST[gender_preds[0].argmax()]
109
- faces_data.append({"bbox":(x1,y1,x2,y2),"gender":gender})
110
- return faces_data
111
-
112
- # ---------------------------
113
- # Core analysis
114
  # ---------------------------
115
  def classify_zip_and_analyze_color(zip_file):
116
  results = []
117
- images_list = []
118
  zip_name = os.path.splitext(os.path.basename(zip_file.name))[0]
 
119
 
120
  with tempfile.TemporaryDirectory() as tmpdir:
121
- with zipfile.ZipFile(zip_file.name,'r') as zip_ref:
122
  zip_ref.extractall(tmpdir)
123
 
124
  for fname in sorted(os.listdir(tmpdir)):
125
- if not fname.lower().endswith(('.png','.jpg','.jpeg')):
126
- continue
127
- img_path = os.path.join(tmpdir,fname)
128
- try:
129
- image = Image.open(img_path).convert("RGB")
130
- images_list.append((image.copy(), fname))
131
- except:
132
- continue
133
-
134
- # Image classification
135
- input_tensor = transform(image).unsqueeze(0).to(device)
136
- with torch.no_grad():
137
- output = model(input_tensor)
138
- probs = F.softmax(output, dim=1)[0]
139
- top3_prob, top3_idx = torch.topk(probs,3)
140
- preds = [(imagenet_classes[idx], f"{prob.item()*100:.2f}%") for idx,prob in zip(top3_idx, top3_prob)]
141
-
142
- # Dominant color
143
- rgb, hex_color = get_dominant_color(image)
144
- basic_color = closest_basic_color(rgb)
145
-
146
- # Face + gender detection
147
- faces_data = detect_faces_and_gender(image)
148
- faces_str = "; ".join([f"Gender: {f['gender']}" for f in faces_data])
149
-
150
- results.append((
151
- fname,
152
- ", ".join([p[0] for p in preds]),
153
- ", ".join([p[1] for p in preds]),
154
- hex_color,
155
- basic_color,
156
- faces_str
157
- ))
158
-
159
- df = pd.DataFrame(results, columns=["Filename","Top 3 Predictions","Confidence","Dominant Color","Basic Color","Face Info"])
160
- out_xlsx = os.path.join(tempfile.gettempdir(), f"{zip_name}_results.xlsx")
161
- df.to_excel(out_xlsx,index=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
 
163
  # ---------------------------
164
- # Plots
165
  # ---------------------------
166
  fig1, ax1 = plt.subplots()
167
  color_counts = df["Basic Color"].value_counts()
168
  ax1.bar(color_counts.index, color_counts.values, color="skyblue")
169
- ax1.set_title("Basic Color Frequency"); ax1.set_ylabel("Count")
170
- buf1 = io.BytesIO(); plt.savefig(buf1, format="png"); plt.close(fig1); buf1.seek(0); plot1_img = Image.open(buf1)
 
 
 
 
 
171
 
 
 
 
172
  fig2, ax2 = plt.subplots()
173
  preds_flat = []
174
- for p in df["Top 3 Predictions"]: preds_flat.extend(p.split(", "))
 
175
  pred_counts = pd.Series(preds_flat).value_counts().head(20)
176
  ax2.barh(pred_counts.index[::-1], pred_counts.values[::-1], color="salmon")
177
- ax2.set_title("Top Prediction Distribution"); ax2.set_xlabel("Count")
178
- buf2 = io.BytesIO(); plt.savefig(buf2, format="png", bbox_inches="tight"); plt.close(fig2); buf2.seek(0); plot2_img = Image.open(buf2)
 
 
 
 
 
 
 
 
 
 
 
 
 
179
 
180
- # Gender distribution
181
- gender_counts = [df["Face Info"].str.count("Homme").sum(), df["Face Info"].str.count("Femme").sum()]
182
  fig3, ax3 = plt.subplots()
183
- ax3.bar(["Homme","Femme"], gender_counts, color=["lightblue","pink"])
184
- ax3.set_title("Gender Distribution"); ax3.set_ylabel("Count")
185
- buf3 = io.BytesIO(); plt.savefig(buf3, format="png"); plt.close(fig3); buf3.seek(0); plot3_img = Image.open(buf3)
 
 
 
 
 
186
 
187
- return df, images_list, out_xlsx, plot1_img, plot2_img, plot3_img
188
 
189
  # ---------------------------
190
- # Gradio interface
191
  # ---------------------------
192
- with gr.Blocks() as demo:
193
- uploaded_zip = gr.File(label="Upload ZIP of images", file_types=[".zip"])
194
- analyze_btn = gr.Button("Run Analysis")
195
-
196
- output_df = gr.Dataframe(headers=["Filename","Top 3 Predictions","Confidence","Dominant Color","Basic Color","Face Info"])
197
- image_gallery = gr.Gallery(label="Preview Images", columns=4, show_label=True)
198
- download_file = gr.File(label="Download XLSX")
199
-
200
- plot1 = gr.Image(label="Basic Color Frequency")
201
- plot2 = gr.Image(label="Top Prediction Distribution")
202
- plot3 = gr.Image(label="Gender Distribution")
203
-
204
- analyze_btn.click(
205
- classify_zip_and_analyze_color,
206
- inputs=uploaded_zip,
207
- outputs=[output_df, image_gallery, download_file, plot1, plot2, plot3]
208
- )
209
-
210
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
 
1
+ import os
2
+ import zipfile
3
+ import tempfile
4
+ import requests
5
  import numpy as np
6
  import pandas as pd
7
  from PIL import Image
 
11
  from torchvision.models import resnet50, ResNet50_Weights
12
  from sklearn.cluster import MiniBatchKMeans
13
  import matplotlib.pyplot as plt
14
+ import io
15
+ from datetime import datetime
16
+
17
  import gradio as gr
18
+
19
+ # Face analysis
20
+ from deepface import DeepFace
21
  import cv2
 
22
 
23
# ---------------------------
# Force CPU if no CUDA
# ---------------------------
# Hiding CUDA devices when torch sees no GPU keeps other frameworks loaded
# later (e.g. by DeepFace) from probing for one — presumably the intent;
# TODO(review): confirm against deployment target.
cuda_available = torch.cuda.is_available()
if not cuda_available:
    os.environ["CUDA_VISIBLE_DEVICES"] = ""

device = torch.device("cuda" if cuda_available else "cpu")
30
 
31
  # ---------------------------
 
34
  weights = ResNet50_Weights.DEFAULT
35
  model = resnet50(weights=weights).to(device)
36
  model.eval()
37
+
38
+ # ---------------------------
39
+ # Transformations
40
+ # ---------------------------
41
  transform = transforms.Compose([
42
  transforms.Resize(256),
43
  transforms.CenterCrop(224),
44
  transforms.ToTensor(),
45
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
46
+ std=[0.229, 0.224, 0.225]),
47
  ])
48
 
49
# ---------------------------
# ImageNet labels
# ---------------------------
LABELS_URL = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
# Bound the download and surface HTTP failures immediately: without a
# timeout a stalled connection hangs app startup forever, and without
# raise_for_status() an error page would be silently parsed as labels.
_labels_response = requests.get(LABELS_URL, timeout=30)
_labels_response.raise_for_status()
imagenet_classes = [line.strip() for line in _labels_response.text.splitlines()]
54
 
55
# ---------------------------
# Color utilities
# ---------------------------
# Reference palette: color name -> RGB triple.
BASIC_COLORS = {
    "Red": (255, 0, 0),
    "Green": (0, 255, 0),
    "Blue": (0, 0, 255),
    "Yellow": (255, 255, 0),
    "Cyan": (0, 255, 255),
    "Magenta": (255, 0, 255),
    "Black": (0, 0, 0),
    "White": (255, 255, 255),
    "Gray": (128, 128, 128),
}

def closest_basic_color(rgb):
    """Return the name of the BASIC_COLORS entry nearest to *rgb*.

    Nearness is squared Euclidean distance in RGB space; ties resolve
    to the earliest palette entry (dict insertion order).
    """
    r, g, b = rgb

    def _sq_dist(item):
        cr, cg, cb = item[1]
        return (r - cr) ** 2 + (g - cg) ** 2 + (b - cb) ** 2

    name, _ = min(BASIC_COLORS.items(), key=_sq_dist)
    return name
80
 
81
+ def get_dominant_color(image, num_colors=5):
82
+ image = image.resize((100, 100))
83
+ pixels = np.array(image).reshape(-1, 3)
84
  kmeans = MiniBatchKMeans(n_clusters=num_colors, random_state=0, n_init=5)
85
  kmeans.fit(pixels)
86
  dominant_color = kmeans.cluster_centers_[np.argmax(np.bincount(kmeans.labels_))]
 
89
  return dominant_color, hex_color
90
 
91
  # ---------------------------
92
+ # Core function
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  # ---------------------------
94
def _fig_to_image(fig):
    """Render a Matplotlib figure to a PIL image, closing the figure."""
    buf = io.BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight")
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


def _analyze_faces(image):
    """Detect and characterize faces in a PIL RGB image via DeepFace.

    Returns a list of {"age", "gender", "emotion"} dicts, with gender
    localized to "Homme"/"Femme". Returns [] on any failure — face
    analysis is best-effort and must not abort the batch.
    """
    try:
        img_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        detected = DeepFace.analyze(
            img_bgr, actions=["age", "gender", "emotion"], enforce_detection=False
        )
        if not isinstance(detected, list):
            detected = [detected]
        faces = []
        for f in detected:
            # Recent DeepFace returns "gender" as a probability dict and the
            # label under "dominant_gender"; older releases returned the label
            # directly under "gender". Comparing the dict to "Man" (as the
            # previous code did) misreports every face as "Femme".
            label = f.get("dominant_gender", f.get("gender"))
            faces.append({
                "age": f["age"],
                "gender": "Homme" if label == "Man" else "Femme",
                "emotion": f["dominant_emotion"],
            })
        return faces
    except Exception:
        return []


def classify_zip_and_analyze_color(zip_file):
    """Classify every image in a ZIP and analyze colors and faces.

    Parameters
    ----------
    zip_file : uploaded file object exposing a ``.name`` path (Gradio).

    Returns
    -------
    tuple
        (results DataFrame, XLSX path, color-frequency plot,
        prediction-distribution plot, gender-distribution plot) —
        the shapes expected by the Gradio outputs.
    """
    results = []
    faces_per_image = []  # raw face dicts per row, for the gender plot
    zip_name = os.path.splitext(os.path.basename(zip_file.name))[0]
    date_str = datetime.now().strftime("%Y%m%d")

    with tempfile.TemporaryDirectory() as tmpdir:
        with zipfile.ZipFile(zip_file.name, 'r') as zip_ref:
            # NOTE(review): extractall trusts member paths; a hostile ZIP
            # could write outside tmpdir on older Pythons ("zip slip").
            zip_ref.extractall(tmpdir)

        # Only top-level entries are scanned; images nested in sub-folders
        # of the archive are ignored (matches the previous behavior).
        for fname in sorted(os.listdir(tmpdir)):
            if not fname.lower().endswith(('.png', '.jpg', '.jpeg')):
                continue
            try:
                image = Image.open(os.path.join(tmpdir, fname)).convert("RGB")
            except Exception:
                continue  # unreadable/corrupt file: skip it

            # --- ImageNet top-3 classification ---
            input_tensor = transform(image).unsqueeze(0).to(device)
            with torch.no_grad():
                probs = F.softmax(model(input_tensor), dim=1)[0]
            top3_prob, top3_idx = torch.topk(probs, 3)
            preds = [
                (imagenet_classes[idx], f"{prob.item() * 100:.2f}%")
                for idx, prob in zip(top3_idx, top3_prob)
            ]

            # --- Dominant color ---
            rgb, hex_color = get_dominant_color(image)
            basic_color = closest_basic_color(rgb)

            # --- Face detection & characterization ---
            faces = _analyze_faces(image)
            faces_per_image.append(faces)
            # Store a readable summary string: raw dicts render as Python
            # reprs in XLSX and violate the string-typed Gradio dataframe.
            faces_str = "; ".join(
                f"{f['gender']}, {f['age']}, {f['emotion']}" for f in faces
            )

            results.append((
                fname,
                ", ".join(p[0] for p in preds),
                ", ".join(p[1] for p in preds),
                hex_color,
                basic_color,
                faces_str,
            ))

    # Build dataframe
    df = pd.DataFrame(results, columns=[
        "Filename", "Top 3 Predictions", "Confidence",
        "Dominant Color", "Basic Color", "Face Info",
    ])

    # Save XLSX named after the archive plus today's date
    out_xlsx = os.path.join(tempfile.gettempdir(),
                            f"{zip_name}_{date_str}_results.xlsx")
    df.to_excel(out_xlsx, index=False)

    # Plot 1: basic color frequency
    fig1, ax1 = plt.subplots()
    color_counts = df["Basic Color"].value_counts()
    ax1.bar(color_counts.index, color_counts.values, color="skyblue")
    ax1.set_title("Basic Color Frequency")
    ax1.set_ylabel("Count")
    plot1_img = _fig_to_image(fig1)

    # Plot 2: top-20 predicted classes across all images
    fig2, ax2 = plt.subplots()
    preds_flat = [p for row in df["Top 3 Predictions"] for p in row.split(", ")]
    pred_counts = pd.Series(preds_flat).value_counts().head(20)
    ax2.barh(pred_counts.index[::-1], pred_counts.values[::-1], color="salmon")
    ax2.set_title("Top Prediction Distribution")
    ax2.set_xlabel("Count")
    plot2_img = _fig_to_image(fig2)

    # Plot 3: gender distribution across all detected faces
    gender_counts = {"Homme": 0, "Femme": 0}
    for faces in faces_per_image:
        for f in faces:
            gender_counts[f["gender"]] += 1
    fig3, ax3 = plt.subplots()
    ax3.bar(list(gender_counts), list(gender_counts.values()),
            color=["lightblue", "pink"])
    ax3.set_title("Gender Distribution")
    ax3.set_ylabel("Count")
    plot3_img = _fig_to_image(fig3)

    return df, out_xlsx, plot1_img, plot2_img, plot3_img
217
 
218
# ---------------------------
# Gradio Interface
# ---------------------------
# Column layout shared by the results dataframe.
_HEADERS = ["Filename", "Top 3 Predictions", "Confidence",
            "Dominant Color", "Basic Color", "Face Info"]

demo = gr.Interface(
    fn=classify_zip_and_analyze_color,
    inputs=gr.File(file_types=[".zip"], label="Upload ZIP of images"),
    outputs=[
        gr.Dataframe(headers=_HEADERS, datatype=["str"] * len(_HEADERS)),
        gr.File(label="Download XLSX"),
        gr.Image(type="pil", label="Basic Color Frequency"),
        gr.Image(type="pil", label="Top Prediction Distribution"),
        gr.Image(type="pil", label="Gender Distribution"),
    ],
    title="Image Classifier with Color & Face Analysis",
    description=(
        "Upload a ZIP of images. Classifies images, analyzes dominant "
        "color, detects/characterizes faces (age, gender, emotion)."
    ),
)

if __name__ == "__main__":
    # Bind on all interfaces so the app is reachable inside a container.
    demo.launch(server_name="0.0.0.0", server_port=7860)