# video-scan / modeling_marlin.py
# cudabenchmarktest's picture
# Add files using upload-large-folder tool
# f0ab8f1 verified
"""Custom HuggingFace modeling code for Marlin.
This module subclasses the upstream ``Qwen3_5ForConditionalGeneration``
(native in ``transformers >= 5.7.0``) and adds two convenience methods —
:meth:`MarlinForConditionalGeneration.caption` and
:meth:`MarlinForConditionalGeneration.find` — that mirror moondream's
image-SDK ergonomics for video captioning and temporal grounding.
The forward pass is **not** modified: we only add chat-template + generate +
post-processing wrappers. Loading the model through
``AutoModelForCausalLM.from_pretrained(..., trust_remote_code=True)`` returns
this subclass thanks to the ``auto_map`` entry in ``config.json``.
Required environment for video inference (set before importing transformers)::
FORCE_QWENVL_VIDEO_READER=torchcodec
VIDEO_MAX_PIXELS=200704
FPS=2.0
FPS_MAX_FRAMES=240
FPS_MIN_FRAMES=4
System requirements:
* transformers >= 5.7.0
* torch >= 2.11.0
* torchcodec
* qwen-vl-utils >= 0.0.14
* av, pillow
"""
from __future__ import annotations
import os
import re
from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union
import torch
# ``Qwen3_5ForConditionalGeneration`` is the native HF class for Marlin's
# backbone (Qwen3.5-2B with vision tower). It ships in transformers >= 5.7.0.
# It is imported eagerly at module top level so that AutoModelForCausalLM
# resolution works when this file is loaded via ``trust_remote_code=True``.
from transformers import Qwen3_5ForConditionalGeneration
# Public API of this module: the two canonical prompts, the result / event
# TypedDicts, the model subclass, and the standalone output parsers.
__all__ = [
    "CAPTION_PROMPT",
    "GROUNDING_PROMPT_TEMPLATE",
    "CaptionResult",
    "FindResult",
    "Event",
    "MarlinForConditionalGeneration",
    "strip_thinking",
    "parse_caption",
    "parse_span",
]
# ---------------------------------------------------------------------------
# Canonical training-time prompts — DO NOT EDIT
# ---------------------------------------------------------------------------
#
# These strings must match exactly what the model was fine-tuned on. Diverging
# from them silently degrades quality.
# Default prompt used by ``MarlinForConditionalGeneration.caption`` when the
# caller does not pass ``prompt``. Two adjacent literals are concatenated into
# a single two-line string.
CAPTION_PROMPT: str = (
    "Provide a spatial description of this clip followed by time-ranged events.\n"
    "For each event, give the time range as <start - end> and a short description."
)
# Default template used by ``MarlinForConditionalGeneration.find``; the
# ``{event}`` placeholder is filled in with ``str.format(event=...)``.
GROUNDING_PROMPT_TEMPLATE: str = (
    'Identify the timestamps during which "{event}" takes place. '
    'Output the time range as "From <start> to <end>." (numbers in seconds).'
)
# ---------------------------------------------------------------------------
# Thinking-tag stripping
# ---------------------------------------------------------------------------
#
# ms-swift's Marlin training template uses ``add_non_thinking_prefix=True``,
# which prefixes every response with a bare ``<think>\n`` (no close tag). The
# model occasionally also emits a complete ``<think>...</think>`` block. Strip
# both robustly.
# All three patterns are case-insensitive so that a complete block written as
# e.g. ``<THINK>...</THINK>`` is removed together with its content instead of
# only losing its tags (which would leak the thinking text into the result).
_THINK_BLOCK = re.compile(r"<think>.*?</think>\s*", re.DOTALL | re.IGNORECASE)
_THINK_PREFIX = re.compile(r"^\s*<think>\s*\n*", re.IGNORECASE)
_THINK_CLOSE = re.compile(r"</think>\s*", re.IGNORECASE)


