"""Custom HuggingFace modeling code for Marlin. This module subclasses the upstream ``Qwen3_5ForConditionalGeneration`` (native in ``transformers >= 5.7.0``) and adds two convenience methods — :meth:`MarlinForConditionalGeneration.caption` and :meth:`MarlinForConditionalGeneration.find` — that mirror moondream's image-SDK ergonomics for video captioning and temporal grounding. The forward pass is **not** modified: we only add chat-template + generate + post-processing wrappers. Loading the model through ``AutoModelForCausalLM.from_pretrained(..., trust_remote_code=True)`` returns this subclass thanks to the ``auto_map`` entry in ``config.json``. Required environment for video inference (set before importing transformers):: FORCE_QWENVL_VIDEO_READER=torchcodec VIDEO_MAX_PIXELS=200704 FPS=2.0 FPS_MAX_FRAMES=240 FPS_MIN_FRAMES=4 System requirements: * transformers >= 5.7.0 * torch >= 2.11.0 * torchcodec * qwen-vl-utils >= 0.0.14 * av, pillow """ from __future__ import annotations import os import re from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union import torch # ``Qwen3_5ForConditionalGeneration`` is the native HF class for Marlin's # backbone (Qwen3.5-2B with vision tower). It ships in transformers >= 5.7.0. # We import it lazily-friendly at module top so AutoModelForCausalLM resolution # works correctly when this file is loaded via ``trust_remote_code=True``. from transformers import Qwen3_5ForConditionalGeneration __all__ = [ "CAPTION_PROMPT", "GROUNDING_PROMPT_TEMPLATE", "CaptionResult", "FindResult", "Event", "MarlinForConditionalGeneration", "strip_thinking", "parse_caption", "parse_span", ] # --------------------------------------------------------------------------- # Canonical training-time prompts — DO NOT EDIT # --------------------------------------------------------------------------- # # These strings must match exactly what the model was fine-tuned on. Diverging # from them silently degrades quality. 
# NOTE(review): both literals below read as if an angle-bracketed placeholder
# was lost in transit ("time range as and", "From to ."). They are reproduced
# untouched here because they must match the fine-tuning data byte-for-byte —
# confirm against the training prompts before changing a single character.
CAPTION_PROMPT: str = (
    "Provide a spatial description of this clip followed by time-ranged events.\n"
    "For each event, give the time range as and a short description."
)
GROUNDING_PROMPT_TEMPLATE: str = (
    'Identify the timestamps during which "{event}" takes place. '
    'Output the time range as "From to ." (numbers in seconds).'
)

# ---------------------------------------------------------------------------
# Thinking-tag stripping
# ---------------------------------------------------------------------------
#
# ms-swift's Marlin training template uses ``add_non_thinking_prefix=True``,
# which prefixes every response with a bare ``<think>\n`` (no close tag). The
# model occasionally also emits a complete ``<think>...</think>`` block. Strip
# both robustly.

# A complete block, including anything between the tags (DOTALL so the body
# may span several lines) and the whitespace that trails the close tag.
_THINK_BLOCK = re.compile(r"<think>.*?</think>\s*", re.DOTALL | re.IGNORECASE)
# A dangling open tag at the very start of the output, with no close tag.
_THINK_PREFIX = re.compile(r"^\s*<think>\s*\n*", re.IGNORECASE)
# A stray close tag left anywhere in the text.
_THINK_CLOSE = re.compile(r"</think>\s*", re.IGNORECASE)


def strip_thinking(text: str) -> str:
    """Remove ``<think>...</think>`` blocks and bare ``<think>`` prefixes.

    The three clean-ups run in a fixed order: complete blocks first, then a
    leading unclosed open tag, then any leftover close tag. Whitespace inside
    the remaining text is preserved; only the outer whitespace is trimmed.

    Parameters
    ----------
    text:
        Raw model output.

    Returns
    -------
    str
        The text with any thinking artifacts removed and outer whitespace
        stripped.
    """
    out = _THINK_BLOCK.sub("", text)
    out = _THINK_PREFIX.sub("", out)
    out = _THINK_CLOSE.sub("", out)
    return out.strip()


# ---------------------------------------------------------------------------
# Mode 1 — dense caption parser
# ---------------------------------------------------------------------------


class Event(TypedDict):
    """A single time-ranged event extracted from a dense caption."""

    # Event start, in seconds.
    start: float
    # Event end, in seconds.
    end: float
    # Free-text description of what happens in the range.
    description: str


