clementBE commited on
Commit
328faae
·
verified ·
1 Parent(s): 517b23e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -10
app.py CHANGED
@@ -1,11 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # ---------------------------
2
- # Core function with gallery and renamed XLSX
3
  # ---------------------------
4
  def classify_zip_and_analyze_color(zip_file):
5
  results = []
6
  thumbnails = []
7
 
8
- # Get base name of zip to rename XLSX
9
  zip_basename = os.path.splitext(os.path.basename(zip_file.name))[0]
10
 
11
  with tempfile.TemporaryDirectory() as tmpdir:
@@ -20,7 +109,7 @@ def classify_zip_and_analyze_color(zip_file):
20
  except Exception:
21
  continue
22
 
23
- # Create small thumbnail for gallery
24
  thumb = image.copy()
25
  thumb.thumbnail((100, 100))
26
  thumbnails.append(thumb)
@@ -73,14 +162,14 @@ def classify_zip_and_analyze_color(zip_file):
73
  # Build dataframe
74
  df = pd.DataFrame(results, columns=["Filename", "Top 3 Predictions", "Confidence", "Dominant Color", "Basic Color", "Face Info"])
75
 
76
- # Save XLSX with zip name
77
  out_xlsx = os.path.join(tempfile.gettempdir(), f"{zip_basename}_results.xlsx")
78
  df.to_excel(out_xlsx, index=False)
79
 
80
  # ---------------------------
81
- # Plotting code (same as before)
82
  # ---------------------------
83
- # Basic color frequency
84
  fig1, ax1 = plt.subplots()
85
  color_counts = df["Basic Color"].value_counts()
86
  ax1.bar(color_counts.index, color_counts.values, color="skyblue")
@@ -92,7 +181,7 @@ def classify_zip_and_analyze_color(zip_file):
92
  buf1.seek(0)
93
  plot1_img = Image.open(buf1)
94
 
95
- # Top prediction distribution
96
  fig2, ax2 = plt.subplots()
97
  preds_flat = []
98
  for p in df["Top 3 Predictions"]:
@@ -107,7 +196,7 @@ def classify_zip_and_analyze_color(zip_file):
107
  buf2.seek(0)
108
  plot2_img = Image.open(buf2)
109
 
110
- # Gender distribution
111
  ages = []
112
  gender_confidence = {"Man": 0, "Woman": 0}
113
  for face_list in df["Face Info"]:
@@ -132,7 +221,7 @@ def classify_zip_and_analyze_color(zip_file):
132
  buf3.seek(0)
133
  plot3_img = Image.open(buf3)
134
 
135
- # Age distribution
136
  fig4, ax4 = plt.subplots()
137
  ax4.hist(ages, bins=range(0, 101, 5), color="lightgreen", edgecolor="black")
138
  ax4.set_title("Age Distribution")
@@ -147,7 +236,7 @@ def classify_zip_and_analyze_color(zip_file):
147
  return df, out_xlsx, thumbnails, plot1_img, plot2_img, plot3_img, plot4_img
148
 
149
  # ---------------------------
150
- # Gradio Interface with gallery
151
  # ---------------------------
