clementBE committed on
Commit
9cffd38
Β·
verified Β·
1 Parent(s): 1ee281a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +164 -256
app.py CHANGED
@@ -1,264 +1,172 @@
1
  import os
2
- import zipfile
3
  import tempfile
4
- import requests
5
- import numpy as np
6
- import pandas as pd
7
- from PIL import Image
8
  import torch
9
- import torch.nn.functional as F
10
- from torchvision import transforms
11
- from torchvision.models import resnet50, ResNet50_Weights
12
- from sklearn.cluster import MiniBatchKMeans
13
- import matplotlib.pyplot as plt
14
- import io
15
-
16
  import gradio as gr
17
-
18
- # Face analysis
19
- from deepface import DeepFace
20
- import cv2
21
-
22
- # ---------------------------
23
- # Force CPU if no CUDA
24
- # ---------------------------
25
- if not torch.cuda.is_available():
26
- os.environ["CUDA_VISIBLE_DEVICES"] = ""
27
-
28
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
-
30
- # ---------------------------
31
- # Load ResNet50
32
- # ---------------------------
33
- weights = ResNet50_Weights.DEFAULT
34
- model = resnet50(weights=weights).to(device)
35
- model.eval()
36
-
37
- # ---------------------------
38
- # Transformations
39
- # ---------------------------
40
- transform = transforms.Compose([
41
- transforms.Resize(256),
42
- transforms.CenterCrop(224),
43
- transforms.ToTensor(),
44
- transforms.Normalize(mean=[0.485, 0.456, 0.406],
45
- std=[0.229, 0.224, 0.225]),
46
- ])
47
-
48
- # ---------------------------
49
- # ImageNet labels
50
- # ---------------------------
51
- LABELS_URL = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
52
- imagenet_classes = [line.strip() for line in requests.get(LABELS_URL).text.splitlines()]
53
-
54
- # ---------------------------
55
- # Color utilities
56
- # ---------------------------
57
- BASIC_COLORS = {
58
- "Red": (255, 0, 0),
59
- "Green": (0, 255, 0),
60
- "Blue": (0, 0, 255),
61
- "Yellow": (255, 255, 0),
62
- "Cyan": (0, 255, 255),
63
- "Magenta": (255, 0, 255),
64
- "Black": (0, 0, 0),
65
- "White": (255, 255, 255),
66
- "Gray": (128, 128, 128),
67
  }
68
 
