Ibad ur Rehman committed on
Commit
b586eeb
·
1 Parent(s): 51c66dc

perf: optimize qwen inference path

Browse files
Files changed (3) hide show
  1. app.py +8 -0
  2. config.py +5 -1
  3. pipeline.py +126 -42
app.py CHANGED
@@ -18,8 +18,12 @@ from config import (
18
  IMAGES_SCALE,
19
  MAX_FILE_SIZE_BYTES,
20
  MAX_FILE_SIZE_MB,
 
 
 
21
  QWEN_MAX_NEW_TOKENS,
22
  QWEN_MODEL,
 
23
  RENDER_DPI,
24
  logger,
25
  )
@@ -52,6 +56,10 @@ async def lifespan(app: FastAPI):
52
  logger.info(f"Max file size: {MAX_FILE_SIZE_MB}MB")
53
  logger.info(f"Qwen Model: {QWEN_MODEL}")
54
  logger.info(f"Qwen Max New Tokens: {QWEN_MAX_NEW_TOKENS}")
 
 
 
 
55
 
56
  logger.info("=" * 60)
57
  logger.info("Docling VLM Parser API ready (Qwen3-VL local parser)")
 
18
  IMAGES_SCALE,
19
  MAX_FILE_SIZE_BYTES,
20
  MAX_FILE_SIZE_MB,
21
+ QWEN_ATTN_IMPLEMENTATION,
22
+ QWEN_BATCH_SIZE,
23
+ QWEN_IMAGE_MAX_SIDE,
24
  QWEN_MAX_NEW_TOKENS,
25
  QWEN_MODEL,
26
+ QWEN_TORCH_DTYPE,
27
  RENDER_DPI,
28
  logger,
29
  )
 
56
  logger.info(f"Max file size: {MAX_FILE_SIZE_MB}MB")
57
  logger.info(f"Qwen Model: {QWEN_MODEL}")
58
  logger.info(f"Qwen Max New Tokens: {QWEN_MAX_NEW_TOKENS}")
59
+ logger.info(f"Qwen Batch Size: {QWEN_BATCH_SIZE}")
60
+ logger.info(f"Qwen Image Max Side: {QWEN_IMAGE_MAX_SIDE}")
61
+ logger.info(f"Qwen Attention: {QWEN_ATTN_IMPLEMENTATION}")
62
+ logger.info(f"Qwen Torch Dtype: {QWEN_TORCH_DTYPE}")
63
 
64
  logger.info("=" * 60)
65
  logger.info("Docling VLM Parser API ready (Qwen3-VL local parser)")
config.py CHANGED
@@ -21,7 +21,11 @@ MAX_FILE_SIZE_BYTES = MAX_FILE_SIZE_MB * 1024 * 1024
21
  RENDER_DPI = int(os.getenv("RENDER_DPI", "200"))
22
 
23
  QWEN_MODEL = os.getenv("QWEN_MODEL", "Qwen/Qwen3-VL-8B-Instruct")
24
- QWEN_MAX_NEW_TOKENS = int(os.getenv("QWEN_MAX_NEW_TOKENS", "4096"))
 
 
 
 
25
 
26
  # Blocked hostnames for SSRF protection