# Tolerates ``<1.2 - 3.4>`` / ``1.2 - 3.4`` / ``1.2-3.4`` with optional units.
# Unit alternation is ordered longest-first so e.g. ``"1.8 seconds"`` consumes
# the full word instead of leaving ``"econds"`` in the description.
# One event line: an optional ``<``, a start time, a dash, an end time, an
# optional ``>``, an optional ``:`` / ``-`` separator, then the description.
#   group(1) = start seconds, group(2) = end seconds, group(3) = description.
# Each unit suffix must end on a word boundary so that a description starting
# with "s" (e.g. "3.4 several people") is not mistaken for a unit, and the end
# time may not be followed by more digits so the regex cannot backtrack into
# splitting "3.4" into an end time of "3" and a description of ".4".
_EVENT_LINE = re.compile(
    r"^\s*<?\s*"
    r"(\d+(?:\.\d+)?)\s*(?:(?:seconds|second|secs|sec|s)\b)?"
    r"\s*[-\u2013\u2014]\s*"
    r"(\d+(?:\.\d+)?)(?!\.?\d)\s*(?:(?:seconds|second|secs|sec|s)\b)?"
    r"\s*>?\s*[:\-]?\s*(.+?)\s*$",
    re.IGNORECASE,
)

# ``Scene:`` header up to (not including) the ``Events:`` header or the end.
_SCENE_SECTION = re.compile(
    r"(?:^|\n)\s*Scene\s*:\s*(.*?)(?=\n\s*Events\s*:|\Z)",
    re.IGNORECASE | re.DOTALL,
)
# ``Events:`` header and everything after it.
_EVENTS_SECTION = re.compile(
    r"(?:^|\n)\s*Events\s*:\s*(.*)\Z",
    re.IGNORECASE | re.DOTALL,
)


def _parse_events(events_block: str) -> List[Event]:
    """Parse a multi-line events block into a list of :class:`Event` dicts.

    Lines that do not look like an event, that have a non-positive duration
    (``end <= start``), or whose description is empty are skipped silently.
    """
    events: List[Event] = []
    for raw_line in events_block.splitlines():
        m = _EVENT_LINE.match(raw_line.strip())
        if m is None:
            continue
        start = float(m.group(1))
        end = float(m.group(2))
        # Drop a leftover leading dash separator from the description.
        desc = m.group(3).strip().lstrip("-").strip()
        if end <= start or not desc:
            continue
        events.append(Event(start=start, end=end, description=desc))
    return events


def parse_caption(text: str) -> Tuple[str, str, List[Event]]:
    """Parse a Mode 1 caption into ``(caption, scene, events)``.

    The model is trained to produce::

        Scene: <scene paragraph>

        Events:
        <start - end> <description>
        ...

    The parser is tolerant: if explicit ``Scene:`` / ``Events:`` headers are
    missing, ``scene`` falls back to everything before the first event line
    and ``events`` is whatever event-shaped lines were detected.

    Parameters
    ----------
    text:
        Raw model output. Thinking artifacts will be stripped.

    Returns
    -------
    tuple
        ``(caption, scene, events)`` — the post-thinking full text, the parsed
        scene paragraph, and a list of :class:`Event` dicts in emission order.
    """
    cleaned = strip_thinking(text)
    scene_match = _SCENE_SECTION.search(cleaned)
    events_match = _EVENTS_SECTION.search(cleaned)

    if scene_match:
        scene = scene_match.group(1).strip()
    else:
        # Fallback: scene = everything before the first event-shaped line.
        scene_lines: List[str] = []
        for line in cleaned.splitlines():
            if _EVENT_LINE.match(line.strip()):
                break
            scene_lines.append(line)
        scene = "\n".join(scene_lines).strip()

    # Without an ``Events:`` header, scan the whole text for event lines.
    events_block = events_match.group(1) if events_match else cleaned
    events = _parse_events(events_block)
    return cleaned, scene, events


# ---------------------------------------------------------------------------
# Mode 2 — temporal grounding parser
# ---------------------------------------------------------------------------

# Tolerates ``From 1.2 to 3.4.``, ``From 1.2s to 3.4 sec``; trailing period
# optional.
_SPAN_RE = re.compile(
    r"From\s+(\d+\.?\d*)\s*(?:s|sec)?\s+to\s+(\d+\.?\d*)\s*(?:s|sec)?\.?",
    re.IGNORECASE,
)


def parse_span(text: str) -> Tuple[str, Optional[Tuple[float, float]]]:
    """Parse a Mode 2 grounding output into ``(text, span)``.

    Parameters
    ----------
    text:
        Raw model output. Thinking artifacts will be stripped.

    Returns
    -------
    tuple
        ``(cleaned, span)`` — the post-thinking text and ``(start, end)`` in
        seconds, or ``None`` if no valid ``"From X to Y"`` substring was found
        or the span was non-positive.
    """
    cleaned = strip_thinking(text)
    m = _SPAN_RE.search(cleaned)
    if not m:
        return cleaned, None
    start = float(m.group(1))
    end = float(m.group(2))
    if end <= start:
        return cleaned, None
    return cleaned, (start, end)


