|
|
import torch.nn as nn |
|
|
from safetensors import safe_open |
|
|
from transformers import GenerationConfig |
|
|
|
|
|
from dataclasses import dataclass, field |
|
|
from typing import Optional, Callable, Dict, List |
|
|
import os |
|
|
import logging |
|
|
import json |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Jinja2 chat template used to render a message list into model input text.
# It inserts a default system prompt when the first message is not a system
# turn, replaces image/video content items with <|vision_start|>...<|vision_end|>
# placeholder spans (numbered when ``add_vision_id`` is set), and appends an
# assistant header when ``add_generation_prompt`` is true.
# NOTE(review): the <|im_start|>/<|image_pad|> tokens look like Qwen-VL special
# tokens — confirm against the tokenizer this template is paired with.
# NOTE(review): the blank lines below are INSIDE the raw string and are part of
# the rendered output; do not "clean them up".
CONVERSATION_TEMPLATE = r"""{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system


You are a helpful assistant.<|im_end|>


{% endif %}<|im_start|>{{ message['role'] }}


{% if message['content'] is string %}{{ message['content'] }}<|im_end|>


{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>


{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant


{% endif %}"""
|
|
|
|
|
|
|
|
def load_state_dict_from_safetensor(model_path) -> Dict:
    """Read every tensor stored in a safetensors file into a state dict.

    Args:
        model_path (str): Path to the safetensors file on disk.

    Returns:
        Dict[str, torch.Tensor]: Mapping from parameter name to the
        corresponding tensor, loaded with the PyTorch framework backend.
    """
    with safe_open(model_path, framework="pt") as handle:
        # Materialize every stored tensor keyed by its parameter name.
        return {name: handle.get_tensor(name) for name in handle.keys()}
|
|
|
|
|
def fix_model_parameters(model: nn.Module):
    """Freeze the model: disable gradient tracking on every parameter.

    Args:
        model (nn.Module): The PyTorch model whose parameters will be frozen.
    """
    for param in model.parameters():
        param.requires_grad_(False)
|
|
|
|
|
def open_model_parameters(model: nn.Module):
    """Unfreeze the model: enable gradient tracking on every parameter.

    Args:
        model (nn.Module): The PyTorch model whose parameters will be unfrozen.
    """
    for param in model.parameters():
        param.requires_grad_(True)
