clementBE committed on
Commit
25bf82c
Β·
verified Β·
1 Parent(s): c0ff534

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +250 -273
app.py CHANGED
@@ -1,277 +1,254 @@
 
 
1
  import os
2
- import zipfile
3
  import tempfile
4
- import requests
5
- import numpy as np
6
- import pandas as pd
7
- from PIL import Image
8
- import torch
9
- import torch.nn.functional as F
10
- from torchvision import transforms
11
- from torchvision.models import resnet50, ResNet50_Weights
12
- from sklearn.cluster import MiniBatchKMeans
13
- import matplotlib.pyplot as plt
14
- import io
15
- from datetime import datetime
16
-
17
- import gradio as gr
18
-
19
- # Face analysis
20
- from deepface import DeepFace
21
- import cv2
22
-
23
- # ---------------------------
24
- # Force CPU if no CUDA
25
- # ---------------------------
26
- if not torch.cuda.is_available():
27
- os.environ["CUDA_VISIBLE_DEVICES"] = ""
28
-
29
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
-
31
- # ---------------------------
32
- # Load ResNet50
33
- # ---------------------------
34
- weights = ResNet50_Weights.DEFAULT
35
- model = resnet50(weights=weights).to(device)
36
- model.eval()
37
-
38
- # ---------------------------
39
- # Transformations
40
- # ---------------------------
41
- transform = transforms.Compose([
42
- transforms.Resize(256),
43
- transforms.CenterCrop(224),
44
- transforms.ToTensor(),
45
- transforms.Normalize(mean=[0.485, 0.456, 0.406],
46
- std=[0.229, 0.224, 0.225]),
47
- ])
48
-
49
- # ---------------------------
50
- # ImageNet labels
51
- # ---------------------------
52
- LABELS_URL = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
53
- imagenet_classes = [line.strip() for line in requests.get(LABELS_URL).text.splitlines()]
54
-
# ---------------------------
# Color utilities
# ---------------------------
BASIC_COLORS = {
    "Red": (255, 0, 0),
    "Green": (0, 255, 0),
    "Blue": (0, 0, 255),
    "Yellow": (255, 255, 0),
    "Cyan": (0, 255, 255),
    "Magenta": (255, 0, 255),
    "Black": (0, 0, 0),
    "White": (255, 255, 255),
    "Gray": (128, 128, 128),
}

def closest_basic_color(rgb):
    """Return the BASIC_COLORS name nearest to *rgb* by squared RGB distance."""
    r, g, b = rgb

    def _sq_dist(item):
        # Squared Euclidean distance; no sqrt needed for ranking.
        _, (cr, cg, cb) = item
        return (r - cr) ** 2 + (g - cg) ** 2 + (b - cb) ** 2

    # min() is stable, so ties resolve to the first dict entry — same as the
    # original strict "<" comparison loop.
    name, _ = min(BASIC_COLORS.items(), key=_sq_dist)
    return name
def get_dominant_color(image, num_colors=5):
    """Return the image's dominant colour as ((r, g, b), '#rrggbb').

    The image is shrunk to 100x100, the pixels are clustered with
    MiniBatchKMeans, and the centre of the most populated cluster wins.
    """
    small = image.resize((100, 100))
    flat_pixels = np.array(small).reshape(-1, 3)
    clusterer = MiniBatchKMeans(n_clusters=num_colors, random_state=0, n_init=5)
    clusterer.fit(flat_pixels)
    # Most populated cluster = dominant colour.
    biggest_cluster = np.argmax(np.bincount(clusterer.labels_))
    center = clusterer.cluster_centers_[biggest_cluster].astype(int)
    rgb = tuple(center)
    hex_code = "#" + "".join(f"{channel:02x}" for channel in rgb)
    return rgb, hex_code
