clementBE committed on
Commit
67208fe
·
verified ·
1 Parent(s): 9cffd38

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +256 -164
app.py CHANGED
@@ -1,172 +1,264 @@
1
  import os
 
2
  import tempfile
3
- import datetime
4
- import time
 
 
5
  import torch
 
 
 
 
 
 
 
6
  import gradio as gr
7
- import spaces
8
- from transformers import pipeline
9
- from docx import Document
10
- from pydub import AudioSegment
11
-
12
# --- Model definitions ---
# Maps the human-readable label shown in the UI dropdown to the
# Hugging Face model id loaded by the ASR pipeline.
MODEL_SIZES = {
    "Tiny (Fastest)": "openai/whisper-tiny",
    "Base (Faster)": "openai/whisper-base",
    "Small (Balanced)": "openai/whisper-small",
    "Distil-Large-v3 (General Purpose)": "distil-whisper/distil-large-v3",
    "Distil-Large-v3-FR (French-Specific)": "eustlb/distil-large-v3-fr"
}

# --- Caches ---
# model_cache: UI label -> loaded ASR pipeline (filled lazily by get_model_pipeline).
# summary_cache: holds the single summarization pipeline under key "summarizer".
model_cache = {}
summary_cache = {}
25
# --- Whisper pipeline loader ---
def get_model_pipeline(model_name, progress):
    """Return a cached ASR pipeline for *model_name*, loading it on first use.

    *model_name* must be a key of MODEL_SIZES; *progress* is a Gradio
    progress callback used to surface loading status in the UI.
    """
    cached = model_cache.get(model_name)
    if cached is not None:
        return cached

    progress(0, desc="🚀 Initializing ZeroGPU instance...")
    hf_model_id = MODEL_SIZES[model_name]
    # transformers accepts a CUDA device index or the string "cpu".
    target_device = 0 if torch.cuda.is_available() else "cpu"
    progress(0.1, desc=f"⏳ Loading {model_name} model...")
    asr = pipeline(
        "automatic-speech-recognition",
        model=hf_model_id,
        device=target_device
    )
    model_cache[model_name] = asr
    progress(0.5, desc="✅ Model loaded successfully!")
    return asr
39
-
40
# --- French summarization pipeline ---
def get_summary_pipeline():
    """Return the cached multilingual summarization pipeline, creating it once."""
    try:
        return summary_cache["summarizer"]
    except KeyError:
        summarizer = pipeline(
            "summarization",
            model="csebuetnlp/mT5_multilingual_XLSum"
        )
        summary_cache["summarizer"] = summarizer
        return summarizer
48
-
49
- # --- Export functions ---
50
- def create_vtt(segments, file_path):
51
- with open(file_path, "w", encoding="utf-8") as f:
52
- f.write("WEBVTT\n\n")
53
- for i, segment in enumerate(segments):
54
- start_seconds = segment.get('start', 0)
55
- end_seconds = segment.get('end', 0)
56
- start = str(datetime.timedelta(seconds=int(start_seconds)))
57
- end = str(datetime.timedelta(seconds=int(end_seconds)))
58
- f.write(f"{i+1}\n{start} --> {end}\n{segment.get('text', '').strip()}\n\n")
59
-
60
def create_docx(segments, file_path, with_timestamps):
    """Export *segments* to a .docx transcription at *file_path*.

    When *with_timestamps* is true each segment becomes its own
    '[start - end] text' paragraph; otherwise all segment texts are
    joined into a single paragraph.
    """
    document = Document()
    document.add_heading("Transcription", 0)
    if with_timestamps:
        for segment in segments:
            body = segment.get('text', '').strip()
            begin = str(datetime.timedelta(seconds=int(segment.get('start', 0))))
            finish = str(datetime.timedelta(seconds=int(segment.get('end', 0))))
            document.add_paragraph(f"[{begin} - {finish}] {body}")
    else:
        joined = " ".join(segment.get('text', '').strip() for segment in segments)
        document.add_paragraph(joined)
    document.save(file_path)
