ASureevaA commited on
Commit
fa051f7
·
1 Parent(s): 35e85d1
Files changed (2) hide show
  1. app.py +62 -59
  2. requirements.txt +1 -1
app.py CHANGED
@@ -8,12 +8,12 @@ import torch
8
  import gradio as gradio_module
9
  from PIL import Image
10
  from transformers import (
11
- TrOCRProcessor,
12
- VisionEncoderDecoderModel,
13
  pipeline,
14
- VitsTokenizer,
15
  VitsModel,
 
16
  )
 
 
17
 
18
  # ============================
19
  # 1. Настройки устройства
@@ -26,62 +26,66 @@ device_string: str = "cuda" if torch.cuda.is_available() else "cpu"
26
  # 2. Модели
27
  # ============================
28
 
29
- # OCR: печатный английский текст
30
- # Модель: microsoft/trocr-small-printed
31
- ocr_processor: TrOCRProcessor = TrOCRProcessor.from_pretrained(
32
- "microsoft/trocr-small-printed"
33
- )
34
- ocr_model: VisionEncoderDecoderModel = VisionEncoderDecoderModel.from_pretrained(
35
- "microsoft/trocr-small-printed"
36
- )
37
- ocr_model.to(device_string)
38
 
39
- # Суммаризация: английский новостной/общий текст
40
- # Модель: sshleifer/distilbart-cnn-12-6
41
  summary_pipeline = pipeline(
42
  task="summarization",
43
  model="sshleifer/distilbart-cnn-12-6",
44
  )
45
 
46
- # TTS: английская MMS VITS
47
- # Модель: facebook/mms-tts-eng
48
  tts_model: VitsModel = VitsModel.from_pretrained("facebook/mms-tts-eng")
49
- tts_tokenizer: VitsTokenizer = VitsTokenizer.from_pretrained("facebook/mms-tts-eng")
50
  tts_model.to(device_string)
51
 
52
 
53
  # ============================
54
- # 3. OCR
55
  # ============================
56
 
57
  def run_ocr(image_object: Image.Image) -> str:
58
  """
59
- Распознавание печатного английского текста с изображения.
60
- Используем TrOCR (microsoft/trocr-small-printed).
61
-
62
- Ожидается более-менее читаемый printed text
63
- (скриншоты, документы, слайды и т.п.).
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  """
 
65
  if image_object is None:
66
  return ""
67
 
68
- rgb_image_object: Image.Image = image_object.convert("RGB")
 
 
69
 
70
- processor_output = ocr_processor(
71
- images=rgb_image_object,
72
- return_tensors="pt",
73
- )
74
- pixel_values_tensor = processor_output.pixel_values.to(device_string)
75
 
76
- with torch.no_grad():
77
- generated_id_tensor = ocr_model.generate(pixel_values_tensor)
 
 
 
78
 
79
- decoded_text_list = ocr_processor.batch_decode(
80
- generated_id_tensor,
81
- skip_special_tokens=True,
82
- )
83
 
84
- recognized_text: str = decoded_text_list[0].strip()
 
 
85
  return recognized_text
86
 
87
 
