clementBE committed on
Commit
92e5b44
·
verified ·
1 Parent(s): d00da7d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +273 -250
app.py CHANGED
@@ -1,254 +1,277 @@
1
- import gradio as gr
2
- import requests
3
  import os
 
4
  import tempfile
5
- import shutil
6
- import urllib.request
7
- import isodate
8
- import datetime
9
-
10
- # --- IMPORTANT: Ensure this environment variable is set ---
11
- API_KEY = os.getenv("YOUTUBE_API_KEY")
12
- BASE_URL = "https://www.googleapis.com/youtube/v3"
13
-
14
- # -----------------------
15
- # API Usage Tracker
16
- # -----------------------
17
- API_USAGE = {"units": 0}
18
-
19
def api_get(url, cost, **kwargs):
    """Perform a GET request, charging *cost* units to the global quota tracker."""
    API_USAGE["units"] += cost
    return requests.get(url, **kwargs)
24
-
25
- # -----------------------
26
- # Helper Functions (Simplified)
27
- # -----------------------
28
def parse_duration(duration_str):
    """Convert an ISO-8601 duration string to whole seconds; 0 on any parse error."""
    try:
        total = isodate.parse_duration(duration_str).total_seconds()
    except Exception:
        return 0
    return int(total)
33
-
34
def get_channel_info(channel_id):
    """Fetch the channel snippet (including title) for *channel_id* (API cost: 1).

    Returns the snippet dict, or None when the request fails or the channel
    does not exist.
    """
    r = api_get(f"{BASE_URL}/channels?part=snippet&id={channel_id}&key={API_KEY}", 1)
    if r.status_code != 200:
        return None
    # Parse the response body once instead of calling r.json() three times.
    items = r.json().get('items') or []
    return items[0]['snippet'] if items else None
40
-
41
def extract_channel_id(url: str):
    """Resolve a YouTube channel URL (/channel/, /@handle, /user/) to a channel ID.

    Returns None when the URL form is unrecognized or the lookup fails.
    The /@handle form costs 100 quota units (search); /user/ costs 1.
    """
    if "channel/" in url:
        # Channel ID is embedded directly in the URL.
        return url.split("channel/")[1].split("/")[0]
    elif "/@" in url:
        handle = url.split("/@")[1].split("/")[0]
        resp = api_get(f"{BASE_URL}/search?part=snippet&type=channel&q={handle}&key={API_KEY}", 100)
        if resp.status_code != 200:
            return None
        found = resp.json().get("items")
        if found:
            return found[0]["snippet"]["channelId"]
    elif "user/" in url:
        username = url.split("user/")[1].split("/")[0]
        resp = api_get(f"{BASE_URL}/channels?part=id&forUsername={username}&key={API_KEY}", 1)
        if resp.status_code != 200:
            return None
        found = resp.json().get("items")
        if found:
            return found[0]["id"]
    return None
60
-
61
def get_uploads_playlist(channel_id):
    """Return the channel's 'uploads' playlist ID (API cost: 1).

    Raises KeyError/IndexError when the channel is absent from the response;
    the caller wraps this call in try/except.
    """
    data = api_get(f"{BASE_URL}/channels?part=contentDetails&id={channel_id}&key={API_KEY}", 1).json()
    return data['items'][0]['contentDetails']['relatedPlaylists']['uploads']
65
-
66
- # -----------------------
67
- # Fetch and Filter Video IDs
68
- # -----------------------
69
def filter_video_ids(video_ids, mode="videos"):
    """Filter video IDs by duration (API cost: 1 unit per batch of 50 IDs).

    mode: "videos" keeps durations >= 60s, "shorts" keeps < 60s,
    "all" keeps everything the API returns contentDetails for.
    """
    kept = []
    for start in range(0, len(video_ids), 50):
        chunk = video_ids[start:start + 50]
        data = api_get(f"{BASE_URL}/videos?part=contentDetails&id={','.join(chunk)}&key={API_KEY}", 1).json()

        for item in data.get("items", []):
            if 'contentDetails' not in item:
                continue
            seconds = parse_duration(item["contentDetails"]["duration"])
            vid = item["id"]

            if mode == "all":
                kept.append(vid)
            elif mode == "videos" and seconds >= 60:
                kept.append(vid)
            elif mode == "shorts" and seconds < 60:
                kept.append(vid)

    return kept