# ---------------------------------------------------------------------------
# Result dicts
# ---------------------------------------------------------------------------


class CaptionResult(TypedDict):
    """Return type for :meth:`MarlinForConditionalGeneration.caption`.

    Keys
    ----
    caption : str
        Post-thinking model output (e.g. ``"Scene: ...\\n\\nEvents:\\n..."``).
    scene : str
        Parsed ``Scene:`` paragraph.
    events : list of :class:`Event`
        Parsed ``{start, end, description}`` dicts in emission order.
    raw : str
        Raw model output *before* thinking-prefix stripping (for debugging).
    """

    caption: str
    scene: str
    events: List[Event]
    raw: str


class FindResult(TypedDict):
    """Return type for :meth:`MarlinForConditionalGeneration.find`.

    Keys
    ----
    raw : str
        Post-thinking model output (e.g. ``"From 1.2 to 3.4."``).
    span : tuple of (float, float) or None
        ``(start, end)`` in seconds, or ``None`` if parsing failed.
    format_ok : bool
        ``True`` iff a valid span was parsed, i.e. ``span is not None``. An
        output in the trained ``"From X to Y."`` shape whose end is not after
        its start therefore still reports ``False``.
    """

    raw: str
    span: Optional[Tuple[float, float]]
    format_ok: bool


# ---------------------------------------------------------------------------
# Default video-preprocessing env vars
# ---------------------------------------------------------------------------
#
# qwen-vl-utils reads these from the environment when ``apply_chat_template``
# decodes a video. We populate them here as a safety net for users who forget
# to set them before importing transformers. Existing env values are NEVER
# overwritten — explicit user settings always win.
_DEFAULT_VIDEO_ENV: Dict[str, str] = {
    "FORCE_QWENVL_VIDEO_READER": "torchcodec",
    "VIDEO_MAX_PIXELS": "200704",
    "FPS": "2.0",
    "FPS_MAX_FRAMES": "240",
    "FPS_MIN_FRAMES": "4",
}
for _k, _v in _DEFAULT_VIDEO_ENV.items():
    # ``setdefault`` only writes keys that are absent from the environment.
    os.environ.setdefault(_k, _v)


# ---------------------------------------------------------------------------
# The actual model class
# ---------------------------------------------------------------------------


