|
|
from typing import Tuple, Optional |
|
|
|
|
|
import tempfile |
|
|
|
|
|
import numpy as numpy_module |
|
|
import soundfile as soundfile_module |
|
|
import torch |
|
|
import gradio as gradio_module |
|
|
from PIL import Image |
|
|
import easyocr |
|
|
from transformers import ( |
|
|
pipeline, |
|
|
VitsModel, |
|
|
AutoTokenizer, |
|
|
) |
|
|
|
|
|
# Compute device for the models in this file ("cpu", or a torch device string
# such as "cuda"). The TTS model is moved onto it explicitly further below.
device_string: str = "cpu"


# English EasyOCR reader, created once at import time and reused by every call
# to run_ocr. GPU use is derived from device_string instead of being hard-coded,
# so switching the device in one place also moves the OCR stage.
ocr_reader = easyocr.Reader(
    ["en"],
    gpu=device_string != "cpu",
)
|
|
|
|
|
|
|
|
def run_ocr(image_object: Image.Image) -> str:
    """
    Recognize printed English text in an image using the module-level EasyOCR reader.

    Args:
        image_object: PIL image to read; Gradio passes None when no image is uploaded.

    Returns:
        The recognized text fragments (non-empty ones only) joined with newlines
        and stripped of surrounding whitespace, or "" when image_object is None.
    """
    if image_object is None:
        return ""

    # EasyOCR takes a numpy array; normalize to 3-channel RGB first so
    # palette / RGBA / grayscale inputs are handled uniformly.
    pixels = numpy_module.array(image_object.convert("RGB"))

    # detail=0 -> only the text strings are returned (no boxes / confidences);
    # paragraph=True -> EasyOCR groups neighbouring detections together.
    fragments = ocr_reader.readtext(pixels, detail=0, paragraph=True)

    kept_lines = []
    for fragment in fragments:
        if fragment:
            kept_lines.append(str(fragment))

    return "\n".join(kept_lines).strip()
|
|
|
|
|
# Text classifier (DistilBERT fine-tuned on SST-2), loaded once at import time.
# The device is pinned to device_string so this stage runs on the same device
# as the rest of the app rather than wherever the pipeline would choose by default.
text_classifier_pipeline = pipeline(
    task="text-classification",
    model="distilbert-base-uncased-finetuned-sst-2-english",
    device=device_string,
)
|
|
|
|
|
|
|
|
def run_text_classification(input_text: str) -> str:
    """
    Classify text with the module-level transformer classifier.

    Args:
        input_text: Text to analyse (typically the OCR output).

    Returns:
        "LABEL (score=0.123)" for the top prediction, or "" when the input is blank.
    """
    text = input_text.strip()
    if text:
        # Truncate to the classifier's 512-token window instead of failing on long input.
        top_prediction = text_classifier_pipeline(
            text,
            truncation=True,
            max_length=512,
        )[0]
        label = str(top_prediction.get("label", ""))
        confidence = float(top_prediction.get("score", 0.0))
        return f"{label} (score={confidence:.3f})"
    return ""
