# NoteMaker / app.py
# ASureevaA
# edit
# 529a697
from typing import Tuple, Optional
import tempfile
import numpy as numpy_module
import soundfile as soundfile_module
import torch
import gradio as gradio_module
from PIL import Image
import easyocr
from transformers import (
pipeline,
VitsModel,
AutoTokenizer,
)
# Device identifier used when placing the TTS model and its input tensors.
device_string: str = "cpu"

# Shared EasyOCR reader for English text, built once at import time with GPU disabled.
ocr_reader = easyocr.Reader(["en"], gpu=False)
def run_ocr(image_object: Image.Image) -> str:
    """
    OCR for printed English text.

    Returns the recognized text chunks joined by newlines (stripped), or an
    empty string when no image is supplied.
    """
    if image_object is None:
        return ""

    # The reader is given a NumPy array built from the RGB version of the image.
    pixel_array = numpy_module.array(image_object.convert("RGB"))
    raw_chunks = ocr_reader.readtext(pixel_array, detail=0, paragraph=True)

    # Drop falsy entries, stringify the rest, and join them one per line.
    return "\n".join(map(str, filter(None, raw_chunks))).strip()
# Text-classification pipeline (DistilBERT fine-tuned on SST-2), loaded once at import time.
text_classifier_pipeline = pipeline(
    "text-classification",
    model="distilbert-base-uncased-finetuned-sst-2-english",
)
def run_text_classification(input_text: str) -> str:
    """
    Analyze the text with the transformer classifier.

    Returns a string of the form "LABEL (score=0.123)", or an empty string
    when the input is blank.
    """
    stripped = input_text.strip()
    if stripped == "":
        return ""

    # truncation/max_length are forwarded to the pipeline so long inputs are cut at 512.
    top_prediction = text_classifier_pipeline(
        stripped,
        truncation=True,
        max_length=512,
    )[0]

    label = str(top_prediction.get("label", ""))
    score = float(top_prediction.get("score", 0.0))
    return f"{label} (score={score:.3f})"