# ---------------------------
# Core function
# ---------------------------
def classify_zip_and_analyze_color(zip_file):
    """Process every image inside an uploaded ZIP.

    For each .png/.jpg/.jpeg file: top-3 ImageNet predictions (ResNet50),
    dominant colour, DeepFace age/gender/emotion, and a 64x64 thumbnail.
    Returns (dataframe, xlsx path, plot1, plot2, plot3, plot4) where the
    plots are PIL images summarising colours, labels, gender and ages.
    """
    results = []

    # Output workbook is named after the uploaded ZIP plus today's date.
    zip_name = os.path.splitext(os.path.basename(zip_file.name))[0]
    date_str = datetime.now().strftime("%Y%m%d")

    with tempfile.TemporaryDirectory() as tmpdir:
        with zipfile.ZipFile(zip_file.name, 'r') as zip_ref:
            zip_ref.extractall(tmpdir)

        # NOTE(review): non-recursive — images in sub-folders of the ZIP are skipped.
        for fname in sorted(os.listdir(tmpdir)):
            if fname.lower().endswith(('.png', '.jpg', '.jpeg')):
                img_path = os.path.join(tmpdir, fname)
                try:
                    image = Image.open(img_path).convert("RGB")
                except Exception:
                    continue  # unreadable / corrupt file: skip silently

                # Classification — top-3 ImageNet classes.
                input_tensor = transform(image).unsqueeze(0).to(device)
                with torch.no_grad():
                    output = model(input_tensor)
                    probs = F.softmax(output, dim=1)[0]

                top3_prob, top3_idx = torch.topk(probs, 3)
                preds = [(imagenet_classes[idx], f"{prob.item()*100:.2f}%") for idx, prob in zip(top3_idx, top3_prob)]

                # Dominant color
                rgb, hex_color = get_dominant_color(image)
                basic_color = closest_basic_color(rgb)

                # Face detection & characterization — best effort; any failure
                # leaves faces_data empty rather than aborting the batch.
                faces_data = []
                try:
                    img_cv2 = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
                    detected_faces = DeepFace.analyze(
                        img_cv2, actions=["age", "gender", "emotion"], enforce_detection=False
                    )
                    # DeepFace returns a list for multi-face images, a dict otherwise.
                    if isinstance(detected_faces, list):
                        for f in detected_faces:
                            faces_data.append({
                                "age": f["age"],
                                "gender": f["gender"],  # dict of per-gender confidences
                                "emotion": f["dominant_emotion"]
                            })
                    else:
                        faces_data.append({
                            "age": detected_faces["age"],
                            "gender": detected_faces["gender"],
                            "emotion": detected_faces["dominant_emotion"]
                        })
                except Exception:
                    faces_data = []

                # Thumbnail preview (in-place resize preserving aspect ratio).
                thumbnail = image.copy()
                thumbnail.thumbnail((64, 64))

                results.append((
                    fname,
                    ", ".join([p[0] for p in preds]),
                    ", ".join([p[1] for p in preds]),
                    hex_color,
                    basic_color,
                    faces_data,
                    thumbnail
                ))

    # Build dataframe
    df = pd.DataFrame(results, columns=[
        "Filename", "Top 3 Predictions", "Confidence",
        "Dominant Color", "Basic Color", "Face Info", "Thumbnail"
    ])

    # Save XLSX with zip name + date
    out_xlsx = os.path.join(tempfile.gettempdir(), f"{zip_name}_{date_str}_results.xlsx")
    df.to_excel(out_xlsx, index=False)

    # ---------------------------
    # Plot 1: Basic color frequency
    # ---------------------------
    fig1, ax1 = plt.subplots()
    color_counts = df["Basic Color"].value_counts()
    ax1.bar(color_counts.index, color_counts.values, color="skyblue")
    ax1.set_title("Basic Color Frequency")
    ax1.set_ylabel("Count")
    buf1 = io.BytesIO()
    plt.savefig(buf1, format="png")
    plt.close(fig1)
    buf1.seek(0)
    plot1_img = Image.open(buf1)

    # ---------------------------
    # Plot 2: Top prediction distribution (20 most frequent labels)
    # NOTE(review): splitting on ", " is lossy — many ImageNet label strings
    # themselves contain commas, so some labels get fragmented here.
    # ---------------------------
    fig2, ax2 = plt.subplots()
    preds_flat = []
    for p in df["Top 3 Predictions"]:
        preds_flat.extend(p.split(", "))
    pred_counts = pd.Series(preds_flat).value_counts().head(20)
    ax2.barh(pred_counts.index[::-1], pred_counts.values[::-1], color="salmon")
    ax2.set_title("Top Prediction Distribution")
    ax2.set_xlabel("Count")
    buf2 = io.BytesIO()
    plt.savefig(buf2, format="png", bbox_inches="tight")
    plt.close(fig2)
    buf2.seek(0)
    plot2_img = Image.open(buf2)

    # ---------------------------
    # Extract ages and genders. Each face's vote is weighted by its gender
    # confidence, capped at 0.9 so a single very confident face cannot dominate.
    # ---------------------------
    ages_male, ages_female = [], []
    gender_confidence = {"Homme": 0, "Femme": 0}

    for face_list in df["Face Info"]:
        for face in face_list:
            age = face["age"]
            gender_dict = face["gender"]
            gender = max(gender_dict, key=gender_dict.get)
            conf = float(gender_dict[gender]) / 100
            weight = min(conf, 0.9)
            gender_trans = "Homme" if gender == "Man" else "Femme"
            gender_confidence[gender_trans] += weight
            if gender_trans == "Homme":
                ages_male.append(age)
            else:
                ages_female.append(age)

    # ---------------------------
    # Plot 3: Gender distribution (sum of capped confidences)
    # ---------------------------
    fig3, ax3 = plt.subplots()
    ax3.bar(gender_confidence.keys(), gender_confidence.values(), color=["lightblue", "pink"])
    ax3.set_title("Gender Distribution (Weighted ≀90%)")
    ax3.set_ylabel("Sum of Confidence")
    buf3 = io.BytesIO()
    plt.savefig(buf3, format="png")
    plt.close(fig3)
    buf3.seek(0)
    plot3_img = Image.open(buf3)

    # ---------------------------
    # Plot 4: Age distribution by gender (5-year bins, 0-100)
    # ---------------------------
    fig4, ax4 = plt.subplots()
    bins = range(0, 101, 5)
    ax4.hist([ages_male, ages_female], bins=bins, color=["lightblue", "pink"], label=["Homme", "Femme"], edgecolor="black")
    ax4.set_title("Age Distribution by Gender")
    ax4.set_xlabel("Age")
    ax4.set_ylabel("Count")
    ax4.legend()
    buf4 = io.BytesIO()
    plt.savefig(buf4, format="png")
    plt.close(fig4)
    buf4.seek(0)
    plot4_img = Image.open(buf4)

    return df, out_xlsx, plot1_img, plot2_img, plot3_img, plot4_img
