| import subprocess |
| import sys |
|
|
| |
# Upgrade Gradio inside the running interpreter before `import gradio` further
# down, presumably so the gr.Interface options used below are available.
# NOTE(review): this performs a networked pip install on every start-up, and
# check_call raises CalledProcessError (aborting start-up) if pip fails —
# consider pinning the dependency in requirements instead.
_GRADIO_UPGRADE_COMMAND = [
    sys.executable, "-m", "pip", "install", "--upgrade", "gradio>=4.44.0",
]
subprocess.check_call(_GRADIO_UPGRADE_COMMAND)
|
|
| from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification |
| import gradio as gr |
| import numpy as np |
| import scipy.io.wavfile |
| import tempfile |
| import os |
| from transformers import VitsModel, AutoTokenizer |
| import torch |
| import re |
| import traceback |
|
|
print("Starting application...")

# Module-level model handles, populated by load_models(). Each one stays None
# when its model fails to load; restore_punctuation() and text_to_speech()
# check for None before using them.
punct_pipe = model = tokenizer = None
|
|
def load_models(
    punctuation_model_id="oliverguhr/fullstop-punctuation-multilang-large",
    tts_model_id="facebook/mms-tts-kmr-script_latin",
):
    """Load the punctuation pipeline and the Kurmanji TTS model into module globals.

    Sets the globals ``punct_pipe``, ``model`` and ``tokenizer``. Loading is
    best-effort. If either part fails, its error message and traceback are
    printed and the corresponding globals are left as ``None``, so the app can
    still start. restore_punctuation() and text_to_speech() check for ``None``
    before use.

    Args:
        punctuation_model_id: Hugging Face model id of the token-classification
            model used to restore punctuation.
        tts_model_id: Hugging Face model id of the VITS text-to-speech model.
            Its tokenizer is loaded from the same id.
    """
    global punct_pipe, model, tokenizer

    print("Loading punctuation model...")
    try:
        punct_tokenizer = AutoTokenizer.from_pretrained(punctuation_model_id)
        punct_model = AutoModelForTokenClassification.from_pretrained(punctuation_model_id)
        punct_pipe = pipeline(
            "token-classification",
            model=punct_model,
            tokenizer=punct_tokenizer,
            aggregation_strategy="simple",
        )
        print("✓ Punctuation model loaded successfully")
    except Exception as e:
        # Best-effort: synthesis still works without punctuation restoration.
        print(f"✗ Error loading punctuation model: {e}")
        traceback.print_exc()
        punct_pipe = None

    print("Loading TTS model...")
    try:
        # Load into locals first so the two globals are only ever set as a pair.
        tts_model = VitsModel.from_pretrained(tts_model_id)
        tts_tokenizer = AutoTokenizer.from_pretrained(tts_model_id)
    except Exception as e:
        print(f"✗ Error loading TTS model: {e}")
        traceback.print_exc()
        model = None
        tokenizer = None
    else:
        model, tokenizer = tts_model, tts_tokenizer
        print("✓ TTS model loaded successfully")


# Load eagerly at import time so the models are ready before the UI is built.
load_models()
|
|
| |
# Kurmanji words for 0-10 (the original lookup table; the speller below builds on it).
num2word = {
    "0": "sifir", "1": "yek", "2": "du", "3": "sê", "4": "çar", "5": "pênc",
    "6": "şeş", "7": "heft", "8": "heşt", "9": "neh", "10": "deh"
}