75
-
76
# --- Extract audio from video/audio ---
def extract_audio_from_video(file_path):
    """Return a path to an audio file for *file_path*.

    Files already in a supported audio container are returned unchanged;
    anything else (e.g. video) is decoded with pydub and exported to a
    temporary .wav file whose path is returned (caller owns cleanup).
    """
    extension = os.path.splitext(file_path)[1].lower()
    if extension in (".wav", ".mp3", ".m4a", ".flac"):
        # Already audio — no re-encoding needed.
        return file_path
    wav_tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    wav_tmp.close()
    AudioSegment.from_file(file_path).export(wav_tmp.name, format="wav")
    return wav_tmp.name
86
-
87
# --- Main transcription function ---
@spaces.GPU
def transcribe_and_export(file, model_size, vtt_output, docx_timestamp_output, docx_no_timestamp_output, generate_summary, progress=gr.Progress()):
    """Transcribe an uploaded audio/video file and export the requested formats.

    Returns a 5-tuple consumed by the Gradio click handler:
    (transcribed text, gr.Files of generated exports, audio file path,
    summary text or None, status message). Output files are written to
    the current working directory with fixed names.
    """
    if file is None:
        # Nothing uploaded: keep all outputs empty and report via status.
        return None, None, None, None, "Please upload an audio or video file."

    start_time = time.time()
    # Videos are converted to a temporary .wav; audio files pass through.
    audio_file_path = extract_audio_from_video(file)

    # Transcription
    pipe = get_model_pipeline(model_size, progress)
    progress(0.75, desc="🎤 Transcribing audio...")
    if model_size == "Distil-Large-v3-FR (French-Specific)":
        # French-only checkpoint: pin the decoding language explicitly.
        raw_output = pipe(audio_file_path, return_timestamps=True, generate_kwargs={"language": "fr"})
    else:
        raw_output = pipe(audio_file_path, return_timestamps=True)

    # "chunks" holds per-segment dicts with timestamps; may be absent.
    segments = raw_output.get("chunks", [])
    outputs = {}
    progress(0.85, desc="📝 Generating output files...")

    if vtt_output:
        vtt_path = "transcription.vtt"
        create_vtt(segments, vtt_path)
        outputs["VTT"] = vtt_path
    if docx_timestamp_output:
        docx_ts_path = "transcription_with_timestamps.docx"
        create_docx(segments, docx_ts_path, with_timestamps=True)
        outputs["DOCX (with timestamps)"] = docx_ts_path
    if docx_no_timestamp_output:
        docx_no_ts_path = "transcription_without_timestamps.docx"
        create_docx(segments, docx_no_ts_path, with_timestamps=False)
        outputs["DOCX (without timestamps)"] = docx_no_ts_path

    transcribed_text = raw_output['text']

    # Generate summary if requested
    summary_text = None
    if generate_summary:
        progress(0.95, desc="📝 Generating summary...")
        summarizer = get_summary_pipeline()
        summary_output = summarizer(transcribed_text, max_length=150, min_length=30, do_sample=False)
        summary_text = summary_output[0]['summary_text']

    end_time = time.time()
    total_time = end_time - start_time
    downloadable_files = [path for path in outputs.values()]
    status_message = f"✅ Transcription complete! Total time: {total_time:.2f} seconds."

    return transcribed_text, gr.Files(value=downloadable_files, label="Download Transcripts"), audio_file_path, summary_text, status_message