69
- def closest_basic_color(rgb):
70
- r, g, b = rgb
71
- min_dist = float("inf")
72
- closest_color = None
73
- for name, (cr, cg, cb) in BASIC_COLORS.items():
74
- dist = (r - cr) ** 2 + (g - cg) ** 2 + (b - cb) ** 2
75
- if dist < min_dist:
76
- min_dist = dist
77
- closest_color = name
78
- return closest_color
79
-
80
- def get_dominant_color(image, num_colors=5):
81
- image = image.resize((100, 100))
82
- pixels = np.array(image).reshape(-1, 3)
83
- kmeans = MiniBatchKMeans(n_clusters=num_colors, random_state=0, n_init=5)
84
- kmeans.fit(pixels)
85
- dominant_color = kmeans.cluster_centers_[np.argmax(np.bincount(kmeans.labels_))]
86
- dominant_color = tuple(dominant_color.astype(int))
87
- hex_color = f"#{dominant_color[0]:02x}{dominant_color[1]:02x}{dominant_color[2]:02x}"
88
- return dominant_color, hex_color
89
-
90
- # ---------------------------
91
- # Core function
92
- # ---------------------------
93
- def classify_zip_and_analyze_color(zip_file):
94
- results = []
95
-
96
- with tempfile.TemporaryDirectory() as tmpdir:
97
- with zipfile.ZipFile(zip_file.name, 'r') as zip_ref:
98
- zip_ref.extractall(tmpdir)
99
-
100
- for fname in sorted(os.listdir(tmpdir)):
101
- if fname.lower().endswith(('.png', '.jpg', '.jpeg')):
102
- img_path = os.path.join(tmpdir, fname)
103
- try:
104
- image = Image.open(img_path).convert("RGB")
105
- except Exception:
106
- continue
107
-
108
- # Classification
109
- input_tensor = transform(image).unsqueeze(0).to(device)
110
- with torch.no_grad():
111
- output = model(input_tensor)
112
- probs = F.softmax(output, dim=1)[0]
113
-
114
- top3_prob, top3_idx = torch.topk(probs, 3)
115
- preds = [(imagenet_classes[idx], f"{prob.item()*100:.2f}%") for idx, prob in zip(top3_idx, top3_prob)]
116
-
117
- # Dominant color
118
- rgb, hex_color = get_dominant_color(image)
119
- basic_color = closest_basic_color(rgb)
120
-
121
- # ---------------------------
122
- # Face detection & characterization
123
- # ---------------------------
124
- face_info = ""
125
- try:
126
- img_cv2 = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
127
- faces = DeepFace.analyze(img_cv2, actions=["age", "gender", "emotion"], enforce_detection=False)
128
- if isinstance(faces, list): # multiple faces
129
- for f in faces:
130
- face_info += f"Age: {f['age']}, Gender: {f['gender']}, Gender Confidence: {f['gender_confidence']*100:.2f}, Emotion: {f['dominant_emotion']}; "
131
- else: # single face
132
- face_info = f"Age: {faces['age']}, Gender: {faces['gender']}, Gender Confidence: {faces['gender_confidence']*100:.2f}, Emotion: {faces['dominant_emotion']}"
133
- except Exception as e:
134
- face_info = "No face detected"
135
-
136
- results.append((
137
- fname,
138
- ", ".join([p[0] for p in preds]),
139
- ", ".join([p[1] for p in preds]),
140
- hex_color,
141
- basic_color,
142
- face_info
143
- ))
144
-
145
- # Build dataframe
146
- df = pd.DataFrame(results, columns=["Filename", "Top 3 Predictions", "Confidence", "Dominant Color", "Basic Color", "Face Info"])
147
-
148
- # Save XLSX
149
- out_xlsx = os.path.join(tempfile.gettempdir(), "results.xlsx")
150
- df.to_excel(out_xlsx, index=False)
151
-
152
- # ---------------------------
153
- # Plot 1: Basic color frequency
154
- # ---------------------------
155
- fig1, ax1 = plt.subplots()
156
- color_counts = df["Basic Color"].value_counts()
157
- ax1.bar(color_counts.index, color_counts.values, color="skyblue")
158
- ax1.set_title("Basic Color Frequency")
159
- ax1.set_ylabel("Count")
160
- buf1 = io.BytesIO()
161
- plt.savefig(buf1, format="png")
162
- plt.close(fig1)
163
- buf1.seek(0)
164
- plot1_img = Image.open(buf1)
165
-
166
- # ---------------------------
167
- # Plot 2: Top prediction distribution
168
- # ---------------------------
169
- fig2, ax2 = plt.subplots()
170
- preds_flat = []
171
- for p in df["Top 3 Predictions"]:
172
- preds_flat.extend(p.split(", "))
173
- pred_counts = pd.Series(preds_flat).value_counts().head(20)
174
- ax2.barh(pred_counts.index[::-1], pred_counts.values[::-1], color="salmon")
175
- ax2.set_title("Top Prediction Distribution")
176
- ax2.set_xlabel("Count")
177
- buf2 = io.BytesIO()
178
- plt.savefig(buf2, format="png", bbox_inches="tight")
179
- plt.close(fig2)
180
- buf2.seek(0)
181
- plot2_img = Image.open(buf2)
182
-
183
- # ---------------------------
184
- # Extract age and gender (confidence ≀ 80%)
185
- # ---------------------------
186
- ages = []
187
- gender_confidence = {"Man": 0, "Woman": 0}
188
-
189
- for info in df["Face Info"]:
190
- if info != "No face detected":
191
- for face_str in info.split(";"):
192
- face_str = face_str.strip()
193
- if face_str:
194
- # Age
195
- age_part = face_str.split(",")[0]
196
- age = int(age_part.replace("Age:", "").strip())
197
- ages.append(age)
198
-
199
- # Gender and confidence
200
- gender_part = face_str.split(",")[1]
201
- gender = gender_part.replace("Gender:", "").strip()
202
-
203
- # Extract confidence
204
- conf = 1.0
205
- for part in face_str.split(","):
206
- if "Gender Confidence:" in part:
207
- conf = float(part.split("Gender Confidence:")[1].strip()) / 100 # convert % to 0-1
208
-
209
- # Only include if confidence ≀ 0.8
210
- if conf <= 0.8:
211
- if gender in gender_confidence:
212
- gender_confidence[gender] += conf
213
- else:
214
- gender_confidence[gender] = conf
215
-
216
- # ---------------------------
217
- # Plot 3: Gender distribution (confidence ≀ 80%)
218
- # ---------------------------
219
- fig3, ax3 = plt.subplots()
220
- ax3.bar(gender_confidence.keys(), gender_confidence.values(), color=["lightblue", "pink"])
221
- ax3.set_title("Gender Distribution (Confidence ≀ 80%)")
222
- ax3.set_ylabel("Sum of Confidence")
223
- buf3 = io.BytesIO()
224
- plt.savefig(buf3, format="png")
225
- plt.close(fig3)
226
- buf3.seek(0)
227
- plot3_img = Image.open(buf3)
228
-
229
- # ---------------------------
230
- # Plot 4: Age distribution
231
- # ---------------------------
232
- fig4, ax4 = plt.subplots()
233
- ax4.hist(ages, bins=range(0, 101, 5), color="lightgreen", edgecolor="black")
234
- ax4.set_title("Age Distribution")
235
- ax4.set_xlabel("Age")
236
- ax4.set_ylabel("Count")
237
- buf4 = io.BytesIO()
238
- plt.savefig(buf4, format="png")
239
- plt.close(fig4)
240
- buf4.seek(0)
241
- plot4_img = Image.open(buf4)
242
-
243
- return df, out_xlsx, plot1_img, plot2_img, plot3_img, plot4_img
244
-
245
- # ---------------------------
246
- # Gradio Interface
247
- # ---------------------------
248
- demo = gr.Interface(
249
- fn=classify_zip_and_analyze_color,
250
- inputs=gr.File(file_types=[".zip"], label="Upload ZIP of images"),
251
- outputs=[
252
- gr.Dataframe(headers=["Filename", "Top 3 Predictions", "Confidence", "Dominant Color", "Basic Color", "Face Info"]),
253
- gr.File(label="Download XLSX"),
254
- gr.Image(type="pil", label="Basic Color Frequency"),
255
- gr.Image(type="pil", label="Top Prediction Distribution"),
256
- gr.Image(type="pil", label="Gender Distribution (≀80% Confidence)"),
257
- gr.Image(type="pil", label="Age Distribution"),
258
- ],
259
- title="Image Classifier with Color & Face Analysis",
260
- description="Upload a ZIP of images. Classifies images, analyzes dominant color, and detects/characterizes faces (age, gender, emotion).",
261
- )
262
 
