prithivMLmods committed
Commit d6f9fb3 · verified · 1 Parent(s): 7abc3cc

update app

Files changed (1):
  app.py  +47 -91
app.py CHANGED
@@ -10,15 +10,12 @@ from PIL import Image
 from threading import Thread
 from typing import Iterable, Optional, Tuple, List
 
-# --- Transformer & Model Imports ---
 from transformers import (
     Qwen2_5_VLForConditionalGeneration,
     AutoProcessor,
     TextIteratorStreamer,
 )
 
-# --- VibeVoice Imports ---
-# Assuming local folder structure exists for these imports
 try:
     from vibevoice.modular.modeling_vibevoice_streaming_inference import (
         VibeVoiceStreamingForConditionalGenerationInference,
@@ -28,39 +25,33 @@ try:
     )
 except ImportError:
     print("CRITICAL WARNING: 'vibevoice' modules not found. Ensure the vibevoice repository structure is present.")
-    # Mocking for syntax checking if files are missing during dry-run
     VibeVoiceStreamingForConditionalGenerationInference = None
     VibeVoiceStreamingProcessor = None
 
-# --- UI Theme Imports ---
 from gradio.themes import Soft
 from gradio.themes.utils import colors, fonts, sizes
 
-# ==========================================
-# 1. THEME CONFIGURATION (Steel Blue)
-# ==========================================
-
-colors.steel_blue = colors.Color(
-    name="steel_blue",
-    c50="#EBF3F8",
-    c100="#D3E5F0",
-    c200="#A8CCE1",
-    c300="#7DB3D2",
-    c400="#529AC3",
-    c500="#4682B4",
-    c600="#3E72A0",
-    c700="#36638C",
-    c800="#2E5378",
-    c900="#264364",
-    c950="#1E3450",
+colors.orange_red = colors.Color(
+    name="orange_red",
+    c50="#FFF0E5",
+    c100="#FFE0CC",
+    c200="#FFC299",
+    c300="#FFA366",
+    c400="#FF8533",
+    c500="#FF4500",
+    c600="#E63E00",
+    c700="#CC3700",
+    c800="#B33000",
+    c900="#992900",
+    c950="#802200",
 )
 
-class SteelBlueTheme(Soft):
+class OrangeRedTheme(Soft):
     def __init__(
         self,
         *,
         primary_hue: colors.Color | str = colors.gray,
-        secondary_hue: colors.Color | str = colors.steel_blue,
+        secondary_hue: colors.Color | str = colors.orange_red,
         neutral_hue: colors.Color | str = colors.slate,
         text_size: sizes.Size | str = sizes.text_lg,
         font: fonts.Font | str | Iterable[fonts.Font | str] = (
@@ -87,8 +78,14 @@ class SteelBlueTheme(Soft):
             button_primary_text_color_hover="white",
             button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
             button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
-            button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_800)",
-            button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_500)",
+            button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_700)",
+            button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_600)",
+            button_secondary_text_color="black",
+            button_secondary_text_color_hover="white",
+            button_secondary_background_fill="linear-gradient(90deg, *primary_300, *primary_300)",
+            button_secondary_background_fill_hover="linear-gradient(90deg, *primary_400, *primary_400)",
+            button_secondary_background_fill_dark="linear-gradient(90deg, *primary_500, *primary_600)",
+            button_secondary_background_fill_hover_dark="linear-gradient(90deg, *primary_500, *primary_500)",
             slider_color="*secondary_500",
             slider_color_dark="*secondary_600",
             block_title_text_weight="600",
@@ -100,7 +97,7 @@ class SteelBlueTheme(Soft):
             block_label_background_fill="*primary_200",
         )
 
-steel_blue_theme = SteelBlueTheme()
+orange_red_theme = OrangeRedTheme()
 
 css = """
 #main-title h1 {
@@ -114,20 +111,15 @@ css = """
 }
 """
 
-# ==========================================
-# 2. MODEL SETUP (Global)
-# ==========================================
-
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 print(f"Using Main Device: {device}")
 
-# --- A. Setup Chandra-OCR (Qwen3-VL) ---
-OCR_MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"
-print(f"Loading OCR Model: {OCR_MODEL_ID}...")
+QWEN_VL_MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"
+print(f"Loading OCR Model: {QWEN_VL_MODEL_ID}...")
 
-ocr_processor = AutoProcessor.from_pretrained(OCR_MODEL_ID, trust_remote_code=True)
-ocr_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-    OCR_MODEL_ID,
+qwen_processor = AutoProcessor.from_pretrained(QWEN_VL_MODEL_ID, trust_remote_code=True)
+qwen_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+    QWEN_VL_MODEL_ID,
     attn_implementation="flash_attention_2",
     trust_remote_code=True,
     torch_dtype=torch.float16
@@ -135,14 +127,11 @@ ocr_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
 
 print("OCR Model loaded successfully.")
 
-# --- B. Setup VibeVoice (TTS) ---
 TTS_MODEL_PATH = "microsoft/VibeVoice-Realtime-0.5B"
 print(f"Loading TTS Model: {TTS_MODEL_PATH}...")
 
-# Load processor
 tts_processor = VibeVoiceStreamingProcessor.from_pretrained(TTS_MODEL_PATH)
 
-# Load model on CPU initially (moved to GPU on demand to save VRAM)
 tts_model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
     TTS_MODEL_PATH,
     torch_dtype=torch.float16,
@@ -152,12 +141,10 @@ tts_model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
 tts_model.eval()
 tts_model.set_ddpm_inference_steps(num_steps=5)
 
-# Voice Mapper Class
 class VoiceMapper:
     """Maps speaker names to voice file paths"""
     def __init__(self):
         self.setup_voice_presets()
-        # Clean up names
         new_dict = {}
         for name, path in self.voice_presets.items():
             if "_" in name: name = name.split("_")[0]
@@ -169,8 +156,6 @@ class VoiceMapper:
         voices_dir = os.path.join(os.path.dirname(__file__), "demo/voices/streaming_model")
         if not os.path.exists(voices_dir):
             print(f"Warning: Voices directory not found at {voices_dir}")
-            # Create a placeholder if dir doesn't exist to prevent crash during init,
-            # though generation will fail if no files.
             self.voice_presets = {}
             self.available_voices = {}
             return
@@ -190,12 +175,10 @@ class VoiceMapper:
     def get_voice_path(self, speaker_name: str) -> str:
         if speaker_name in self.voice_presets:
             return self.voice_presets[speaker_name]
-        # Partial match
         speaker_lower = speaker_name.lower()
         for preset_name, path in self.voice_presets.items():
             if preset_name.lower() in speaker_lower or speaker_lower in preset_name.lower():
                 return path
-        # Default
        if self.voice_presets:
             return list(self.voice_presets.values())[0]
         return ""
@@ -203,12 +186,7 @@ class VoiceMapper:
 VOICE_MAPPER = VoiceMapper()
 print("TTS Model loaded successfully.")
 
-
-# ==========================================
-# 3. GENERATION FUNCTIONS
-# ==========================================
-
-@spaces.GPU(duration=120)
+@spaces.GPU
 def process_pipeline(
     image: Image.Image,
     query: str,
@@ -224,10 +202,8 @@ def process_pipeline(
     if image is None:
         return "Please upload an image.", None, "Error: No image provided."
 
-    # --- Step 1: OCR ---
     progress(0.1, desc="Analyzing Image (OCR)...")
 
-    # Clean query
     if not query:
         query = "OCR the content perfectly."
 
@@ -239,19 +215,16 @@ def process_pipeline(
         ]
     }]
 
-    prompt_full = ocr_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    prompt_full = qwen_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
 
-    # Process inputs (move to device)
-    inputs = ocr_processor(
+    inputs = qwen_processor(
         text=[prompt_full],
         images=[image],
         return_tensors="pt",
         padding=True
     ).to(device)
 
-    # Generate Text
-    # We use standard generate here instead of streamer to get the full string for TTS easily
-    generated_ids = ocr_model.generate(
+    generated_ids = qwen_model.generate(
         **inputs,
         max_new_tokens=ocr_max_tokens,
         do_sample=True,
@@ -262,31 +235,26 @@ def process_pipeline(
     generated_ids_trimmed = [
         out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
     ]
-    extracted_text = ocr_processor.batch_decode(
+    extracted_text = qwen_processor.batch_decode(
         generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
     )[0]
 
-    # Clean cleanup
     extracted_text = extracted_text.replace("<|im_end|>", "").strip()
 
     progress(0.5, desc=f"OCR Complete. Converting to speech ({len(extracted_text)} chars)...")
 
-    # --- Step 2: TTS ---
     if not extracted_text:
         return extracted_text, None, "OCR produced no text."
 
     try:
-        # Pre-process text
         full_script = extracted_text.replace("'", "'").replace('"', '"').replace('"', '"')
 
-        # Get voice
         voice_path = VOICE_MAPPER.get_voice_path(speaker_name)
         if not voice_path:
             return extracted_text, None, "Error: Voice file not found."
 
         all_prefilled_outputs = torch.load(voice_path, map_location="cuda", weights_only=False)
 
-        # Prepare inputs
         tts_inputs = tts_processor.process_input_with_cached_prompt(
             text=full_script,
             cached_prompt=all_prefilled_outputs,
@@ -295,13 +263,11 @@ def process_pipeline(
             return_attention_mask=True,
         )
 
-        # Move TTS model to GPU
         tts_model.to("cuda")
         for k, v in tts_inputs.items():
             if torch.is_tensor(v):
                 tts_inputs[k] = v.to("cuda")
 
-        # Generate Audio
         with torch.cuda.amp.autocast():
             outputs = tts_model.generate(
                 **tts_inputs,
@@ -313,7 +279,6 @@ def process_pipeline(
                 all_prefilled_outputs=copy.deepcopy(all_prefilled_outputs)
             )
 
-        # Move TTS back to CPU to be safe
         tts_model.to("cpu")
         torch.cuda.empty_cache()
 
@@ -329,7 +294,7 @@ def process_pipeline(
                 output_path=output_path,
             )
 
-            status = f"✅ Success! OCR Text Length: {len(extracted_text)} chars."
+            status = f"✅ Success! Text Length: {len(extracted_text)} chars."
             return extracted_text, output_path, status
         else:
             return extracted_text, None, "TTS Generation failed (no output)."
@@ -340,25 +305,14 @@ def process_pipeline(
         import traceback
         return extracted_text, None, f"Error during TTS: {str(e)}"
 
-# ==========================================
-# 4. GRADIO INTERFACE
-# ==========================================
-
-image_examples = [
-    ["OCR the content perfectly.", "examples/3.jpg"],
-    ["Perform OCR on the image.", "examples/1.jpg"],
-    ["Extract the contents. [page].", "examples/2.jpg"],
-]
-
 with gr.Blocks() as demo:
     gr.Markdown("# **Vision-to-VibeVoice-en**", elem_id="main-title")
-
+    gr.Markdown("Perform vision-to-audio inference with [Qwen2.5VL](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) + [VibeVoice-Realtime-0.5B](https://huggingface.co/microsoft/VibeVoice-Realtime-0.5B).")
     with gr.Row():
-        # --- Left Column: Inputs ---
         with gr.Column(scale=1):
             gr.Markdown("### 1. Vision Input")
             image_upload = gr.Image(type="pil", label="Upload Image", height=300)
-            image_query = gr.Textbox(label="OCR Prompt", value="OCR the content perfectly.", placeholder="E.g., Read this page...")
+            image_query = gr.Textbox(label="Enter the prompt", value="Give a short description indicating whether the image is safe or unsafe.", placeholder="E.g., Read this page...")
 
             gr.Markdown("### 2. Voice Settings")
             voice_choices = list(VOICE_MAPPER.available_voices.keys())
@@ -373,23 +327,20 @@ with gr.Blocks() as demo:
             cfg_slider = gr.Slider(minimum=1.0, maximum=3.0, value=1.5, step=0.1, label="CFG Scale (Speech Fidelity)")
 
             with gr.Accordion("Advanced Options", open=False):
-                max_new_tokens = gr.Slider(label="Max OCR Tokens", minimum=128, maximum=4096, step=128, value=2048)
-                temperature = gr.Slider(label="OCR Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.1)
+                max_new_tokens = gr.Slider(label="Max Tokens", minimum=128, maximum=4096, step=128, value=2048)
+                temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.1)
 
-            submit_btn = gr.Button("🚀 Process Vision to Voice", variant="primary", size="lg")
+            submit_btn = gr.Button("Process Vision to Voice", variant="primary", size="lg")
 
-        # --- Right Column: Outputs ---
         with gr.Column(scale=1):
             gr.Markdown("### 3. Results", elem_id="output-title")
 
-            # Text Output
             text_output = gr.Textbox(
                 label="Extracted Text (Editable)",
                 interactive=True,
                 lines=10,
             )
 
-            # Audio Output
             audio_output = gr.Audio(
                 label="Generated Speech",
                 type="filepath",
@@ -397,8 +348,13 @@ with gr.Blocks() as demo:
             )
 
             status_output = gr.Textbox(label="Status Log", lines=2)
+
+            gr.Examples(
+                examples=[["Perform OCR on the image.", "examples/1.jpg"]],
+                inputs=[image_query, image_upload],
+                label="Example"
+            )
 
-    # --- Logic Connection ---
     submit_btn.click(
         fn=process_pipeline,
         inputs=[
@@ -413,4 +369,4 @@ with gr.Blocks() as demo:
     )
 
 if __name__ == "__main__":
-    demo.queue(max_size=10).launch(css=css, theme=steel_blue_theme, ssr_mode=False, show_error=True)
+    demo.queue(max_size=40).launch(css=css, theme=orange_red_theme, ssr_mode=False, show_error=True)
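Note on the decorator change: the commit swaps `@spaces.GPU(duration=120)` for the bare `@spaces.GPU`, so the function falls back to the ZeroGPU default time budget instead of reserving 120 seconds per call. A minimal sketch of the two forms (function names are illustrative, not from app.py):

import spaces

@spaces.GPU  # bare form: default GPU time budget per call
def short_task():
    ...

@spaces.GPU(duration=120)  # explicit form: request up to 120 s per call
def long_task():
    ...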
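The diff keeps the app's VRAM-saving placement pattern: the VibeVoice model is loaded on CPU at startup and moved to CUDA only around `generate`, then parked back on CPU with the allocator cache cleared. A minimal standalone sketch of that pattern, using a hypothetical `synthesize` helper rather than the app's actual call:

import torch

def synthesize(model, inputs: dict):
    # Move the model and its tensors to CUDA only for this call.
    model.to("cuda")
    inputs = {k: (v.to("cuda") if torch.is_tensor(v) else v) for k, v in inputs.items()}
    try:
        with torch.cuda.amp.autocast():
            return model.generate(**inputs)
    finally:
        # Park the model back on CPU and free cached VRAM, as app.py does.
        model.to("cpu")
        torch.cuda.empty_cache()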
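For reference, the prompt-trimming idiom retained around `qwen_model.generate` is the standard Qwen2.5-VL decode recipe: `generate` echoes the prompt tokens, so each output sequence is sliced past its input length before decoding. A self-contained sketch of that step (the helper name is illustrative):

def decode_new_tokens(processor, input_ids, generated_ids) -> str:
    # Keep only tokens produced after the prompt, then decode to text.
    trimmed = [out[len(inp):] for inp, out in zip(input_ids, generated_ids)]
    return processor.batch_decode(
        trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]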