Spaces:
Running
Running
| """ | |
| VEFX-Reward: Qwen3-VL based reward model for video editing quality assessment. | |
| Extends Qwen3VLForConditionalGeneration with an rm_head for ordinal regression, | |
| scoring video edits on Instructional Following (IF), Render Quality (RQ), | |
| and Edit Exclusivity (EE) on a 1–4 scale. | |
| """ | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from typing import List, Optional | |
| from transformers import Qwen3VLForConditionalGeneration | |
class Qwen3VLRewardModelBT(Qwen3VLForConditionalGeneration):
    """Qwen3-VL with a reward head for ordinal video edit quality scoring.

    A bias-free linear ``rm_head`` maps every hidden state of the backbone to
    ``output_dim`` values.  ``forward`` then pools those per-token values into
    one row per sample according to ``reward_token``:

    * ``"last"``    - value at the last non-padding token.
    * ``"mean"``    - mean over all tokens up to and including the last
      non-padding token.
    * ``"special"`` - values at the positions whose token id is listed in
      ``special_token_ids``.
    """

    def __init__(self, config, output_dim=3, reward_token="special",
                 special_token_ids=None, use_ordinal=True, num_classes=4, **kwargs):
        """Build the backbone and attach the reward head.

        Args:
            config: Model config; ``config.text_config.hidden_size`` sets the
                input width of the reward head.
            output_dim: Number of values the reward head emits per token.
            reward_token: Pooling mode, one of ``"last"``, ``"mean"`` or
                ``"special"``.  Forced to ``"special"`` whenever
                ``special_token_ids`` is given.
            special_token_ids: Token ids marking the positions to read the
                reward from in ``"special"`` mode.
            use_ordinal: If true, ``"special"`` pooling keeps all
                ``output_dim`` values of every marker position.
            num_classes: Number of ordinal classes; stored on the instance and
                not used inside this class.
            **kwargs: Forwarded to the parent constructor.  A ``use_cache``
                entry is moved onto ``config`` first.
        """
        if "use_cache" in kwargs:
            config.use_cache = kwargs.pop("use_cache")
        super().__init__(config, **kwargs)

        hidden_size = config.text_config.hidden_size
        self.output_dim = output_dim
        self.rm_head = nn.Linear(hidden_size, output_dim, bias=False)
        nn.init.normal_(self.rm_head.weight, mean=0.0, std=1.0 / hidden_size)

        self.reward_token = reward_token
        self.use_ordinal = use_ordinal
        self.num_classes = num_classes
        self.special_token_ids = special_token_ids
        # Supplying marker ids implies marker-based pooling.
        if self.special_token_ids is not None:
            self.reward_token = "special"

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        mm_token_type_ids: Optional[torch.IntTensor] = None,
        **kwargs,
    ):
        """Run the backbone and return pooled reward-head outputs.

        ``labels`` and ``use_cache`` are accepted for signature compatibility
        but are not used here; no loss is computed.

        Returns:
            ``{"logits": pooled}`` where ``pooled`` has one row per sample:
            ``[B, output_dim]`` for ``"last"``/``"mean"``; for ``"special"``
            the marker positions are flattened into the second axis.

        Raises:
            ValueError: If ``reward_token`` is unknown; if ``"last"``/``"mean"``
                pooling is requested for a batch larger than one without a
                padding token; or if ``"special"`` pooling is requested without
                ``input_ids``/``special_token_ids`` or with a different number
                of marker tokens per sample.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            pixel_values=pixel_values,
            pixel_values_videos=pixel_values_videos,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            mm_token_type_ids=mm_token_type_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs,
        )
        hidden_states = outputs[0]  # [B, L, D]
        logits = self.rm_head(hidden_states)  # [B, L, output_dim]

        if self.reward_token == "last":
            last_index = self._rm_last_token_index(input_ids, logits)
            rows = torch.arange(logits.shape[0], device=logits.device)
            pooled_logits = logits[rows, last_index]
        elif self.reward_token == "mean":
            last_index = self._rm_last_token_index(input_ids, logits)
            pooled_logits = self._rm_pool_mean(logits, last_index)
        elif self.reward_token == "special":
            # Marker pooling never looks at padding, so no pad token is needed.
            pooled_logits = self._rm_pool_special(input_ids, logits)
        else:
            raise ValueError(f"Invalid reward_token: {self.reward_token}")

        return {"logits": pooled_logits}

    def _rm_last_token_index(self, input_ids, logits):
        """Return the per-sample index of the last non-padding token, or ``-1``.

        ``-1`` (the final position) is returned when no padding token is
        configured or when only ``inputs_embeds`` were supplied.
        """
        pad_token_id = getattr(self.config.text_config, "pad_token_id", None)
        if pad_token_id is None and logits.shape[0] != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if pad_token_id is None or input_ids is None:
            return -1
        # argmax finds the first pad position; with no pad present it yields 0,
        # and the modulo wraps the resulting -1 around to the final position.
        first_pad = torch.eq(input_ids, pad_token_id).int().argmax(-1)
        last_index = (first_pad - 1) % input_ids.shape[-1]
        return last_index.to(logits.device)

    @staticmethod
    def _rm_pool_mean(logits, last_index):
        """Average ``logits`` over positions ``0..last_index`` of each sample."""
        if not torch.is_tensor(last_index):
            # No padding information available: every position counts.
            return logits.mean(dim=1)
        # +1 so the last valid token itself is part of the average; this also
        # guarantees a non-empty slice for one-token sequences.
        ends = (last_index + 1).tolist()
        return torch.stack([sample[:end].mean(dim=0) for sample, end in zip(logits, ends)])

    def _rm_pool_special(self, input_ids, logits):
        """Gather the reward-head outputs at the special marker positions."""
        if input_ids is None:
            raise ValueError("reward_token='special' requires input_ids to locate the special tokens.")
        if not self.special_token_ids:
            raise ValueError("reward_token='special' requires special_token_ids to be set.")

        batch_size = logits.shape[0]
        marker_mask = torch.zeros_like(input_ids, dtype=torch.bool)
        for token_id in self.special_token_ids:
            marker_mask = marker_mask | (input_ids == token_id)
        marker_mask = marker_mask.to(logits.device)

        markers_per_sample = marker_mask.sum(dim=1)
        num_dims = int(markers_per_sample[0].item())
        # The reshape below silently mixes samples unless every sample carries
        # the same, non-zero number of marker tokens.
        if num_dims == 0 or bool((markers_per_sample != num_dims).any()):
            raise ValueError(
                "Every sample must contain the same non-zero number of special tokens; "
                f"got counts {markers_per_sample.tolist()}."
            )

        pooled = logits[marker_mask, ...].view(batch_size, num_dims, -1)
        if not self.use_ordinal and self.output_dim == num_dims:
            # Non-ordinal head: the i-th marker reads the i-th output channel.
            pooled = pooled.diagonal(dim1=1, dim2=2)
        return pooled.reshape(batch_size, -1)
def ordinal_predict(logits: np.ndarray, num_classes: int):
    """
    Convert CORN ordinal logits to predicted scores.

    Args:
        logits: [..., K-1] raw threshold logits (typically [B, D, K-1]);
            anything ``np.asarray`` accepts.
        num_classes: K (number of ordinal classes)

    Returns:
        hard_preds: [...] integer predictions in {1..K}
        soft_preds: [...] continuous expected value E[Y]

    Raises:
        ValueError: if the last axis of ``logits`` does not hold exactly
            ``num_classes - 1`` thresholds.
    """
    logits = np.asarray(logits, dtype=np.float64)
    if logits.ndim == 0 or logits.shape[-1] != num_classes - 1:
        raise ValueError(
            f"logits must have num_classes - 1 = {num_classes - 1} entries on the "
            f"last axis, got shape {logits.shape}"
        )

    # sigmoid(x) = exp(-log(1 + exp(-x))); logaddexp keeps this finite for
    # arbitrarily large |x| where the naive 1 / (1 + exp(-x)) overflows.
    probs = np.exp(-np.logaddexp(0.0, -logits))  # P(Y>k | Y>=k)
    cum_probs = np.cumprod(probs, axis=-1)  # P(Y>k) = prod_{j<=k} P(Y>j|Y>=j)

    # cum_probs is non-increasing along the last axis, so counting the
    # thresholds passed with probability > 0.5 gives the predicted rank.
    hard_preds = (cum_probs > 0.5).sum(axis=-1) + 1

    # Pad with P(Y>0) = 1 and P(Y>K) = 0, then difference adjacent survival
    # probabilities to obtain the per-class probabilities P(Y=k).
    edge_shape = (*cum_probs.shape[:-1], 1)
    cum_ext = np.concatenate(
        [np.ones(edge_shape), cum_probs, np.zeros(edge_shape)],
        axis=-1,
    )
    p_class = np.maximum(cum_ext[..., :-1] - cum_ext[..., 1:], 0)

    class_values = np.arange(1, num_classes + 1)
    soft_preds = (p_class * class_values).sum(axis=-1)
    return hard_preds, soft_preds