# LCZZZZ's picture
# Upload MemGen code and data
# e34b94f verified
import torch.nn as nn
from safetensors import safe_open
from transformers import GenerationConfig
from dataclasses import dataclass, field
from typing import Optional, Callable, Dict, List
import os
import logging
import json
# ===== chat template =====
# Qwen2.5-VL chat template with vision support.
# Jinja2 template that renders a list of chat ``messages`` into one prompt string:
#   * If the first message is not a system message, a default system turn
#     ("You are a helpful assistant.") is emitted before it.
#   * Each message is wrapped in <|im_start|>{role} ... <|im_end|> markers.
#   * String content is emitted verbatim. List content is walked item by item:
#     image items (type == 'image', or an 'image' / 'image_url' key) become
#     <|vision_start|><|image_pad|><|vision_end|>; video items (type == 'video'
#     or a 'video' key) become <|vision_start|><|video_pad|><|vision_end|>;
#     items with a 'text' key contribute their text; other items emit nothing.
#   * When ``add_vision_id`` is truthy, each placeholder is prefixed with a
#     "Picture N: " / "Video N: " label. The counters are created once, outside
#     the message loop, so numbering runs across the whole conversation.
#   * When ``add_generation_prompt`` is truthy, an opening
#     "<|im_start|>assistant" header is appended for the model to continue.
# NOTE(review): presumably assigned to a tokenizer/processor ``chat_template``
# by the caller — confirm at the call site.
CONVERSATION_TEMPLATE = r"""{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system
You are a helpful assistant.<|im_end|>
{% endif %}<|im_start|>{{ message['role'] }}
{% if message['content'] is string %}{{ message['content'] }}<|im_end|>
{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>
{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant
{% endif %}"""
# ===== torch part =====
def load_state_dict_from_safetensor(model_path) -> Dict:
    """Read every tensor stored in a safetensors file into a plain dict.

    Args:
        model_path (str): Location of the safetensors file on disk.

    Returns:
        Dict[str, torch.Tensor]: Mapping from each tensor name stored in the
        file to the tensor itself, loaded through the PyTorch ("pt") backend.
    """
    with safe_open(model_path, framework="pt") as reader:
        # The comprehension runs to completion before the file is closed.
        return {name: reader.get_tensor(name) for name in reader.keys()}
def fix_model_parameters(model: nn.Module):
    """Freeze ``model`` in place so none of its parameters receive gradients.

    Every tensor yielded by ``model.parameters()`` gets ``requires_grad``
    switched off; nothing is returned.

    Args:
        model (nn.Module): Module whose parameters should be frozen.
    """
    for weight in model.parameters():
        weight.requires_grad_(False)
def open_model_parameters(model: nn.Module):
    """Unfreeze ``model`` in place so every parameter takes part in training.

    Switches ``requires_grad`` on for each tensor yielded by
    ``model.parameters()``; nothing is returned.

    Args:
        model (nn.Module): Module whose parameters should become trainable.
    """
    for tensor in model.parameters():
        tensor.requires_grad_(True)
def log_trainable_params(model: nn.Module):
    """Log every trainable parameter of ``model`` plus a trainable/total summary.

    One INFO line is emitted on the root logger per parameter whose
    ``requires_grad`` is set (name, element count, shape). Frozen parameters
    are skipped but still counted towards the total. A final line reports how
    many of the model's scalar parameters are trainable.

    Args:
        model (nn.Module): The PyTorch model to inspect.
    """
    logging.info("Trainable parameters in the model:")
    trainable = 0
    total = 0
    for name, param in model.named_parameters():
        count = param.numel()
        total += count
        if param.requires_grad:
            trainable += count
            # %-style args are only formatted when INFO is actually enabled.
            logging.info(" %s: %d params, shape=%s", name, count, param.shape)
    # Guard against models with no parameters at all.
    share = 100.0 * trainable / total if total else 0.0
    logging.info("Total trainable parameters: %d / %d (%.2f%%)", trainable, total, share)