91
-
92
def get_playlist_video_ids(playlist_id, max_videos=50, mode="videos"):
    """Page through a playlist, filtering by *mode*, until *max_videos* IDs are found.

    Stops early when the playlist runs out of pages or items.
    """
    collected = []
    page_token = None

    while len(collected) < max_videos:
        url = f"{BASE_URL}/playlistItems?part=snippet&playlistId={playlist_id}&maxResults=50&key={API_KEY}"
        if page_token:
            url += f"&pageToken={page_token}"

        data = api_get(url, 1).json()
        raw = [entry["snippet"]["resourceId"]["videoId"] for entry in data.get("items", [])]
        matching = filter_video_ids(raw, mode=mode)

        # Take only as many as still fit under the cap.
        collected.extend(matching[:max_videos - len(collected)])

        page_token = data.get("nextPageToken")
        if not page_token or not raw:
            break

    return collected[:max_videos]
114
-
115
def get_live_video_ids(channel_id, max_videos=50):
    """Fetch IDs of completed live streams for a channel (API cost: 100).

    The search.list endpoint accepts maxResults only in the range 0-50, so the
    request is clamped; at most min(max_videos, 50) IDs are returned. Previously
    the UI slider could pass up to 100, which the API rejects with HTTP 400.
    """
    limit = min(max_videos, 50)  # search.list caps maxResults at 50
    url = f"{BASE_URL}/search?part=id&channelId={channel_id}&eventType=completed&type=video&maxResults={limit}&key={API_KEY}"
    r = api_get(url, 100).json()
    return [item["id"]["videoId"] for item in r.get("items", [])]
123
-
124
- # -----------------------
125
- # Thumbnails Download and Prep
126
- # -----------------------
127
def download_thumbnails(video_ids):
    """Download the best available thumbnail for each video (API cost: 1 per 50 IDs).

    Returns (temp_directory, list_of_downloaded_jpg_paths). Filenames embed a
    sanitized video title plus the video ID for readability in gr.Files.
    """
    tmp_dir = tempfile.mkdtemp()
    thumb_paths = []
    for start in range(0, len(video_ids), 50):
        chunk = video_ids[start:start + 50]
        data = api_get(f"{BASE_URL}/videos?part=snippet&id={','.join(chunk)}&key={API_KEY}", 1).json()
        for item in data.get("items", []):
            if 'snippet' not in item:
                continue
            snippet = item['snippet']
            thumbnails = snippet['thumbnails']
            # Prefer the highest resolution variant that exists for this video.
            best = thumbnails.get("maxres", thumbnails.get("standard", thumbnails.get("high", thumbnails.get("default"))))

            # Keep only alphanumerics/underscores so the title is filesystem-safe.
            title_safe = "".join(c if c.isalnum() or c in (' ', '_') else '_' for c in snippet['title']).strip().replace(' ', '_')
            target = os.path.join(tmp_dir, f"{title_safe}_{item['id']}.jpg")

            urllib.request.urlretrieve(best["url"], target)
            thumb_paths.append(target)
    return tmp_dir, thumb_paths
147
-
148
def fetch_channel_thumbnails(channel_url, max_videos, page_mode):
    """Orchestrate channel resolution, video-ID collection, and thumbnail download.

    Returns (status_message, thumbnail_paths, temp_dir, channel_name); the
    last three are None on any failure.
    """
    channel_id = extract_channel_id(channel_url)
    if not channel_id:
        return "❌ Could not extract channel ID", None, None, None

    channel_info = get_channel_info(channel_id)
    if not channel_info:
        # Fixed: this message was missing the ❌ marker used by every other error path.
        return "❌ Could not fetch channel info", None, None, None

    channel_name = channel_info.get("title", "unknown_channel")

    if page_mode in ["videos", "shorts", "all"]:
        try:
            playlist_id = get_uploads_playlist(channel_id)
        except Exception:
            return "❌ Could not find channel 'uploads' playlist ID", None, None, None
        video_ids = get_playlist_video_ids(playlist_id, max_videos=max_videos, mode=page_mode)

    elif page_mode == "live":
        video_ids = get_live_video_ids(channel_id, max_videos=max_videos)
    else:
        return "❌ Unknown mode", None, None, None

    if not video_ids:
        return f"❌ No {page_mode} found", None, None, None

    tmp_dir, thumbs = download_thumbnails(video_ids)
    return f"✅ Fetched {len(thumbs)} {page_mode}", thumbs, tmp_dir, channel_name
