ASureevaA commited on
Commit
efbc18d
·
1 Parent(s): fa051f7
Files changed (2) hide show
  1. app.py +87 -67
  2. requirements.txt +2 -1
app.py CHANGED
@@ -7,13 +7,12 @@ import soundfile as soundfile_module
7
  import torch
8
  import gradio as gradio_module
9
  from PIL import Image
 
10
  from transformers import (
11
  pipeline,
12
  VitsModel,
13
  AutoTokenizer,
14
  )
15
- from nemotron_ocr.inference.pipeline import NemotronOCR # <-- Nemotron OCR v1
16
-
17
 
18
  # ============================
19
  # 1. Настройки устройства
@@ -23,76 +22,87 @@ device_string: str = "cuda" if torch.cuda.is_available() else "cpu"
23
 
24
 
25
  # ============================
26
- # 2. Модели
27
  # ============================
28
 
29
- ocr_engine: NemotronOCR = NemotronOCR()
30
-
31
- summary_pipeline = pipeline(
32
- task="summarization",
33
- model="sshleifer/distilbart-cnn-12-6",
34
  )
35
 
36
- tts_model: VitsModel = VitsModel.from_pretrained("facebook/mms-tts-eng")
37
- tts_tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
38
- tts_model.to(device_string)
39
-
40
-
41
- # ============================
42
- # 3. OCR через NemotronOCR
43
- # ============================
44
 
45
  def run_ocr(image_object: Image.Image) -> str:
46
  """
47
- OCR для печатного (и вообще любого) английского текста с картины.
48
-
49
- Используем NemotronOCR из nvidia/nemotron-ocr-v1.
50
- Модель сама делает:
51
- - детекцию текстовых блоков,
52
- - распознавание текста,
53
- - анализ порядка чтения.
54
-
55
- На выходе NemotronOCR даёт список dict:
56
- [
57
- {
58
- "text": "...",
59
- "confidence": float,
60
- "left": float,
61
- "upper": float,
62
- "right": float,
63
- "lower": float,
64
- ...
65
- },
66
- ...
67
- ]
68
  """
69
-
70
  if image_object is None:
71
  return ""
72
 
73
- with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temporary_file:
74
- image_object.save(temporary_file.name)
75
- image_path: str = temporary_file.name
76
 
77
- predictions = ocr_engine(image_path)
 
 
 
 
 
 
 
78
 
79
  text_parts = []
80
- for prediction in predictions:
81
- text_value = prediction.get("text", "")
82
  if not text_value:
83
  continue
84
-
85
-
86
- text_parts.append(str(text_value))
87
 
88
  recognized_text: str = "\n".join(text_parts).strip()
89
  return recognized_text
90
 
91
 
92
  # ============================
93
- # 4. Суммаризация (английский)
94
  # ============================
