TGPro1 commited on
Commit
76c263b
·
verified ·
1 Parent(s): c2af569

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -0
app.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import requests
3
+ import base64
4
+ import os
5
+ import json
6
+ import numpy as np
7
+ import scipy.io.wavfile as wavfile
8
+ import tempfile
9
+ import torch
10
+ from google import genai
11
+ from google.genai import types
12
+ from gradio_client import Client, handle_file
13
+ from pyannote.audio import Pipeline
14
+
15
+ # Configuration
16
+ SEAMLESS_SPACE = "tgpro1/sttr"
17
+ GEMINI_API_KEY = os.environ.get('GEMINI_API_KEY')
18
+ HF_TOKEN = os.environ.get('HF_TOKEN')
19
+
20
+ LANGUAGES = {
21
+ "Darija": "ar-SA",
22
+ "Arabic": "ar-SA",
23
+ "French": "fr-FR",
24
+ "English": "en-US",
25
+ "Spanish": "es-ES",
26
+ "German": "de-DE",
27
+ "Italian": "it-IT",
28
+ "Portuguese": "pt-PT",
29
+ "Chinese": "zh-CN",
30
+ "Japanese": "ja-JP",
31
+ "Korean": "ko-KR",
32
+ "Russian": "ru-RU",
33
+ }
34
+
35
+ # Pyannote Diarization
36
+ diarization_pipeline = None
37
+ try:
38
+ if HF_TOKEN:
39
+ diarization_pipeline = Pipeline.from_pretrained(
40
+ "pyannote/speaker-diarization-3.1",
41
+ use_auth_token=HF_TOKEN
42
+ )
43
+ if torch.cuda.is_available():
44
+ diarization_pipeline.to(torch.device("cuda"))
45
+ print("Pyannote: LOADED (GPU)")
46
+ else:
47
+ print("Pyannote: LOADED (CPU)")
48
+ except Exception as e:
49
+ print(f"Pyannote Error: {e}")
50
+
51
+ def diarize_audio(audio_path, min_speakers=1, max_speakers=5):
52
+ if not diarization_pipeline:
53
+ return {"error": "Diarization not available"}
54
+ try:
55
+ diarization = diarization_pipeline(audio_path, min_speakers=int(min_speakers), max_speakers=int(max_speakers))
56
+ speakers = []
57
+ for turn, _, speaker in diarization.itertracks(yield_label=True):
58
+ speakers.append({"speaker": speaker, "start": round(turn.start, 2), "end": round(turn.end, 2)})
59
+ return {"segments": speakers, "num_speakers": len(set(s["speaker"] for s in speakers))}
60
+ except Exception as e:
61
+ return {"error": str(e)}
62
+
63
+ with gr.Blocks(title="STTR") as demo:
64
+ gr.Markdown("# STTR - Speaker Diarization")
65
+ with gr.Tab("Diarization"):
66
+ audio_in = gr.Audio(type="filepath", label="Audio")
67
+ with gr.Row():
68
+ min_spk = gr.Slider(1, 10, value=1, step=1, label="Min Speakers")
69
+ max_spk = gr.Slider(1, 10, value=5, step=1, label="Max Speakers")
70
+ btn = gr.Button("Analyze", variant="primary")
71
+ output = gr.JSON(label="Result")
72
+ btn.click(diarize_audio, [audio_in, min_spk, max_spk], output, api_name="/diarize")
73
+
74
+ if __name__ == "__main__":
75
+ demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))