def strip_thinking(text: str) -> str:
    """Remove ``<think>...</think>`` blocks and bare ``<think>`` prefixes.

    The substitutions run in a fixed order: complete blocks first (tags *and*
    content), then a leading unclosed ``<think>`` tag, then any stray
    ``</think>`` closers that remain.

    Parameters
    ----------
    text:
        Raw model output.

    Returns
    -------
    str
        The text with any thinking artifacts removed and outer whitespace
        stripped.
    """
    out = _THINK_BLOCK.sub("", text)
    out = _THINK_PREFIX.sub("", out)
    out = _THINK_CLOSE.sub("", out)
    return out.strip()
# ---------------------------------------------------------------------------
# Mode 1 — dense caption parser
# ---------------------------------------------------------------------------
class Event(TypedDict):
    """A single time-ranged event extracted from a dense caption."""

    start: float  # start of the time range, as parsed from the event line
    end: float  # end of the time range; always > ``start`` for parsed events
    description: str  # free-text description that follows the time range


# Tolerates ``<1.2 - 3.4>`` / ``1.2 - 3.4`` / ``1.2-3.4`` with optional units.
# Unit alternation is ordered longest-first so e.g. ``"1.8 seconds"`` consumes
# the full word instead of leaving ``"econds"`` in the description.
# The trailing ``\b`` makes a unit match only as a whole word; without it a
# bracket-less line such as ``1.2 - 3.4 someone waves`` loses the leading
# ``s`` of its description to the optional unit group.
_EVENT_LINE = re.compile(
    r"^\s*<?\s*(\d+\.?\d*)\s*(?:(?:seconds?|secs?|s)\b)?\s*-\s*"
    r"(\d+\.?\d*)\s*(?:(?:seconds?|secs?|s)\b)?\s*>?\s*[:\-]?\s*(.+?)\s*$"
)


def _parse_events(events_block: str) -> List[Event]:
    """Parse a multi-line events block into a list of :class:`Event` dicts.

    Each line is matched against ``_EVENT_LINE``. Lines that do not look like
    an event, have an empty description, or have ``end <= start`` are skipped
    silently.

    Parameters
    ----------
    events_block:
        Text with (at most) one event per line.

    Returns
    -------
    list of Event
        Parsed events in the order they appear in ``events_block``.
    """
    out: List[Event] = []
    for raw_line in events_block.splitlines():
        line = raw_line.strip()
        if not line:
            continue
        m = _EVENT_LINE.match(line)
        if not m:
            continue
        start = float(m.group(1))
        end = float(m.group(2))
        # Drop a leftover separator dash (e.g. ``"<1 - 2> -- text"``).
        desc = m.group(3).strip().lstrip("-").strip()
        if end <= start or not desc:
            continue
        out.append(Event(start=start, end=end, description=desc))
    return out
def _scene_before_events(text: str) -> str:
    """Return ``text`` cut off at its first event-shaped line, stripped."""
    kept: List[str] = []
    for line in text.splitlines():
        if _EVENT_LINE.match(line.strip()):
            break
        kept.append(line)
    return "\n".join(kept).strip()