95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  def run_summarization(
97
  input_text: str,
98
  max_summary_tokens: int = 128,
@@ -106,13 +116,13 @@ def run_summarization(
106
  return ""
107
 
108
  word_count: int = len(cleaned_text.split())
109
-
110
  dynamic_max_length: int = min(
111
  max_summary_tokens,
112
  max(32, word_count + 20),
113
  )
114
 
115
  if word_count < 8:
 
116
  return cleaned_text
117
 
118
  summary_result_list = summary_pipeline(
@@ -127,9 +137,14 @@ def run_summarization(
127
 
128
 
129
  # ============================
130
- # 5. TTS (английский, MMS VITS)
131
  # ============================
132
 
 
 
 
 
 
133
  def run_tts(summary_text: str) -> Optional[str]:
134
  """
135
  Озвучка английского текста конспекта через VitsModel (facebook/mms-tts-eng).
@@ -151,9 +166,7 @@ def run_tts(summary_text: str) -> Optional[str]:
151
  }
152
 
153
  input_ids_tensor = tokenized_inputs.get("input_ids")
154
- if input_ids_tensor is None:
155
- return None
156
- if input_ids_tensor.numel() == 0 or input_ids_tensor.shape[1] == 0:
157
  return None
158
 
159
  try:
@@ -188,15 +201,18 @@ def run_tts(summary_text: str) -> Optional[str]:
188
  def full_flow(
189
  image_object: Image.Image,
190
  max_summary_tokens: int = 128,
191
- ) -> Tuple[str, str, Optional[str]]:
192
  """
193
  Полный пайплайн:
194
- 1) OCR: изображение -> исходный английский текст
195
- 2) Суммаризация: текст -> конспект (английский)
196
- 3) TTS: конспект -> .wav файл (или None, если TTS не смог)
 
197
  """
198
  recognized_text: str = run_ocr(image_object=image_object)
199
 
 
 
200
  summary_text: str = run_summarization(
201
  input_text=recognized_text,
202
  max_summary_tokens=max_summary_tokens,
@@ -204,7 +220,7 @@ def full_flow(
204
 
205
  audio_file_path: Optional[str] = run_tts(summary_text=summary_text)
206
 
207
- return recognized_text, summary_text, audio_file_path
208
 
209
 
210
  # ============================
@@ -228,25 +244,29 @@ gradio_interface = gradio_module.Interface(
228
  ],
229
  outputs=[
230
  gradio_module.Textbox(
231
- label="Распознанный текст (Nemotron OCR)",
232
  lines=8,
233
  ),
234
  gradio_module.Textbox(
235
- label="Конспект (английский текст)",
 
 
 
 
236
  lines=6,
237
  ),
238
  gradio_module.Audio(
239
- label="Озвучка конспекта (английский TTS)",
240
  type="filepath",
241
  ),
242
  ],
243
- title="Картинка → Текст → Конспект → Озвучка (Nemotron OCR + английские модели)",
244
  description=(
245
- "1) Nemotron OCR v1 (nvidia/nemotron-ocr-v1) распознаёт текст с документа.\n"
246
- "2) Английский трансформер суммаризации делает краткий пересказ.\n"
247
- "3) VITS-модель MMS (facebook/mms-tts-eng) озвучивает конспект.\n\n"
248
- "Если озвучка не сгенерировалась, значит конкретный текст не понравился TTS-модели "
249
- "и она упала внутри пайплайн просто пропустит аудио."
250
  ),
251
  )
252
 
 
7
  import torch
8
  import gradio as gradio_module
9
  from PIL import Image
10
+ import easyocr
11
  from transformers import (
12
  pipeline,
13
  VitsModel,
14
  AutoTokenizer,
15
  )
 
 
16
 
17
  # ============================
18
  # 1. Настройки устройства
 
22
 
23
 
24
  # ============================
25
+ # 2. OCR (easyocr, английский)
26
  # ============================
27
 
28
+ # TODO_USER: при желании можно добавить другие языки, но тогда конспект и TTS всё равно останутся на английском
29
+ ocr_reader = easyocr.Reader(
30
+ ["en"], # языки
31
+ gpu=(device_string == "cuda"),
 
32
  )
33
 
 
 
 
 
 
 
 
 
34
 
35
  def run_ocr(image_object: Image.Image) -> str:
36
  """
37
+ OCR для печатного английского текста.
38
+ Используем easyocr, потому что он реально более устойчивый для
39
+ произвольных сканов/фото, чем большинство трансформеров, которые мы пробовали.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  """
 
41
  if image_object is None:
42
  return ""
43
 
44
+ rgb_image_object: Image.Image = image_object.convert("RGB")
 
 
45
 
46
+ # easyocr работает с numpy-массивом
47
+ numpy_image = numpy_module.array(rgb_image_object)
48
+
49
+ results = ocr_reader.readtext(
50
+ numpy_image,
51
+ detail=1, # возвращаем bbox + текст + confidence
52
+ paragraph=True, # склеивать текст в параграфы, где это возможно
53
+ )
54
 
55
  text_parts = []
56
+ for bbox, text_value, confidence_value in results:
 
57
  if not text_value:
58
  continue
59
+ # TODO_USER: при желании можно фильтровать по confidence_value
60
+ text_parts.append(text_value)
 
61
 
62
  recognized_text: str = "\n".join(text_parts).strip()
63
  return recognized_text
64
 
65
 
66
  # ============================
67
+ # 3. Трансформер #1: классификация текста
68
  # ============================
69
 
70
+ text_classifier_pipeline = pipeline(
71
+ task="text-classification",
72
+ model="distilbert-base-uncased-finetuned-sst-2-english",
73
+ )
74
+
75
+
76
+ def run_text_classification(input_text: str) -> str:
77
+ """
78
+ Пример анализа текста трансформером:
79
+ используем sentiment-классификатор как демонстрацию.
80
+ Возвращаем строку вида: "label: POSITIVE, score: 0.98".
81
+ """
82
+ cleaned_text: str = input_text.strip()
83
+ if not cleaned_text:
84
+ return ""
85
+
86
+ result_list = text_classifier_pipeline(cleaned_text)
87
+ result = result_list[0]
88
+
89
+ label_value: str = str(result.get("label", ""))
90
+ score_value: float = float(result.get("score", 0.0))
91
+
92
+ classification_text: str = f"{label_value} (score={score_value:.3f})"
93
+ return classification_text
94
+
95
+
96
+ # ============================
97
+ # 4. Трансформер #2: суммаризация (английский)
98
+ # ============================
99
+
100
+ summary_pipeline = pipeline(
101
+ task="summarization",
102
+ model="sshleifer/distilbart-cnn-12-6",
103
+ )
104
+
105
+
106
  def run_summarization(
107
  input_text: str,
108
  max_summary_tokens: int = 128,
 
116
  return ""
117
 
118
  word_count: int = len(cleaned_text.split())
 
119
  dynamic_max_length: int = min(
120
  max_summary_tokens,
121
  max(32, word_count + 20),
122
  )
123
 
124
  if word_count < 8:
125
+ # TODO_USER: для очень короткого текста суммаризация сомнительна, возвращаем исходный текст
126
  return cleaned_text
127
 
128
  summary_result_list = summary_pipeline(
 
137
 
138
 
139
  # ============================
140
+ # 5. Трансформер #3: TTS (английский, MMS VITS)
141
  # ============================
142
 
143
+ tts_model: VitsModel = VitsModel.from_pretrained("facebook/mms-tts-eng")
144
+ tts_tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
145
+ tts_model.to(device_string)
146
+
147
+
148
  def run_tts(summary_text: str) -> Optional[str]:
149
  """
150
  Озвучка английского текста конспекта через VitsModel (facebook/mms-tts-eng).
 
166
  }