137
-
138
# --- Gradio UI ---
# NOTE(review): original indentation was lost in the diff rendering; the
# widget nesting below (dropdown/format options inside the right-hand
# column) is reconstructed from the source order — confirm against the
# rendered layout.
with gr.Blocks(title="Whisper ZeroGPU Transcription") as demo:
    gr.Markdown("# 🎙️ Whisper ZeroGPU Transcription")
    gr.Markdown("Transcribe audio or video files with timestamps, and optionally generate a French summary.")

    with gr.Row():
        audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Audio/Video File")
        with gr.Column(scale=2):
            model_selector = gr.Dropdown(
                label="Choose Whisper Model Size",
                choices=list(MODEL_SIZES.keys()),
                value="Distil-Large-v3-FR (French-Specific)"
            )
            gr.Markdown("### Choose Output Formats")
            with gr.Row():
                vtt_checkbox = gr.Checkbox(label="VTT", value=True)
                docx_ts_checkbox = gr.Checkbox(label="DOCX (with timestamps)", value=False)
                docx_no_ts_checkbox = gr.Checkbox(label="DOCX (without timestamps)", value=True)
                summary_checkbox = gr.Checkbox(label="Generate Summary", value=False)

    transcribe_btn = gr.Button("Transcribe", variant="primary")
    status_text = gr.Textbox(label="Status", interactive=False)

    transcription_output = gr.Textbox(label="Full Transcription", lines=10)
    downloadable_files_output = gr.Files(label="Download Transcripts")
    summary_output = gr.Textbox(label="Summary", lines=5)

    # Third output slot re-targets audio_input so the (possibly extracted)
    # audio is played back in the same widget the user uploaded to.
    transcribe_btn.click(
        fn=transcribe_and_export,
        inputs=[audio_input, model_selector, vtt_checkbox, docx_ts_checkbox, docx_no_ts_checkbox, summary_checkbox],
        outputs=[transcription_output, downloadable_files_output, audio_input, summary_output, status_text]
    )

if __name__ == "__main__":
    demo.launch()
 
1
  import os
2
+ import zipfile
3
  import tempfile
4
+ import requests
5
+ import numpy as np
6
+ import pandas as pd
7
+ from PIL import Image
8
  import torch
9
+ import torch.nn.functional as F
10
+ from torchvision import transforms
11
+ from torchvision.models import resnet50, ResNet50_Weights
12
+ from sklearn.cluster import MiniBatchKMeans
13
+ import matplotlib.pyplot as plt
14
+ import io
15
+
16
  import gradio as gr
17
+
18
+ # Face analysis
19
+ from deepface import DeepFace
20
+ import cv2
21
+
22
# ---------------------------
# Force CPU if no CUDA
# ---------------------------
# Hiding all CUDA devices keeps downstream libraries (torch, DeepFace)
# from probing for a GPU that is not there.
if not torch.cuda.is_available():
    os.environ["CUDA_VISIBLE_DEVICES"] = ""

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---------------------------
# Load ResNet50
# ---------------------------
# Pretrained ImageNet classifier, loaded once at import time and kept in
# eval mode (inference only — no dropout/batch-norm updates).
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights).to(device)
model.eval()

# ---------------------------
# Transformations
# ---------------------------
# Standard ImageNet preprocessing matching the ResNet50 training recipe.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# ---------------------------
# ImageNet labels
# ---------------------------
# NOTE(review): fetched over the network at import time with no timeout or
# error handling — app startup fails if this URL is unreachable; consider
# bundling the label file.
LABELS_URL = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
imagenet_classes = [line.strip() for line in requests.get(LABELS_URL).text.splitlines()]
53
+
54
+ # ---------------------------
55
+ # Color utilities
56
+ # ---------------------------
57
# Reference palette: coarse color names mapped to their pure RGB values.
BASIC_COLORS = {
    "Red": (255, 0, 0),
    "Green": (0, 255, 0),
    "Blue": (0, 0, 255),
    "Yellow": (255, 255, 0),
    "Cyan": (0, 255, 255),
    "Magenta": (255, 0, 255),
    "Black": (0, 0, 0),
    "White": (255, 255, 255),
    "Gray": (128, 128, 128),
}

def closest_basic_color(rgb):
    """Return the BASIC_COLORS name nearest to *rgb* by squared Euclidean distance.

    Ties resolve to the first palette entry in insertion order, matching
    a linear minimum scan.
    """
    r, g, b = rgb

    def squared_distance(entry):
        cr, cg, cb = entry[1]
        return (r - cr) ** 2 + (g - cg) ** 2 + (b - cb) ** 2

    return min(BASIC_COLORS.items(), key=squared_distance)[0]
79
+
80
def get_dominant_color(image, num_colors=5):
    """Return the dominant color of a PIL image as ((r, g, b), "#rrggbb").

    Downscales to a 100x100 thumbnail, clusters its pixels into
    *num_colors* groups with MiniBatchKMeans (fixed random_state for
    determinism), and reports the center of the largest cluster.
    """
    thumbnail = image.resize((100, 100))
    flat_pixels = np.array(thumbnail).reshape(-1, 3)
    clusterer = MiniBatchKMeans(n_clusters=num_colors, random_state=0, n_init=5)
    clusterer.fit(flat_pixels)
    biggest = np.argmax(np.bincount(clusterer.labels_))
    dominant = tuple(clusterer.cluster_centers_[biggest].astype(int))
    hex_code = f"#{dominant[0]:02x}{dominant[1]:02x}{dominant[2]:02x}"
    return dominant, hex_code