177
-
178
def prepare_zip(thumb_dir, channel_name):
    """Archive *thumb_dir* as a ZIP in the system temp dir and return its path.

    The archive name combines a filesystem-safe channel name with today's date.
    """
    safe_name = "".join(c if c.isalnum() or c in (' ', '_') else '_' for c in channel_name).strip().replace(' ', '_')
    stamp = datetime.datetime.now().strftime("%Y%m%d")
    archive_base = os.path.join(tempfile.gettempdir(), f"{safe_name}_Thumbnails_{stamp}")
    shutil.make_archive(archive_base, 'zip', thumb_dir)
    return archive_base + ".zip"
190
-
191
- # -----------------------
192
- # Generator for live status updates
193
- # -----------------------
194
def fetch_and_zip_progress(channel_url, max_videos, page_mode):
    """Generator behind the Start button: streams status updates to the UI.

    Yields (status_text, thumbnail_paths, zip_path, gr.File update) twice —
    once at start, once when finished. Resets the quota counter per run.
    """
    API_USAGE["units"] = 0
    yield f"Starting fetch... | API quota used: {API_USAGE['units']} units", [], None, gr.File(visible=False)  # 💡 Added gr.File update

    status, thumbs, tmp_dir, channel_name = fetch_channel_thumbnails(channel_url, max_videos, page_mode)
    quota_used = API_USAGE["units"]

    # Reword the status so the duration-based filtering is explicit in the UI.
    final_status = status.replace("videos", "long-form videos (>= 60s)") if page_mode == "videos" else status
    final_status = final_status.replace("shorts", "shorts (< 60s)") if page_mode == "shorts" else final_status

    zip_file = None
    if thumbs:
        zip_file = prepare_zip(tmp_dir, channel_name)

    elif tmp_dir and os.path.isdir(tmp_dir):
        # Nothing was downloaded: remove the now-useless temp directory.
        shutil.rmtree(tmp_dir)

    # 💡 IMPORTANT: Now yielding a list of file paths (thumbs) and the zip file path.
    # The 'thumbs' list goes to gr.Files.
    yield f"{final_status} | API quota used: {quota_used} units", thumbs, zip_file, gr.File(visible=True)  # 💡 Set visible=True on success
214
-
215
- # -----------------------
216
- # Gradio Interface (Modified)
217
- # -----------------------
218
# Top-level Gradio Blocks UI: URL + mode + count inputs, status text,
# per-thumbnail file list, and a ZIP download slot.
with gr.Blocks() as demo:
    gr.Markdown("## 🎬 YouTube Channel Thumbnails Downloader (Files Preview)")
    gr.Markdown("Thumbnails are now listed as individual files. Click the filename to preview/download.")

    url_input = gr.Textbox(label="YouTube Channel URL", placeholder="https://www.youtube.com/@roisinmurphyofficial")
    page_selector = gr.Dropdown(
        choices=["videos", "shorts", "live", "all"],
        value="videos",
        label="Page to Collect"
    )
    max_videos_slider = gr.Slider(minimum=1, maximum=100, step=1, value=20, label="Max Items to Fetch")
    start_btn = gr.Button("🚀 Start Collect")

    status_output = gr.Textbox(label="Status")

    # 💡 REPLACED gr.Gallery with gr.Files
    thumbs_list = gr.Files(
        label="Thumbnails Preview and Download (Click name for preview)",
        file_count="multiple",  # Allows multiple files
        type="filepath",  # Returns the path, which is what we need
        visible=True  # Ensure it's visible initially
    )

    download_btn = gr.File(label="Download All Thumbnails (ZIP)")

    start_btn.click(
        fetch_and_zip_progress,
        inputs=[url_input, max_videos_slider, page_selector],
        # 💡 Updated output targets to match the new return values
        # NOTE(review): download_btn appears twice, so the 4th yielded value
        # (a gr.File visibility update) lands on the same component as the
        # 3rd (the zip path) — presumably the zip path is overwritten; verify
        # in the running UI and consider a dedicated component if so.
        outputs=[status_output, thumbs_list, download_btn, download_btn]
        # Note: Added download_btn twice as the generator yields 4 items,
        # but the last one is a gr.File update to hide/show the component.
        # This is a slightly awkward necessity of Gradio's generator API.
    )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
 