# Summarization pipeline (DistilBART CNN 12-6 checkpoint), loaded once at import time.
summary_pipeline = pipeline(
    "summarization",
    model="sshleifer/distilbart-cnn-12-6",
)
def run_summarization(
    input_text: str,
    max_summary_tokens: int = 128,
) -> str:
    """
    Summarize English text with the module-level summarization pipeline.

    Args:
        input_text: Text to summarize; surrounding whitespace is ignored.
        max_summary_tokens: Upper bound for the generated summary length.
            Non-integer numbers (e.g. a float coming from a UI slider) are
            converted with ``int()``; values below 1 are treated as 1.

    Returns:
        The stripped summary text. Blank input yields an empty string, and
        input shorter than 8 words is returned stripped without calling
        the model.
    """
    cleaned_text: str = input_text.strip()
    if not cleaned_text:
        return ""

    word_count: int = len(cleaned_text.split())
    if word_count < 8:
        # Too short to be worth summarizing; hand the text back unchanged.
        return cleaned_text

    # Generation length arguments must be integers and at least 1.
    token_budget: int = max(1, int(max_summary_tokens))

    # Do not ask for a summary much longer than the text itself.
    dynamic_max_length: int = min(token_budget, max(32, word_count + 20))

    # Keep min_length <= max_length even when the caller's budget is tiny.
    dynamic_min_length: int = min(
        max(10, dynamic_max_length // 3),
        dynamic_max_length,
    )

    summary_result_list = summary_pipeline(
        cleaned_text,
        max_length=dynamic_max_length,
        min_length=dynamic_min_length,
        do_sample=False,
        # Cut over-long input at tokenization time instead of feeding the
        # model more tokens than it accepts (mirrors the classifier call).
        truncation=True,
    )
    return summary_result_list[0]["summary_text"].strip()
# Text-to-speech tokenizer and VITS model from the facebook/mms-tts-eng checkpoint,
# loaded once at import time; the model is then placed on the configured device.
tts_tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(
    "facebook/mms-tts-eng"
)
tts_model: VitsModel = VitsModel.from_pretrained(
    "facebook/mms-tts-eng"
)
tts_model.to(device_string)
def run_tts(summary_text: str) -> Optional[str]:
    """
    Voice the English summary text with VitsModel (facebook/mms-tts-eng).

    Returns the path of a temporary .wav file, or None when the text is
    blank, tokenizes to no input ids, or the model raises a RuntimeError.
    """
    text_to_speak = summary_text.strip()
    if text_to_speak == "":
        return None

    encoded = tts_tokenizer(text_to_speak, return_tensors="pt")
    model_inputs = {
        name: tensor.to(device_string) for name, tensor in encoded.items()
    }

    # Nothing to synthesize if the tokenizer produced no ids.
    token_ids = model_inputs.get("input_ids")
    if token_ids is None:
        return None
    if token_ids.numel() == 0:
        return None

    try:
        with torch.no_grad():
            # Original author's note: waveform is (batch, n_samples).
            waveform = tts_model(**model_inputs).waveform
    except RuntimeError as runtime_error:
        print(f"[WARN] TTS RuntimeError: {runtime_error}")
        return None

    # Convert to float32 samples and clamp them to the [-1.0, 1.0] range.
    samples = waveform.squeeze().cpu().numpy().astype("float32")
    samples = numpy_module.clip(samples, -1.0, 1.0)

    # delete=False leaves the file on disk after the handle is closed,
    # so the returned path stays usable.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as wav_file:
        wav_path: str = wav_file.name
        soundfile_module.write(
            wav_path,
            samples,
            tts_model.config.sampling_rate,
        )
    return wav_path
def full_flow(
    image_object: Image.Image,
    max_summary_tokens: int = 128,
) -> Tuple[str, str, str, Optional[str]]:
    """
    Run the whole chain on one image.

    1) OCR
    2) Text classification
    3) Summarization
    4) TTS

    Returns a tuple of (recognized text, classification text, summary text,
    audio file path or None).
    """
    ocr_text = run_ocr(image_object=image_object)
    label_text = run_text_classification(ocr_text)
    summary = run_summarization(
        input_text=ocr_text,
        max_summary_tokens=max_summary_tokens,
    )
    # The audio is generated from the summary, not from the full OCR text.
    return ocr_text, label_text, summary, run_tts(summary_text=summary)
# Gradio UI wired to full_flow: an image and a summary-length slider go in;
# three text boxes and one audio file come out.
gradio_interface = gradio_module.Interface(
    fn=full_flow,
    inputs=[
        gradio_module.Image(
            label="Изображение с напечатанным английским текстом",
            type="pil",
        ),
        gradio_module.Slider(
            label="Максимальная длина конспекта (токены, примерно)",
            minimum=32,
            maximum=256,
            step=16,
            value=128,
        ),
    ],
    outputs=[
        gradio_module.Textbox(lines=8, label="Распознанный текст (OCR, easyocr)"),
        gradio_module.Textbox(lines=2, label="Анализ текста (классификация, DistilBERT)"),
        gradio_module.Textbox(lines=6, label="Конспект (английский текст, DistilBART)"),
        gradio_module.Audio(type="filepath", label="Озвучка конспекта (английский TTS, VITS)"),
    ],
    title="Картинка → Текст → Анализ → Конспект → Озвучка",
    description=(
        "1) easyocr распознаёт печатный английский текст с картинки.\n"
        "2) Трансформер-классификатор (DistilBERT) оценивает тон текста.\n"
        "3) Трансформер-суммаризатор (DistilBART) делает краткий конспект.\n"
        "4) Трансформер TTS (MMS VITS) озвучивает конспект."
    ),
)

if __name__ == "__main__":
    # Start the web UI only when this file is executed as a script.
    gradio_interface.launch()