# ---------------------------
# Gradio Interface
# ---------------------------
# Single-function UI: upload a ZIP, get back a table, an XLSX and four plots.
# Output order must match classify_zip_and_analyze_color's return tuple.
demo = gr.Interface(
    fn=classify_zip_and_analyze_color,
    inputs=gr.File(file_types=[".zip"], label="Upload ZIP of images"),
    outputs=[
        gr.Dataframe(
            headers=["Filename", "Top 3 Predictions", "Confidence",
                     "Dominant Color", "Basic Color", "Face Info", "Thumbnail"],
            datatype=["str","str","str","str","str","str","pil"]
        ),
        gr.File(label="Download XLSX"),
        gr.Image(type="pil", label="Basic Color Frequency"),
        gr.Image(type="pil", label="Top Prediction Distribution"),
        gr.Image(type="pil", label="Gender Distribution (Weighted ≀90%)"),
        gr.Image(type="pil", label="Age Distribution by Gender"),
    ],
    title="Image Classifier with Color & Face Analysis",
    description="Upload a ZIP of images. Classifies images, analyzes dominant color, detects/characterizes faces (age, gender, emotion), and shows thumbnails.",
)
 
276
  if __name__ == "__main__":
277
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
import datetime
import os
import re
import shutil
import tempfile
import urllib.parse
import urllib.request

import gradio as gr
import isodate
import requests
9
+
10
+ # --- IMPORTANT: Ensure this environment variable is set ---
11
+ API_KEY = os.getenv("YOUTUBE_API_KEY")
12
+ BASE_URL = "https://www.googleapis.com/youtube/v3"
13
+
14
+ # -----------------------
15
+ # API Usage Tracker
16
+ # -----------------------
17
+ API_USAGE = {"units": 0}
18
+
def api_get(url, cost, **kwargs):
    """GET *url*, adding *cost* quota units to the global API_USAGE tally."""
    API_USAGE["units"] = API_USAGE["units"] + cost
    response = requests.get(url, **kwargs)
    return response