# 11-19 are not composed from smaller words.
# NOTE(review): several spelling variants exist (e.g. pazdeh/panzdeh,
# diwazdeh/dwanzdeh); confirm the preferred forms with a native speaker.
_TEENS = {
    11: "yazdeh", 12: "diwazdeh", 13: "sêzdeh", 14: "çardeh", 15: "pazdeh",
    16: "şazdeh", 17: "hevdeh", 18: "hejdeh", 19: "nozdeh",
}
# Round tens, keyed by the tens digit.
_TENS = {
    2: "bîst", 3: "sî", 4: "çil", 5: "pêncî",
    6: "şêst", 7: "heftê", 8: "heştê", 9: "nod",
}
# Multiplicative units, largest first. A count of one is spoken bare
# ("sed", "hezar") rather than "yek sed".
_SCALES = ((10**9, "milyar"), (10**6, "milyon"), (1000, "hezar"), (100, "sed"))
# Longer digit runs (IDs, phone numbers) are read digit by digit. This also
# keeps int() far below Python's integer/string conversion length limit.
_MAX_CARDINAL_DIGITS = 15


def _spell_cardinal(n):
    """Return the Kurmanji words for the non-negative integer ``n``."""
    if n <= 10:
        return num2word[str(n)]
    if n < 20:
        return _TEENS[n]
    if n < 100:
        tens, rest = divmod(n, 10)
        word = _TENS[tens]
        return word if rest == 0 else f"{word} û {num2word[str(rest)]}"
    # n >= 100, so the last scale (100) always matches.
    value, name = next(scale for scale in _SCALES if n >= scale[0])
    count, rest = divmod(n, value)
    head = name if count == 1 else f"{_spell_cardinal(count)} {name}"
    return head if rest == 0 else f"{head} û {_spell_cardinal(rest)}"


def replace_numbers_with_words(text):
    """Replace every standalone run of digits in ``text`` with Kurmanji words.

    A run is matched only at word boundaries, so "25km" is left untouched.
    Runs of up to ``_MAX_CARDINAL_DIGITS`` digits are spoken as cardinals
    ("25" -> "bîst û pênc"). Runs that start with a leading zero ("007"), or
    that are longer than that, are read digit by digit.
    """
    def repl(match):
        token = match.group()
        if len(token) > _MAX_CARDINAL_DIGITS or (len(token) > 1 and int(token[0]) == 0):
            return " ".join(num2word[str(int(ch))] for ch in token)
        return _spell_cardinal(int(token))
    return re.sub(r'\b\d+\b', repl, text)
|
|
| |
# Abbreviations pronounced as a single word.
# NOTE(review): the lowercase "ram" key also rewrites any standalone lowercase
# word "ram" in the input (matching is case-sensitive) — confirm this is intended.
abbrev_as_word = {
    "KCK": "Keceke",
    "PKK": "Pekeke",
    "PAJK": "Pajek",
    "PYD": "Peyede",
    "YPG": "Yepege",
    "YPJ": "Yepeje",
    "HDP": "Hedepe",
    "DBP": "Debepe",
    "KDP": "Kedepe",
    "PDK": "Pedeke",
    "PUK": "Pûk",
    "YNK": "Yeneke",
    "TAK": "Tak",
    "PJAK": "Pejak",
    "ENKS": "Enekese",
    "TEV-DEM": "Tevdem",
    "KOMKAR": "Komkar",
    "NATO": "Nato",
    "UNESCO": "Yunesko",
    "UNICEF": "Yunîsef",
    "VOA": "Voa",
    "RAM": "Rem",
    "ram": "Rem",
}

# Abbreviations spoken letter by letter.
abbrev_spelled = {
    "UN": "Û En",
    "EU": "E Û",
    "NGO": "En Cî O",
    "KRG": "Ke Re Ge",
    "BBC": "Bî Bî Sî",
    "CNN": "Sî En En",
    "DW": "De We",
    "TRT": "Te Re Te",
    "RT": "Er Te",
    "USB": "U Se Be",
    "PDF": "Pe De Fe",
    "AI": "A Î",
    "IT": "Ay Tî",
    "HTTP": "He Te Te Pe",
    "HTML": "He Te Me Le",
    "URL": "U Re Le",
    "IP": "Ay Pî",
    "CPU": "Sî Pî U",
    "GPU": "Cî Pî U",
    "SMS": "Es Em Es",
    "GPS": "Cî Pî Es",
}

