File size: 6,535 Bytes
f80380f
b31d0e9
f80380f
b31d0e9
 
 
f80380f
b31d0e9
a88eb1e
ed03824
a88eb1e
 
f6e6de6
fa051f7
a88eb1e
 
ed03824
b31d0e9
ed03824
529a697
 
a88eb1e
 
35e85d1
9eec39f
 
ed03824
9eec39f
 
a88eb1e
9eec39f
efbc18d
ed03824
 
 
 
b2f68cd
ed03824
 
 
b2f68cd
9eec39f
ed03824
35e85d1
 
ed03824
 
 
 
 
 
 
 
529a697
ed03824
 
 
 
 
2717a3f
 
 
 
 
ed03824
 
 
 
 
 
 
 
efbc18d
 
 
 
 
 
 
b31d0e9
 
 
 
 
35e85d1
b31d0e9
 
 
a88eb1e
 
b31d0e9
 
 
 
 
 
fa051f7
 
 
b31d0e9
 
 
35e85d1
b31d0e9
 
 
 
 
 
 
efbc18d
 
 
 
 
b31d0e9
 
35e85d1
b31d0e9
 
 
 
 
 
 
 
 
35e85d1
 
 
 
b31d0e9
 
efbc18d
b31d0e9
 
 
 
 
fa051f7
b31d0e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35e85d1
b31d0e9
 
 
ed03824
b31d0e9
529a697
 
 
 
b31d0e9
 
 
ed03824
 
b31d0e9
 
 
 
 
 
 
ed03824
b31d0e9
 
 
a88eb1e
b31d0e9
 
 
fa051f7
b31d0e9
 
 
 
 
 
fa051f7
b31d0e9
 
a88eb1e
b31d0e9
ed03824
 
 
 
 
 
efbc18d
 
ed03824
b31d0e9
 
 
ed03824
b31d0e9
 
a88eb1e
ed03824
a88eb1e
ed03824
 
 
529a697
a88eb1e
 
 
b31d0e9
a88eb1e
b31d0e9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
from typing import Tuple, Optional

import tempfile

import numpy as numpy_module
import soundfile as soundfile_module
import torch
import gradio as gradio_module
from PIL import Image
import easyocr
from transformers import (
    pipeline,
    VitsModel,
    AutoTokenizer,
)

# Inference device name; used for the TTS model and its inputs.
device_string: str = "cpu"

# A single EasyOCR reader for English, created once at import time.
# GPU use follows device_string, which is "cpu" here, so the GPU stays off.
ocr_reader = easyocr.Reader(
    lang_list=["en"],
    gpu=device_string != "cpu",
)


def run_ocr(image_object: Image.Image) -> str:
    """
    Recognize printed English text in an image with EasyOCR.

    Args:
        image_object: PIL image to read; ``None`` is accepted and yields "".

    Returns:
        The truthy OCR results (grouped into paragraphs by the reader),
        converted to ``str``, joined with newlines and stripped.
    """
    if image_object is None:
        return ""

    # The reader is given a NumPy array of the RGB-converted image.
    pixels = numpy_module.array(image_object.convert("RGB"))
    paragraphs = ocr_reader.readtext(pixels, detail=0, paragraph=True)

    # filter(None, ...) drops falsy entries such as empty strings.
    return "\n".join(map(str, filter(None, paragraphs))).strip()

# Sentiment/tone classifier (DistilBERT fine-tuned on SST-2), loaded once at import.
text_classifier_pipeline = pipeline(
    "text-classification",
    model="distilbert-base-uncased-finetuned-sst-2-english",
)


def run_text_classification(input_text: str) -> str:
    """
    Classify text with the DistilBERT classification pipeline.

    Args:
        input_text: Text to classify; surrounding whitespace is ignored.

    Returns:
        "" for blank input, otherwise ``"<LABEL> (score=<score>)"`` with the
        score of the top prediction formatted to three decimals.
    """
    text_to_classify = input_text.strip()
    if text_to_classify == "":
        return ""

    # Inputs are truncated to 512 tokens by the tokenizer.
    predictions = text_classifier_pipeline(
        text_to_classify,
        truncation=True,
        max_length=512,
    )
    top_prediction = predictions[0]

    label = str(top_prediction.get("label", ""))
    score = float(top_prediction.get("score", 0.0))
    return f"{label} (score={score:.3f})"


# English abstractive summarizer (distilbart-cnn-12-6 checkpoint), loaded once at import.
summary_pipeline = pipeline(
    "summarization",
    model="sshleifer/distilbart-cnn-12-6",
)