def parse_caption(text: str) -> Tuple[str, str, List[Event]]:
    """Parse a Mode 1 caption into ``(caption, scene, events)``.

    The model is trained to produce::

        Scene: <one-paragraph spatial description>
        Events:
        <start - end> <description>
        <start - end> <description>

    The parser is tolerant of missing headers:

    * both headers present: ``scene`` is the text between them;
    * only ``Scene:`` present: ``scene`` is the text after the header up to
      the first event-shaped line (so event lines are not duplicated into
      the scene);
    * only ``Events:`` present: ``scene`` is the text before that header
      (again cut at the first event-shaped line, if any);
    * neither present: ``scene`` is everything before the first event-shaped
      line.

    ``events`` is parsed from the text after ``Events:`` when that header
    exists, otherwise from the whole cleaned text.

    Parameters
    ----------
    text:
        Raw model output. Thinking artifacts will be stripped.

    Returns
    -------
    tuple
        ``(caption, scene, events)`` — the post-thinking full text, the parsed
        scene paragraph, and a list of :class:`Event` dicts in emission order.
    """
    cleaned = strip_thinking(text)
    scene_match = re.search(
        r"(?:^|\n)\s*Scene\s*:\s*(.*?)(?=\n\s*Events\s*:|\Z)",
        cleaned,
        re.IGNORECASE | re.DOTALL,
    )
    events_match = re.search(
        r"(?:^|\n)\s*Events\s*:\s*(.*)\Z",
        cleaned,
        re.IGNORECASE | re.DOTALL,
    )

    if scene_match and events_match:
        # The lookahead in the scene pattern already stops at ``Events:``.
        scene = scene_match.group(1).strip()
    elif scene_match:
        # No ``Events:`` header: the scene capture runs to the end of the
        # text, so trim it where the event lines begin.
        scene = _scene_before_events(scene_match.group(1))
    elif events_match:
        # No ``Scene:`` header: use what precedes ``Events:`` so the header
        # line itself does not end up inside the scene.
        scene = _scene_before_events(cleaned[: events_match.start()])
    else:
        scene = _scene_before_events(cleaned)

    events_block = events_match.group(1) if events_match else cleaned
    events = _parse_events(events_block)
    return cleaned, scene, events
# ---------------------------------------------------------------------------
# Mode 2 — temporal grounding parser
# ---------------------------------------------------------------------------
# Tolerates ``From 1.2 to 3.4.``, ``From 1.2s to 3.4 sec`` and
# ``From 1.2 seconds to 3.4 seconds``; trailing period optional. The unit
# spellings mirror the ones accepted by the dense-caption event parser and are
# ordered longest-first.
_SPAN_RE = re.compile(
    r"From\s+(\d+\.?\d*)\s*(?:seconds?|secs?|s)?\s+to\s+"
    r"(\d+\.?\d*)\s*(?:seconds?|secs?|s)?\.?",
    re.IGNORECASE,
)


def parse_span(text: str) -> Tuple[str, Optional[Tuple[float, float]]]:
    """Parse a Mode 2 grounding output into ``(text, span)``.

    Only the first ``"From X to Y"`` occurrence in the cleaned text is
    considered.

    Parameters
    ----------
    text:
        Raw model output. Thinking artifacts will be stripped.

    Returns
    -------
    tuple
        ``(cleaned, span)`` — the post-thinking text and ``(start, end)`` in
        seconds, or ``None`` if no valid ``"From X to Y"`` substring was found
        or the span was non-positive (``end <= start``).
    """
    cleaned = strip_thinking(text)
    m = _SPAN_RE.search(cleaned)
    if not m:
        return cleaned, None
    start = float(m.group(1))
    end = float(m.group(2))
    if end <= start:
        return cleaned, None
    return cleaned, (start, end)
# ---------------------------------------------------------------------------
# Result dicts
# ---------------------------------------------------------------------------
class CaptionResult(TypedDict):
    """Dictionary returned by :meth:`MarlinForConditionalGeneration.caption`.

    Keys
    ----
    caption : str
        The model output after thinking artifacts have been stripped,
        i.e. the full ``Scene`` / ``Events`` text.
    scene : str
        The scene paragraph extracted from ``caption``.
    events : list of :class:`Event`
        ``{start, end, description}`` dicts, in the order the model emitted
        them.
    raw : str
        The untouched model output, *before* thinking artifacts are
        stripped. Kept for debugging.
    """

    caption: str
    scene: str
    events: List[Event]
    raw: str
class FindResult(TypedDict):
    """Dictionary returned by :meth:`MarlinForConditionalGeneration.find`.

    Keys
    ----
    raw : str
        The model output after thinking artifacts have been stripped
        (for example ``"From 1.2 to 3.4."``).
    span : tuple of (float, float) or None
        The located ``(start, end)`` range in seconds; ``None`` when no
        usable range could be parsed.
    format_ok : bool
        ``True`` iff the output matched the trained ``"From X to Y."``
        format.
    """

    raw: str
    span: Optional[Tuple[float, float]]
    format_ok: bool