89
+
90
# ---------------------------
# Core function
# ---------------------------

def _plot_to_image(fig):
    """Render a Matplotlib figure to a PIL image and release the figure."""
    buf = io.BytesIO()
    # Save via the figure object (not pyplot's implicit current figure) so
    # interleaved figure creation cannot save the wrong plot.
    fig.savefig(buf, format="png", bbox_inches="tight")
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


def _top3_predictions(image):
    """Classify a PIL image with ResNet50; return ([label]*3, ["xx.xx%"]*3)."""
    input_tensor = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        probs = F.softmax(model(input_tensor), dim=1)[0]
    top3_prob, top3_idx = torch.topk(probs, 3)
    labels = [imagenet_classes[idx] for idx in top3_idx]
    confs = [f"{prob.item() * 100:.2f}%" for prob in top3_prob]
    return labels, confs


def _analyze_faces(image):
    """Detect/characterize faces with DeepFace.

    Returns (info_string, ages, gender_scores):
      - info_string: "Age: .., Gender: .., Gender Confidence: .., Emotion: .."
        joined by "; " per face, or "No face detected".
      - ages: estimated age per detected face.
      - gender_scores: {"Man"/"Woman": summed confidence in [0, 1]} counting
        only faces whose gender confidence is <= 80% (feeds the
        "Gender Distribution (Confidence <= 80%)" plot).

    Fix vs. original: current DeepFace returns a list of dicts whose
    'gender' key is a per-class confidence dict and has no
    'gender_confidence' key, so the old code always raised KeyError and
    reported "No face detected". Both old and new result schemas are now
    handled.
    """
    ages = []
    gender_scores = {"Man": 0, "Woman": 0}
    try:
        img_cv2 = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        faces = DeepFace.analyze(img_cv2, actions=["age", "gender", "emotion"], enforce_detection=False)
        if not isinstance(faces, list):  # legacy DeepFace: single dict for one face
            faces = [faces]
        parts = []
        for f in faces:
            raw_gender = f.get("gender")
            if isinstance(raw_gender, dict):
                # New schema: {"Man": pct, "Woman": pct} plus "dominant_gender".
                gender = f.get("dominant_gender") or max(raw_gender, key=raw_gender.get)
                conf = raw_gender.get(gender, 100.0) / 100.0
            else:
                # Old schema: plain label; confidence key may be absent.
                gender = raw_gender
                conf = float(f.get("gender_confidence", 1.0))
            ages.append(int(f["age"]))
            if conf <= 0.8:
                gender_scores[gender] = gender_scores.get(gender, 0) + conf
            parts.append(
                f"Age: {f['age']}, Gender: {gender}, "
                f"Gender Confidence: {conf * 100:.2f}, Emotion: {f['dominant_emotion']}"
            )
        return "; ".join(parts), ages, gender_scores
    except Exception:
        # Face analysis is best-effort; never fail the whole batch over it.
        return "No face detected", [], {"Man": 0, "Woman": 0}