27
  BLOCKED_HOSTNAMES = {
 
21
  RENDER_DPI = int(os.getenv("RENDER_DPI", "200"))
22
 
23
  QWEN_MODEL = os.getenv("QWEN_MODEL", "Qwen/Qwen3-VL-8B-Instruct")
24
+ QWEN_MAX_NEW_TOKENS = int(os.getenv("QWEN_MAX_NEW_TOKENS", "1536"))
25
+ QWEN_BATCH_SIZE = int(os.getenv("QWEN_BATCH_SIZE", "2"))
26
+ QWEN_IMAGE_MAX_SIDE = int(os.getenv("QWEN_IMAGE_MAX_SIDE", "1536"))
27
+ QWEN_ATTN_IMPLEMENTATION = os.getenv("QWEN_ATTN_IMPLEMENTATION", "flash_attention_2")
28
+ QWEN_TORCH_DTYPE = os.getenv("QWEN_TORCH_DTYPE", "bfloat16")
29
 
30
  # Blocked hostnames for SSRF protection
31
  BLOCKED_HOSTNAMES = {
pipeline.py CHANGED
@@ -11,7 +11,15 @@ import torch
11
  from PIL import Image
12
  from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
13
 
14
- from config import QWEN_MAX_NEW_TOKENS, QWEN_MODEL, logger
 
 
 
 
 
 
 
 
15
  from postprocess import _post_process_merged_markdown
16
  from rendering import _image_file_to_png_bytes, _pdf_to_page_images
17
 
@@ -31,18 +39,48 @@ _OCR_PROMPT = (
31
  )
32
 
33
 
 
 
 
 
 
 
 
 
 
 
 
34
  def _get_pipeline() -> tuple[Qwen3VLForConditionalGeneration, AutoProcessor]:
35
  """Get or create the global Qwen3-VL pipeline."""
36
  global _model, _processor
37
  if _model is None or _processor is None:
38
  logger.info(f"Loading Qwen model: {QWEN_MODEL}")
39
  _processor = AutoProcessor.from_pretrained(QWEN_MODEL, trust_remote_code=True)
40
- _model = Qwen3VLForConditionalGeneration.from_pretrained(
41
- QWEN_MODEL,
42
- torch_dtype="auto",
43
- device_map="auto",
44
- trust_remote_code=True,
45
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  _model.eval()
47
  return _model, _processor
48
 
@@ -79,51 +117,96 @@ def _create_images_zip(output_dir: Path) -> tuple[Optional[str], int]:
79
  return base64.b64encode(zip_buffer.getvalue()).decode("utf-8"), image_count
80
 
81
 
82
- def _extract_markdown_from_image(
83
- image_bytes: bytes,
84
- page_label: str,
85
- ) -> str:
86
- """Run a single page image through Qwen3-VL."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  model, processor = _get_pipeline()
88
- image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
89
- messages = [
90
- {
91
- "role": "user",
92
- "content": [
93
- {"type": "image", "image": image},
94
- {"type": "text", "text": _OCR_PROMPT},
95
- ],
96
- }
97
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
- inputs = processor.apply_chat_template(
100
- messages,
101
- tokenize=True,
102
- add_generation_prompt=True,
103
- return_dict=True,
104
  return_tensors="pt",
105
  )
106
 
107
  device = next(model.parameters()).device
108
- inputs = inputs.to(device)
 
 
 
109
 
110
  with torch.inference_mode():
111
  generated_ids = model.generate(
112
- **inputs,
113
  max_new_tokens=QWEN_MAX_NEW_TOKENS,
114
  do_sample=False,
115
  )
116
 
117
- prompt_length = inputs["input_ids"].shape[1]
118
- output_ids = generated_ids[:, prompt_length:]
119
- text = processor.batch_decode(
120
- output_ids,
121
- skip_special_tokens=True,
122
- clean_up_tokenization_spaces=False,
123
- )[0].strip()
 
 
 
 
 
 
 
124
 
125
- logger.info(f"[{page_label}] Qwen generated {len(text)} chars")
126
- return text
 
 
 
 
 
 
 
127
 
128
 
129
  def _collect_page_images(
@@ -159,10 +242,11 @@ def _convert_document(
159
  raise ValueError("No pages available to parse")
160
 
161
  markdown_pages: list[str] = []
162
- for page_idx, image_bytes in page_images:
163
- page_label = f"{request_id}:page:{page_idx + 1}"
164
- text = _extract_markdown_from_image(image_bytes, page_label)
165
- markdown_pages.append(text)
 
166
 
167
  markdown_content = "\n\n".join(p for p in markdown_pages if p).strip()
168
  markdown_content = _post_process_merged_markdown(markdown_content)
 
11
  from PIL import Image
12
  from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
13
 
14
+ from config import (
15
+ QWEN_ATTN_IMPLEMENTATION,
16
+ QWEN_BATCH_SIZE,
17
+ QWEN_IMAGE_MAX_SIDE,
18
+ QWEN_MAX_NEW_TOKENS,
19
+ QWEN_MODEL,
20
+ QWEN_TORCH_DTYPE,
21
+ logger,
22
+ )
23
  from postprocess import _post_process_merged_markdown
24
  from rendering import _image_file_to_png_bytes, _pdf_to_page_images
25
 
 
39
  )
40
 
41
 
42
+ def _resolve_torch_dtype() -> torch.dtype | str:
43
+ """Resolve configured dtype to a torch dtype when possible."""
44
+ dtype_map = {
45
+ "auto": "auto",
46
+ "bfloat16": torch.bfloat16,
47
+ "float16": torch.float16,
48
+ "float32": torch.float32,
49
+ }
50
+ return dtype_map.get(QWEN_TORCH_DTYPE.lower(), "auto")
51
+
52
+
53
def _get_pipeline() -> tuple[Qwen3VLForConditionalGeneration, AutoProcessor]:
    """Get or create the global Qwen3-VL pipeline (lazy singleton)."""
    global _model, _processor
    if _model is None or _processor is None:
        logger.info(f"Loading Qwen model: {QWEN_MODEL}")
        _processor = AutoProcessor.from_pretrained(QWEN_MODEL, trust_remote_code=True)

        # Only request a custom attention backend when one is configured;
        # "none" (any case) or an empty value means "let transformers decide".
        use_custom_attn = bool(QWEN_ATTN_IMPLEMENTATION) and QWEN_ATTN_IMPLEMENTATION.lower() != "none"
        model_kwargs = {
            "torch_dtype": _resolve_torch_dtype(),
            "device_map": "auto",
            "trust_remote_code": True,
        }
        if use_custom_attn:
            model_kwargs["attn_implementation"] = QWEN_ATTN_IMPLEMENTATION

        try:
            _model = Qwen3VLForConditionalGeneration.from_pretrained(QWEN_MODEL, **model_kwargs)
        except Exception as e:
            if not use_custom_attn:
                raise
            # Best-effort fallback: the configured backend (e.g. flash-attn)
            # may simply not be installed on this host.
            logger.warning(
                f"Failed to load Qwen with attn_implementation={QWEN_ATTN_IMPLEMENTATION}: {e}. "
                "Retrying without custom attention."
            )
            del model_kwargs["attn_implementation"]
            _model = Qwen3VLForConditionalGeneration.from_pretrained(QWEN_MODEL, **model_kwargs)
    # eval() is called on every access, matching the original control flow;
    # it is idempotent on an already-eval model.
    _model.eval()
    return _model, _processor
86
 
 
117
  return base64.b64encode(zip_buffer.getvalue()).decode("utf-8"), image_count
118
 
119
 
120
def _resize_image(image: Image.Image, max_side: int | None = None) -> Image.Image:
    """Downscale an image so its longest side is at most ``max_side``.

    Keeps the aspect ratio and returns the image untouched when it is already
    small enough. Reducing pixel count reduces the visual token count and
    therefore generation latency.

    Args:
        image: Source PIL image.
        max_side: Longest-side limit in pixels; defaults to the
            QWEN_IMAGE_MAX_SIDE config value when omitted.

    Returns:
        The original image, or a LANCZOS-resampled downscaled copy.
    """
    if max_side is None:
        max_side = QWEN_IMAGE_MAX_SIDE
    longest = max(image.size)
    if longest <= max_side:
        return image

    scale = max_side / longest
    width, height = image.size
    # max(1, ...) guards against a degenerate zero-pixel dimension after
    # scaling very thin images.
    new_size = (max(1, int(width * scale)), max(1, int(height * scale)))
    return image.resize(new_size, Image.Resampling.LANCZOS)
132
+
133
+
134
def _extract_markdown_from_images(
    page_images: list[tuple[int, bytes]],
    request_id: str,
) -> dict[int, str]:
    """Run a batch of page images through Qwen3-VL.

    Args:
        page_images: (page_index, image_bytes) pairs for one batch.
        request_id: Identifier used to prefix per-page log lines.

    Returns:
        Mapping of page index -> generated markdown text (stripped).
    """
    # Guard the empty batch: the processor would otherwise be called with
    # empty text/image lists.
    if not page_images:
        return {}

    model, processor = _get_pipeline()
    prompt_texts: list[str] = []
    images: list[Image.Image] = []
    page_indices: list[int] = []

    for page_idx, image_bytes in page_images:
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        image = _resize_image(image)
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": _OCR_PROMPT},
                ],
            }
        ]
        prompt_texts.append(
            processor.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
        )
        images.append(image)
        page_indices.append(page_idx)

    inputs = processor(
        text=prompt_texts,
        images=images,
        padding=True,
        return_tensors="pt",
    )

    device = next(model.parameters()).device
    model_inputs = {
        key: value.to(device) if hasattr(value, "to") else value
        for key, value in inputs.items()
    }

    with torch.inference_mode():
        generated_ids = model.generate(
            **model_inputs,
            max_new_tokens=QWEN_MAX_NEW_TOKENS,
            do_sample=False,
        )

    # generate() returns rows shaped [padded_prompt | new_tokens], so slicing
    # at the *padded* prompt length isolates exactly the generated tokens
    # regardless of the tokenizer's padding side. (Using per-row
    # attention_mask sums here would leave prompt remnants in the output
    # under left padding.)
    prompt_length = model_inputs["input_ids"].shape[1]
    decoded = processor.batch_decode(
        generated_ids[:, prompt_length:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )

    decoded_pages: dict[int, str] = {}
    for page_idx, text in zip(page_indices, decoded):
        text = text.strip()
        decoded_pages[page_idx] = text
        logger.info(f"[{request_id}:page:{page_idx + 1}] Qwen generated {len(text)} chars")

    return decoded_pages
200
 
201
+
202
def _extract_markdown_from_image(
    image_bytes: bytes,
    page_label: str,
) -> str:
    """Backwards-compatible single-image wrapper."""
    # Delegate to the batch path with a single synthetic page index.
    results = _extract_markdown_from_images([(0, image_bytes)], page_label)
    return results[0]
210
 
211
 
212
  def _collect_page_images(
 
242
  raise ValueError("No pages available to parse")
243
 
244
  markdown_pages: list[str] = []
245
+ for batch_start in range(0, len(page_images), QWEN_BATCH_SIZE):
246
+ batch = page_images[batch_start : batch_start + QWEN_BATCH_SIZE]
247
+ batch_outputs = _extract_markdown_from_images(batch, request_id)
248
+ for page_idx, _ in batch:
249
+ markdown_pages.append(batch_outputs.get(page_idx, ""))
250
 
251
  markdown_content = "\n\n".join(p for p in markdown_pages if p).strip()
252
  markdown_content = _post_process_merged_markdown(markdown_content)