| from typing import Tuple, Optional |
|
|
| import tempfile |
|
|
| import numpy as numpy_module |
| import soundfile as soundfile_module |
| import torch |
| import gradio as gradio_module |
| from PIL import Image |
| import easyocr |
| from transformers import ( |
| pipeline, |
| VitsModel, |
| AutoTokenizer, |
| ) |
|
|
# Device the models in this module are run on (e.g. "cpu" or "cuda").
device_string: str = "cpu"

# English-only OCR reader, created once at import time and reused by every
# run_ocr() call. easyocr takes a boolean ``gpu`` flag, so derive it from
# device_string instead of hard-coding False; with the default "cpu" this is
# still gpu=False.
ocr_reader = easyocr.Reader(
    ["en"],
    gpu=device_string.startswith("cuda"),
)
|
|
|
|
def run_ocr(image_object: Image.Image) -> str:
    """Extract printed English text from an image with the shared easyocr reader.

    Args:
        image_object: PIL image to read, or None (presumably what the UI
            passes when no image was uploaded).

    Returns:
        The non-empty recognized fragments joined by newlines and stripped of
        surrounding whitespace; "" when ``image_object`` is None.
    """
    if image_object is not None:
        # Normalise to 3-channel RGB before handing easyocr a numpy array.
        pixel_array = numpy_module.array(image_object.convert("RGB"))
        # detail=0 asks easyocr for text only (no boxes/confidences);
        # paragraph=True asks it to merge results into paragraphs.
        fragments = ocr_reader.readtext(pixel_array, detail=0, paragraph=True)
        # filter(None, ...) drops empty/falsy fragments before joining.
        return "\n".join(map(str, filter(None, fragments))).strip()
    return ""
|
|
# Sentiment classifier (DistilBERT fine-tuned on SST-2), loaded once at import
# time. It is placed on device_string explicitly, consistent with the TTS
# model, which is also moved to device_string, instead of relying on the
# pipeline's own default device.
text_classifier_pipeline = pipeline(
    task="text-classification",
    model="distilbert-base-uncased-finetuned-sst-2-english",
    device=device_string,
)
|
|
|
|
def run_text_classification(input_text: str) -> str:
    """Classify the tone of *input_text* with ``text_classifier_pipeline``.

    Args:
        input_text: Text to classify; leading/trailing whitespace is ignored.

    Returns:
        "" for blank input, otherwise ``"<label> (score=<score>)"`` built from
        the pipeline's first result, with the score shown to three decimals.
    """
    text = input_text.strip()
    if text:
        # truncation + max_length cut over-long inputs down to at most
        # 512 tokens before classification.
        first_result = text_classifier_pipeline(
            text, truncation=True, max_length=512
        )[0]
        label = str(first_result.get("label", ""))
        score = float(first_result.get("score", 0.0))
        return f"{label} (score={score:.3f})"
    return ""
