# sbompolas — Update app.py (commit 594aeee, verified)
import gradio as gr
import torch
import gc
import time
import logging
from transformers import (
pipeline,
AutoProcessor,
AutoModelForSpeechSeq2Seq,
AutoModelForCTC,
WhisperForConditionalGeneration,
WhisperProcessor,
)
# Configure root logging once at import time; module-level logger for this app.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class MultiASRApp:
    """Load and run automatic-speech-recognition models for Greek dialects.

    Two checkpoint families are supported:
      * Whisper (seq2seq) models, including the ILSP dialect fine-tunes.
      * XLS-R (CTC) models.
    Only one model is kept in memory at a time; switching models releases
    the previous pipeline first.
    """

    def __init__(self):
        self.pipe = None           # active HF ASR pipeline, or None
        self.current_model = None  # model id of the loaded checkpoint
        self.current_kind = None   # "whisper" | "ctc"
        self.available_models = [
            "openai/whisper-tiny",
            "openai/whisper-base",
            "openai/whisper-small",
            "openai/whisper-medium",
            "openai/whisper-large-v2",
            "openai/whisper-large-v3",
            "ilsp/whisper_greek_dialect_of_lesbos",
            "ilsp/xls-r-greek-cretan",
        ]

    # ------------------------
    # Model detection
    # ------------------------
    def detect_model_kind(self, model_name):
        """Return "ctc" for XLS-R style checkpoints, "whisper" otherwise."""
        name = model_name.lower()
        if "xls-r" in name or "xlsr" in name:
            return "ctc"
        return "whisper"

    def is_fine_tuned_whisper(self, model_name):
        """True for ILSP Whisper fine-tunes; these are loaded conservatively."""
        name = model_name.lower()
        return "ilsp/" in name and "whisper" in name

    # ------------------------
    # Device & dtype
    # ------------------------
    def pick_device(self, conservative=True):
        """Choose a (device, dtype) pair.

        On CUDA, conservative loading keeps float32 (used for the fine-tuned
        checkpoints); otherwise float16 is allowed. CPU is always float32.
        """
        if torch.cuda.is_available():
            return "cuda:0", torch.float32 if conservative else torch.float16
        return "cpu", torch.float32

    # ------------------------
    # Pipeline creation
    # ------------------------
    def create_whisper_pipe(self, model_name):
        """Build an ASR pipeline for a Whisper (seq2seq) checkpoint."""
        conservative = self.is_fine_tuned_whisper(model_name)
        device, dtype = self.pick_device(conservative)
        try:
            model = WhisperForConditionalGeneration.from_pretrained(
                model_name,
                torch_dtype=dtype,
                low_cpu_mem_usage=True,
            )
            processor = WhisperProcessor.from_pretrained(model_name)
        except Exception:
            # Fall back to the generic auto classes for non-standard repos.
            model = AutoModelForSpeechSeq2Seq.from_pretrained(
                model_name,
                torch_dtype=dtype,
                low_cpu_mem_usage=True,
            )
            processor = AutoProcessor.from_pretrained(model_name)
        model.to(device)
        return pipeline(
            "automatic-speech-recognition",
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            device=device,
            torch_dtype=dtype,
            chunk_length_s=30,
        )

    def create_ctc_pipe(self, model_name):
        """Build an ASR pipeline for a CTC (XLS-R) checkpoint."""
        device, dtype = self.pick_device(conservative=True)
        processor = AutoProcessor.from_pretrained(model_name)
        model = AutoModelForCTC.from_pretrained(
            model_name,
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
        )
        model.to(device)
        return pipeline(
            "automatic-speech-recognition",
            model=model,
            tokenizer=getattr(processor, "tokenizer", None),
            feature_extractor=getattr(processor, "feature_extractor", None),
            device=device,
            torch_dtype=dtype,
            chunk_length_s=20,
            stride_length_s=(4, 2),
        )

    def load_model(self, model_name):
        """Load *model_name*, replacing any previously loaded model.

        Returns True on success, False on failure (the error is logged and
        any partially created state is cleared).
        """
        if self.current_model == model_name and self.pipe is not None:
            return True  # already loaded
        self.clear_model()
        kind = self.detect_model_kind(model_name)
        try:
            if kind == "ctc":
                self.pipe = self.create_ctc_pipe(model_name)
            else:
                self.pipe = self.create_whisper_pipe(model_name)
            self.current_model = model_name
            self.current_kind = kind
            return True
        except Exception as e:
            logger.error(e, exc_info=True)
            self.clear_model()
            return False

    def clear_model(self):
        """Drop the current pipeline and release GPU/CPU memory."""
        if self.pipe is not None:
            del self.pipe
        self.pipe = None
        self.current_model = None
        self.current_kind = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

    # ------------------------
    # Transcription
    # ------------------------
    def transcribe(self, audio, model_name):
        """Transcribe *audio* (a file path) with *model_name*.

        Returns (text, info): the transcription and a Greek info string with
        the model id and elapsed wall-clock time. On failure, returns a
        user-facing error message and an empty info string.
        """
        if audio is None:
            return "Ανέβασε ένα ηχητικό αρχείο.", ""
        start = time.time()
        if not self.load_model(model_name):
            return "Σφάλμα φόρτωσης μοντέλου.", ""
        # FIX: an exception raised inside the pipeline call previously
        # propagated out and crashed the Gradio handler; log it and report
        # a friendly error instead (mirrors load_model's failure handling).
        try:
            if self.current_kind == "whisper":
                # 🔒 Force Greek for all Whisper models.
                result = self.pipe(
                    audio,
                    generate_kwargs={
                        "language": "greek",
                        "task": "transcribe",
                    },
                )
            else:
                # XLS-R (CTC) pipelines take no generation kwargs.
                result = self.pipe(audio)
        except Exception as e:
            logger.error(e, exc_info=True)
            return "Σφάλμα κατά την αναγνώριση ομιλίας.", ""
        text = result.get("text", "")
        info = (
            f"Μοντέλο: {model_name}\n"
            f"Χρόνος επεξεργασίας: {time.time() - start:.2f} δευτ."
        )
        return text.strip(), info

    def status(self):
        """Human-readable status of the currently loaded model (Greek UI)."""
        if not self.current_model:
            return "Δεν έχει φορτωθεί μοντέλο"
        return f"✔ {self.current_model}"