class MarlinForConditionalGeneration(Qwen3_5ForConditionalGeneration):
    """Marlin with ``.caption()`` and ``.find()`` convenience methods.

    Inherits the full forward / generate / from_pretrained machinery from
    :class:`transformers.Qwen3_5ForConditionalGeneration`; only adds two
    helpers that wrap chat-template construction, generation, and the trained
    output parsers.

    Use it via the standard auto class::

        from transformers import AutoModelForCausalLM

        model = AutoModelForCausalLM.from_pretrained(
            "NemoStation/Marlin-2B",
            trust_remote_code=True,
            dtype=torch.bfloat16,
            device_map={"": "cuda"},
        )
        result = model.caption("video.mp4")
        span = model.find("video.mp4", event="a person enters the room")
    """

    # ------------------------------------------------------------------ utils

    @property
    def processor(self):  # type: ignore[override]
        """Lazily-loaded :class:`~transformers.AutoProcessor` for this checkpoint.

        Loaded from ``self.config._name_or_path`` on first access and cached
        on the instance (as ``_processor``) so later calls reuse it.
        """
        cached = getattr(self, "_processor", None)
        if cached is None:
            # Local import: only needed the first time the processor is built.
            from transformers import AutoProcessor

            cached = AutoProcessor.from_pretrained(
                self.config._name_or_path,
                trust_remote_code=True,
            )
            self._processor = cached
        return cached

    def compile(self, *args: Any, **kwargs: Any) -> "MarlinForConditionalGeneration":
        """Compile the model in place and return ``self`` for chaining.

        Returns ``self`` so it chains naturally after ``from_pretrained``::

            model = AutoModelForCausalLM.from_pretrained(...).compile()

        All positional / keyword args are forwarded to ``torch.compile``.
        """
        # ``torch.compile(module)`` returns a *new* wrapper and leaves the
        # module untouched, so calling it and dropping the result does nothing.
        # ``nn.Module.compile`` is the in-place variant (it compiles this
        # module's ``__call__``); it returns ``None``, hence the explicit
        # ``return self`` for fluent use.
        super().compile(*args, **kwargs)
        return self

    # ----------------------------------------------------------- core generate

    def _generate_video(
        self,
        video_path: Union[str, os.PathLike],
        prompt: str,
        max_tokens: int,
        *,
        do_sample: bool = False,
        temperature: float = 1.0,
        top_p: float = 1.0,
    ) -> str:
        """Build a chat message with one video + one text turn and decode.

        Parameters
        ----------
        video_path:
            Path to the video; converted with ``str()`` before being handed to
            the processor.
        prompt:
            Text of the user turn that follows the video.
        max_tokens:
            Passed to ``generate`` as ``max_new_tokens``.
        do_sample, temperature, top_p:
            Sampling controls. ``temperature`` / ``top_p`` are forced to
            ``1.0`` when ``do_sample`` is ``False``.

        Returns
        -------
        str
            The raw decoded continuation (with any ``<think>`` artifacts still
            attached — callers are expected to run :func:`strip_thinking`).
        """
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "video", "video": str(video_path)},
                    {"type": "text", "text": prompt},
                ],
            }
        ]
        inputs = self.processor.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
            return_dict=True,
        ).to(self.device)

        with torch.inference_mode():
            out = self.generate(
                **inputs,
                max_new_tokens=max_tokens,
                do_sample=do_sample,
                # Neutral values under greedy decoding so the caller's
                # sampling parameters cannot leak into a non-sampling run.
                temperature=temperature if do_sample else 1.0,
                top_p=top_p if do_sample else 1.0,
            )

        # Strip the prompt prefix so we only return the model's continuation.
        prompt_len = inputs["input_ids"].shape[1]
        out = out[:, prompt_len:]
        return self.processor.batch_decode(out, skip_special_tokens=True)[0]

    # ---------------------------------------------------------------- caption

    def caption(
        self,
        video_path: Union[str, os.PathLike],
        *,
        prompt: Optional[str] = None,
        do_sample: bool = False,
        temperature: float = 1.0,
        top_p: float = 1.0,
        max_new_tokens: int = 2048,
    ) -> CaptionResult:
        """Generate a dense caption for a video.

        Parameters
        ----------
        video_path:
            Local path to a video file (mp4, webm, etc.).
        prompt:
            Override the canonical training prompt. Almost always leave at
            ``None``; diverging from training silently degrades quality.
        do_sample:
            If ``True``, switch to nucleus sampling. Defaults to greedy.
        temperature, top_p:
            Sampling params, only used when ``do_sample=True``.
        max_new_tokens:
            Generation cap. Default 2048 is enough for any dense caption the
            model produces in practice.

        Returns
        -------
        CaptionResult
            Dict with keys ``caption``, ``scene``, ``events``, ``raw``.
        """
        prompt_text = prompt if prompt is not None else CAPTION_PROMPT
        raw = self._generate_video(
            video_path,
            prompt_text,
            max_tokens=max_new_tokens,
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
        )
        cleaned, scene, events = parse_caption(raw)
        return CaptionResult(
            caption=cleaned,
            scene=scene,
            events=events,
            raw=raw,
        )

    # ------------------------------------------------------------------- find

    def find(
        self,
        video_path: Union[str, os.PathLike],
        event: str,
        *,
        prompt_template: Optional[str] = None,
        do_sample: bool = False,
        temperature: float = 1.0,
        top_p: float = 1.0,
        max_new_tokens: int = 64,
    ) -> FindResult:
        """Locate when a natural-language event occurs in a video.

        Parameters
        ----------
        video_path:
            Local path to a video file.
        event:
            Free-form description of the event to locate, e.g.
            ``"a person enters the room"``. Inserted into the trained prompt
            via the ``{event}`` placeholder.
        prompt_template:
            Override the canonical training prompt template. Must include a
            ``{event}`` placeholder (it is filled with ``str.format``).
            Almost always leave at ``None``.
        do_sample:
            If ``True``, switch to nucleus sampling. Defaults to greedy.
        temperature, top_p:
            Sampling params, only used when ``do_sample=True``.
        max_new_tokens:
            Output budget. 64 is plenty for the one-line trained format.

        Returns
        -------
        FindResult
            Dict with keys ``raw``, ``span`` and ``format_ok``.

        Raises
        ------
        ValueError
            If ``event`` is empty or whitespace-only.
        """
        # ``event or ""`` also turns a ``None`` into the empty-string error.
        event_str = (event or "").strip()
        if not event_str:
            raise ValueError("`event` must be a non-empty string")

        template = (
            prompt_template if prompt_template is not None else GROUNDING_PROMPT_TEMPLATE
        )
        prompt_text = template.format(event=event_str)
        raw = self._generate_video(
            video_path,
            prompt_text,
            max_tokens=max_new_tokens,
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
        )
        cleaned, span = parse_span(raw)
        return FindResult(
            raw=cleaned,
            span=span,
            format_ok=span is not None,
        )