253
# Launch the Gradio app when run as a script (default host/port).
if __name__ == "__main__":
    demo.launch()
 
 
 
1
  import os
2
+ import zipfile
3
  import tempfile
4
+ import requests
5
+ import numpy as np
6
+ import pandas as pd
7
+ from PIL import Image
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from torchvision import transforms
11
+ from torchvision.models import resnet50, ResNet50_Weights
12
+ from sklearn.cluster import MiniBatchKMeans
13
+ import matplotlib.pyplot as plt
14
+ import io
15
+ from datetime import datetime
16
+
17
+ import gradio as gr
18
+
19
+ # Face analysis
20
+ from deepface import DeepFace
21
+ import cv2
22
+
23
+ # ---------------------------
24
+ # Force CPU if no CUDA
25
+ # ---------------------------
26
+ if not torch.cuda.is_available():
27
+ os.environ["CUDA_VISIBLE_DEVICES"] = ""
28
+
29
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
+
31
+ # ---------------------------
32
+ # Load ResNet50
33
+ # ---------------------------
34
+ weights = ResNet50_Weights.DEFAULT
35
+ model = resnet50(weights=weights).to(device)
36
+ model.eval()
37
+
38
+ # ---------------------------
39
+ # Transformations
40
+ # ---------------------------
41
+ transform = transforms.Compose([
42
+ transforms.Resize(256),
43
+ transforms.CenterCrop(224),
44
+ transforms.ToTensor(),
45
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
46
+ std=[0.229, 0.224, 0.225]),
47
+ ])
48
+
49
+ # ---------------------------
50
+ # ImageNet labels
51
+ # ---------------------------
52
+ LABELS_URL = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
53
+ imagenet_classes = [line.strip() for line in requests.get(LABELS_URL).text.splitlines()]
54
+
55
+ # ---------------------------
56
+ # Color utilities
57
+ # ---------------------------
58
+ BASIC_COLORS = {
59
+ "Red": (255, 0, 0),
60
+ "Green": (0, 255, 0),
61
+ "Blue": (0, 0, 255),
62
+ "Yellow": (255, 255, 0),
63
+ "Cyan": (0, 255, 255),
64
+ "Magenta": (255, 0, 255),
65
+ "Black": (0, 0, 0),
66
+ "White": (255, 255, 255),
67
+ "Gray": (128, 128, 128),
68
+ }
69
+
70
def closest_basic_color(rgb, palette=None):
    """Return the name of the palette color nearest to *rgb*.

    Uses squared Euclidean distance in RGB space (no sqrt needed for
    comparison). Ties keep the first palette entry, matching the original
    strict-less-than scan.

    Args:
        rgb: (r, g, b) tuple of 0-255 values.
        palette: optional {name: (r, g, b)} mapping; defaults to the
            module-level BASIC_COLORS (backward compatible).
    """
    if palette is None:
        palette = BASIC_COLORS
    r, g, b = rgb
    return min(
        palette,
        key=lambda name: (r - palette[name][0]) ** 2
                         + (g - palette[name][1]) ** 2
                         + (b - palette[name][2]) ** 2,
    )
80
+
81
def get_dominant_color(image, num_colors=5):
    """Estimate the dominant color of a PIL image via MiniBatchKMeans.

    The image is downsampled to 100x100 for speed, pixels are clustered into
    *num_colors* groups, and the center of the most populated cluster is taken
    as the dominant color.

    Returns:
        ((r, g, b) int tuple, "#rrggbb" hex string).
    """
    # Force 3 channels first: RGBA/LA/P/L inputs would make reshape(-1, 3)
    # fail or silently mis-group pixels.
    image = image.convert("RGB").resize((100, 100))
    pixels = np.array(image).reshape(-1, 3)
    kmeans = MiniBatchKMeans(n_clusters=num_colors, random_state=0, n_init=5)
    kmeans.fit(pixels)
    # Dominant = center of the cluster owning the most pixels.
    dominant_color = kmeans.cluster_centers_[np.argmax(np.bincount(kmeans.labels_))]
    dominant_color = tuple(dominant_color.astype(int))
    hex_color = f"#{dominant_color[0]:02x}{dominant_color[1]:02x}{dominant_color[2]:02x}"
    return dominant_color, hex_color