# -----------------------
# Helper Functions (Simplified)
# -----------------------
# Matches the ISO-8601 durations the YouTube API emits (e.g. "PT1M30S",
# "P1DT2H"). Year/month designators are deliberately unsupported — like the
# previous isodate-based code, such inputs yield 0.
_ISO8601_DURATION = re.compile(
    r"P(?:(\d+)W)?(?:(\d+)D)?"
    r"(?:T(?:(\d+)H)?(?:(\d+)M)?(?:(\d+(?:\.\d+)?)S)?)?"
)

def parse_duration(duration_str):
    """Convert an ISO-8601 duration string to whole seconds.

    Returns 0 for anything unparsable (None, empty string, malformed input),
    matching the previous best-effort behaviour. Uses only the stdlib instead
    of the third-party `isodate` package.
    """
    try:
        match = _ISO8601_DURATION.fullmatch(duration_str)
    except TypeError:
        # duration_str is not a string (e.g. None).
        return 0
    if not match:
        return 0
    weeks, days, hours, minutes, seconds = (
        float(g) if g else 0.0 for g in match.groups()
    )
    return int(weeks * 604800 + days * 86400 + hours * 3600 + minutes * 60 + seconds)
def get_channel_info(channel_id):
    """Fetch the channel snippet (including title) (Cost: 1).

    Returns the snippet dict, or None on HTTP error / unknown channel.
    """
    resp = api_get(f"{BASE_URL}/channels?part=snippet&id={channel_id}&key={API_KEY}", 1)
    if resp.status_code == 200:
        # Parse the body once instead of re-parsing per check.
        items = resp.json().get("items")
        if items:
            return items[0]["snippet"]
    return None
def extract_channel_id(url: str):
    """Resolve a YouTube channel URL (/channel/, /@handle, /user/) to a channel ID.

    Returns None when the URL form is unknown or the API lookup fails.
    Handle and username path segments are URL-encoded before being placed in
    the query string, so non-ASCII handles no longer corrupt the request.
    """
    if "channel/" in url:
        # Canonical URL already embeds the channel ID.
        return url.split("channel/")[1].split("/")[0]
    if "/@" in url:
        handle = url.split("/@")[1].split("/")[0]
        # channels.list with forHandle resolves a handle exactly for 1 quota
        # unit, unlike the previous search.list call which cost 100 units and
        # could return a similarly-named but wrong channel.
        r = api_get(
            f"{BASE_URL}/channels?part=id&forHandle={urllib.parse.quote(handle)}&key={API_KEY}", 1
        )
        if r.status_code != 200:
            return None
        data = r.json()
        if data.get("items"):
            return data["items"][0]["id"]
    elif "user/" in url:
        username = url.split("user/")[1].split("/")[0]
        r = api_get(
            f"{BASE_URL}/channels?part=id&forUsername={urllib.parse.quote(username)}&key={API_KEY}", 1
        )
        if r.status_code != 200:
            return None
        data = r.json()
        if data.get("items"):
            return data["items"][0]["id"]
    return None
def get_uploads_playlist(channel_id):
    """Return the channel's 'uploads' playlist ID (Cost: 1).

    Raises (KeyError/IndexError) for an unknown channel; callers wrap this
    in try/except.
    """
    data = api_get(f"{BASE_URL}/channels?part=contentDetails&id={channel_id}&key={API_KEY}", 1).json()
    channel = data['items'][0]
    return channel['contentDetails']['relatedPlaylists']['uploads']
# -----------------------
# Fetch and Filter Video IDs
# -----------------------
def filter_video_ids(video_ids, mode="videos"):
    """Filter video IDs by duration (Cost: 1 unit per 50 videos).

    mode="videos" keeps >= 60s, "shorts" keeps < 60s, "all" keeps everything;
    an unrecognised mode keeps nothing.
    """
    keep = []
    for start in range(0, len(video_ids), 50):
        chunk = video_ids[start:start + 50]
        data = api_get(f"{BASE_URL}/videos?part=contentDetails&id={','.join(chunk)}&key={API_KEY}", 1).json()

        for item in data.get("items", []):
            if 'contentDetails' not in item:
                continue
            seconds = parse_duration(item["contentDetails"]["duration"])
            wanted = (
                mode == "all"
                or (mode == "videos" and seconds >= 60)
                or (mode == "shorts" and seconds < 60)
            )
            if wanted:
                keep.append(item["id"])

    return keep