263
  if __name__ == "__main__":
264
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
1
  import os
 
2
  import tempfile
3
+ import datetime
4
+ import time
 
 
5
  import torch
 
 
 
 
 
 
 
6
  import gradio as gr
7
+ import spaces
8
+ from transformers import pipeline
9
+ from docx import Document
10
+ from pydub import AudioSegment
11
+
12
# --- Model definitions ---
# Maps the human-readable dropdown label (used as the Gradio choice) to the
# Hugging Face model id passed to the ASR pipeline.
MODEL_SIZES = {
    "Tiny (Fastest)": "openai/whisper-tiny",
    "Base (Faster)": "openai/whisper-base",
    "Small (Balanced)": "openai/whisper-small",
    "Distil-Large-v3 (General Purpose)": "distil-whisper/distil-large-v3",
    "Distil-Large-v3-FR (French-Specific)": "eustlb/distil-large-v3-fr"
}

# --- Caches ---
# model_cache: dropdown label -> loaded ASR pipeline (avoids reloading per call).
# summary_cache: holds the lazily created summarization pipeline under "summarizer".
model_cache = {}
summary_cache = {}
24
+
25
# --- Whisper pipeline loader ---
def get_model_pipeline(model_name, progress):
    """Return the ASR pipeline for *model_name*, loading and caching it on first use.

    *progress* is a Gradio progress callback used to report loading status.
    """
    cached = model_cache.get(model_name)
    if cached is not None:
        return cached

    progress(0, desc="🚀 Initializing ZeroGPU instance...")
    hf_model_id = MODEL_SIZES[model_name]
    # Pipeline device: first CUDA device when available, otherwise CPU.
    target_device = 0 if torch.cuda.is_available() else "cpu"
    progress(0.1, desc=f"⏳ Loading {model_name} model...")
    asr = pipeline(
        "automatic-speech-recognition",
        model=hf_model_id,
        device=target_device,
    )
    model_cache[model_name] = asr
    progress(0.5, desc="✅ Model loaded successfully!")
    return asr
39
+
40
# --- French summarization pipeline ---
def get_summary_pipeline():
    """Lazily build and memoize the multilingual summarization pipeline."""
    try:
        return summary_cache["summarizer"]
    except KeyError:
        summarizer = pipeline(
            "summarization",
            model="csebuetnlp/mT5_multilingual_XLSum"
        )
        summary_cache["summarizer"] = summarizer
        return summarizer