90
+
91
+ # ---------------------------
92
+ # Core function
93
+ # ---------------------------
94
def classify_zip_and_analyze_color(zip_file):
    """Process a ZIP of images: ImageNet top-3, dominant color, face analysis.

    For each image directly inside the archive, runs ResNet50 classification,
    dominant-color extraction, and DeepFace age/gender/emotion analysis, then
    builds summary plots.

    Returns:
        (results DataFrame, path to XLSX export, and four PIL plot images:
        color frequency, prediction distribution, gender distribution,
        age histogram by gender).
    """
    results = []

    # Output file naming: <zip basename>_<YYYYMMDD>_results.xlsx
    zip_name = os.path.splitext(os.path.basename(zip_file.name))[0]
    date_str = datetime.now().strftime("%Y%m%d")

    with tempfile.TemporaryDirectory() as tmpdir:
        with zipfile.ZipFile(zip_file.name, 'r') as zip_ref:
            zip_ref.extractall(tmpdir)

        # NOTE(review): only the top level of the extracted archive is scanned;
        # images inside sub-folders of the ZIP are skipped — confirm intended.
        for fname in sorted(os.listdir(tmpdir)):
            if fname.lower().endswith(('.png', '.jpg', '.jpeg')):
                img_path = os.path.join(tmpdir, fname)
                try:
                    image = Image.open(img_path).convert("RGB")
                except Exception:
                    # Unreadable/corrupt file: skip silently.
                    continue

                # Classification: ResNet50 forward pass, softmax, top-3 labels.
                input_tensor = transform(image).unsqueeze(0).to(device)
                with torch.no_grad():
                    output = model(input_tensor)
                    probs = F.softmax(output, dim=1)[0]

                top3_prob, top3_idx = torch.topk(probs, 3)
                preds = [(imagenet_classes[idx], f"{prob.item()*100:.2f}%") for idx, prob in zip(top3_idx, top3_prob)]

                # Dominant color (k-means) mapped to the nearest basic color name.
                rgb, hex_color = get_dominant_color(image)
                basic_color = closest_basic_color(rgb)

                # Face detection & characterization (best-effort: any DeepFace
                # failure yields an empty face list rather than aborting).
                faces_data = []
                try:
                    # DeepFace/cv2 expect BGR channel order.
                    img_cv2 = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
                    detected_faces = DeepFace.analyze(
                        img_cv2, actions=["age", "gender", "emotion"], enforce_detection=False
                    )
                    # DeepFace may return a list (multiple faces) or one dict.
                    if isinstance(detected_faces, list):
                        for f in detected_faces:
                            faces_data.append({
                                "age": f["age"],
                                "gender": f["gender"],
                                "emotion": f["dominant_emotion"]
                            })
                    else:
                        faces_data.append({
                            "age": detected_faces["age"],
                            "gender": detected_faces["gender"],
                            "emotion": detected_faces["dominant_emotion"]
                        })
                except Exception:
                    faces_data = []

                # Thumbnail preview (64x64 max, aspect ratio preserved).
                thumbnail = image.copy()
                thumbnail.thumbnail((64, 64))

                results.append((
                    fname,
                    ", ".join([p[0] for p in preds]),
                    ", ".join([p[1] for p in preds]),
                    hex_color,
                    basic_color,
                    faces_data,
                    thumbnail
                ))

    # Build dataframe
    df = pd.DataFrame(results, columns=[
        "Filename", "Top 3 Predictions", "Confidence",
        "Dominant Color", "Basic Color", "Face Info", "Thumbnail"
    ])

    # Save XLSX with zip name + date
    # NOTE(review): "Face Info" holds lists of dicts and "Thumbnail" holds PIL
    # images — to_excel will serialize their str() forms; confirm acceptable.
    out_xlsx = os.path.join(tempfile.gettempdir(), f"{zip_name}_{date_str}_results.xlsx")
    df.to_excel(out_xlsx, index=False)

    # ---------------------------
    # Plot 1: Basic color frequency
    # ---------------------------
    fig1, ax1 = plt.subplots()
    color_counts = df["Basic Color"].value_counts()
    ax1.bar(color_counts.index, color_counts.values, color="skyblue")
    ax1.set_title("Basic Color Frequency")
    ax1.set_ylabel("Count")
    buf1 = io.BytesIO()
    plt.savefig(buf1, format="png")
    plt.close(fig1)
    buf1.seek(0)
    plot1_img = Image.open(buf1)

    # ---------------------------
    # Plot 2: Top prediction distribution (20 most frequent labels)
    # ---------------------------
    fig2, ax2 = plt.subplots()
    preds_flat = []
    for p in df["Top 3 Predictions"]:
        preds_flat.extend(p.split(", "))
    pred_counts = pd.Series(preds_flat).value_counts().head(20)
    # Reverse so the most frequent label appears at the top of the barh chart.
    ax2.barh(pred_counts.index[::-1], pred_counts.values[::-1], color="salmon")
    ax2.set_title("Top Prediction Distribution")
    ax2.set_xlabel("Count")
    buf2 = io.BytesIO()
    plt.savefig(buf2, format="png", bbox_inches="tight")
    plt.close(fig2)
    buf2.seek(0)
    plot2_img = Image.open(buf2)

    # ---------------------------
    # Extract ages and genders
    # ---------------------------
    ages_male, ages_female = [], []
    gender_confidence = {"Homme": 0, "Femme": 0}

    for face_list in df["Face Info"]:
        for face in face_list:
            age = face["age"]
            # NOTE(review): assumes DeepFace's "gender" field is a
            # {label: confidence-percentage} dict — confirm for the pinned
            # deepface version.
            gender_dict = face["gender"]
            gender = max(gender_dict, key=gender_dict.get)
            conf = float(gender_dict[gender]) / 100
            # Cap any single face's vote at 0.9 (the "Weighted ≤90%" in plot 3).
            weight = min(conf, 0.9)
            gender_trans = "Homme" if gender == "Man" else "Femme"
            gender_confidence[gender_trans] += weight
            if gender_trans == "Homme":
                ages_male.append(age)
            else:
                ages_female.append(age)

    # ---------------------------
    # Plot 3: Gender distribution (sum of capped confidences)
    # ---------------------------
    fig3, ax3 = plt.subplots()
    ax3.bar(gender_confidence.keys(), gender_confidence.values(), color=["lightblue", "pink"])
    ax3.set_title("Gender Distribution (Weighted ≤90%)")
    ax3.set_ylabel("Sum of Confidence")
    buf3 = io.BytesIO()
    plt.savefig(buf3, format="png")
    plt.close(fig3)
    buf3.seek(0)
    plot3_img = Image.open(buf3)

    # ---------------------------
    # Plot 4: Age distribution by gender (5-year bins, 0-100)
    # ---------------------------
    fig4, ax4 = plt.subplots()
    bins = range(0, 101, 5)
    ax4.hist([ages_male, ages_female], bins=bins, color=["lightblue", "pink"], label=["Homme", "Femme"], edgecolor="black")
    ax4.set_title("Age Distribution by Gender")
    ax4.set_xlabel("Age")
    ax4.set_ylabel("Count")
    ax4.legend()
    buf4 = io.BytesIO()
    plt.savefig(buf4, format="png")
    plt.close(fig4)
    buf4.seek(0)
    plot4_img = Image.open(buf4)

    return df, out_xlsx, plot1_img, plot2_img, plot3_img, plot4_img