# ------------------------
# Gradio App
# ------------------------
# Single shared app instance backing the Gradio callbacks below.
app = MultiASRApp()


def run(audio, model):
    # Gradio callback: transcribe the uploaded audio with the selected model.
    return app.transcribe(audio, model)


def status():
    # Gradio callback: current model-status string for the status textbox.
    return app.status()
# Gradio UI: audio upload + model dropdown on the left, results on the right.
with gr.Blocks(title="Ίντα λαλείς;", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
# Ίντα λαλείς;
## Η Τεχνητή Νοημοσύνη μαθαίνει ελληνικές διαλέκτους
🎧 Ανέβασε ένα ηχητικό αρχείο και δες πώς η Τεχνητή Νοημοσύνη
αναγνωρίζει την ελληνική γλώσσα και τις διαλέκτους της.
📍 Athens Science Festival 2025
🏛 Ωδείο Αθηνών | 18–21 Δεκεμβρίου 2025
"""
    )
    model_status = gr.Textbox(
        label="Κατάσταση μοντέλου",
        value=status(),
        interactive=False,
    )
    with gr.Row():
        with gr.Column():
            audio = gr.Audio(
                label="🎵 Ανέβασε ηχητικό αρχείο",
                type="filepath",
            )
            model = gr.Dropdown(
                choices=app.available_models,
                value="openai/whisper-small",
                label="Μοντέλο αναγνώρισης ομιλίας",
            )
            btn = gr.Button(
                "🗣️ Μετατροπή ομιλίας σε κείμενο",
                variant="primary",
            )
        with gr.Column():
            text_out = gr.Textbox(
                label="📄 Κείμενο",
                lines=8,
                show_copy_button=True,
            )
            info_out = gr.Textbox(
                label="Πληροφορίες",
                lines=4,
            )
    btn.click(
        run,
        inputs=[audio, model],
        outputs=[text_out, info_out],
    )
    # FIX: the original lambda took one positional argument but declared no
    # `inputs`, so Gradio invoked it with zero arguments and the callback
    # raised TypeError on every dropdown change. Bind the dropdown as input.
    model.change(lambda _selected: status(), inputs=model, outputs=model_status)
    gr.Markdown(
        """
🔬 Έρευνα & τεχνολογία για τη γλωσσική ποικιλία
🎙️ Η φωνή ως πολιτιστική κληρονομιά
"""
    )

if __name__ == "__main__":
    demo.launch()