def classify_zip_and_analyze_color(zip_file):
    """Process every .png/.jpg/.jpeg in an uploaded ZIP.

    For each image: ResNet50 top-3 classification, dominant-color
    analysis, and DeepFace age/gender/emotion detection. Returns
    (DataFrame, XLSX path, color-frequency plot, prediction-distribution
    plot, gender plot, age histogram) — same interface as before.

    Improvement: ages and gender confidences are accumulated as
    structured data during the per-image loop instead of being re-parsed
    out of the display strings afterwards (the old round-trip was fragile
    and silently dropped data on any format drift).
    """
    results = []
    all_ages = []
    gender_confidence = {"Man": 0, "Woman": 0}

    with tempfile.TemporaryDirectory() as tmpdir:
        with zipfile.ZipFile(zip_file.name, 'r') as zip_ref:
            zip_ref.extractall(tmpdir)

        for fname in sorted(os.listdir(tmpdir)):
            if not fname.lower().endswith(('.png', '.jpg', '.jpeg')):
                continue
            try:
                image = Image.open(os.path.join(tmpdir, fname)).convert("RGB")
            except Exception:
                continue  # skip unreadable/corrupt files

            labels, confs = _top3_predictions(image)
            rgb, hex_color = get_dominant_color(image)
            face_info, ages, gender_scores = _analyze_faces(image)

            all_ages.extend(ages)
            for g, c in gender_scores.items():
                gender_confidence[g] = gender_confidence.get(g, 0) + c

            results.append((
                fname,
                ", ".join(labels),
                ", ".join(confs),
                hex_color,
                closest_basic_color(rgb),
                face_info,
            ))

    # Build dataframe
    df = pd.DataFrame(results, columns=["Filename", "Top 3 Predictions", "Confidence", "Dominant Color", "Basic Color", "Face Info"])

    # Save XLSX
    out_xlsx = os.path.join(tempfile.gettempdir(), "results.xlsx")
    df.to_excel(out_xlsx, index=False)

    # Plot 1: basic color frequency
    fig1, ax1 = plt.subplots()
    color_counts = df["Basic Color"].value_counts()
    ax1.bar(color_counts.index, color_counts.values, color="skyblue")
    ax1.set_title("Basic Color Frequency")
    ax1.set_ylabel("Count")
    plot1_img = _plot_to_image(fig1)

    # Plot 2: 20 most frequent top-3 prediction labels
    fig2, ax2 = plt.subplots()
    preds_flat = []
    for p in df["Top 3 Predictions"]:
        preds_flat.extend(p.split(", "))
    pred_counts = pd.Series(preds_flat).value_counts().head(20)
    ax2.barh(pred_counts.index[::-1], pred_counts.values[::-1], color="salmon")
    ax2.set_title("Top Prediction Distribution")
    ax2.set_xlabel("Count")
    plot2_img = _plot_to_image(fig2)

    # Plot 3: gender distribution (faces with <= 80% confidence only)
    fig3, ax3 = plt.subplots()
    ax3.bar(gender_confidence.keys(), gender_confidence.values(), color=["lightblue", "pink"])
    ax3.set_title("Gender Distribution (Confidence ≤ 80%)")
    ax3.set_ylabel("Sum of Confidence")
    plot3_img = _plot_to_image(fig3)

    # Plot 4: age histogram in 5-year buckets (0-100)
    fig4, ax4 = plt.subplots()
    ax4.hist(all_ages, bins=range(0, 101, 5), color="lightgreen", edgecolor="black")
    ax4.set_title("Age Distribution")
    ax4.set_xlabel("Age")
    ax4.set_ylabel("Count")
    plot4_img = _plot_to_image(fig4)

    return df, out_xlsx, plot1_img, plot2_img, plot3_img, plot4_img
244
+
245
# ---------------------------
# Gradio Interface
# ---------------------------
# Single-function app: the six outputs positionally match the 6-tuple
# returned by classify_zip_and_analyze_color.
demo = gr.Interface(
    fn=classify_zip_and_analyze_color,
    inputs=gr.File(file_types=[".zip"], label="Upload ZIP of images"),
    outputs=[
        gr.Dataframe(headers=["Filename", "Top 3 Predictions", "Confidence", "Dominant Color", "Basic Color", "Face Info"]),
        gr.File(label="Download XLSX"),
        gr.Image(type="pil", label="Basic Color Frequency"),
        gr.Image(type="pil", label="Top Prediction Distribution"),
        gr.Image(type="pil", label="Gender Distribution (≤80% Confidence)"),
        gr.Image(type="pil", label="Age Distribution"),
    ],
    title="Image Classifier with Color & Face Analysis",
    description="Upload a ZIP of images. Classifies images, analyzes dominant color, and detects/characterizes faces (age, gender, emotion).",
)

if __name__ == "__main__":
    # Bind to all interfaces on the conventional Spaces port so the app is
    # reachable from outside the container.
    demo.launch(server_name="0.0.0.0", server_port=7860)