from __future__ import annotations import ast import json import re from pathlib import Path from typing import Any, Optional import numpy as np from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( Qwen2_5_VLForConditionalGeneration, ) _TOOL_RE = re.compile( r'[^"]+)"\s+label=(?P\[[^\]]*\])\s*>\s*', re.IGNORECASE, ) _TOKEN_RE = re.compile( r'Reviewing\s+.+?\.\.\.|Inspecting\s+.+?\.\.\.|\s*', re.IGNORECASE | re.DOTALL, ) class VReasonQwen2_5_VLForConditionalGeneration(Qwen2_5_VLForConditionalGeneration): """Qwen2.5-VL model with VReason helper methods. Loaded via: AutoModelForVision2Seq.from_pretrained(..., trust_remote_code=True) Extra method: model.visual_reason(...) """ @staticmethod def _parse_labels(raw: str) -> list[str]: try: parsed = json.loads(raw) if isinstance(parsed, list): return [str(x).strip() for x in parsed if str(x).strip()] except Exception: pass try: parsed = ast.literal_eval(raw) if isinstance(parsed, list): return [str(x).strip() for x in parsed if str(x).strip()] except Exception: pass return [x.strip().strip('"').strip("'") for x in raw.strip("[]").split(",") if x.strip()] @staticmethod def _extract_tag(text: str, tag: str) -> str: m = re.search(rf"<{tag}>(.*?)", text, flags=re.IGNORECASE | re.DOTALL) return m.group(1).strip() if m else "" @classmethod def _parse_reasoning(cls, model_output: str) -> list[dict[str, Any]]: interpret = cls._extract_tag(model_output, "interpret") if not interpret: return [] regions: list[dict[str, Any]] = [] current_region: Optional[dict[str, Any]] = None current_sub: Optional[dict[str, Any]] = None cursor = 0 for tok in _TOKEN_RE.finditer(interpret): between = interpret[cursor:tok.start()].strip() if between and current_sub is not None: current_sub["reason"] = (current_sub["reason"] + " " + between).strip() token = tok.group(0) low = token.lower() if low.startswith("reviewing "): region = token[len("Reviewing ") :].strip() if region.endswith("..."): region = region[:-3].strip() current_region = { "region": region, "labels": [], "path": None, "pathological": [], } regions.append(current_region) current_sub = None elif low.startswith("inspecting ") and current_region is not None: anatomy = token[len("Inspecting ") :].strip() if anatomy.endswith("..."): anatomy = anatomy[:-3].strip() current_sub = { "anatomies": anatomy, "labels": [], "reason": "", "path": None, } current_region["pathological"].append(current_sub) else: m = _TOOL_RE.search(token) if m: labels = cls._parse_labels(m.group("labels")) tool_type = m.group("tool_type").lower() if tool_type == "anatomical_roi" and current_region is not None: current_region["labels"] = labels elif tool_type == "pathological_roi" and current_sub is not None: current_sub["labels"] = labels cursor = tok.end() trailing = interpret[cursor:].strip() if trailing and current_sub is not None: current_sub["reason"] = (current_sub["reason"] + " " + trailing).strip() for r in regions: for p in r["pathological"]: p["reason"] = re.sub(r"\s+", " ", p["reason"]).strip() return regions @staticmethod def _ids_for_many(names: list[str], name2id: dict[str, list[int]]) -> list[int]: out: list[int] = [] seen = set() for name in names: key = str(name).strip().lower() if key in name2id: for idx in name2id[key]: idx = int(idx) if idx not in seen: seen.add(idx) out.append(idx) return out @staticmethod def _mask_union(mask_array: np.ndarray, indices: list[int]) -> Optional[np.ndarray]: if mask_array is None or not indices: return None safe = [i for i in indices if 0 <= i < mask_array.shape[0]] if not safe: return None return mask_array[safe].any(axis=0) @staticmethod def _bbox_from_mask(mask: Optional[np.ndarray], width: int, height: int, pad: int = 0) -> tuple[int, int, int, int]: if mask is None or not mask.any(): return 0, width, 0, height ys, xs = np.where(mask) x0 = max(0, int(xs.min()) - pad) x1 = min(width, int(xs.max()) + pad) y0 = max(0, int(ys.min()) - pad) y1 = min(height, int(ys.max()) + pad) return x0, x1, y0, y1 @staticmethod def _to_alpha(mask: Optional[np.ndarray], invert: bool, feather: int, size_wh: tuple[int, int]): from PIL import Image import cv2 if mask is None: return Image.new("L", size_wh, color=255 if invert else 0) selected = (~mask if invert else mask).astype(np.uint8) * 255 if size_wh != (mask.shape[1], mask.shape[0]): selected = cv2.resize(selected, size_wh, interpolation=cv2.INTER_NEAREST) if feather > 0: selected = cv2.GaussianBlur(selected, ksize=(0, 0), sigmaX=feather, sigmaY=feather) return Image.fromarray(selected, mode="L") @staticmethod def _save_viz( img_base, mask: Optional[np.ndarray], out_path: Path, mode: str, blur_radius: int, feather: int, ring: int, roi_wh: Optional[tuple[int, int]], ) -> None: import cv2 from PIL import Image, ImageFilter base = img_base.convert("RGB") w, h = base.size if mask is not None and (mask.shape[1], mask.shape[0]) != (w, h): mask = cv2.resize(mask.astype(np.uint8), (w, h), interpolation=cv2.INTER_NEAREST).astype(bool) if mode == "blur": blurred = base.filter(ImageFilter.GaussianBlur(radius=blur_radius)) alpha = VReasonQwen2_5_VLForConditionalGeneration._to_alpha(mask, invert=True, feather=feather, size_wh=base.size) out = Image.composite(blurred, base, alpha) elif mode == "crop": x0, x1, y0, y1 = VReasonQwen2_5_VLForConditionalGeneration._bbox_from_mask(mask, w, h, pad=ring) out = base.crop((x0, y0, x1, y1)) else: if mask is None: out = base else: x0, x1, y0, y1 = VReasonQwen2_5_VLForConditionalGeneration._bbox_from_mask(mask, w, h, pad=ring) crop = base.crop((x0, y0, x1, y1)) crop_mask = mask[y0:y1, x0:x1] blurred = crop.filter(ImageFilter.GaussianBlur(radius=blur_radius)) alpha = VReasonQwen2_5_VLForConditionalGeneration._to_alpha(crop_mask, invert=True, feather=feather, size_wh=crop.size) out = Image.composite(blurred, crop, alpha) if roi_wh: out = out.resize((int(roi_wh[0]), int(roi_wh[1])), Image.BICUBIC) out_path.parent.mkdir(parents=True, exist_ok=True) out.save(out_path, format="JPEG", quality=95, subsampling=1, optimize=True) def visual_reason( self, *, processor, image, prompt_text: str = "Based on the provided chest radiograph, explain your diagnosis procedure and write a report.", model_output_text: Optional[str] = None, messages: Optional[list[dict[str, Any]]] = None, max_new_tokens: int = 1024, generation_kwargs: Optional[dict[str, Any]] = None, skip_special_tokens: bool = False, output_dir: Optional[str] = None, generate_roi: bool = False, mask_npy: Optional[str] = None, cxas_gpus: str = "0", viz_mode: str = "blurcrop", context_ring: int = 8, blur_radius: int = 31, feather: int = 6, resize_roi_to: Optional[tuple[int, int]] = None, ) -> dict[str, Any]: """Generate and parse VReason output, optionally producing ROI image artifacts. Args: processor: HF processor matching this model. image: Path or PIL image for frontal CXR. prompt_text: Default user prompt if `messages` is not provided. model_output_text: If provided, skips generation and parses this text directly. messages: Optional fully custom chat messages. output_dir: Directory for `reasoning.json` and ROI images. generate_roi: Whether to render ROI images using CXAS masks. mask_npy: Optional precomputed mask array path ([C,H,W]). """ import torch from PIL import Image if model_output_text is None: if messages is None: msg_image = image if isinstance(image, Path): msg_image = str(image) messages = [ { "role": "user", "content": [ {"type": "image", "image": msg_image}, {"type": "text", "text": prompt_text}, ], } ] prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) images_for_proc = [] for msg in messages: for part in msg.get("content", []): if part.get("type") != "image": continue img_obj = part.get("image") if isinstance(img_obj, Image.Image): images_for_proc.append(img_obj.convert("RGB")) else: images_for_proc.append(Image.open(str(img_obj)).convert("RGB")) inputs = processor(text=[prompt], images=[images_for_proc], return_tensors="pt").to(self.device) kwargs = dict(max_new_tokens=max_new_tokens) if generation_kwargs: kwargs.update(generation_kwargs) with torch.no_grad(): output_ids = self.generate(**inputs, **kwargs) model_output_text = processor.batch_decode(output_ids, skip_special_tokens=skip_special_tokens)[0] assert model_output_text is not None regions = self._parse_reasoning(model_output_text) result: dict[str, Any] = { "text": model_output_text, "reasoning": regions, "finding": self._extract_tag(model_output_text, "finding"), "impression": self._extract_tag(model_output_text, "impression"), "report": self._extract_tag(model_output_text, "report"), "viz_mode": viz_mode, } if generate_roi: if output_dir is None: raise ValueError("output_dir is required when generate_roi=True") out_dir = Path(output_dir).resolve() out_dir.mkdir(parents=True, exist_ok=True) if isinstance(image, Image.Image): base_img = image.convert("RGB") input_image_name = "input.jpg" else: image_path = Path(str(image)).resolve() base_img = Image.open(image_path).convert("RGB") input_image_name = image_path.name if mask_npy: mask_array = np.load(mask_npy).astype(bool) else: try: import cxas_vreason as cxas # type: ignore except Exception: import cxas # type: ignore mask_array = np.asarray(cxas.CXAS(gpus=cxas_gpus).eval().seg(str(image)), dtype=bool) try: from cxas_vreason.label_mapper import name2id as name2id # type: ignore except Exception: from cxas.label_mapper import name2id as name2id # type: ignore for i, region in enumerate(regions): region_labels = region.get("labels") or [region.get("region", "")] region_idx = self._ids_for_many(region_labels, name2id) region_mask = self._mask_union(mask_array, region_idx) region_name = f"anatomy_{i:03d}.jpg" region_path = out_dir / region_name self._save_viz(base_img, region_mask, region_path, viz_mode, blur_radius, feather, context_ring, resize_roi_to) region["path"] = region_name for j, sub in enumerate(region.get("pathological", [])): sub_labels = sub.get("labels") or [sub.get("anatomies", "")] sub_idx = self._ids_for_many(sub_labels, name2id) sub_mask = self._mask_union(mask_array, sub_idx) sub_name = f"pathology_{i:03d}_{j:03d}.jpg" sub_path = out_dir / sub_name self._save_viz(base_img, sub_mask, sub_path, viz_mode, blur_radius, feather, context_ring, resize_roi_to) sub["path"] = sub_name reasoning_json_path = out_dir / "reasoning.json" reasoning_json_path.write_text( json.dumps( { "input_image": input_image_name, "viz_mode": viz_mode, "reasoning": regions, "finding": result["finding"], "impression": result["impression"], "report": result["report"], }, ensure_ascii=False, indent=2, ), encoding="utf-8", ) result["reasoning_json"] = str(reasoning_json_path) return result