167
 
168
  input_ids_tensor = tokenized_inputs.get("input_ids")
169
+ if input_ids_tensor is None or input_ids_tensor.numel() == 0:
 
 
170
  return None
171
 
172
  try:
 
201
  def full_flow(
202
  image_object: Image.Image,
203
  max_summary_tokens: int = 128,
204
+ ) -> Tuple[str, str, str, Optional[str]]:
205
  """
206
  Полный пайплайн:
207
+ 1) OCR (easyocr): изображение -> исходный текст (английский)
208
+ 2) Классификация текста трансформером (sentiment)
209
+ 3) Суммаризация: текст -> конспект
210
+ 4) TTS: конспект -> .wav файл (или None)
211
  """
212
  recognized_text: str = run_ocr(image_object=image_object)
213
 
214
+ classification_text: str = run_text_classification(recognized_text)
215
+
216
  summary_text: str = run_summarization(
217
  input_text=recognized_text,
218
  max_summary_tokens=max_summary_tokens,
 
220
 
221
  audio_file_path: Optional[str] = run_tts(summary_text=summary_text)
222
 
223
+ return recognized_text, classification_text, summary_text, audio_file_path
224
 
225
 
226
  # ============================
 
244
  ],
245
  outputs=[
246
  gradio_module.Textbox(
247
+ label="Распознанный текст (OCR, easyocr)",
248
  lines=8,
249
  ),
250
  gradio_module.Textbox(
251
+ label="Анализ текста (классификация, DistilBERT)",
252
+ lines=2,
253
+ ),
254
+ gradio_module.Textbox(
255
+ label="Конспект (английский текст, DistilBART)",
256
  lines=6,
257
  ),
258
  gradio_module.Audio(
259
+ label="Озвучка конспекта (английский TTS, VITS)",
260
  type="filepath",
261
  ),
262
  ],
263
+ title="Картинка → Текст → Анализ → Конспект → Озвучка",
264
  description=(
265
+ "1) easyocr распознаёт печатный английский текст с картинки.\n"
266
+ "2) Трансформер-классификатор (DistilBERT) оценивает тон текста.\n"
267
+ "3) Трансформер-суммаризатор (DistilBART) делает краткий конспект.\n"
268
+ "4) Трансформер TTS (MMS VITS) озвучивает конспект.\n"
269
+ "В проекте используются три трансф��рмера с Hugging Face, OCR сделан через easyocr."
270
  ),
271
  )
272
 
requirements.txt CHANGED
@@ -1,7 +1,8 @@
1
- transformers>=4.40.0
2
  torch
3
  sentencepiece
4
  gradio
5
  Pillow
6
  numpy
7
  soundfile
 
 
1
+ transformers>=4.33.0
2
  torch
3
  sentencepiece
4
  gradio
5
  Pillow
6
  numpy
7
  soundfile
8
+ easyocr