48
+
49
# --- Export functions ---
def _segment_bounds(segment):
    """Return (start, end) in seconds for a transcription segment.

    Supports both the Hugging Face ASR pipeline chunk shape
    ({"timestamp": (start, end), "text": ...}) and a flat
    {"start": ..., "end": ...} shape. Whisper may emit None as the final
    chunk's end; that falls back to the start time.
    """
    if "timestamp" in segment:
        start, end = segment["timestamp"]
    else:
        start, end = segment.get("start", 0), segment.get("end", 0)
    start = start if start is not None else 0
    end = end if end is not None else start
    return start, end


def _format_vtt_time(seconds):
    """Format seconds as HH:MM:SS.mmm — the cue timing WebVTT requires."""
    millis = int(round(seconds * 1000))
    hours, rem = divmod(millis, 3_600_000)
    minutes, rem = divmod(rem, 60_000)
    secs, millis = divmod(rem, 1000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d}.{millis:03d}"


def create_vtt(segments, file_path):
    """Write *segments* to *file_path* as a WebVTT subtitle file.

    Fixes over the previous version: timestamps are read from the HF
    pipeline's "timestamp" tuple when present (the flat 'start'/'end'
    lookup always defaulted to 0 for Whisper chunks), and cue timings are
    emitted as HH:MM:SS.mmm — str(datetime.timedelta) yields "H:MM:SS",
    which is not valid WebVTT.
    """
    with open(file_path, "w", encoding="utf-8") as f:
        f.write("WEBVTT\n\n")
        for i, segment in enumerate(segments):
            start_seconds, end_seconds = _segment_bounds(segment)
            start = _format_vtt_time(start_seconds)
            end = _format_vtt_time(end_seconds)
            f.write(f"{i+1}\n{start} --> {end}\n{segment.get('text', '').strip()}\n\n")
59
+
60
def create_docx(segments, file_path, with_timestamps):
    """Write *segments* to *file_path* as a .docx transcription.

    When *with_timestamps* is true, each segment becomes one paragraph
    prefixed with its [start - end] range; otherwise all segment texts are
    joined into a single paragraph.

    Fix: timestamps are read from the HF pipeline's "timestamp"
    (start, end) tuple when present — the previous flat 'start'/'end'
    lookup always defaulted to 0 for Whisper chunks. A None end (possible
    on the final chunk) falls back to the start time.
    """
    document = Document()
    document.add_heading("Transcription", 0)
    if with_timestamps:
        for segment in segments:
            text = segment.get('text', '').strip()
            # Accept both HF chunk shape and flat start/end keys.
            if "timestamp" in segment:
                start_seconds, end_seconds = segment["timestamp"]
            else:
                start_seconds = segment.get('start', 0)
                end_seconds = segment.get('end', 0)
            start_seconds = start_seconds if start_seconds is not None else 0
            if end_seconds is None:
                end_seconds = start_seconds
            start = str(datetime.timedelta(seconds=int(start_seconds)))
            end = str(datetime.timedelta(seconds=int(end_seconds)))
            document.add_paragraph(f"[{start} - {end}] {text}")
    else:
        full_text = " ".join(segment.get('text', '').strip() for segment in segments)
        document.add_paragraph(full_text)
    document.save(file_path)
75
+
76
# --- Extract audio from video/audio ---
def extract_audio_from_video(file_path):
    """Return a path to an audio file for *file_path*.

    Paths that already carry a common audio extension are returned
    untouched; anything else (e.g. a video container) is decoded with
    pydub and exported to a fresh temporary .wav whose path is returned.
    """
    _, extension = os.path.splitext(file_path)
    if extension.lower() in (".wav", ".mp3", ".m4a", ".flac"):
        # Already audio — no conversion needed.
        return file_path
    wav_target = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    wav_target.close()
    AudioSegment.from_file(file_path).export(wav_target.name, format="wav")
    return wav_target.name
