STTR / app.py
TGPro1's picture
Create app.py
76c263b verified
raw
history blame
2.51 kB
import gradio as gr
import requests
import base64
import os
import json
import numpy as np
import scipy.io.wavfile as wavfile
import tempfile
import torch
from google import genai
from google.genai import types
from gradio_client import Client, handle_file
from pyannote.audio import Pipeline
# Configuration
# Space identifier; presumably consumed via gradio_client elsewhere — it is not
# referenced in the visible part of this file.
SEAMLESS_SPACE = "tgpro1/sttr"

# Secrets are read from the environment; each is None when the variable is unset.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
HF_TOKEN = os.getenv("HF_TOKEN")

# Display name -> locale code. "Darija" shares the "ar-SA" code with "Arabic".
LANGUAGES = {
    "Darija": "ar-SA",
    "Arabic": "ar-SA",
    "French": "fr-FR",
    "English": "en-US",
    "Spanish": "es-ES",
    "German": "de-DE",
    "Italian": "it-IT",
    "Portuguese": "pt-PT",
    "Chinese": "zh-CN",
    "Japanese": "ja-JP",
    "Korean": "ko-KR",
    "Russian": "ru-RU",
}
# Pyannote Diarization
# Module-level handle to the loaded pipeline. It stays None when HF_TOKEN is
# unset or loading fails; diarize_audio() checks it before running.
diarization_pipeline = None
try:
    if not HF_TOKEN:
        # Say why the feature is off instead of staying silent.
        print("Pyannote: DISABLED (HF_TOKEN not set)")
    else:
        diarization_pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-3.1",
            use_auth_token=HF_TOKEN,
        )
        if diarization_pipeline is None:
            # from_pretrained() may return None rather than raise when the
            # model cannot be downloaded (e.g. gated model, token without
            # access), so do not report success or call .to() on it.
            print(
                "Pyannote Error: pipeline could not be loaded "
                "(check that HF_TOKEN has access to the model)"
            )
        elif torch.cuda.is_available():
            diarization_pipeline.to(torch.device("cuda"))
            print("Pyannote: LOADED (GPU)")
        else:
            print("Pyannote: LOADED (CPU)")
except Exception as e:
    # Best effort: a load failure must not prevent the app from starting.
    print(f"Pyannote Error: {e}")
def diarize_audio(audio_path, min_speakers=1, max_speakers=5):
    """Run speaker diarization on an audio file and summarize the speaker turns.

    Args:
        audio_path: Path of the audio file to analyze. ``None`` or an empty
            string (nothing provided) yields an error result.
        min_speakers: Lower bound on the speaker count; coerced with ``int()``.
        max_speakers: Upper bound on the speaker count; coerced with ``int()``.

    Returns:
        On success, ``{"segments": [...], "num_speakers": n}`` where each
        segment is ``{"speaker": label, "start": t0, "end": t1}`` with the
        turn boundaries rounded to two decimals and ``n`` the number of
        distinct speaker labels. On failure, ``{"error": message}``; any
        ``Exception`` raised while running the pipeline is reported this way
        instead of propagating.
    """
    # Identity check: only "not loaded" (None) disables the feature.
    if diarization_pipeline is None:
        return {"error": "Diarization not available"}
    # Give a clear message for a missing input instead of whatever exception
    # the pipeline would raise on None.
    if not audio_path:
        return {"error": "No audio provided"}
    try:
        diarization = diarization_pipeline(
            audio_path,
            min_speakers=int(min_speakers),
            max_speakers=int(max_speakers),
        )
        segments = [
            {
                "speaker": speaker,
                "start": round(turn.start, 2),
                "end": round(turn.end, 2),
            }
            for turn, _, speaker in diarization.itertracks(yield_label=True)
        ]
        distinct_speakers = {segment["speaker"] for segment in segments}
        return {"segments": segments, "num_speakers": len(distinct_speakers)}
    except Exception as e:
        # The result is rendered as JSON in the UI, so report rather than raise.
        return {"error": str(e)}
# UI: a single "Diarization" tab that feeds an audio file and the speaker-count
# bounds to diarize_audio() and shows the returned dict as JSON.
with gr.Blocks(title="STTR") as demo:
    gr.Markdown("# STTR - Speaker Diarization")

    with gr.Tab("Diarization"):
        # type="filepath" makes the handler receive a path string.
        audio_in = gr.Audio(type="filepath", label="Audio")

        with gr.Row():
            min_spk = gr.Slider(
                minimum=1, maximum=10, value=1, step=1, label="Min Speakers"
            )
            max_spk = gr.Slider(
                minimum=1, maximum=10, value=5, step=1, label="Max Speakers"
            )

        btn = gr.Button("Analyze", variant="primary")
        output = gr.JSON(label="Result")

        # NOTE(review): api_name carries a leading slash; confirm the endpoint
        # path Gradio derives from it before changing, as clients may rely on it.
        btn.click(
            fn=diarize_audio,
            inputs=[audio_in, min_spk, max_spk],
            outputs=output,
            api_name="/diarize",
        )
# Script entry point: serve the app on all interfaces.
if __name__ == "__main__":
    # PORT from the environment wins; 7860 is the fallback.
    listen_port = int(os.getenv("PORT", 7860))
    demo.launch(server_name="0.0.0.0", server_port=listen_port)