@@ -103,13 +107,14 @@ def run_summarization(
103
 
104
  word_count: int = len(cleaned_text.split())
105
 
106
- # Простая адаптация длины под размер текста,
107
- # чтобы не было бессмысленных max_length >> input_length.
108
  dynamic_max_length: int = min(
109
  max_summary_tokens,
110
  max(32, word_count + 20),
111
  )
112
 
 
 
 
113
  summary_result_list = summary_pipeline(
114
  cleaned_text,
115
  max_length=dynamic_max_length,
@@ -129,10 +134,8 @@ def run_tts(summary_text: str) -> Optional[str]:
129
  """
130
  Озвучка английского текста конспекта через VitsModel (facebook/mms-tts-eng).
131
 
132
- ВАЖНО:
133
- - защищаемся от пустого ввода;
134
- - ловим RuntimeError изнутри модели (бывают краши на редких входах);
135
- в этом случае просто возвращаем None, чтобы не ронять весь Space.
136
  """
137
  cleaned_text: str = summary_text.strip()
138
  if not cleaned_text:
@@ -142,7 +145,6 @@ def run_tts(summary_text: str) -> Optional[str]:
142
  cleaned_text,
143
  return_tensors="pt",
144
  )
145
-
146
  tokenized_inputs = {
147
  key: value.to(device_string)
148
  for key, value in tokenized_inputs.items()
@@ -151,14 +153,13 @@ def run_tts(summary_text: str) -> Optional[str]:
151
  input_ids_tensor = tokenized_inputs.get("input_ids")
152
  if input_ids_tensor is None:
153
  return None
154
-
155
  if input_ids_tensor.numel() == 0 or input_ids_tensor.shape[1] == 0:
156
  return None
157
 
158
  try:
159
  with torch.no_grad():
160
  model_output = tts_model(**tokenized_inputs)
161
- waveform_tensor = model_output.waveform # shape: (batch, n_samples)
162
  except RuntimeError as runtime_error:
163
  print(f"[WARN] TTS RuntimeError: {runtime_error}")
164
  return None
@@ -190,9 +191,9 @@ def full_flow(
190
  ) -> Tuple[str, str, Optional[str]]:
191
  """
192
  Полный пайплайн:
193
- 1) OCR: изображение -> исходный текст (английский)
194
- 2) Суммаризация: текст -> краткое резюме
195
- 3) TTS: резюме -> .wav файл (или None, если TTS не смог)
196
  """
197
  recognized_text: str = run_ocr(image_object=image_object)
198
 
@@ -207,7 +208,7 @@ def full_flow(
207
 
208
 
209
  # ============================
210
- # 7. Gradio UI
211
  # ============================
212
 
213
  gradio_interface = gradio_module.Interface(
@@ -215,35 +216,37 @@ gradio_interface = gradio_module.Interface(
215
  inputs=[
216
  gradio_module.Image(
217
  type="pil",
218
- label="Image with printed English text",
219
  ),
220
  gradio_module.Slider(
221
  minimum=32,
222
  maximum=256,
223
  value=128,
224
  step=16,
225
- label="Maximum summary length (tokens, approx)",
226
  ),
227
  ],
228
  outputs=[
229
  gradio_module.Textbox(
230
- label="Recognized text (OCR)",
231
- lines=6,
232
  ),
233
  gradio_module.Textbox(
234
- label="Summary (English)",
235
  lines=6,
236
  ),
237
  gradio_module.Audio(
238
- label="Summary narration (MMS VITS, en)",
239
  type="filepath",
240
  ),
241
  ],
242
- title="ImageTextSummarySpeech (English models)",
243
  description=(
244
- "1) English OCR transformer recognizes printed text from the image.\n"
245
- "2) English summarization transformer creates a short summary.\n"
246
- "3) English VITS (facebook/mms-tts-eng) reads the summary aloud."
 
 
247
  ),
248
  )
249
 
 
8
  import gradio as gradio_module
9
  from PIL import Image
10
  from transformers import (
 
 
11
  pipeline,
 
12
  VitsModel,
13
+ AutoTokenizer,
14
  )
15
+ from nemotron_ocr.inference.pipeline import NemotronOCR # <-- Nemotron OCR v1
16
+
17
 
18
  # ============================
19
  # 1. Настройки устройства
 
26
  # 2. Модели
27
  # ============================
28
 
29
+ ocr_engine: NemotronOCR = NemotronOCR()
 
 
 
 
 
 
 
 
30
 
 
 
31
  summary_pipeline = pipeline(
32
  task="summarization",
33
  model="sshleifer/distilbart-cnn-12-6",
34
  )
35
 
 
 
36
  tts_model: VitsModel = VitsModel.from_pretrained("facebook/mms-tts-eng")
37
+ tts_tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
38
  tts_model.to(device_string)
39
 
40
 
41
  # ============================
42
+ # 3. OCR через NemotronOCR
43
  # ============================
44
 
45
  def run_ocr(image_object: Image.Image) -> str:
46
  """
47
+ OCR для печатного (и вообще любого) английского текста с картины.
48
+
49
+ Используем NemotronOCR из nvidia/nemotron-ocr-v1.
50
+ Модель сама делает:
51
+ - детекцию текстовых блоков,
52
+ - распознавание текста,
53
+ - анализ порядка чтения.
54
+
55
+ На выходе NemotronOCR даёт список dict:
56
+ [
57
+ {
58
+ "text": "...",
59
+ "confidence": float,
60
+ "left": float,
61
+ "upper": float,
62
+ "right": float,
63
+ "lower": float,
64
+ ...
65
+ },
66
+ ...
67
+ ]
68
  """
69
+
70
  if image_object is None:
71
  return ""
72
 
73
+ with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temporary_file:
74
+ image_object.save(temporary_file.name)
75
+ image_path: str = temporary_file.name
76
 
77
+ predictions = ocr_engine(image_path)
 
 
 
 
78
 
79
+ text_parts = []
80
+ for prediction in predictions:
81
+ text_value = prediction.get("text", "")
82
+ if not text_value:
83
+ continue
84
 
 
 
 
 
85
 
86
+ text_parts.append(str(text_value))
87
+
88
+ recognized_text: str = "\n".join(text_parts).strip()
89
  return recognized_text
90
 
91
 
 
107
 
108
  word_count: int = len(cleaned_text.split())
109
 
 
 
110
  dynamic_max_length: int = min(
111
  max_summary_tokens,
112
  max(32, word_count + 20),
113
  )
114
 
115
+ if word_count < 8:
116
+ return cleaned_text
117
+
118
  summary_result_list = summary_pipeline(
119
  cleaned_text,
120
  max_length=dynamic_max_length,
 
134
  """
135
  Озвучка английского текста конспекта через VitsModel (facebook/mms-tts-eng).
136
 
137
+ Если модель внутри упадёт (известный баг на некоторых странных инпутах),
138
+ мы просто вернём None и не будем ронять всё приложение.
 
 
139
  """
140
  cleaned_text: str = summary_text.strip()
141
  if not cleaned_text:
 
145
  cleaned_text,
146
  return_tensors="pt",
147
  )
 
148
  tokenized_inputs = {
149
  key: value.to(device_string)
150
  for key, value in tokenized_inputs.items()
 
153
  input_ids_tensor = tokenized_inputs.get("input_ids")
154
  if input_ids_tensor is None:
155
  return None
 
156
  if input_ids_tensor.numel() == 0 or input_ids_tensor.shape[1] == 0:
157
  return None
158
 
159
  try:
160
  with torch.no_grad():
161
  model_output = tts_model(**tokenized_inputs)
162
+ waveform_tensor = model_output.waveform # (batch, n_samples)
163
  except RuntimeError as runtime_error:
164
  print(f"[WARN] TTS RuntimeError: {runtime_error}")
165
  return None
 
191
  ) -> Tuple[str, str, Optional[str]]:
192
  """
193
  Полный пайплайн:
194
+ 1) OCR: изображение -> исходный английский текст
195
+ 2) Суммаризация: текст -> конспект (английский)
196
+ 3) TTS: конспект -> .wav файл (или None, если TTS не смог)
197
  """
198
  recognized_text: str = run_ocr(image_object=image_object)
199
 
 
208
 
209
 
210
  # ============================
211
+ # 7. Gradio UI (на русском)
212
  # ============================
213
 
214
  gradio_interface = gradio_module.Interface(
 
216
  inputs=[
217
  gradio_module.Image(
218
  type="pil",
219
+ label="Изображение с напечатанным английским текстом",
220
  ),
221
  gradio_module.Slider(
222
  minimum=32,
223
  maximum=256,
224
  value=128,
225
  step=16,
226
+ label="Максимальная длина конспекта (токены, примерно)",
227
  ),
228
  ],
229
  outputs=[
230
  gradio_module.Textbox(
231
+ label="Распознанный текст (Nemotron OCR)",
232
+ lines=8,
233
  ),
234
  gradio_module.Textbox(
235
+ label="Конспект (английский текст)",
236
  lines=6,
237
  ),
238
  gradio_module.Audio(
239
+ label="Озвучка конспекта (английский TTS)",
240
  type="filepath",
241
  ),
242
  ],
243
+ title="КартинкаТекстКонспектОзвучка (Nemotron OCR + английские модели)",
244
  description=(
245
+ "1) Nemotron OCR v1 (nvidia/nemotron-ocr-v1) распознаёт текст с документа.\n"
246
+ "2) Английский трансформер суммаризации делает краткий пересказ.\n"
247
+ "3) VITS-модель MMS (facebook/mms-tts-eng) озвучивает конспект.\n\n"
248
+ "Если озвучка не сгенерировалась, значит конкретный текст не понравился TTS-модели "
249
+ "и она упала внутри — пайплайн просто пропустит аудио."
250
  ),
251
  )
252
 
requirements.txt CHANGED
@@ -1,4 +1,4 @@
1
- transformers>=4.33.0
2
  torch
3
  sentencepiece
4
  gradio
 
1
+ transformers>=4.40.0
2
  torch
3
  sentencepiece
4
  gradio