86
+
87
# --- Main transcription function ---
@spaces.GPU  # requests a ZeroGPU slice for the duration of this call
def transcribe_and_export(file, model_size, vtt_output, docx_timestamp_output, docx_no_timestamp_output, generate_summary, progress=gr.Progress()):
    """Transcribe an uploaded audio/video file and build the requested exports.

    Parameters
    ----------
    file : str | None
        Filepath from the gr.Audio input (audio or video), or None.
    model_size : str
        Key into MODEL_SIZES selecting the Whisper model.
    vtt_output, docx_timestamp_output, docx_no_timestamp_output : bool
        Which export files to generate.
    generate_summary : bool
        Whether to also summarize the transcript.
    progress : gr.Progress
        Gradio progress tracker (injected/managed by Gradio).

    Returns
    -------
    tuple
        (transcript text, gr.Files with download paths, audio file path,
        summary text or None, status message). When *file* is None, four
        Nones plus an error message.
    """
    if file is None:
        return None, None, None, None, "Please upload an audio or video file."

    start_time = time.time()
    # Video inputs are first converted to a temporary WAV; plain audio
    # passes through unchanged.
    audio_file_path = extract_audio_from_video(file)

    # Transcription
    pipe = get_model_pipeline(model_size, progress)
    progress(0.75, desc="🎤 Transcribing audio...")
    # The French-specific distil model is pinned to French decoding.
    if model_size == "Distil-Large-v3-FR (French-Specific)":
        raw_output = pipe(audio_file_path, return_timestamps=True, generate_kwargs={"language": "fr"})
    else:
        raw_output = pipe(audio_file_path, return_timestamps=True)

    # NOTE(review): HF pipeline chunks look like {"timestamp": (start, end),
    # "text": ...} — confirm create_vtt/create_docx read that shape.
    segments = raw_output.get("chunks", [])
    outputs = {}  # export label -> generated file path (cwd-relative)
    progress(0.85, desc="📝 Generating output files...")

    if vtt_output:
        vtt_path = "transcription.vtt"
        create_vtt(segments, vtt_path)
        outputs["VTT"] = vtt_path
    if docx_timestamp_output:
        docx_ts_path = "transcription_with_timestamps.docx"
        create_docx(segments, docx_ts_path, with_timestamps=True)
        outputs["DOCX (with timestamps)"] = docx_ts_path
    if docx_no_timestamp_output:
        docx_no_ts_path = "transcription_without_timestamps.docx"
        create_docx(segments, docx_no_ts_path, with_timestamps=False)
        outputs["DOCX (without timestamps)"] = docx_no_ts_path

    transcribed_text = raw_output['text']

    # Generate summary if requested
    summary_text = None
    if generate_summary:
        progress(0.95, desc="📝 Generating summary...")
        summarizer = get_summary_pipeline()
        # NOTE(review): long transcripts may exceed the summarizer's input
        # limit — consider chunking; verify behavior on long recordings.
        summary_output = summarizer(transcribed_text, max_length=150, min_length=30, do_sample=False)
        summary_text = summary_output[0]['summary_text']

    end_time = time.time()
    total_time = end_time - start_time
    downloadable_files = [path for path in outputs.values()]
    status_message = f"✅ Transcription complete! Total time: {total_time:.2f} seconds."

    return transcribed_text, gr.Files(value=downloadable_files, label="Download Transcripts"), audio_file_path, summary_text, status_message
137
+
138
# --- Gradio UI ---
with gr.Blocks(title="Whisper ZeroGPU Transcription") as demo:
    gr.Markdown("# 🎙️ Whisper ZeroGPU Transcription")
    gr.Markdown("Transcribe audio or video files with timestamps, and optionally generate a French summary.")

    with gr.Row():
        # Left: recorder/uploader. Right: model choice + export options.
        audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Audio/Video File")
        with gr.Column(scale=2):
            model_selector = gr.Dropdown(
                label="Choose Whisper Model Size",
                choices=list(MODEL_SIZES.keys()),
                value="Distil-Large-v3-FR (French-Specific)"
            )
            gr.Markdown("### Choose Output Formats")
            with gr.Row():
                vtt_checkbox = gr.Checkbox(label="VTT", value=True)
                docx_ts_checkbox = gr.Checkbox(label="DOCX (with timestamps)", value=False)
                docx_no_ts_checkbox = gr.Checkbox(label="DOCX (without timestamps)", value=True)
            summary_checkbox = gr.Checkbox(label="Generate Summary", value=False)

    transcribe_btn = gr.Button("Transcribe", variant="primary")
    status_text = gr.Textbox(label="Status", interactive=False)

    transcription_output = gr.Textbox(label="Full Transcription", lines=10)
    downloadable_files_output = gr.Files(label="Download Transcripts")
    summary_output = gr.Textbox(label="Summary", lines=5)

    # Wire the button to the transcription function. Note that audio_input
    # also appears in outputs: it is updated to the extracted audio path.
    transcribe_btn.click(
        fn=transcribe_and_export,
        inputs=[audio_input, model_selector, vtt_checkbox, docx_ts_checkbox, docx_no_ts_checkbox, summary_checkbox],
        outputs=[transcription_output, downloadable_files_output, audio_input, summary_output, status_text]
    )


if __name__ == "__main__":
    demo.launch()