|
|
|
|
|
def log_trainable_params(model: nn.Module):
    """Log the name, element count, and shape of each trainable parameter.

    Args:
        model (nn.Module): The PyTorch model to inspect.
    """
    logging.info("Trainable parameters in the model:")
    for name, param in model.named_parameters():
        # Skip frozen parameters; only trainable ones are reported.
        if not param.requires_grad:
            continue
        logging.info(f"  {name}: {param.numel()} params, shape={param.shape}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class EvalConfig:
    """Configuration values for an evaluation run."""

    # Destination directory for evaluation outputs; None means unset.
    output_dir: Optional[str] = None


    # Number of examples processed per evaluation step.
    batch_size: int = 1


    # transformers GenerationConfig controlling decoding; None means unset
    # (presumably the model's defaults are used — confirm at the call site).
    generation_config: Optional[GenerationConfig] = None
|
|
|
|
|
@dataclass
class StaticEvalRecorder:
    """Accumulates batched metric scores over a fixed evaluation set.

    Attributes:
        compute_metrics: Batched metric functions. Each is called as
            ``metric(completions=..., **example_columns)`` where every key of
            the example dicts becomes a keyword holding a per-example list,
            and must return one score per completion.
        log_file: Optional JSONL path; one record is appended per example and
            a summary record is appended by ``finalize``.
        writer: Optional TensorBoard-style object exposing ``add_scalar``.
    """
    compute_metrics: List[Callable[[str, str, str], float]] = field(default_factory=list)
    log_file: Optional[str] = None
    writer: Optional[object] = None

    # Running totals per metric name, used to compute means; set in __post_init__.
    metric_sums: Dict[str, float] = field(init=False)
    metric_counts: Dict[str, int] = field(init=False)

    def __post_init__(self):
        self.metric_sums = {metric.__name__: 0.0 for metric in self.compute_metrics}
        self.metric_counts = {metric.__name__: 0 for metric in self.compute_metrics}
        if self.log_file:
            # Only create a directory when the path actually has one:
            # os.makedirs("") raises FileNotFoundError for bare filenames.
            log_dir = os.path.dirname(self.log_file)
            if log_dir:
                os.makedirs(log_dir, exist_ok=True)
            # Truncate any previous log so each run starts clean.
            # utf-8 is used consistently here, in record_batch, and in finalize.
            with open(self.log_file, 'w', encoding='utf-8') as f:
                f.write('')

    def record_batch(self, completions: List[str], examples: List[Dict]):
        """Record results for a batch of model outputs.

        Args:
            completions (List[str]): The model's answers (outputs).
            examples (List[Dict]): Each completion's corresponding question and
                related attributes. Each example is expected to contain the
                keys: "prompt" and "solution".
        """
        # Guard the empty batch: examples[0] below would raise IndexError.
        if not completions:
            return

        # Transpose per-example dicts into per-key lists so each metric can
        # score the whole batch in one call.
        keys = list(examples[0])
        reward_kwargs = {key: [example[key] for example in examples] for key in keys}
        reward_kwargs['completions'] = completions

        # metric name -> list of per-completion scores for this batch.
        batched_results = {}
        for metric in self.compute_metrics:
            batched_results[metric.__name__] = metric(**reward_kwargs)

        for i, (completion, example) in enumerate(zip(completions, examples)):
            metrics_result = {
                metric_name: batched_results[metric_name][i]
                for metric_name in batched_results
            }

            # Fold this example's scores into the running totals.
            for metric_name, score in metrics_result.items():
                self.metric_sums[metric_name] += score
                self.metric_counts[metric_name] += 1

            if self.log_file:
                record = {
                    'prompt': example.get("prompt", ""),
                    'solution': example.get("solution", ""),
                    'completion': completion,
                    'metrics': metrics_result
                }
                with open(self.log_file, 'a', encoding='utf-8') as f:
                    f.write(json.dumps(record, ensure_ascii=False) + '\n')

            if self.writer:
                # Report running means after each example; metric_counts
                # doubles as the global step.
                mean_metrics = self.get_mean_metrics()
                for name, value in mean_metrics.items():
                    self.writer.add_scalar(name, value, global_step=self.metric_counts[name])

    def get_mean_metrics(self) -> Dict[str, float]:
        """Return the running mean of every metric (0.0 before any samples)."""
        return {
            name: (self.metric_sums[name] / self.metric_counts[name]) if self.metric_counts[name] > 0 else 0.0
            for name in self.metric_sums
        }

    def finalize(self):
        """Append a summary record to the log and emit "_final" scalars."""
        mean_metrics = self.get_mean_metrics()
        final_record = {
            'summary_metrics': mean_metrics
        }

        if self.log_file:
            with open(self.log_file, 'a', encoding='utf-8') as f:
                f.write(json.dumps(final_record, ensure_ascii=False) + '\n')

        if self.writer:
            for name, value in mean_metrics.items():
                self.writer.add_scalar(name + "_final", value, global_step=self.metric_counts[name])
|
|
|
|
|
|
|
|
@dataclass
class DynamicEvalRecorder:
    """Streams conversation/reward pairs to a plain-text log and a writer.

    Attributes:
        log_file: Required path of the text log; truncated on construction.
        writer: Optional TensorBoard-style object exposing ``add_scalar``.

    Raises:
        ValueError: If ``log_file`` is not provided.
    """
    log_file: Optional[str] = None
    writer: Optional[object] = None

    def __post_init__(self):
        if self.log_file is None:
            raise ValueError("log_file path must be provided")

        # Only create a directory when the path actually has one:
        # os.makedirs("") raises FileNotFoundError for bare filenames.
        log_dir = os.path.dirname(self.log_file)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        self.logger = logging.getLogger("DynamicEvalRecorder")

        # Running totals for the average reward across all recorded batches.
        self._total_reward = 0.0
        self._count = 0

        # Start a fresh log for this run.
        with open(self.log_file, "w", encoding="utf-8") as f:
            f.write("DynamicEvalRecorder Log\n\n")

    def record_batch(self, conversations: List[str], rewards: List[float]):
        """Record a batch of conversations and their associated rewards.

        Args:
            conversations (List[str]): List of conversation texts.
            rewards (List[float]): List of reward values corresponding to conversations.

        Raises:
            ValueError: If the two lists differ in length.
        """
        if len(conversations) != len(rewards):
            raise ValueError("conversations and rewards must have the same length")

        with open(self.log_file, "a", encoding="utf-8") as f:
            for conv, rew in zip(conversations, rewards):
                f.write(f"Conversation:\n{conv}\n")
                f.write(f"Reward: {rew:.4f}\n")
                f.write("-" * 40 + "\n")

                self._total_reward += rew
                self._count += 1

        avg_reward = self._total_reward / self._count if self._count > 0 else 0.0

        if self.writer is not None:
            self.writer.add_scalar("reward/avg", avg_reward, self._count)

        self.logger.info(f"Recorded {len(conversations)} items, avg_reward={avg_reward:.4f}")

    def finalize(self):
        """Finalize evaluation: write final average reward to both log file and TensorBoard."""
        avg_reward = self._total_reward / self._count if self._count > 0 else 0.0

        with open(self.log_file, "a", encoding="utf-8") as f:
            f.write("\nFinal Results\n")
            f.write("=" * 40 + "\n")
            f.write(f"Average Reward: {avg_reward:.4f}\n")

        if self.writer:
            self.writer.add_scalar("ave_reward_final", avg_reward, global_step=self._count)
|
|
|
|
|
|