def run_summarization(
    input_text: str,
    max_summary_tokens: int = 128,
) -> str:
    """
    Summarize English text with the DistilBART summarization pipeline.

    Args:
        input_text: Text to summarize; surrounding whitespace is ignored and
            ``None`` is treated like an empty string.
        max_summary_tokens: Upper bound on the summary length in tokens. It
            may be a float (e.g. when fed from a UI slider), so it is
            truncated to ``int`` before use.

    Returns:
        "" for empty input, the cleaned input unchanged when it has fewer
        than 8 words, otherwise the stripped generated summary.
    """
    cleaned_text: str = (input_text or "").strip()
    if not cleaned_text:
        return ""

    word_count: int = len(cleaned_text.split())
    # Very short text is not worth summarizing; echo it back.
    if word_count < 8:
        return cleaned_text

    # Cap the summary near the input length (word count is a rough proxy
    # for tokens), but never above the caller's limit.
    max_length: int = max(1, min(int(max_summary_tokens), max(32, word_count + 20)))
    # Keep min_length <= max_length even when the caller's cap is small;
    # otherwise generation receives contradictory bounds.
    min_length: int = min(max_length, max(10, max_length // 3))

    summary_result_list = summary_pipeline(
        cleaned_text,
        max_length=max_length,
        min_length=min_length,
        do_sample=False,
        # OCR output can exceed the model's input window; truncate rather
        # than fail on over-long inputs.
        truncation=True,
    )

    summary_text: str = summary_result_list[0]["summary_text"].strip()
    return summary_text


# English VITS text-to-speech model and its tokenizer, both from the same
# MMS checkpoint, loaded once at import time.
_tts_checkpoint = "facebook/mms-tts-eng"
tts_model: VitsModel = VitsModel.from_pretrained(_tts_checkpoint)
tts_model.to(device_string)
# AutoTokenizer is only a factory; the returned object is a concrete
# tokenizer class, so it is not annotated as AutoTokenizer here.
tts_tokenizer = AutoTokenizer.from_pretrained(_tts_checkpoint)


def run_tts(summary_text: str) -> Optional[str]:
    """
    Synthesize English speech for a text with VitsModel (facebook/mms-tts-eng).

    Args:
        summary_text: Text to speak; surrounding whitespace is ignored and
            ``None`` is treated like an empty string.

    Returns:
        Path to a newly created temporary ``.wav`` file, or ``None`` when the
        text is empty, the tokenizer produces no input ids, or the model
        raises ``RuntimeError``. The file is created with ``delete=False``,
        so it is not removed automatically.
    """
    cleaned_text: str = (summary_text or "").strip()
    if not cleaned_text:
        return None

    tokenized_inputs = tts_tokenizer(
        cleaned_text,
        return_tensors="pt",
    )
    model_inputs = {
        key: value.to(device_string)
        for key, value in tokenized_inputs.items()
    }

    input_ids_tensor = model_inputs.get("input_ids")
    if input_ids_tensor is None or input_ids_tensor.numel() == 0:
        return None

    try:
        with torch.no_grad():
            model_output = tts_model(**model_inputs)
            waveform_tensor = model_output.waveform  # (batch, n_samples)
    except RuntimeError as runtime_error:
        print(f"[WARN] TTS RuntimeError: {runtime_error}")
        return None

    # A single text was tokenized, so batch == 1. reshape(-1) always yields a
    # 1-D array, whereas squeeze() would collapse a 1-sample waveform to 0-d.
    waveform_array = waveform_tensor.reshape(-1).cpu().numpy().astype("float32")
    waveform_array = numpy_module.clip(waveform_array, -1.0, 1.0)

    # Write through the already-open file object instead of reopening it by
    # name: reopening a NamedTemporaryFile while it is open fails on Windows.
    # The format is given explicitly because it is not inferred from a path.
    with tempfile.NamedTemporaryFile(
        suffix=".wav",
        delete=False,
    ) as temporary_file:
        soundfile_module.write(
            temporary_file,
            waveform_array,
            tts_model.config.sampling_rate,
            format="WAV",
        )
        file_path: str = temporary_file.name

    return file_path


def full_flow(
    image_object: Image.Image,
    max_summary_tokens: int = 128,
) -> Tuple[str, str, str, Optional[str]]:
    """
    Run the whole pipeline on one image and return every intermediate result.

    Steps, in order: OCR -> text classification -> summarization -> TTS.
    Classification and summarization both operate on the OCR text; TTS
    speaks the summary.

    Returns:
        A 4-tuple ``(recognized_text, classification_text, summary_text,
        audio_file_path)``; the audio path may be ``None``.
    """
    ocr_text = run_ocr(image_object=image_object)
    tone = run_text_classification(ocr_text)
    summary = run_summarization(
        input_text=ocr_text,
        max_summary_tokens=max_summary_tokens,
    )
    audio_path = run_tts(summary_text=summary)
    return ocr_text, tone, summary, audio_path


# UI inputs: the source image and the summary-length cap passed to full_flow.
_input_components = [
    gradio_module.Image(
        type="pil",
        label="Изображение с напечатанным английским текстом",
    ),
    gradio_module.Slider(
        minimum=32,
        maximum=256,
        value=128,
        step=16,
        label="Максимальная длина конспекта (токены, примерно)",
    ),
]

# UI outputs, in the same order as the tuple returned by full_flow.
_output_components = [
    gradio_module.Textbox(
        label="Распознанный текст (OCR, easyocr)",
        lines=8,
    ),
    gradio_module.Textbox(
        label="Анализ текста (классификация, DistilBERT)",
        lines=2,
    ),
    gradio_module.Textbox(
        label="Конспект (английский текст, DistilBART)",
        lines=6,
    ),
    gradio_module.Audio(
        label="Озвучка конспекта (английский TTS, VITS)",
        type="filepath",
    ),
]

# One line per pipeline stage, shown under the title.
_interface_description = "\n".join(
    (
        "1) easyocr распознаёт печатный английский текст с картинки.",
        "2) Трансформер-классификатор (DistilBERT) оценивает тон текста.",
        "3) Трансформер-суммаризатор (DistilBART) делает краткий конспект.",
        "4) Трансформер TTS (MMS VITS) озвучивает конспект.",
    )
)

gradio_interface = gradio_module.Interface(
    fn=full_flow,
    inputs=_input_components,
    outputs=_output_components,
    title="Картинка → Текст → Анализ → Конспект → Озвучка",
    description=_interface_description,
)


def _launch_app() -> None:
    """Start the Gradio UI held in ``gradio_interface``."""
    gradio_interface.launch()


# Launch only when run as a script, not when imported.
if __name__ == "__main__":
    _launch_app()