# VReason-QwenVL / modeling_vreason.py
# Uploaded by EvidenceAIResearch - commit 35beca7 (verified): "Upload vreason-huatuo model artifacts"
# Original file size: 14.6 kB
from __future__ import annotations
import ast
import json
import re
from pathlib import Path
from typing import Any, Optional
import numpy as np
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VLForConditionalGeneration,
)
_TOOL_RE = re.compile(
r'<tool\s+type="(?P<tool_type>[^"]+)"\s+label=(?P<labels>\[[^\]]*\])\s*>\s*<image>',
re.IGNORECASE,
)
_TOKEN_RE = re.compile(
r'Reviewing\s+.+?\.\.\.|Inspecting\s+.+?\.\.\.|<tool\s+type="[^"]+"\s+label=\[[^\]]*\]\s*>\s*<image>',
re.IGNORECASE | re.DOTALL,
)
class VReasonQwen2_5_VLForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
    """Qwen2.5-VL model with VReason helper methods.

    Loaded via:
        AutoModelForVision2Seq.from_pretrained(..., trust_remote_code=True)

    Extra method:
        model.visual_reason(...)
    """

    # ------------------------------------------------------------------ #
    # Parsing of the model's text output
    # ------------------------------------------------------------------ #

    @staticmethod
    def _parse_labels(raw: str) -> list[str]:
        """Parse a bracketed label list such as ``["left lung", "heart"]``.

        Strict JSON is tried first, then a Python literal (which accepts
        single-quoted strings). If neither yields a list, the text is split on
        commas and surrounding quotes/whitespace are removed. Empty labels are
        always dropped.
        """
        for loader in (json.loads, ast.literal_eval):
            try:
                parsed = loader(raw)
            except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                # Not valid for this loader; try the next, more lenient one.
                continue
            if isinstance(parsed, list):
                return [str(x).strip() for x in parsed if str(x).strip()]
        cleaned = (x.strip().strip('"').strip("'").strip() for x in raw.strip("[]").split(","))
        return [x for x in cleaned if x]

    @staticmethod
    def _extract_tag(text: str, tag: str) -> str:
        """Return the stripped content of the first ``<tag>...</tag>`` pair, or ``""``.

        Matching is case-insensitive and spans newlines.
        """
        tag_re = re.escape(tag)
        m = re.search(rf"<{tag_re}>(.*?)</{tag_re}>", text, flags=re.IGNORECASE | re.DOTALL)
        return m.group(1).strip() if m else ""

    @classmethod
    def _parse_reasoning(cls, model_output: str) -> list[dict[str, Any]]:
        """Turn the ``<interpret>`` section of a model response into a region tree.

        The section is scanned with ``_TOKEN_RE``:

        * ``Reviewing <region>...`` opens a new region entry;
        * ``Inspecting <anatomy>...`` opens a sub-entry under the current region
          (ignored while no region is open);
        * ``<tool type="anatomical_roi" label=[...]><image>`` sets the current
          region's labels, ``type="pathological_roi"`` sets the current
          sub-entry's labels.

        Free text between tokens is appended to the current sub-entry's
        ``reason``; text that appears before a region's first ``Inspecting``
        token is dropped.

        Returns:
            A list of ``{"region", "labels", "path", "pathological"}`` dicts whose
            ``pathological`` items are ``{"anatomies", "labels", "reason", "path"}``.
            ``path`` stays ``None`` until ROI images are rendered. Returns ``[]``
            when there is no ``<interpret>`` section.
        """
        interpret = cls._extract_tag(model_output, "interpret")
        if not interpret:
            return []

        regions: list[dict[str, Any]] = []
        current_region: Optional[dict[str, Any]] = None
        current_sub: Optional[dict[str, Any]] = None
        cursor = 0

        for tok in _TOKEN_RE.finditer(interpret):
            between = interpret[cursor:tok.start()].strip()
            if between and current_sub is not None:
                current_sub["reason"] = f"{current_sub['reason']} {between}".strip()
            cursor = tok.end()

            token = tok.group(0)
            # _TOKEN_RE allows any whitespace (including newlines) after the
            # keyword, so split on whitespace instead of testing a literal
            # "Reviewing " / "Inspecting " prefix.
            head, *tail = token.split(None, 1)
            keyword = head.lower()

            if keyword in ("reviewing", "inspecting"):
                name = tail[0].strip() if tail else ""
                if name.endswith("..."):
                    name = name[:-3]
                # Collapse internal whitespace/newlines so names are single-line.
                name = " ".join(name.split())
                if keyword == "reviewing":
                    current_region = {
                        "region": name,
                        "labels": [],
                        "path": None,
                        "pathological": [],
                    }
                    regions.append(current_region)
                    current_sub = None
                elif current_region is not None:
                    current_sub = {
                        "anatomies": name,
                        "labels": [],
                        "reason": "",
                        "path": None,
                    }
                    current_region["pathological"].append(current_sub)
                continue

            m = _TOOL_RE.search(token)
            if m is None:
                continue
            labels = cls._parse_labels(m.group("labels"))
            tool_type = m.group("tool_type").lower()
            if tool_type == "anatomical_roi" and current_region is not None:
                current_region["labels"] = labels
            elif tool_type == "pathological_roi" and current_sub is not None:
                current_sub["labels"] = labels

        trailing = interpret[cursor:].strip()
        if trailing and current_sub is not None:
            current_sub["reason"] = f"{current_sub['reason']} {trailing}".strip()

        for r in regions:
            for p in r["pathological"]:
                p["reason"] = re.sub(r"\s+", " ", p["reason"]).strip()
        return regions

    # ------------------------------------------------------------------ #
    # Mask / label utilities
    # ------------------------------------------------------------------ #

    @staticmethod
    def _ids_for_many(names: list[str], name2id: dict[str, list[int]]) -> list[int]:
        """Map label names to de-duplicated mask channel indices, in first-seen order.

        Each name is stripped and lower-cased before lookup, so the keys of
        *name2id* are expected to be lower-case. Unknown names are skipped.
        A scalar value in *name2id* is treated as a single index.
        """
        out: list[int] = []
        seen: set[int] = set()
        for name in names:
            ids = name2id.get(str(name).strip().lower())
            if ids is None:
                continue
            if isinstance(ids, (int, np.integer)):
                ids = [ids]
            for idx in map(int, ids):
                if idx not in seen:
                    seen.add(idx)
                    out.append(idx)
        return out

    @staticmethod
    def _mask_union(mask_array: Optional[np.ndarray], indices: list[int]) -> Optional[np.ndarray]:
        """Return the pixel-wise OR of the channels *indices* of *mask_array* (axis 0).

        Out-of-range indices are ignored. Returns ``None`` when there is no
        array or no usable index.
        """
        if mask_array is None or not indices:
            return None
        safe = [i for i in indices if 0 <= i < mask_array.shape[0]]
        if not safe:
            return None
        return mask_array[safe].any(axis=0)

    @staticmethod
    def _bbox_from_mask(mask: Optional[np.ndarray], width: int, height: int, pad: int = 0) -> tuple[int, int, int, int]:
        """Return ``(x0, x1, y0, y1)`` bounding the True pixels of *mask*, padded by *pad*.

        ``x1`` and ``y1`` are exclusive, so the box works directly with
        ``PIL.Image.crop`` and numpy slicing and includes the last mask
        row/column. The box is clipped to ``width`` x ``height``. If the mask is
        missing or empty, the full frame is returned.
        """
        if mask is None or not mask.any():
            return 0, width, 0, height
        ys, xs = np.nonzero(mask)
        x0 = max(0, int(xs.min()) - pad)
        x1 = min(width, int(xs.max()) + 1 + pad)
        y0 = max(0, int(ys.min()) - pad)
        y1 = min(height, int(ys.max()) + 1 + pad)
        return x0, x1, y0, y1

    # ------------------------------------------------------------------ #
    # ROI rendering
    # ------------------------------------------------------------------ #

    @staticmethod
    def _to_alpha(mask: Optional[np.ndarray], invert: bool, feather: int, size_wh: tuple[int, int]):
        """Build an 8-bit ``"L"`` alpha image of size *size_wh* ``(width, height)``.

        Selected pixels are set to 255 and the rest to 0. The selection is the
        mask, or its complement when *invert* is set. ``None`` means "nothing
        selected", which gives all 0, or all 255 when inverted. The mask is
        resized with nearest-neighbour interpolation if its size differs. A
        positive *feather* softens the edge with a Gaussian blur of that sigma.
        """
        from PIL import Image
        import cv2

        if mask is None:
            return Image.new("L", size_wh, color=255 if invert else 0)
        mask = np.asarray(mask, dtype=bool)  # `~` must be logical NOT, not bitwise
        selected = (~mask if invert else mask).astype(np.uint8) * 255
        if tuple(size_wh) != (mask.shape[1], mask.shape[0]):
            selected = cv2.resize(selected, tuple(size_wh), interpolation=cv2.INTER_NEAREST)
        if feather > 0:
            selected = cv2.GaussianBlur(selected, ksize=(0, 0), sigmaX=feather, sigmaY=feather)
        # A 2-D uint8 array maps to mode "L"; the `mode=` argument is deprecated in Pillow.
        return Image.fromarray(selected)

    @classmethod
    def _save_viz(
        cls,
        img_base,
        mask: Optional[np.ndarray],
        out_path: Path,
        mode: str,
        blur_radius: int,
        feather: int,
        ring: int,
        roi_wh: Optional[tuple[int, int]],
    ) -> None:
        """Render one ROI visualisation of *img_base* and save it to *out_path* as JPEG.

        Modes:
            ``"blur"``: full frame, with pixels outside the mask Gaussian-blurred.
            ``"crop"``: crop to the mask bounding box, padded by *ring* pixels.
            any other value (``"blurcrop"``): crop to the padded box, then blur
            the pixels of the crop that lie outside the mask.

        A mask of a different size is resized to the image. If *mask* is
        ``None`` or has no True pixels, every mode writes the unmodified full
        image, so a missing segmentation never produces an all-blurred picture.
        *roi_wh* ``(width, height)``, if given, resizes the output bicubically.
        Parent directories of *out_path* are created as needed.
        """
        import cv2
        from PIL import Image, ImageFilter

        base = img_base.convert("RGB")
        w, h = base.size
        if mask is not None:
            mask = np.asarray(mask, dtype=bool)
            if (mask.shape[1], mask.shape[0]) != (w, h):
                mask = cv2.resize(mask.astype(np.uint8), (w, h), interpolation=cv2.INTER_NEAREST).astype(bool)
            if not mask.any():
                mask = None

        if mask is None:
            out = base
        elif mode == "blur":
            blurred = base.filter(ImageFilter.GaussianBlur(radius=blur_radius))
            alpha = cls._to_alpha(mask, invert=True, feather=feather, size_wh=base.size)
            out = Image.composite(blurred, base, alpha)
        elif mode == "crop":
            x0, x1, y0, y1 = cls._bbox_from_mask(mask, w, h, pad=ring)
            out = base.crop((x0, y0, x1, y1))
        else:
            x0, x1, y0, y1 = cls._bbox_from_mask(mask, w, h, pad=ring)
            crop = base.crop((x0, y0, x1, y1))
            blurred = crop.filter(ImageFilter.GaussianBlur(radius=blur_radius))
            alpha = cls._to_alpha(mask[y0:y1, x0:x1], invert=True, feather=feather, size_wh=crop.size)
            out = Image.composite(blurred, crop, alpha)

        if roi_wh:
            out = out.resize((int(roi_wh[0]), int(roi_wh[1])), Image.BICUBIC)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        out.save(out_path, format="JPEG", quality=95, subsampling=1, optimize=True)

    # ------------------------------------------------------------------ #
    # Generation and segmentation helpers
    # ------------------------------------------------------------------ #

    def _generate_text(
        self,
        processor,
        messages: list[dict[str, Any]],
        max_new_tokens: int,
        generation_kwargs: Optional[dict[str, Any]],
        skip_special_tokens: bool,
    ) -> str:
        """Run one chat-templated generation and return only the decoded continuation.

        Image parts of *messages* (PIL images or file paths) are loaded as RGB
        and passed to *processor*. *generation_kwargs* override
        ``max_new_tokens``.
        """
        import torch
        from PIL import Image

        prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

        images_for_proc = []
        for msg in messages:
            content = msg.get("content", [])
            if isinstance(content, str):
                # Plain-text message (e.g. a system prompt) has no image parts.
                continue
            for part in content:
                if part.get("type") != "image":
                    continue
                img_obj = part.get("image")
                if isinstance(img_obj, Image.Image):
                    images_for_proc.append(img_obj.convert("RGB"))
                else:
                    with Image.open(str(img_obj)) as opened:
                        images_for_proc.append(opened.convert("RGB"))

        inputs = processor(
            text=[prompt],
            # Text-only conversations must not pass an empty image batch.
            images=[images_for_proc] if images_for_proc else None,
            return_tensors="pt",
        ).to(self.device)

        gen_kwargs: dict[str, Any] = {"max_new_tokens": max_new_tokens}
        if generation_kwargs:
            gen_kwargs.update(generation_kwargs)
        with torch.no_grad():
            output_ids = self.generate(**inputs, **gen_kwargs)

        # generate() returns prompt + continuation. Keep only the continuation
        # so the text does not repeat the chat template and image placeholder
        # tokens, and tag extraction cannot match tags inside the prompt.
        prompt_len = inputs["input_ids"].shape[1]
        return processor.batch_decode(output_ids[:, prompt_len:], skip_special_tokens=skip_special_tokens)[0]

    @staticmethod
    def _segment_with_cxas(image_path: Path, gpus: str) -> np.ndarray:
        """Segment *image_path* with CXAS and return the result as a boolean array.

        ``cxas_vreason`` is preferred; the upstream ``cxas`` package is used if
        it is not installed.
        """
        try:
            import cxas_vreason as cxas  # type: ignore
        except ImportError:
            import cxas  # type: ignore
        return np.asarray(cxas.CXAS(gpus=gpus).eval().seg(str(image_path)), dtype=bool)

    @staticmethod
    def _load_name2id():
        """Return the CXAS label-name -> mask-channel mapping (``cxas_vreason`` preferred)."""
        try:
            from cxas_vreason.label_mapper import name2id  # type: ignore
        except ImportError:
            from cxas.label_mapper import name2id  # type: ignore
        return name2id

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #

    def visual_reason(
        self,
        *,
        processor,
        image,
        prompt_text: str = "Based on the provided chest radiograph, explain your diagnosis procedure and write a report.",
        model_output_text: Optional[str] = None,
        messages: Optional[list[dict[str, Any]]] = None,
        max_new_tokens: int = 1024,
        generation_kwargs: Optional[dict[str, Any]] = None,
        skip_special_tokens: bool = False,
        output_dir: Optional[str] = None,
        generate_roi: bool = False,
        mask_npy: Optional[str] = None,
        cxas_gpus: str = "0",
        viz_mode: str = "blurcrop",
        context_ring: int = 8,
        blur_radius: int = 31,
        feather: int = 6,
        resize_roi_to: Optional[tuple[int, int]] = None,
    ) -> dict[str, Any]:
        """Generate and parse VReason output, optionally producing ROI image artifacts.

        Args:
            processor: HF processor matching this model.
            image: Path (``str``/``Path``) or PIL image for the frontal CXR.
            prompt_text: Default user prompt if `messages` is not provided.
            model_output_text: If provided, skips generation and parses this text directly.
            messages: Optional fully custom chat messages. Image parts may hold a
                PIL image or a file path.
            max_new_tokens: Generation length limit (overridable via `generation_kwargs`).
            generation_kwargs: Extra keyword arguments forwarded to ``generate``.
            skip_special_tokens: Forwarded to ``processor.batch_decode``.
            output_dir: Directory for `reasoning.json` and ROI images.
            generate_roi: Whether to render ROI images using CXAS masks.
            mask_npy: Optional precomputed mask array path ([C,H,W]); otherwise
                CXAS segmentation is run on the input image.
            cxas_gpus: ``gpus`` argument passed to ``CXAS``.
            viz_mode: ``"blur"``, ``"crop"`` or ``"blurcrop"`` (any other value
                behaves like ``"blurcrop"``).
            context_ring: Padding in pixels around the ROI bounding box.
            blur_radius: PIL Gaussian blur radius for de-emphasised pixels.
            feather: Gaussian sigma used to soften the mask edge (0 disables).
            resize_roi_to: Optional ``(width, height)`` for every ROI image.

        Returns:
            A dict with ``text`` (only the newly generated text), ``reasoning``
            (see ``_parse_reasoning``), ``finding``, ``impression``, ``report``
            and ``viz_mode``. When ``generate_roi`` is true, each reasoning
            entry's ``path`` holds its ROI file name and ``reasoning_json`` is
            the path of the written JSON file. If `image` is a PIL image, it is
            also saved as ``input.jpg`` in `output_dir`.

        Raises:
            ValueError: if ``generate_roi`` is true and `output_dir` is ``None``.
        """
        from PIL import Image

        if model_output_text is None:
            if messages is None:
                msg_image = str(image) if isinstance(image, Path) else image
                messages = [
                    {
                        "role": "user",
                        "content": [
                            {"type": "image", "image": msg_image},
                            {"type": "text", "text": prompt_text},
                        ],
                    }
                ]
            model_output_text = self._generate_text(
                processor, messages, max_new_tokens, generation_kwargs, skip_special_tokens
            )

        regions = self._parse_reasoning(model_output_text)
        result: dict[str, Any] = {
            "text": model_output_text,
            "reasoning": regions,
            "finding": self._extract_tag(model_output_text, "finding"),
            "impression": self._extract_tag(model_output_text, "impression"),
            "report": self._extract_tag(model_output_text, "report"),
            "viz_mode": viz_mode,
        }
        if not generate_roi:
            return result

        if output_dir is None:
            raise ValueError("output_dir is required when generate_roi=True")
        out_dir = Path(output_dir).resolve()
        out_dir.mkdir(parents=True, exist_ok=True)

        if isinstance(image, Image.Image):
            base_img = image.convert("RGB")
            input_image_name = "input.jpg"
            # Persist the in-memory image: reasoning.json refers to it by name,
            # and CXAS segmentation needs a real file path.
            source_path = out_dir / input_image_name
            base_img.save(source_path, format="JPEG", quality=95)
        else:
            source_path = Path(str(image)).resolve()
            with Image.open(source_path) as opened:
                base_img = opened.convert("RGB")
            input_image_name = source_path.name

        if mask_npy:
            mask_array = np.load(mask_npy).astype(bool)
        else:
            mask_array = self._segment_with_cxas(source_path, cxas_gpus)
        name2id = self._load_name2id()

        def render(labels: list[str], file_name: str) -> str:
            # Union the mask channels for *labels* and write one ROI image.
            roi_mask = self._mask_union(mask_array, self._ids_for_many(labels, name2id))
            self._save_viz(
                base_img, roi_mask, out_dir / file_name, viz_mode,
                blur_radius, feather, context_ring, resize_roi_to,
            )
            return file_name

        for i, region in enumerate(regions):
            # Fall back to the region's own name when the model gave no labels.
            region_labels = region.get("labels") or [region.get("region", "")]
            region["path"] = render(region_labels, f"anatomy_{i:03d}.jpg")
            for j, sub in enumerate(region.get("pathological", [])):
                sub_labels = sub.get("labels") or [sub.get("anatomies", "")]
                sub["path"] = render(sub_labels, f"pathology_{i:03d}_{j:03d}.jpg")

        reasoning_json_path = out_dir / "reasoning.json"
        reasoning_json_path.write_text(
            json.dumps(
                {
                    "input_image": input_image_name,
                    "viz_mode": viz_mode,
                    "reasoning": regions,
                    "finding": result["finding"],
                    "impression": result["impression"],
                    "report": result["report"],
                },
                ensure_ascii=False,
                indent=2,
            ),
            encoding="utf-8",
        )
        result["reasoning_json"] = str(reasoning_json_path)
        return result