# Combined lookup used by expand_abbreviations(): word-style entries first,
# then spelled-out ones (the two tables share no keys).
abbrev_map = {**abbrev_as_word, **abbrev_spelled}
|
|
def expand_abbreviations(text: str) -> str:
    """Replace known abbreviations in ``text`` with their spoken Kurmanji form.

    Entries of ``abbrev_map`` are applied one after another, in dict order.
    A match must stand alone: it may not be preceded or followed by a word
    character, so "RT" inside "TRT" is not touched. Matching is case-sensitive.
    """
    expanded = text
    for short_form, spoken in abbrev_map.items():
        standalone = rf"(?<!\w){re.escape(short_form)}(?!\w)"
        expanded = re.sub(standalone, spoken, expanded)
    return expanded
|
|
def normalize_text(text: str) -> str:
    """Fold typographic (curly) quotes into plain ASCII quotes.

    “ and ” become a double quote; ’ and ‘ become an apostrophe. All other
    characters are left untouched.
    """
    quote_table = str.maketrans({"“": "\"", "”": "\"", "’": "'", "‘": "'"})
    return text.translate(quote_table)
|
|
# Map punctuation-model labels to the mark inserted after a word group.
# NOTE(review): per its model card, oliverguhr/fullstop-punctuation-multilang-large
# labels words with the marks themselves ("0" = none, ".", ",", "?", "-", ":").
# Verify against punct_model.config.id2label. "-" (and "0") are deliberately left
# unmapped, so no mark is inserted for them. The named PERIOD/COMMA labels
# checked originally are kept for models that use them.
_PUNCT_LABEL_TO_MARK = {
    ".": ".", ",": ",", "?": "?", ":": ":",
    "PERIOD": ".", "COMMA": ",",
}
# If a group already ends with one of these, no further mark is appended.
_EXISTING_MARKS = (".", ",", "?", "!", ":", ";")


def restore_punctuation(text):
    """Insert punctuation predicted by ``punct_pipe`` into ``text``.

    ``punct_pipe`` is the token-classification pipeline built in load_models()
    with ``aggregation_strategy="simple"``. Each result is a dict with a
    ``'word'`` (the text of a group of consecutive words that share a label) and
    an ``'entity_group'`` (that label). The mark for the label is appended after
    the group, and the groups are joined with single spaces.

    Returns ``text`` unchanged when the pipeline is unavailable or raises, so
    punctuation restoration never blocks speech synthesis.
    """
    if punct_pipe is None:
        print("Punctuation model not available, skipping...")
        return text

    try:
        results = punct_pipe(text)
        pieces = []
        for token in results:
            word = token['word'].strip()
            if not word:
                continue
            mark = _PUNCT_LABEL_TO_MARK.get(token.get('entity_group', ''), "")
            # Don't double up when the input already has punctuation here.
            if mark and not word.endswith(_EXISTING_MARKS):
                word += mark
            pieces.append(word)
        return " ".join(pieces)
    except Exception as e:
        # Best-effort step: fall back to the unpunctuated text.
        print(f"Punctuation error: {e}")
        return text
|
|
| |
def preprocess_text(text: str) -> str:
    """Run the text-normalization steps that precede TTS tokenization.

    The steps run in this order:

    1. Fold curly quotes.
    2. Replace digits with Kurmanji words.
    3. Expand known abbreviations.
    4. Restore punctuation with the punctuation model.

    Each step receives the previous step's output.
    """
    normalization_steps = (
        normalize_text,
        replace_numbers_with_words,
        expand_abbreviations,
        restore_punctuation,
    )
    for step in normalization_steps:
        text = step(text)
    return text