253
+
254
+ # ---------------------------
255
+ # Gradio Interface
256
+ # ---------------------------
257
# Gradio UI: single-function interface around classify_zip_and_analyze_color.
# Outputs map 1:1 to the function's 6-tuple return value.
demo = gr.Interface(
    fn=classify_zip_and_analyze_color,
    inputs=gr.File(file_types=[".zip"], label="Upload ZIP of images"),
    outputs=[
        gr.Dataframe(
            headers=["Filename", "Top 3 Predictions", "Confidence",
                     "Dominant Color", "Basic Color", "Face Info", "Thumbnail"],
            datatype=["str","str","str","str","str","str","pil"]
        ),
        gr.File(label="Download XLSX"),
        gr.Image(type="pil", label="Basic Color Frequency"),
        gr.Image(type="pil", label="Top Prediction Distribution"),
        gr.Image(type="pil", label="Gender Distribution (Weighted ≤90%)"),
        gr.Image(type="pil", label="Age Distribution by Gender"),
    ],
    title="Image Classifier with Color & Face Analysis",
    description="Upload a ZIP of images. Classifies images, analyzes dominant color, detects/characterizes faces (age, gender, emotion), and shows thumbnails.",
)

if __name__ == "__main__":
    # Bind to all interfaces on port 7860 (conventional for container/Spaces hosting).
    demo.launch(server_name="0.0.0.0", server_port=7860)