def get_playlist_video_ids(playlist_id, max_videos=50, mode="videos"):
    """Page through a playlist, filtering by duration, until max_videos IDs collected."""
    collected = []
    page_token = None

    while len(collected) < max_videos:
        # Always request a full API page (50); filtering happens client-side.
        url = (f"{BASE_URL}/playlistItems?part=snippet&playlistId={playlist_id}"
               f"&maxResults=50&key={API_KEY}")
        if page_token:
            url += f"&pageToken={page_token}"

        data = api_get(url, 1).json()
        page_ids = [item["snippet"]["resourceId"]["videoId"] for item in data.get("items", [])]
        matching = filter_video_ids(page_ids, mode=mode)

        # Take only as many as still fit.
        collected.extend(matching[:max_videos - len(collected)])

        page_token = data.get("nextPageToken")
        if not page_token or not page_ids:
            break

    return collected[:max_videos]
def get_live_video_ids(channel_id, max_videos=50):
    """Fetch IDs of completed live streams via search.list (Cost: 100).

    search.list caps maxResults at 50; the UI slider allows up to 100, and
    passing a value above 50 makes the API reject the request outright, so
    the requested count is clamped.
    """
    capped = min(max_videos, 50)
    url = (f"{BASE_URL}/search?part=id&channelId={channel_id}"
           f"&eventType=completed&type=video&maxResults={capped}&key={API_KEY}")
    r = api_get(url, 100).json()
    video_ids = []
    for item in r.get("items", []):
        video_ids.append(item["id"]["videoId"])
    return video_ids
# -----------------------
# Thumbnails Download and Prep
# -----------------------
def download_thumbnails(video_ids):
    """Download best-available thumbnails to a temp directory (Cost: 1 unit per 50 videos).

    Returns (temp dir path, list of saved file paths). Filenames embed a
    sanitised video title plus the video ID for uniqueness.
    """
    out_dir = tempfile.mkdtemp()
    saved = []
    for start in range(0, len(video_ids), 50):
        chunk = video_ids[start:start + 50]
        data = api_get(f"{BASE_URL}/videos?part=snippet&id={','.join(chunk)}&key={API_KEY}", 1).json()
        for item in data.get("items", []):
            if 'snippet' not in item:
                continue
            snippet = item['snippet']
            thumbs = snippet['thumbnails']
            # Prefer the largest variant available.
            best = thumbs.get("maxres", thumbs.get("standard", thumbs.get("high", thumbs.get("default"))))

            # Use the video title for the filename for better context in gr.Files
            safe_title = "".join(
                ch if ch.isalnum() or ch in (' ', '_') else '_' for ch in snippet['title']
            ).strip().replace(' ', '_')
            path = os.path.join(out_dir, f"{safe_title}_{item['id']}.jpg")

            urllib.request.urlretrieve(best["url"], path)
            saved.append(path)
    return out_dir, saved
def fetch_channel_thumbnails(channel_url, max_videos, page_mode):
    """Resolve the channel, collect matching video IDs, download thumbnails.

    Returns (status message, thumbnail paths, temp dir, channel title);
    the last three are None on any failure.
    """
    channel_id = extract_channel_id(channel_url)
    if not channel_id:
        return "❌ Could not extract channel ID", None, None, None

    channel_info = get_channel_info(channel_id)
    if not channel_info:
        return "❌ Could not fetch channel info", None, None, None

    channel_name = channel_info.get("title", "unknown_channel")

    # Pick the ID source for the requested page type.
    if page_mode == "live":
        video_ids = get_live_video_ids(channel_id, max_videos=max_videos)
    elif page_mode in ("videos", "shorts", "all"):
        try:
            uploads = get_uploads_playlist(channel_id)
        except Exception:
            return "❌ Could not find channel 'uploads' playlist ID", None, None, None
        video_ids = get_playlist_video_ids(uploads, max_videos=max_videos, mode=page_mode)
    else:
        return "❌ Unknown mode", None, None, None

    if not video_ids:
        return f"❌ No {page_mode} found", None, None, None

    tmp_dir, thumbs = download_thumbnails(video_ids)
    return f"βœ… Fetched {len(thumbs)} {page_mode}", thumbs, tmp_dir, channel_name