|
|
|
|
# English summarizer (DistilBART CNN), loaded once at import time. It is
# placed on device_string explicitly, consistent with the TTS model, which is
# also moved to device_string, instead of relying on the pipeline's own
# default device.
summary_pipeline = pipeline(
    task="summarization",
    model="sshleifer/distilbart-cnn-12-6",
    device=device_string,
)
|
|
|
|
def run_summarization(
    input_text: str,
    max_summary_tokens: int = 128,
) -> str:
    """Summarize English text with the module-level ``summary_pipeline``.

    Args:
        input_text: Text to summarize; leading/trailing whitespace is ignored.
        max_summary_tokens: Upper bound on the summary length in tokens. It
            may be a float (e.g. from a UI slider), so it is coerced to int.

    Returns:
        "" for blank input; the stripped input itself when it has fewer than
        8 words (too short to be worth summarizing); otherwise the generated
        summary, stripped.
    """
    cleaned_text: str = input_text.strip()
    if not cleaned_text:
        return ""

    word_count: int = len(cleaned_text.split())
    # Very short inputs are returned unchanged.
    if word_count < 8:
        return cleaned_text

    # Cap the summary length relative to the input size (word count is used
    # as a rough proxy for token count), never below 1.
    dynamic_max_length: int = min(
        max(1, int(max_summary_tokens)),
        max(32, word_count + 20),
    )
    # The minimum must never exceed the maximum. Without the outer min() that
    # happens whenever max_summary_tokens < 10, and generation would receive
    # contradictory length constraints.
    dynamic_min_length: int = min(
        max(10, dynamic_max_length // 3),
        dynamic_max_length,
    )

    summary_result_list = summary_pipeline(
        cleaned_text,
        max_length=dynamic_max_length,
        min_length=dynamic_min_length,
        do_sample=False,
        # Long OCR output can exceed the model's maximum input length;
        # truncate the input instead of letting the pipeline fail.
        truncation=True,
    )

    summary_text: str = summary_result_list[0]["summary_text"].strip()
    return summary_text
|
|
|
|
# English MMS-TTS (VITS) tokenizer and model, loaded once at import time and
# used by run_tts(). nn.Module.to() returns the module itself, so the model is
# moved to device_string as part of the assignment.
tts_tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
tts_model: VitsModel = VitsModel.from_pretrained("facebook/mms-tts-eng").to(device_string)
|
|
|
|
def run_tts(summary_text: str) -> Optional[str]:
    """Voice English *summary_text* with the VITS model (facebook/mms-tts-eng).

    Args:
        summary_text: English text to synthesize; leading/trailing whitespace
            is ignored.

    Returns:
        Path to a newly written ``.wav`` file, or None when the text is blank,
        the tokenizer produces no input ids, or the model raises a
        RuntimeError. The file is created with ``delete=False`` and is not
        removed by this function.

    Raises:
        Any exception raised by ``soundfile`` while writing the file; the
        temporary file is removed before the exception propagates.
    """
    import os  # local import: only needed to clean up after a failed write

    cleaned_text: str = summary_text.strip()
    if not cleaned_text:
        return None

    tokenized_inputs = tts_tokenizer(
        cleaned_text,
        return_tensors="pt",
    )
    tokenized_inputs = {
        key: value.to(device_string)
        for key, value in tokenized_inputs.items()
    }

    # Guard against an empty tokenization (presumably possible when the
    # tokenizer discards every character of the input) — nothing to voice.
    input_ids_tensor = tokenized_inputs.get("input_ids")
    if input_ids_tensor is None or input_ids_tensor.numel() == 0:
        return None

    try:
        with torch.no_grad():
            model_output = tts_model(**tokenized_inputs)
            waveform_tensor = model_output.waveform
    except RuntimeError as runtime_error:
        print(f"[WARN] TTS RuntimeError: {runtime_error}")
        return None

    # Flatten with reshape(-1) rather than squeeze(): squeeze() would turn a
    # single-sample waveform into a 0-d array that soundfile cannot write.
    waveform_array = waveform_tensor.reshape(-1).cpu().numpy().astype("float32")
    waveform_array = numpy_module.clip(waveform_array, -1.0, 1.0)

    # Only reserve a unique file name here and close the handle before
    # writing: on Windows a file still held open by NamedTemporaryFile cannot
    # be opened a second time by soundfile.
    with tempfile.NamedTemporaryFile(
        suffix=".wav",
        delete=False,
    ) as temporary_file:
        file_path: str = temporary_file.name

    try:
        soundfile_module.write(
            file_path,
            waveform_array,
            tts_model.config.sampling_rate,
        )
    except Exception:
        # Do not leave an empty or partial file behind.
        os.remove(file_path)
        raise

    return file_path
|
|
|
|
def full_flow(
    image_object: Image.Image,
    max_summary_tokens: int = 128,
) -> Tuple[str, str, str, Optional[str]]:
    """Run the whole pipeline for one image.

    Steps, in order: OCR of the image, classification of the OCR text,
    summarization of the OCR text, speech synthesis of the summary.

    Args:
        image_object: Image to process (may be None; run_ocr then yields "").
        max_summary_tokens: Summary length cap forwarded to run_summarization.

    Returns:
        ``(recognized_text, classification_text, summary_text,
        audio_file_path)``; ``audio_file_path`` may be None.
    """
    text = run_ocr(image_object=image_object)
    sentiment = run_text_classification(text)
    summary = run_summarization(
        input_text=text,
        max_summary_tokens=max_summary_tokens,
    )
    return text, sentiment, summary, run_tts(summary_text=summary)
|
|
|
|
# Gradio UI: an image plus a summary-length slider go in; the four values
# returned by full_flow come out, in the same order as its return tuple.
gradio_interface = gradio_module.Interface(
    fn=full_flow,
    title="Картинка → Текст → Анализ → Конспект → Озвучка",
    description="\n".join(
        [
            "1) easyocr распознаёт печатный английский текст с картинки.",
            "2) Трансформер-классификатор (DistilBERT) оценивает тон текста.",
            "3) Трансформер-суммаризатор (DistilBART) делает краткий конспект.",
            "4) Трансформер TTS (MMS VITS) озвучивает конспект.",
        ]
    ),
    inputs=[
        gradio_module.Image(
            label="Изображение с напечатанным английским текстом",
            type="pil",
        ),
        gradio_module.Slider(
            label="Максимальная длина конспекта (токены, примерно)",
            minimum=32,
            maximum=256,
            step=16,
            value=128,
        ),
    ],
    outputs=[
        gradio_module.Textbox(lines=8, label="Распознанный текст (OCR, easyocr)"),
        gradio_module.Textbox(lines=2, label="Анализ текста (классификация, DistilBERT)"),
        gradio_module.Textbox(lines=6, label="Конспект (английский текст, DistilBART)"),
        gradio_module.Audio(type="filepath", label="Озвучка конспекта (английский TTS, VITS)"),
    ],
)


if __name__ == "__main__":
    # Start the web UI only when executed as a script, not when imported.
    gradio_interface.launch()
|
|