# ===== Eval Part =====
@dataclass
class EvalConfig:
    """Settings bundle for an evaluation run.

    Both ``output_dir`` and ``generation_config`` default to ``None`` (unset),
    so they are annotated as ``Optional``; how each field is consumed is
    decided by the caller.
    """

    # Output location for evaluation artifacts (presumably a directory — confirm with caller).
    output_dir: Optional[str] = None
    # Batch size used by the evaluation loop.
    batch_size: int = 1
    # Decoding settings for generation; None means "not configured here".
    generation_config: Optional[GenerationConfig] = None
@dataclass
class StaticEvalRecorder:
    """Accumulate metric scores over a fixed evaluation set.

    Each callable in ``compute_metrics`` is invoked once per batch with
    keyword arguments built from the batch (one list per example key, plus
    ``completions``) and must return one score per example. Running sums
    feed :meth:`get_mean_metrics`. When ``log_file`` is set, every example is
    appended to it as one JSON line and :meth:`finalize` appends a summary
    line. When ``writer`` is set (anything exposing a TensorBoard-style
    ``add_scalar(tag, value, global_step=...)``), running means are reported
    after each batch.
    """

    compute_metrics: List[Callable[..., List[float]]] = field(default_factory=list)
    log_file: Optional[str] = None
    writer: Optional[object] = None
    # Internal storage: running score sum / scored-example count per metric name.
    metric_sums: Dict[str, float] = field(init=False)
    metric_counts: Dict[str, int] = field(init=False)

    @staticmethod
    def _metric_name(metric: Callable) -> str:
        """Key under which ``metric`` is tracked: its ``__name__``, falling back
        to its type name for callables without one (e.g. ``functools.partial``)."""
        return getattr(metric, "__name__", type(metric).__name__)

    def __post_init__(self):
        names = [self._metric_name(metric) for metric in self.compute_metrics]
        self.metric_sums = {name: 0.0 for name in names}
        self.metric_counts = {name: 0 for name in names}
        if self.log_file:
            # os.path.dirname() is "" for a bare filename and os.makedirs("")
            # raises, so only create a directory when there is one.
            log_dir = os.path.dirname(self.log_file)
            if log_dir:
                os.makedirs(log_dir, exist_ok=True)
            # Truncate any log left over from a previous run.
            with open(self.log_file, "w", encoding="utf-8"):
                pass

    def record_batch(self, completions: List[str], examples: List[Dict]):
        """Score one batch and record the per-example results.

        Args:
            completions (List[str]): Model outputs, one per example.
            examples (List[Dict]): The example each completion answers. Every
                key of the first example is passed to the metrics as a list,
                so all examples must share those keys. ``"prompt"`` and
                ``"solution"`` are copied into the log when present.

        Raises:
            ValueError: If ``completions`` and ``examples`` differ in length.
        """
        if len(completions) != len(examples):
            raise ValueError("completions and examples must have the same length")
        if not examples:
            return  # nothing to score (and examples[0] below would fail)

        # Transpose the batch: one list per example field, plus the completions.
        reward_kwargs = {key: [example[key] for example in examples] for key in examples[0]}
        reward_kwargs["completions"] = completions

        # Each metric scores the whole batch at once, one score per example.
        batched_results = {
            self._metric_name(metric): metric(**reward_kwargs)
            for metric in self.compute_metrics
        }

        log_lines = []
        for i, (completion, example) in enumerate(zip(completions, examples)):
            metrics_result = {name: scores[i] for name, scores in batched_results.items()}
            for name, score in metrics_result.items():
                self.metric_sums[name] += score
                self.metric_counts[name] += 1
            if self.log_file:
                record = {
                    "prompt": example.get("prompt", ""),
                    "solution": example.get("solution", ""),
                    "completion": completion,
                    "metrics": metrics_result,
                }
                log_lines.append(json.dumps(record, ensure_ascii=False) + "\n")

        if log_lines:
            # One open per batch; UTF-8 explicitly because ensure_ascii=False
            # writes raw non-ASCII characters.
            with open(self.log_file, "a", encoding="utf-8") as f:
                f.writelines(log_lines)

        if self.writer is not None:
            # Running means over everything recorded so far.
            for name, value in self.get_mean_metrics().items():
                self.writer.add_scalar(name, value, global_step=self.metric_counts[name])

    def get_mean_metrics(self) -> Dict[str, float]:
        """Return the mean score per metric over all recorded examples
        (0.0 for a metric that has not scored anything yet)."""
        return {
            name: self.metric_sums[name] / count if count > 0 else 0.0
            for name, count in self.metric_counts.items()
        }

    def finalize(self) -> Dict[str, float]:
        """Emit the final mean metrics to the log file and writer.

        Returns:
            Dict[str, float]: The final mean score per metric. When
            ``log_file`` is set it is also appended as a
            ``{"summary_metrics": ...}`` JSON line; when ``writer`` is set each
            mean is reported under ``<metric>_final``.
        """
        mean_metrics = self.get_mean_metrics()
        if self.log_file:
            with open(self.log_file, "a", encoding="utf-8") as f:
                f.write(json.dumps({"summary_metrics": mean_metrics}, ensure_ascii=False) + "\n")
        if self.writer is not None:
            for name, value in mean_metrics.items():
                self.writer.add_scalar(name + "_final", value, global_step=self.metric_counts[name])
        return mean_metrics