|
|
|
|
|
|
|
|
# Abstractive summarizer (DistilBART fine-tuned on CNN), loaded once at import
# time. The device is pinned to device_string so every model in the app is
# placed consistently instead of relying on the pipeline's default choice.
summary_pipeline = pipeline(
    task="summarization",
    model="sshleifer/distilbart-cnn-12-6",
    device=device_string,
)
|
|
|
|
|
|
|
|
def run_summarization(
    input_text: str,
    max_summary_tokens: int = 128,
) -> str:
    """
    Summarize English text with the module-level summarization pipeline.

    Args:
        input_text: Text to summarize (typically the OCR output).
        max_summary_tokens: Upper bound on the summary length in tokens. UI
            sliders may deliver this as a float, so it is coerced to int.

    Returns:
        The generated summary; the stripped input itself when it has fewer
        than 8 words (too short to summarize meaningfully); "" for blank input.
    """
    cleaned_text: str = input_text.strip()
    if not cleaned_text:
        return ""

    word_count: int = len(cleaned_text.split())

    # Very short inputs are returned verbatim; there is nothing to condense.
    if word_count < 8:
        return cleaned_text

    # Scale the summary budget with the input size (word count is only a rough
    # proxy for token count) so short texts don't get padded-out summaries.
    upper_bound: int = max(1, int(max_summary_tokens))
    dynamic_max_length: int = min(upper_bound, max(32, word_count + 20))
    # min_length must never exceed max_length, otherwise generation is
    # asked to satisfy contradictory constraints.
    dynamic_min_length: int = min(
        dynamic_max_length,
        max(10, dynamic_max_length // 3),
    )

    summary_result_list = summary_pipeline(
        cleaned_text,
        max_length=dynamic_max_length,
        min_length=dynamic_min_length,
        do_sample=False,
        # Long OCR output can exceed the model's maximum input length; truncate
        # it instead of letting the pipeline raise.
        truncation=True,
    )

    summary_text: str = summary_result_list[0]["summary_text"].strip()
    return summary_text
|
|
|
|
|
|
|
|
# Hugging Face checkpoint shared by the TTS model and its tokenizer; kept in
# one place so the two can never drift apart.
TTS_MODEL_ID: str = "facebook/mms-tts-eng"

# English MMS VITS text-to-speech model, loaded once and moved to the app device.
tts_model: VitsModel = VitsModel.from_pretrained(TTS_MODEL_ID)

# AutoTokenizer.from_pretrained returns a concrete tokenizer instance
# (presumably a VitsTokenizer for this checkpoint), not an AutoTokenizer,
# so no AutoTokenizer annotation is given here.
tts_tokenizer = AutoTokenizer.from_pretrained(TTS_MODEL_ID)

tts_model.to(device_string)
|
|
|
|
|
|
|
|
def run_tts(summary_text: str) -> Optional[str]:
    """
    Voice English text with the module-level VitsModel (facebook/mms-tts-eng).

    Args:
        summary_text: Text to synthesize (the summary).

    Returns:
        Path to a newly created .wav file, or None when the text is blank, it
        tokenizes to nothing, the model raises RuntimeError, or the model
        produces an empty waveform. The file is not deleted here; the caller
        (e.g. Gradio) consumes it.
    """
    import os  # local: only needed for temp-file handling below

    cleaned_text: str = summary_text.strip()
    if not cleaned_text:
        return None

    tokenized_inputs = tts_tokenizer(cleaned_text, return_tensors="pt")
    tokenized_inputs = {
        key: value.to(device_string)
        for key, value in tokenized_inputs.items()
    }

    # Non-blank text can presumably still encode to an empty sequence
    # (e.g. if no character is in the tokenizer's vocabulary).
    input_ids_tensor = tokenized_inputs.get("input_ids")
    if input_ids_tensor is None or input_ids_tensor.numel() == 0:
        return None

    try:
        with torch.no_grad():
            waveform_tensor = tts_model(**tokenized_inputs).waveform
    except RuntimeError as runtime_error:
        print(f"[WARN] TTS RuntimeError: {runtime_error}")
        return None

    # reshape(-1) flattens the single-item batch; unlike squeeze() it never
    # collapses a one-sample waveform into a 0-d array soundfile cannot write.
    waveform_array = waveform_tensor.reshape(-1).cpu().numpy().astype("float32")
    if waveform_array.size == 0:
        return None
    waveform_array = numpy_module.clip(waveform_array, -1.0, 1.0)

    # Create the file and close our handle before writing by name: reopening a
    # still-open temporary file by name does not work on Windows.
    file_descriptor, file_path = tempfile.mkstemp(suffix=".wav")
    os.close(file_descriptor)
    try:
        soundfile_module.write(
            file_path,
            waveform_array,
            tts_model.config.sampling_rate,
        )
    except Exception:
        # Don't leave an empty/partial file behind when writing fails.
        os.remove(file_path)
        raise

    return file_path
|
|
|
|
|
|
|
|
def full_flow(
    image_object: Image.Image,
    max_summary_tokens: int = 128,
) -> Tuple[str, str, str, Optional[str]]:
    """
    Run the whole image-to-speech chain for one uploaded image.

    Stages, in order: OCR -> text classification -> summarization -> TTS of
    the summary. Each stage receives the output of OCR (or, for TTS, the summary).

    Args:
        image_object: Uploaded PIL image (None when nothing was uploaded).
        max_summary_tokens: Upper bound on the summary length, forwarded to
            run_summarization.

    Returns:
        (recognized text, classification string, summary, path to the audio
        file or None), matching the four Gradio output components.
    """
    ocr_text = run_ocr(image_object=image_object)
    tone = run_text_classification(ocr_text)
    digest = run_summarization(
        input_text=ocr_text,
        max_summary_tokens=max_summary_tokens,
    )
    speech_path = run_tts(summary_text=digest)

    return (ocr_text, tone, digest, speech_path)
|
|
|
|
|
|
|
|
# Input widgets: the picture to read and the cap on summary length.
_interface_inputs = [
    gradio_module.Image(
        type="pil",
        label="Изображение с напечатанным английским текстом",
    ),
    gradio_module.Slider(
        minimum=32,
        maximum=256,
        value=128,
        step=16,
        label="Максимальная длина конспекта (токены, примерно)",
    ),
]

# Output widgets, in the same order as the tuple returned by full_flow:
# OCR text, classification, summary, synthesized audio (file path).
_interface_outputs = [
    gradio_module.Textbox(label="Распознанный текст (OCR, easyocr)", lines=8),
    gradio_module.Textbox(label="Анализ текста (классификация, DistilBERT)", lines=2),
    gradio_module.Textbox(label="Конспект (английский текст, DistilBART)", lines=6),
    gradio_module.Audio(label="Озвучка конспекта (английский TTS, VITS)", type="filepath"),
]

# One description line per pipeline stage, newline-separated.
_interface_description = "\n".join(
    [
        "1) easyocr распознаёт печатный английский текст с картинки.",
        "2) Трансформер-классификатор (DistilBERT) оценивает тон текста.",
        "3) Трансформер-суммаризатор (DistilBART) делает краткий конспект.",
        "4) Трансформер TTS (MMS VITS) озвучивает конспект.",
    ]
)

# The web UI wiring the image -> text -> analysis -> summary -> speech flow.
gradio_interface = gradio_module.Interface(
    fn=full_flow,
    inputs=_interface_inputs,
    outputs=_interface_outputs,
    title="Картинка → Текст → Анализ → Конспект → Озвучка",
    description=_interface_description,
)


if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not on import.
    gradio_interface.launch()
|
|
|