152
  demo = gr.Interface(
153
  fn=classify_zip_and_analyze_color,
 
1
+ import os
2
+ import zipfile
3
+ import tempfile
4
+ import requests
5
+ import numpy as np
6
+ import pandas as pd
7
+ from PIL import Image
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from torchvision import transforms
11
+ from torchvision.models import resnet50, ResNet50_Weights
12
+ from sklearn.cluster import MiniBatchKMeans
13
+ import matplotlib.pyplot as plt
14
+ import io
15
+
16
+ import gradio as gr
17
+
18
+ # Face analysis
19
+ from deepface import DeepFace
20
+ import cv2
21
+
22
+ # ---------------------------
23
+ # Force CPU if no CUDA
24
+ # ---------------------------
25
+ if not torch.cuda.is_available():
26
+ os.environ["CUDA_VISIBLE_DEVICES"] = ""
27
+
28
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
+
30
+ # ---------------------------
31
+ # Load ResNet50
32
+ # ---------------------------
33
+ weights = ResNet50_Weights.DEFAULT
34
+ model = resnet50(weights=weights).to(device)
35
+ model.eval()
36
+
37
+ # ---------------------------
38
+ # Transformations
39
+ # ---------------------------
40
+ transform = transforms.Compose([
41
+ transforms.Resize(256),
42
+ transforms.CenterCrop(224),
43
+ transforms.ToTensor(),
44
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
45
+ std=[0.229, 0.224, 0.225]),
46
+ ])
47
+
48
+ # ---------------------------
49
+ # ImageNet labels
50
+ # ---------------------------
51
+ LABELS_URL = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
52
+ imagenet_classes = [line.strip() for line in requests.get(LABELS_URL).text.splitlines()]
53
+
54
+ # ---------------------------
55
+ # Basic color utilities
56
+ # ---------------------------
57
+ BASIC_COLORS = {
58
+ "Red": (255, 0, 0),
59
+ "Green": (0, 255, 0),
60
+ "Blue": (0, 0, 255),
61
+ "Yellow": (255, 255, 0),
62
+ "Cyan": (0, 255, 255),
63
+ "Magenta": (255, 0, 255),
64
+ "Black": (0, 0, 0),
65
+ "White": (255, 255, 255),
66
+ "Gray": (128, 128, 128),
67
+ }
68
+
69
+ def closest_basic_color(rgb):
70
+ r, g, b = rgb
71
+ min_dist = float("inf")
72
+ closest_color = None
73
+ for name, (cr, cg, cb) in BASIC_COLORS.items():
74
+ dist = (r - cr) ** 2 + (g - cg) ** 2 + (b - cb) ** 2
75
+ if dist < min_dist:
76
+ min_dist = dist
77
+ closest_color = name
78
+ return closest_color
79
+
80
+ def get_dominant_color(image, num_colors=5):
81
+ image = image.resize((100, 100))
82
+ pixels = np.array(image).reshape(-1, 3)
83
+ kmeans = MiniBatchKMeans(n_clusters=num_colors, random_state=0, n_init=5)
84
+ kmeans.fit(pixels)
85
+ dominant_color = kmeans.cluster_centers_[np.argmax(np.bincount(kmeans.labels_))]
86
+ dominant_color = tuple(dominant_color.astype(int))
87
+ hex_color = f"#{dominant_color[0]:02x}{dominant_color[1]:02x}{dominant_color[2]:02x}"
88
+ return dominant_color, hex_color
89
+
90
  # ---------------------------
91
+ # Core function
92
  # ---------------------------
93
  def classify_zip_and_analyze_color(zip_file):
94
  results = []
95
  thumbnails = []
96
 
97
+ # Name XLSX after zip
98
  zip_basename = os.path.splitext(os.path.basename(zip_file.name))[0]
99
 
100
  with tempfile.TemporaryDirectory() as tmpdir:
 
109
  except Exception:
110
  continue
111
 
112
+ # Thumbnail for gallery
113
  thumb = image.copy()
114
  thumb.thumbnail((100, 100))
115
  thumbnails.append(thumb)
 
162
  # Build dataframe
163
  df = pd.DataFrame(results, columns=["Filename", "Top 3 Predictions", "Confidence", "Dominant Color", "Basic Color", "Face Info"])
164
 
165
+ # Save XLSX
166
  out_xlsx = os.path.join(tempfile.gettempdir(), f"{zip_basename}_results.xlsx")
167
  df.to_excel(out_xlsx, index=False)
168
 
169
  # ---------------------------
170
+ # Plots
171
  # ---------------------------
172
+ # 1. Basic color frequency
173
  fig1, ax1 = plt.subplots()
174
  color_counts = df["Basic Color"].value_counts()
175
  ax1.bar(color_counts.index, color_counts.values, color="skyblue")
 
181
  buf1.seek(0)
182
  plot1_img = Image.open(buf1)
183
 
184
+ # 2. Top prediction distribution
185
  fig2, ax2 = plt.subplots()
186
  preds_flat = []
187
  for p in df["Top 3 Predictions"]:
 
196
  buf2.seek(0)
197
  plot2_img = Image.open(buf2)
198
 
199
+ # 3. Gender distribution
200
  ages = []
201
  gender_confidence = {"Man": 0, "Woman": 0}
202
  for face_list in df["Face Info"]:
 
221
  buf3.seek(0)
222
  plot3_img = Image.open(buf3)
223
 
224
+ # 4. Age distribution
225
  fig4, ax4 = plt.subplots()
226
  ax4.hist(ages, bins=range(0, 101, 5), color="lightgreen", edgecolor="black")
227
  ax4.set_title("Age Distribution")
 
236
  return df, out_xlsx, thumbnails, plot1_img, plot2_img, plot3_img, plot4_img
237
 
238
  # ---------------------------
239
+ # Gradio Interface
240
  # ---------------------------
241
  demo = gr.Interface(
242
  fn=classify_zip_and_analyze_color,