Hugging Face Space · Running on T4

Commit b586eeb · parent 51c66dc · Ibad ur Rehman committed

perf: optimize qwen inference path

Files changed:
- app.py (+8, -0)
- config.py (+5, -1)
- pipeline.py (+126, -42)
app.py CHANGED

@@ -18,8 +18,12 @@ from config import (
     IMAGES_SCALE,
     MAX_FILE_SIZE_BYTES,
     MAX_FILE_SIZE_MB,
+    QWEN_ATTN_IMPLEMENTATION,
+    QWEN_BATCH_SIZE,
+    QWEN_IMAGE_MAX_SIDE,
     QWEN_MAX_NEW_TOKENS,
     QWEN_MODEL,
+    QWEN_TORCH_DTYPE,
     RENDER_DPI,
     logger,
 )
@@ -52,6 +56,10 @@ async def lifespan(app: FastAPI):
     logger.info(f"Max file size: {MAX_FILE_SIZE_MB}MB")
     logger.info(f"Qwen Model: {QWEN_MODEL}")
     logger.info(f"Qwen Max New Tokens: {QWEN_MAX_NEW_TOKENS}")
+    logger.info(f"Qwen Batch Size: {QWEN_BATCH_SIZE}")
+    logger.info(f"Qwen Image Max Side: {QWEN_IMAGE_MAX_SIDE}")
+    logger.info(f"Qwen Attention: {QWEN_ATTN_IMPLEMENTATION}")
+    logger.info(f"Qwen Torch Dtype: {QWEN_TORCH_DTYPE}")

     logger.info("=" * 60)
     logger.info("Docling VLM Parser API ready (Qwen3-VL local parser)")
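With the defaults added in config.py below, the four new startup lines would read roughly:

    Qwen Batch Size: 2
    Qwen Image Max Side: 1536
    Qwen Attention: flash_attention_2
    Qwen Torch Dtype: bfloat16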
config.py CHANGED

@@ -21,7 +21,11 @@ MAX_FILE_SIZE_BYTES = MAX_FILE_SIZE_MB * 1024 * 1024
 RENDER_DPI = int(os.getenv("RENDER_DPI", "200"))

 QWEN_MODEL = os.getenv("QWEN_MODEL", "Qwen/Qwen3-VL-8B-Instruct")
-QWEN_MAX_NEW_TOKENS = int(os.getenv("QWEN_MAX_NEW_TOKENS", "…
+QWEN_MAX_NEW_TOKENS = int(os.getenv("QWEN_MAX_NEW_TOKENS", "1536"))
+QWEN_BATCH_SIZE = int(os.getenv("QWEN_BATCH_SIZE", "2"))
+QWEN_IMAGE_MAX_SIDE = int(os.getenv("QWEN_IMAGE_MAX_SIDE", "1536"))
+QWEN_ATTN_IMPLEMENTATION = os.getenv("QWEN_ATTN_IMPLEMENTATION", "flash_attention_2")
+QWEN_TORCH_DTYPE = os.getenv("QWEN_TORCH_DTYPE", "bfloat16")

 # Blocked hostnames for SSRF protection
 BLOCKED_HOSTNAMES = {
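All five knobs are read once, at import time, via os.getenv, so they can be tuned per deployment without touching code. A minimal sketch of overriding them from Python (the override values here are illustrative, not from this commit):

    import os

    # Hypothetical overrides; they must be set before config.py is imported,
    # because each value is captured by os.getenv at import time.
    os.environ["QWEN_BATCH_SIZE"] = "4"              # pages per generate() call
    os.environ["QWEN_IMAGE_MAX_SIDE"] = "1024"       # cap on the longer image side
    os.environ["QWEN_ATTN_IMPLEMENTATION"] = "sdpa"  # or "none" to skip the kwarg
    os.environ["QWEN_TORCH_DTYPE"] = "float16"

    import config

    print(config.QWEN_BATCH_SIZE, config.QWEN_TORCH_DTYPE)  # -> 4 float16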
pipeline.py CHANGED

@@ -11,7 +11,15 @@ import torch
 from PIL import Image
 from transformers import AutoProcessor, Qwen3VLForConditionalGeneration

-from config import …
+from config import (
+    QWEN_ATTN_IMPLEMENTATION,
+    QWEN_BATCH_SIZE,
+    QWEN_IMAGE_MAX_SIDE,
+    QWEN_MAX_NEW_TOKENS,
+    QWEN_MODEL,
+    QWEN_TORCH_DTYPE,
+    logger,
+)
 from postprocess import _post_process_merged_markdown
 from rendering import _image_file_to_png_bytes, _pdf_to_page_images

@@ -31,18 +39,48 @@ _OCR_PROMPT = (
 )


+def _resolve_torch_dtype() -> torch.dtype | str:
+    """Resolve configured dtype to a torch dtype when possible."""
+    dtype_map = {
+        "auto": "auto",
+        "bfloat16": torch.bfloat16,
+        "float16": torch.float16,
+        "float32": torch.float32,
+    }
+    return dtype_map.get(QWEN_TORCH_DTYPE.lower(), "auto")
+
+
 def _get_pipeline() -> tuple[Qwen3VLForConditionalGeneration, AutoProcessor]:
     """Get or create the global Qwen3-VL pipeline."""
     global _model, _processor
     if _model is None or _processor is None:
         logger.info(f"Loading Qwen model: {QWEN_MODEL}")
         _processor = AutoProcessor.from_pretrained(QWEN_MODEL, trust_remote_code=True)
-        _model = Qwen3VLForConditionalGeneration.from_pretrained(
-            …
-        )
+        model_kwargs = {
+            "torch_dtype": _resolve_torch_dtype(),
+            "device_map": "auto",
+            "trust_remote_code": True,
+        }
+        if QWEN_ATTN_IMPLEMENTATION and QWEN_ATTN_IMPLEMENTATION.lower() != "none":
+            model_kwargs["attn_implementation"] = QWEN_ATTN_IMPLEMENTATION
+        try:
+            _model = Qwen3VLForConditionalGeneration.from_pretrained(
+                QWEN_MODEL,
+                **model_kwargs,
+            )
+        except Exception as e:
+            if "attn_implementation" in model_kwargs:
+                logger.warning(
+                    f"Failed to load Qwen with attn_implementation={QWEN_ATTN_IMPLEMENTATION}: {e}. "
+                    "Retrying without custom attention."
+                )
+                model_kwargs.pop("attn_implementation", None)
+                _model = Qwen3VLForConditionalGeneration.from_pretrained(
+                    QWEN_MODEL,
+                    **model_kwargs,
+                )
+            else:
+                raise
         _model.eval()
     return _model, _processor

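The resolver above is case-insensitive and fails safe: any value outside the four known names resolves to "auto" rather than raising. The try/except around from_pretrained fails safe the same way, retrying without attn_implementation when flash-attention kernels are unavailable. A standalone sketch of the dtype mapping (sample inputs are hypothetical):

    import torch

    # Same mapping as _resolve_torch_dtype(); unknown strings fall back to "auto".
    dtype_map = {
        "auto": "auto",
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "float32": torch.float32,
    }
    for configured in ("bfloat16", "FLOAT16", "fp16", ""):
        print(repr(configured), "->", dtype_map.get(configured.lower(), "auto"))
    # 'bfloat16' -> torch.bfloat16
    # 'FLOAT16' -> torch.float16
    # 'fp16' -> auto
    # '' -> auto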
@@ -79,51 +117,96 @@ def _create_images_zip(output_dir: Path) -> tuple[Optional[str], int]:
     return base64.b64encode(zip_buffer.getvalue()).decode("utf-8"), image_count


-def _extract_markdown_from_image(…
-    …
+def _resize_image(image: Image.Image) -> Image.Image:
+    """Downscale images to reduce visual token count and generation latency."""
+    max_side = max(image.size)
+    if max_side <= QWEN_IMAGE_MAX_SIDE:
+        return image
+
+    scale = QWEN_IMAGE_MAX_SIDE / max_side
+    new_size = (
+        max(1, int(image.size[0] * scale)),
+        max(1, int(image.size[1] * scale)),
+    )
+    return image.resize(new_size, Image.Resampling.LANCZOS)
+
+
+def _extract_markdown_from_images(
+    page_images: list[tuple[int, bytes]],
+    request_id: str,
+) -> dict[int, str]:
+    """Run a batch of page images through Qwen3-VL."""
     model, processor = _get_pipeline()
-    …
+    prompt_texts: list[str] = []
+    images: list[Image.Image] = []
+    page_indices: list[int] = []
+
+    for page_idx, image_bytes in page_images:
+        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
+        image = _resize_image(image)
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image", "image": image},
+                    {"type": "text", "text": _OCR_PROMPT},
+                ],
+            }
+        ]
+        prompt_texts.append(
+            processor.apply_chat_template(
+                messages,
+                tokenize=False,
+                add_generation_prompt=True,
+            )
+        )
+        images.append(image)
+        page_indices.append(page_idx)

-    inputs = processor…
-        …
-        return_dict=True,
+    inputs = processor(
+        text=prompt_texts,
+        images=images,
+        padding=True,
         return_tensors="pt",
     )

     device = next(model.parameters()).device
-    …
+    model_inputs = {
+        key: value.to(device) if hasattr(value, "to") else value
+        for key, value in inputs.items()
+    }

     with torch.inference_mode():
         generated_ids = model.generate(
-            **…
+            **model_inputs,
             max_new_tokens=QWEN_MAX_NEW_TOKENS,
             do_sample=False,
         )

-    …
-        output_ids,
-        …
+    input_lengths = model_inputs["attention_mask"].sum(dim=1).tolist()
+    decoded_pages: dict[int, str] = {}
+    for row_idx, prompt_length in enumerate(input_lengths):
+        output_ids = generated_ids[row_idx : row_idx + 1, int(prompt_length) :]
+        text = processor.batch_decode(
+            output_ids,
+            skip_special_tokens=True,
+            clean_up_tokenization_spaces=False,
+        )[0].strip()
+        page_idx = page_indices[row_idx]
+        decoded_pages[page_idx] = text
+        logger.info(f"[{request_id}:page:{page_idx + 1}] Qwen generated {len(text)} chars")

-    …
+    return decoded_pages
+
+def _extract_markdown_from_image(
+    image_bytes: bytes,
+    page_label: str,
+) -> str:
+    """Backwards-compatible single-image wrapper."""
+    page_idx = 0
+    page_map = _extract_markdown_from_images([(page_idx, image_bytes)], page_label)
+    return page_map[page_idx]


 def _collect_page_images(
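A usage sketch for the new batched entry point (file names and request id are hypothetical; in the app the PNG bytes come from _pdf_to_page_images):

    from pathlib import Path

    # Hypothetical inputs: (page index, PNG bytes) pairs.
    page_images = [
        (0, Path("page0.png").read_bytes()),
        (1, Path("page1.png").read_bytes()),
    ]

    pages = _extract_markdown_from_images(page_images, request_id="req-123")

    # The result maps page index -> markdown, so page order is restored explicitly.
    markdown = "\n\n".join(pages[idx] for idx, _ in page_images)

Because decoding slices each output row at its own prompt length (taken from the attention mask), padding introduced by batching never leaks into the returned markdown.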
@@ -159,10 +242,11 @@ def _convert_document(
         raise ValueError("No pages available to parse")

     markdown_pages: list[str] = []
-    for …
-        …
+    for batch_start in range(0, len(page_images), QWEN_BATCH_SIZE):
+        batch = page_images[batch_start : batch_start + QWEN_BATCH_SIZE]
+        batch_outputs = _extract_markdown_from_images(batch, request_id)
+        for page_idx, _ in batch:
+            markdown_pages.append(batch_outputs.get(page_idx, ""))

     markdown_content = "\n\n".join(p for p in markdown_pages if p).strip()
     markdown_content = _post_process_merged_markdown(markdown_content)