@dataclass
class DynamicEvalRecorder:
    """Log conversations with their rewards and track the running mean reward.

    On construction the text log at ``log_file`` is (re)created with a header.
    :meth:`record_batch` appends each conversation and its reward, then
    reports the running average to ``writer`` (if given) under the tag
    ``"reward/avg"``. :meth:`finalize` appends the overall average to the log
    and reports it under ``"ave_reward_final"``.

    Raises:
        ValueError: If ``log_file`` is not provided.
    """

    log_file: Optional[str] = None  # path to the txt log file (required)
    writer: Optional[object] = field(default=None)  # TensorBoard SummaryWriter (anything with add_scalar)

    def __post_init__(self):
        if self.log_file is None:
            raise ValueError("log_file path must be provided")
        # os.path.dirname() is "" for a bare filename and os.makedirs("")
        # raises, so only create a directory when there is one.
        log_dir = os.path.dirname(self.log_file)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        self.logger = logging.getLogger("DynamicEvalRecorder")
        # Running totals behind the average reward.
        self._total_reward = 0.0
        self._count = 0
        # Start a fresh log, discarding content from any previous run.
        with open(self.log_file, "w", encoding="utf-8") as f:
            f.write("DynamicEvalRecorder Log\n\n")

    def _average_reward(self) -> float:
        """Mean reward over all recorded items, or 0.0 before anything is recorded."""
        return self._total_reward / self._count if self._count > 0 else 0.0

    def record_batch(self, conversations: List[str], rewards: List[float]):
        """Record a batch of conversations and their associated rewards.

        Args:
            conversations (List[str]): Conversation texts.
            rewards (List[float]): One reward per conversation.

        Raises:
            ValueError: If the two lists differ in length.
        """
        if len(conversations) != len(rewards):
            raise ValueError("conversations and rewards must have the same length")
        with open(self.log_file, "a", encoding="utf-8") as f:
            for conv, rew in zip(conversations, rewards):
                f.write(f"Conversation:\n{conv}\n")
                f.write(f"Reward: {rew:.4f}\n")
                f.write("-" * 40 + "\n")
                self._total_reward += rew
                self._count += 1
        avg_reward = self._average_reward()
        if self.writer is not None:
            # Step = number of items recorded so far.
            self.writer.add_scalar("reward/avg", avg_reward, self._count)
        self.logger.info("Recorded %d items, avg_reward=%.4f", len(conversations), avg_reward)

    def finalize(self):
        """Write the final average reward to both the log file and the writer."""
        avg_reward = self._average_reward()
        with open(self.log_file, "a", encoding="utf-8") as f:
            f.write("\nFinal Results\n")
            f.write("=" * 40 + "\n")
            f.write(f"Average Reward: {avg_reward:.4f}\n")
        if self.writer is not None:
            # Tag name kept as-is so existing dashboards keep matching.
            self.writer.add_scalar("ave_reward_final", avg_reward, global_step=self._count)