def prepare_zip(thumb_dir, channel_name):
    """Zip *thumb_dir* into <tmp>/<Channel>_Thumbnails_<YYYYMMDD>.zip; return the path."""
    # Sanitise the channel name the same way thumbnail filenames are sanitised:
    # keep alphanumerics, spaces and underscores; everything else becomes '_'.
    cleaned = [ch if ch.isalnum() or ch in (' ', '_') else '_' for ch in channel_name]
    safe_name = "".join(cleaned).strip().replace(' ', '_')

    stamp = datetime.datetime.now().strftime("%Y%m%d")
    archive_base = os.path.join(tempfile.gettempdir(), f"{safe_name}_Thumbnails_{stamp}")

    # make_archive appends the ".zip" extension itself.
    shutil.make_archive(archive_base, 'zip', thumb_dir)
    return archive_base + ".zip"
# -----------------------
# Generator for live status updates
# -----------------------
def fetch_and_zip_progress(channel_url, max_videos, page_mode):
    """Generator driving the UI.

    Yields (status text, thumbnail path list, zip path, gr.File update) —
    once at start (hiding the download component) and once with the results.
    """
    API_USAGE["units"] = 0
    yield f"Starting fetch... | API quota used: {API_USAGE['units']} units", [], None, gr.File(visible=False)

    status, thumbs, tmp_dir, channel_name = fetch_channel_thumbnails(channel_url, max_videos, page_mode)
    quota_used = API_USAGE["units"]

    # Annotate the status message with the duration filter that was applied.
    if page_mode == "videos":
        final_status = status.replace("videos", "long-form videos (>= 60s)")
    elif page_mode == "shorts":
        final_status = status.replace("shorts", "shorts (< 60s)")
    else:
        final_status = status

    zip_file = None
    if thumbs:
        zip_file = prepare_zip(tmp_dir, channel_name)
    elif tmp_dir and os.path.isdir(tmp_dir):
        # Nothing downloaded — clean up the empty temp directory.
        shutil.rmtree(tmp_dir)

    # The 'thumbs' list feeds gr.Files; the gr.File update re-shows the
    # download component on completion.
    yield f"{final_status} | API quota used: {quota_used} units", thumbs, zip_file, gr.File(visible=True)
# -----------------------
# Gradio Interface (Modified)
# -----------------------
with gr.Blocks() as demo:
    gr.Markdown("## 🎬 YouTube Channel Thumbnails Downloader (Files Preview)")
    gr.Markdown("Thumbnails are now listed as individual files. Click the filename to preview/download.")

    url_input = gr.Textbox(label="YouTube Channel URL", placeholder="https://www.youtube.com/@roisinmurphyofficial")
    page_selector = gr.Dropdown(
        choices=["videos", "shorts", "live", "all"],
        value="videos",
        label="Page to Collect"
    )
    max_videos_slider = gr.Slider(minimum=1, maximum=100, step=1, value=20, label="Max Items to Fetch")
    start_btn = gr.Button("πŸš€ Start Collect")

    status_output = gr.Textbox(label="Status")

    # gr.Files (replacing a Gallery) so each thumbnail is an individually
    # previewable/downloadable file entry.
    thumbs_list = gr.Files(
        label="Thumbnails Preview and Download (Click name for preview)",
        file_count="multiple",  # Allows multiple files
        type="filepath",        # Returns the path, which is what we need
        visible=True            # Ensure it's visible initially
    )

    download_btn = gr.File(label="Download All Thumbnails (ZIP)")

    start_btn.click(
        fetch_and_zip_progress,
        inputs=[url_input, max_videos_slider, page_selector],
        # The generator yields 4 values, so download_btn is listed twice: the
        # 3rd value supplies its file path and the 4th applies a visibility
        # update to the same component.
        # NOTE(review): the trailing gr.File(visible=...) update carries no
        # value and may override the zip path assigned by the 3rd output —
        # confirm in the UI; folding both into one gr.File(value=..., visible=True)
        # yield would avoid the double-binding.
        outputs=[status_output, thumbs_list, download_btn, download_btn]
    )


if __name__ == "__main__":
    demo.launch()