|
|
def text_to_speech(text):
    """Synthesize Kurmanji speech for ``text`` and return the path of a WAV file.

    The text is preprocessed with preprocess_text(), tokenized, and run through
    the VITS model loaded by load_models(). The waveform is peak-normalized and
    written as a WAV file at the model's ``sampling_rate`` (16000 if the config
    has none).

    Returns:
        Path to a temporary ``.wav`` file, or ``None`` in any of these cases:
        the input is empty, the model is not loaded, there is nothing left to
        speak, or any step raises. The file is created with ``delete=False``
        and is not removed here. Errors are printed, not raised.
    """
    print("=== TTS Function Called ===")
    print(f"Input text: '{text}'")

    try:
        if not text or text.strip() == "":
            print("Error: Please enter some text")
            return None

        if model is None or tokenizer is None:
            print("Error: TTS model not loaded properly")
            return None

        print("Processing text...")
        processed_text = preprocess_text(text.strip())
        print(f"Processed text: '{processed_text}'")

        print("Tokenizing...")
        inputs = tokenizer(processed_text, return_tensors="pt")
        print(f"Tokenized successfully, input_ids shape: {inputs['input_ids'].shape}")
        # The model cannot run on an empty token sequence (e.g. the input held
        # only characters the tokenizer drops), so report that clearly.
        if inputs['input_ids'].shape[-1] == 0:
            print("Error: no speakable characters left after preprocessing")
            return None

        print("Generating audio...")
        with torch.no_grad():
            output = model(**inputs).waveform
        print(f"Audio generated, shape: {output.shape}")

        waveform = output.squeeze().numpy()
        if waveform.size == 0:
            print("Error: model returned an empty waveform")
            return None
        # Peak-normalize to [-1, 1]. An all-zero (silent) waveform is left as is,
        # because dividing by a zero peak would fill the file with NaN samples.
        peak = np.max(np.abs(waveform))
        if peak > 0:
            waveform = waveform / peak
        print(f"Waveform shape: {waveform.shape}")

        print("Saving audio file...")
        # delete=False: the file must outlive this call, because the returned
        # path is handed to the gr.Audio output component.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
            tmp_path = tmp_file.name

        sampling_rate = getattr(model.config, "sampling_rate", 16000)
        scipy.io.wavfile.write(tmp_path, rate=sampling_rate, data=waveform)

        print(f"✓ Audio saved to: {tmp_path}")
        print("=== TTS Function Completed Successfully ===")
        return tmp_path

    except Exception as e:
        # UI boundary: log with a traceback and return no audio instead of
        # propagating the exception.
        print(f"✗ Error in TTS: {e}")
        traceback.print_exc()
        return None
|
|
print("Creating Gradio interface...")

# Sample inputs shown under the textbox, one sentence each.
_EXAMPLE_SENTENCES = [
    "Silav! Ez baş im.",
    "Tu çawa yî?",
    "Ez ji Kurdistanê me.",
    "HDP û KCK li ser vê mijarê axivîn.",
    "Ez bi USB yekî vê belavim.",
]

# Single-input/single-output UI: a Kurmanji textbox feeding text_to_speech(),
# whose returned WAV path is played by the audio component.
interface = gr.Interface(
    fn=text_to_speech,
    inputs=gr.Textbox(
        label="Nivîseke bi kurmancî binivîse",
        placeholder="Mînak: Silav! Ez baş im.",
    ),
    outputs=gr.Audio(label="Deng"),
    title="Bernameya Nivîs-bo-Deng ya bi kurmancî - Kurmanji Text-to-Speech",
    description="Nivîseke bi kurmancî binivîse ku bo deng bê veguherandin. / Write Kurmanji Kurdish text and listen to it.",
    submit_btn="Bişîne",
    clear_btn="Paqij bike",
    # One inner list per example, holding one value per input component.
    examples=[[sentence] for sentence in _EXAMPLE_SENTENCES],
)

print("Launching interface...")

if __name__ == "__main__":
    interface.launch()
|
|