# app.py — veureu/asr (Aina faster-whisper Catalan · ZeroGPU) — compatible with ENGINE
from __future__ import annotations
import os, json, tempfile
from typing import Dict, Any, List, Tuple, Optional
import gradio as gr
import spaces
import torch
# faster-whisper (CTranslate2)
from faster_whisper import WhisperModel
# =========================
# Config and lazy loading
# =========================
# By default we use the Catalan finetune from projecte-aina on HF.
# Change MODEL_ID to the exact repo you are using (e.g.: "projecte-aina/faster-whisper-large-v3-ca-3catparla")
MODEL_ID = os.environ.get("MODEL_ID", "projecte-aina/faster-whisper-large-v3-ca-3catparla")
# Detect if there is a GPU (ZeroGPU) -> fp16, otherwise INT8
HAS_CUDA = os.environ.get("CUDA_VISIBLE_DEVICES") not in (None, "", "-1")
DEVICE = "cuda" if HAS_CUDA else "cpu"
COMPUTE_TYPE = "float16" if HAS_CUDA else "int8" # "int8_float16" also works on low-end GPUs
_model: Optional[WhisperModel] = None
def _lazy_model() -> WhisperModel:
global _model
if _model is None:
_model = WhisperModel(
MODEL_ID,
device=DEVICE,
compute_type=COMPUTE_TYPE,
download_root=os.environ.get("HF_HOME") or None, # optional
)
return _model
_model_whis = None
_processor_whis = None
def _lazy_load_whisper():
"""
Lazy load for Whisper on Hugging Face Spaces (compatible with Stateless GPU).
Avoids initializing CUDA in the main process.
"""
global _model_whis, _processor_whis
if _model_whis is None or _processor_whis is None:
model_name = "projecte-aina/whisper-large-v3-ca-3catparla"
# processor
_processor_whis = WhisperProcessor.from_pretrained(model_name)
# model
m = WhisperForConditionalGeneration.from_pretrained(
model_name,
low_cpu_mem_usage=True,
use_safetensors=True,
)
m = m.to(DEVICE)
_model_whis = m
return _processor_whis, _model_whis
# ==================================
# Transcription core (Catalan)
# ==================================
@spaces.GPU
def _transcribe_core(
audio_path: str,
language: str = "ca",
task: str = "transcribe",
vad_filter: bool = True,
beam_size: int = 5,
temperature: float = 0.0,
word_timestamps: bool = False,
) -> Dict[str, Any]:
"""
Returns:
{
"text": "transcription…",
"segments": [
{"start": 0.10, "end": 1.92, "text": "…"},
...
],
"language": "ca",
"info": { "duration": ..., "device": "cuda/cpu", "compute_type": "float16/int8" }
}
"""
model = _lazy_model()
# faster-whisper produces a generator of segments + info
segments, info = model.transcribe(
audio_path,
language=language or "ca",
task=task,
vad_filter=vad_filter,
beam_size=int(beam_size),
temperature=float(temperature),
word_timestamps=bool(word_timestamps),
)
segs: List[Dict[str, Any]] = []
full_text_parts: List[str] = []
for seg in segments:
text = (seg.text or "").strip()
full_text_parts.append(text)
segs.append({
"start": round(float(seg.start), 3) if seg.start is not None else None,
"end": round(float(seg.end), 3) if seg.end is not None else None,
"text": text,
})
out = {
"text": " ".join([t for t in full_text_parts if t]),
"segments": segs,
"language": language or "ca",
"info": {
"duration": getattr(info, "duration", None),
"device": DEVICE,
"compute_type": COMPUTE_TYPE,
},
}
return out
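# Illustrative direct call for local debugging (hypothetical file path); the returned
# dict follows the shape documented in the docstring above:
#   result = _transcribe_core("sample.wav", language="ca", word_timestamps=True)
#   print(result["text"], result["segments"][:2])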
# ==========================
# Gradio endpoints (API/UI)
# ==========================
# 1) /predict — the endpoint the ENGINE uses via gradio_client
# Minimal signature: only the audio; everything else uses defaults.
def predict_for_engine(
audio_file, # gr.Audio or gr.File
language: str = "ca",
timestamps: bool = True,
vad_filter: bool = True,
) -> Dict[str, Any]:
"""
The ENGINE normally calls: client.predict(<audio_path>, api_name="/predict").
We return a dict with 'text' and 'segments'.
"""
# Gradio may hand us a dict {'name', 'data'} or a plain file path
path = None
if isinstance(audio_file, dict) and audio_file.get("name"):
path = audio_file["name"]
elif isinstance(audio_file, str):
path = audio_file
elif hasattr(audio_file, "name"):
path = audio_file.name
if not path:
return {"text": "", "segments": [], "language": language, "info": {"error": "no_audio"}}
return _transcribe_core(
path,
language=language or "ca",
task="transcribe",
vad_filter=bool(vad_filter),
beam_size=5,
temperature=0.0,
word_timestamps=bool(timestamps),
)
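# Illustrative ENGINE-side call via gradio_client (Space id "veureu/asr" assumed from the
# header comment; depending on the gradio_client version the path may need handle_file()):
#   from gradio_client import Client
#   client = Client("veureu/asr")
#   result = client.predict("sample.wav", "ca", True, True, api_name="/predict")
#   print(result["text"], len(result["segments"]))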
# 2) /transcribe — alternative endpoint with more controls (useful for manual/HTTP testing)
def transcribe_advanced(
audio_file,
language: str = "ca",
task: str = "transcribe", # "transcribe" | "translate"
vad_filter: bool = True,
beam_size: int = 5,
temperature: float = 0.0,
word_timestamps: bool = False,
) -> Dict[str, Any]:
path = None
if isinstance(audio_file, dict) and audio_file.get("name"):
path = audio_file["name"]
elif isinstance(audio_file, str):
path = audio_file
elif hasattr(audio_file, "name"):
path = audio_file.name
if not path:
return {"text": "", "segments": [], "language": language, "info": {"error": "no_audio"}}
return _transcribe_core(
path,
language=language or "ca",
task=task or "transcribe",
vad_filter=bool(vad_filter),
beam_size=int(beam_size),
temperature=float(temperature),
word_timestamps=bool(word_timestamps),
)
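# Illustrative call for the advanced endpoint (argument order follows the function
# signature; reuses the Client from the example above, values are examples only):
#   result = client.predict("sample.wav", "ca", "transcribe", True, 5, 0.0, False,
#                           api_name="/transcribe")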
import math
from typing import Any, Dict, List, Tuple
from pydub import AudioSegment
from pyannote.audio import Pipeline
from io import BytesIO
import base64
import soundfile as sf
def diarize_audio(
wav_file: str,
min_segment_duration: float = 0.5,
max_segment_duration: float = 50.0,
) -> Tuple[List[str], List[Dict[str, Any]]]:
"""
Audio diarization that:
- Reads a WAV file
- Writes one temporary WAV clip per segment and returns their file paths (for the Gradio File output)
- Returns the list of segments [{'start','end','speaker'}]
"""
# Load audio and calculate duration
audio = AudioSegment.from_wav(wav_file)
duration = len(audio) / 1000.0
# Diarization pipeline
pipeline = Pipeline.from_pretrained(
"pyannote/speaker-diarization-3.1",
use_auth_token=os.getenv('HF_TOKEN')
)
diarization = pipeline(wav_file)
clip_buffers: List[Tuple[str, BytesIO]] = []
segments: List[Dict[str, Any]] = []
spk_map: Dict[str, int] = {}
prev_end = 0.0
# Process each segment
for i, (turn, _, speaker) in enumerate(diarization.itertracks(yield_label=True)):
start, end = max(0.0, float(turn.start)), min(duration, float(turn.end))
if start < prev_end:
start = prev_end
if end <= start:
continue
seg_dur = end - start
if seg_dur < min_segment_duration:
continue
# Split very long segments
if seg_dur > max_segment_duration:
n = int(math.ceil(seg_dur / max_segment_duration))
sub_d = seg_dur / n
for j in range(n):
s = start + j * sub_d
e = min(end, start + (j + 1) * sub_d)
clip = audio[int(s*1000):int(e*1000)]
print(f"Creating clip from {s} to {e} seconds")
buf = BytesIO()
clip.export(buf, format="wav")
buf.seek(0)
clip_buffers.append((f"segment_{i:03d}_{j:02d}.wav", buf))
if speaker not in spk_map:
spk_map[speaker] = len(spk_map)
segments.append({"start": s, "end": e, "speaker": f"SPEAKER_{spk_map[speaker]:02d}"})
prev_end = e
else:
clip = audio[int(start*1000):int(end*1000)]
buf = BytesIO()
clip.export(buf, format="wav")
buf.seek(0)
clip_buffers.append((f"segment_{i:03d}.wav", buf))
if speaker not in spk_map:
spk_map[speaker] = len(spk_map)
segments.append({"start": start, "end": end, "speaker": f"SPEAKER_{spk_map[speaker]:02d}"})
prev_end = end
# If no segments were produced, fall back to returning the whole audio as a single clip
if not segments:
    fallback_path = os.path.join(tempfile.gettempdir(), "clip1.wav")
    audio.export(fallback_path, format="wav")
    return [fallback_path], [{"start": 0.0, "end": duration, "speaker": "SPEAKER_00"}]
# Convert all clips to dicts for Gradio
print("Clip buffers:")
print(clip_buffers)
gr_clips = []
for i, (name, buf) in enumerate(clip_buffers, start=1):
buf.seek(0)
# Create temporary file but with friendly name
tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
tmp_file.write(buf.read())
tmp_file.close()
# Rename to something like "clip1.wav", "clip2.wav", ...
new_name = f"clip{i}.wav"
new_path = os.path.join(tempfile.gettempdir(), new_name)
os.rename(tmp_file.name, new_path)
gr_clips.append(new_path)
print("Gradio clips prepared.")
print(gr_clips)
return gr_clips, segments
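# Illustrative result for a two-speaker recording (paths and times are made up):
#   clips, segments = diarize_audio("meeting.wav")
#   # clips    -> ["/tmp/clip1.wav", "/tmp/clip2.wav", ...]
#   # segments -> [{"start": 0.0, "end": 4.2, "speaker": "SPEAKER_00"},
#   #              {"start": 4.2, "end": 9.7, "speaker": "SPEAKER_01"}, ...]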
import numpy as np
import torchaudio.transforms as T
from speechbrain.inference import SpeakerRecognition
from typing import List
import torchaudio
import torch
def voice_embedder(wav_file: str) -> List[float]:
print("======================================================")
model = SpeakerRecognition.from_hparams(
source="pretrained_models/spkrec-ecapa-voxceleb",
savedir="pretrained_models/spkrec-ecapa-voxceleb"
)
model.eval()
print("======================================================")
# Audio preprocessing
waveform, sr = torchaudio.load(wav_file)
target_sr = 16000
# Resample if needed
if sr != target_sr:
waveform = T.Resample(orig_freq=sr, new_freq=target_sr)(waveform)
# Convert to mono if stereo
if waveform.shape[0] > 1:
waveform = waveform.mean(dim=0, keepdim=True)
# Minimum duration of 0.2 seconds
min_samples = int(0.2 * target_sr)
if waveform.shape[1] < min_samples:
pad = min_samples - waveform.shape[1]
waveform = torch.cat([waveform, torch.zeros((1, pad))], dim=1)
# Compute speaker embedding
with torch.no_grad():
emb = (
model.encode_batch(waveform)
.squeeze()
.cpu()
.numpy()
.astype(float)
)
# Normalize embedding
emb = emb / np.linalg.norm(emb)
print(len(emb))
print(emb.tolist())
return emb.tolist()
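# Note: because the embedding is L2-normalized, the Euclidean distance used by
# identify_speaker() below is monotonically related to cosine similarity:
#   ||a - b||^2 = 2 * (1 - cos(a, b))   for unit-norm a and b
# Quick illustrative check with two hypothetical clips:
#   a = np.array(voice_embedder("anna.wav"))
#   b = np.array(voice_embedder("joan.wav"))
#   assert np.isclose(np.linalg.norm(a - b) ** 2, 2.0 * (1.0 - np.dot(a, b)))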
def identify_speaker(wav_file: str, voice_col: str) -> Dict[str, Any]:
    """Identify the closest known speaker. voice_col is a JSON-encoded list of dicts
    with the keys "nombre" and "embedding"."""
    voice_embedding = voice_embedder(wav_file)
    voice_col = json.loads(voice_col)
identity = "Desconegut"
knn = []
if voice_col and voice_embedding is not None:
try:
num_embeddings = len(voice_col)
if num_embeddings < 1:
knn = []
identity = "Desconegut"
else:
n_results = min(3, num_embeddings)
voice_embedding = np.array(voice_embedding)
distances_embedding = []
# Compute Euclidean distance between the detected voice and each stored embedding
for voice_base_datos in voice_col:
voice_base_datos_embedding = np.array(voice_base_datos["embedding"])
distance = np.linalg.norm(voice_embedding - voice_base_datos_embedding)
distances_embedding.append({
"identity": voice_base_datos["nombre"],
"distance": float(distance)
})
# Sort by distance and keep the top N matches
distances_embedding = sorted(distances_embedding, key=lambda x: x["distance"])
knn = distances_embedding[:n_results]
# Assign identity if closest match exists
if knn:
identity = knn[0]["identity"]
else:
identity = "Desconegut"
except Exception as e:
print(f"Voice KNN failed: {e}")
knn = []
identity = "Desconegut"
return {"knn": knn, "identity": identity}
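# Illustrative voice_col payload (a JSON string, as provided by the Textbox or an API
# client); "nombre" and "embedding" are the keys this function reads, and the names
# and vectors below are placeholders:
#   voice_col = json.dumps([
#       {"nombre": "Anna", "embedding": [0.12, 0.88, ...]},
#       {"nombre": "Joan", "embedding": [0.05, -0.31, ...]},
#   ])
#   identify_speaker("query.wav", voice_col)  # -> {"knn": [...], "identity": "Anna"}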
import subprocess
from pathlib import Path
from audio_extract import extract_audio
import os
import shutil
import tempfile
def convert_to_temporary(original_file):
"""
Converts a file to a temporary file, deletes the original, and returns
the path of the temporary file.
"""
if not os.path.exists(original_file):
raise FileNotFoundError(f"{original_file} does not exist")
# Create a temporary file in persistent mode
temp_fd, temp_path = tempfile.mkstemp(suffix=os.path.splitext(original_file)[1])
os.close(temp_fd) # Close the file descriptor; we'll use it as a normal file
# Copy the content to the temporary file
shutil.copy2(original_file, temp_path)
# Delete the original file
os.remove(original_file)
return temp_path
def extract_audio_ffmpeg(video_file, sr: int = 16000, mono: bool = True):
"""
Extracts the audio track from a video file (through the audio_extract helper,
which relies on FFmpeg) and returns the path to the extracted audio file.
Parameters
----------
video_file : str
    The temporary file path provided by Gradio for the uploaded video.
sr : int
    Target audio sample rate (currently not applied in this call).
mono : bool
    Whether to convert the audio to a mono channel (currently not applied).
Returns
-------
str
    Filepath to the extracted audio file, moved to a temporary location.
"""
if video_file is None:
return None
# Extract the file name without extension
base_name = os.path.splitext(os.path.basename(video_file))[0]
# Build the output path with .wav extension
audio_out = f"./{base_name}.wav"
# If the extracted file already exists, reuse it (moved to a temporary file,
# just like the normal path below)
if os.path.exists(audio_out + ".mp3"):
    return convert_to_temporary(audio_out + ".mp3")
# Call the function that performs the extraction
extract_audio(input_path=video_file, output_path=audio_out)
return convert_to_temporary(audio_out+".mp3")
import torch
import torchaudio
from dataclasses import dataclass
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import logging
def load_audio(path, target_sr=16000):
waveform, sr = torchaudio.load(path)
if sr != target_sr:
waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)(waveform)
return waveform.squeeze().numpy()
def transcribe_wav(wav_path: str) -> str:
    """Transcribe a short WAV file with the Catalan Whisper model (HF Transformers)."""
    # Lazy-load the Whisper processor and model; device selection happens inside the loader
    processor, model = _lazy_load_whisper()
# Load the WAV file
waveform, sr = torchaudio.load(wav_path)
target_sr = 16000
if sr != target_sr:
# Resample audio if sample rate differs
waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)(waveform)
sr = target_sr
# Preprocess the audio
inputs = processor(
waveform.numpy(), sampling_rate=sr, return_tensors="pt"
).input_features.to(model.device)
# Generate transcription with the model
with torch.no_grad():
ids = model.generate(inputs, max_new_tokens=440)[0]
# Decode the transcription
txt = processor.decode(ids, skip_special_tokens=True)
# Normalize text if necessary
norm = getattr(processor.tokenizer, "_normalize", None)
return norm(txt) if callable(norm) else txt
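# Note: max_new_tokens=440 keeps a single generate() call just under Whisper's decoder
# limit of 448 target positions, so the positional embeddings cannot overflow.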
def transcribe_long_audio(
wav_path: str,
chunk_length_s: int = 20,
overlap_s: int = 2,
) -> str:
# Lazy-load the Whisper processor and model; device selection happens inside the loader
processor, model = _lazy_load_whisper()
# Load the full WAV file
waveform, sr = torchaudio.load(wav_path)
target_sr = 16000
if sr != target_sr:
# Resample if sample rate differs
waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)(waveform)
sr = target_sr
total_samples = waveform.shape[1]
# Calculate chunk size and overlap in samples
chunk_size = chunk_length_s * sr
overlap_size = overlap_s * sr
transcriptions = []
start = 0
while start < total_samples:
end = min(start + chunk_size, total_samples)
chunk = waveform[:, start:end] # Transcribe in small fragments
# Preprocess the chunk
input_features = processor(
chunk.numpy(),
sampling_rate=sr,
return_tensors="pt"
).input_features.to(model.device)
# Generate transcription for the chunk
with torch.no_grad():
predicted_ids = model.generate(
input_features,
max_new_tokens=440,
num_beams=1,
)[0]
# Decode and store the chunk transcription
text = processor.decode(predicted_ids, skip_special_tokens=True)
transcriptions.append(text.strip())
# Move to the next chunk with overlap
start += chunk_size - overlap_size
# Join all chunks into a single string
return " ".join(transcriptions).strip()
"""
# ==============================================================================
# UI & Endpoints
# ==============================================================================
Collection of Gradio interface elements and API endpoints used by the application.
This section defines the user-facing interface for the Catalan ASR Space,
allowing users to extract audio from video, diarize recordings, compute voice
embeddings, identify speakers, and transcribe short or long audio.
The components and endpoints in this module typically:
- Accept audio or video files from the user
- Apply optional parameters such as language, VAD filtering, beam size, or temperature
- Preprocess inputs and invoke the transcription, diarization, and embedding
functions defined above
- Return structured outputs, including transcribed text, JSON metadata,
or audio clips
All endpoints are designed to be stateless, safe for concurrent calls,
and compatible with both interactive UI usage and programmatic API access.
# ==============================================================================
"""
custom_css = """
h2 {
background: #e3e4e6 !important;
padding: 14px 22px !important;
border-radius: 14px !important;
box-shadow: 0 4px 12px rgba(0,0,0,0.08) !important;
display: block !important; /* take the full width */
width: 100% !important; /* ensure 100% */
margin: 20px auto !important;
text-align:center;
}
"""
with gr.Blocks(title="Aina faster-whisper (Català) · ZeroGPU", css=custom_css, theme=gr.themes.Soft()) as demo:
# Extract audio from video
gr.Markdown('<h2 style="text-align:center">Extreure àudio d\'un vídeo</h2>')
with gr.Row():
video_input = gr.Video(label="Puja un vídeo")
with gr.Row():
extract_btn = gr.Button("Extreure àudio", variant="primary")
with gr.Row():
audio_output = gr.Audio(label="Àudio extret (WAV)", type="filepath")
extract_btn.click(
fn=extract_audio_ffmpeg,
inputs=video_input,
outputs=audio_output
)
# Diarization section
gr.Markdown('<h2 style="text-align:center">Diarització de l\'àudio</h2>')
with gr.Row():
audio_input = gr.Audio(label="Àudio per diaritzar", type="filepath")
process_btn = gr.Button("Diaritzar àudio", variant="primary")
clips_output = gr.File(label="Clips d\'àudio generats", file_types=[".wav"], file_count="multiple")
diarization_output = gr.JSON(label="Resultat de la diarització")
process_btn.click(
diarize_audio,
inputs=[audio_input],
outputs=[clips_output, diarization_output],
api_name="diaritzar_audio",
concurrency_limit=1
)
# Voice embeddings section
gr.Markdown('<h2 style="text-align:center">Obtenir l\'embedding d\'un àudio</h2>')
with gr.Row():
audio_input = gr.Audio(label="Àudio per obtenir l\'embedding", type="filepath")
with gr.Row():
process_btn = gr.Button("Obtenir embedding", variant="primary")
with gr.Row():
clip_out = gr.JSON(label="Embedding de veu (vector)")
process_btn.click(
voice_embedder,
[audio_input],
clip_out,
api_name="voice_embedding",
concurrency_limit=1
)
gr.Markdown("---")
# Speaker identification
gr.Markdown('<h2 style="text-align:center">Identificació de parlants</h2>')
with gr.Row():
audio_input = gr.Audio(label="Àudio per identificar el parlant", type="filepath")
with gr.Row():
voice_col_input = gr.Textbox(
label="Llista de diccionaris voice_col (format JSON)",
placeholder='[{"nombre": "Anna", "embedding": [0.12, 0.88, ...]}, ...]',
lines=5
)
with gr.Row():
process_btn = gr.Button("Processar àudio (Persones)", variant="primary")
with gr.Row():
output_json = gr.JSON(label="Resultat complet")
process_btn.click(
identify_speaker,
inputs=[audio_input, voice_col_input],
outputs=output_json,
api_name="identificar_veu",
concurrency_limit=1
)
# Short audio transcription
gr.Markdown('<h2 style="text-align:center">Aina faster-whisper (Català) Àudio curt → text</h2>')
with gr.Row():
audio_input = gr.Audio(type="filepath", label="Puja el teu àudio")
with gr.Row():
boton = gr.Button("Transcriure", variant="primary")
with gr.Row():
output_text = gr.Textbox(label="Text transcrit")
boton.click(
fn=transcribe_wav,
inputs=audio_input,
outputs=output_text
)
# Long audio transcription
gr.Markdown('<h2 style="text-align:center">Aina faster-whisper (Català) Àudio llarg → text</h2>')
with gr.Row():
audio_input = gr.Audio(type="filepath", label="Puja el teu àudio")
with gr.Row():
boton2 = gr.Button("Transcriure", variant="primary")
with gr.Row():
output_text = gr.Textbox(label="Text transcrit")
boton2.click(
fn=transcribe_long_audio,
inputs=audio_input,
outputs=output_text
)
# Main transcription section
gr.Markdown('<h2 style="text-align:center">Aina faster-whisper (Català) · ZeroGPU - Reconeixement de veu en català finetune projecte-aina</h2>')
with gr.Row():
with gr.Column():
inp = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Àudio (WAV/MP3/MP4, etc.)")
lang = gr.Textbox(label="Idioma", value="ca")
ts = gr.Checkbox(label="Marques de temps", value=True)
vad = gr.Checkbox(label="Filtre VAD", value=True)
with gr.Column():
out = gr.JSON(label="Sortida /predict")
with gr.Row():
btn = gr.Button("Transcriure (ENGINE /predict)", variant="primary")
# Button callback
btn.click(predict_for_engine, [inp, lang, ts, vad], out, api_name="predict", concurrency_limit=1)
# Advanced transcription section
gr.Markdown('<h2 style="text-align:center">Avançat (/transcribe)</h2>')
with gr.Row():
with gr.Column():
inp2 = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Àudio")
lang2 = gr.Textbox(label="Idioma", value="ca")
task2 = gr.Dropdown(["transcribe", "translate"], value="transcribe", label="Tasques")
vad2 = gr.Checkbox(label="Filtre VAD", value=True)
beam2 = gr.Slider(1, 10, value=5, step=1, label="Mida del feix")
temp2 = gr.Slider(0.0, 1.5, value=0.0, step=0.1, label="Temperatura")
wts2 = gr.Checkbox(label="Marques de temps per paraula", value=False)
with gr.Column():
out2 = gr.JSON(label="Sortida /transcribe")
with gr.Row():
btn2 = gr.Button("Transcriure (avançat)", variant="primary")
# Button callback advanced
btn2.click(
transcribe_advanced,
[inp2, lang2, task2, vad2, beam2, temp2, wts2],
out2,
api_name="transcribe",
concurrency_limit=1
)
demo.queue(max_size=8).launch(share=True, show_error=True)