# ---------------------------------------------------------------------------
# Default video-preprocessing env vars
# ---------------------------------------------------------------------------
#
# qwen-vl-utils reads these from the environment when ``apply_chat_template``
# decodes a video. We populate them here as a safety net for users who forget
# to set them before importing transformers. Existing env values are NEVER
# overwritten — explicit user settings always win.
# Fallback values for the video-preprocessing environment variables. They are
# applied at import time, and only for variables the user has not set.
_DEFAULT_VIDEO_ENV: Dict[str, str] = dict(
    FORCE_QWENVL_VIDEO_READER="torchcodec",
    VIDEO_MAX_PIXELS="200704",
    FPS="2.0",
    FPS_MAX_FRAMES="240",
    FPS_MIN_FRAMES="4",
)
for _k, _v in _DEFAULT_VIDEO_ENV.items():
    # Never clobber a value that is already present in the environment.
    if _k not in os.environ:
        os.environ[_k] = _v
# ---------------------------------------------------------------------------
# The actual model class
# ---------------------------------------------------------------------------
class MarlinForConditionalGeneration(Qwen3_5ForConditionalGeneration):
    """Marlin with ``.caption()`` and ``.find()`` convenience methods.

    Inherits the full forward / generate / from_pretrained machinery from
    :class:`transformers.Qwen3_5ForConditionalGeneration`; only adds two
    helpers that wrap chat-template construction, generation, and the trained
    output parsers.

    Use it via the standard auto class::

        from transformers import AutoModelForCausalLM

        model = AutoModelForCausalLM.from_pretrained(
            "NemoStation/Marlin-2B",
            trust_remote_code=True,
            dtype=torch.bfloat16,
            device_map={"": "cuda"},
        )
        result = model.caption("video.mp4")
        span = model.find("video.mp4", event="a person enters the room")
    """

    # ------------------------------------------------------------------ utils
    @property
    def processor(self):  # type: ignore[override]
        """Lazily-loaded :class:`~transformers.AutoProcessor` for this checkpoint.

        Loaded from ``self.config._name_or_path`` on first access and cached
        on the instance (as ``_processor``) so later calls reuse it.
        """
        cached = getattr(self, "_processor", None)
        if cached is None:
            from transformers import AutoProcessor

            cached = AutoProcessor.from_pretrained(
                self.config._name_or_path,
                trust_remote_code=True,
            )
            self._processor = cached
        return cached

    def compile(self, *args: Any, **kwargs: Any) -> "MarlinForConditionalGeneration":
        """Compile the model in place with ``torch.compile`` and return ``self``.

        Returns ``self`` so it chains naturally after ``from_pretrained``::

            model = AutoModelForCausalLM.from_pretrained(...).compile()

        All positional / keyword args are forwarded to ``torch.compile``.
        """
        # ``torch.compile(module)`` returns a *new* optimized wrapper and
        # leaves ``module`` untouched, so calling it and discarding the result
        # compiles nothing. ``torch.nn.Module.compile`` is the in-place
        # variant (it compiles this module's ``__call__``), which is what a
        # fluent ``.compile()`` needs; it returns ``None``, hence the
        # explicit ``return self``.
        super().compile(*args, **kwargs)
        return self

    # ----------------------------------------------------------- core generate
    def _generate_video(
        self,
        video_path: Union[str, os.PathLike],
        prompt: str,
        max_tokens: int,
        *,
        do_sample: bool = False,
        temperature: float = 1.0,
        top_p: float = 1.0,
    ) -> str:
        """Build a chat message with one video + one text turn and decode.

        Parameters
        ----------
        video_path:
            Path to the video; converted to ``str`` for the chat message.
        prompt:
            Text of the user turn that follows the video.
        max_tokens:
            Passed to ``generate`` as ``max_new_tokens``.
        do_sample, temperature, top_p:
            Sampling controls. ``temperature`` and ``top_p`` are forced to
            ``1.0`` unless ``do_sample`` is ``True``.

        Returns
        -------
        str
            The raw decoded continuation (with any ``<think>`` artifacts still
            attached — callers are expected to run :func:`strip_thinking`).
        """
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "video", "video": str(video_path)},
                    {"type": "text", "text": prompt},
                ],
            }
        ]
        inputs = self.processor.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
            return_dict=True,
        ).to(self.device)
        with torch.inference_mode():
            out = self.generate(
                **inputs,
                max_new_tokens=max_tokens,
                do_sample=do_sample,
                temperature=temperature if do_sample else 1.0,
                top_p=top_p if do_sample else 1.0,
            )
        # Strip the prompt prefix so we only return the model's continuation.
        prompt_len = inputs["input_ids"].shape[1]
        out = out[:, prompt_len:]
        return self.processor.batch_decode(out, skip_special_tokens=True)[0]

    # ---------------------------------------------------------------- caption
    def caption(
        self,
        video_path: Union[str, os.PathLike],
        *,
        prompt: Optional[str] = None,
        do_sample: bool = False,
        temperature: float = 1.0,
        top_p: float = 1.0,
        max_new_tokens: int = 2048,
    ) -> CaptionResult:
        """Generate a dense caption for a video.

        Parameters
        ----------
        video_path:
            Local path to a video file (mp4, webm, etc.).
        prompt:
            Override the canonical training prompt. Almost always leave at
            ``None``; diverging from training silently degrades quality.
        do_sample:
            If ``True``, switch to nucleus sampling. Defaults to greedy.
        temperature, top_p:
            Sampling params, only used when ``do_sample=True``.
        max_new_tokens:
            Generation cap. Default 2048 is enough for any dense caption the
            model produces in practice.

        Returns
        -------
        CaptionResult
            Dict with keys ``caption``, ``scene``, ``events``, ``raw``.
        """
        prompt_text = prompt if prompt is not None else CAPTION_PROMPT
        raw = self._generate_video(
            video_path,
            prompt_text,
            max_tokens=max_new_tokens,
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
        )
        cleaned, scene, events = parse_caption(raw)
        return CaptionResult(
            caption=cleaned,
            scene=scene,
            events=events,
            raw=raw,
        )

    # ------------------------------------------------------------------- find
    def find(
        self,
        video_path: Union[str, os.PathLike],
        event: str,
        *,
        prompt_template: Optional[str] = None,
        do_sample: bool = False,
        temperature: float = 1.0,
        top_p: float = 1.0,
        max_new_tokens: int = 64,
    ) -> FindResult:
        """Locate when a natural-language event occurs in a video.

        Parameters
        ----------
        video_path:
            Local path to a video file.
        event:
            Free-form description of the event to locate, e.g.
            ``"a person enters the room"``. Inserted into the trained prompt
            via the ``{event}`` placeholder.
        prompt_template:
            Override the canonical training prompt template. Must include a
            ``{event}`` placeholder. Almost always leave at ``None``.
        do_sample:
            If ``True``, switch to nucleus sampling. Defaults to greedy.
        temperature, top_p:
            Sampling params, only used when ``do_sample=True``.
        max_new_tokens:
            Output budget. 64 is plenty for the one-line trained format.

        Returns
        -------
        FindResult
            Dict with keys ``raw``, ``span`` and ``format_ok``.

        Raises
        ------
        ValueError
            If ``event`` is empty or whitespace-only.
        """
        event_str = (event or "").strip()
        if not event_str:
            raise ValueError("`event` must be a non-empty string")
        template = prompt_template if prompt_template is not None else GROUNDING_PROMPT_TEMPLATE
        prompt_text = template.format(event=event_str)
        raw = self._generate_video(
            video_path,
            prompt_text,
            max_tokens=max_new_tokens,
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
        )
        cleaned, span = parse_span(raw)
        return FindResult(
            raw=cleaned,
            span=span,
            format_ok=span is not None,
        )