import numpy as np
import torch

class PerPromptStatTracker:
    """Tracks per-prompt reward statistics and converts rewards to advantages."""

    def __init__(self, global_std=False):
        self.global_std = global_std  # normalize with the batch-wide std instead of the per-prompt std
        self.stats = {}               # prompt -> list of rewards seen so far for that prompt
        self.history_prompts = set()  # hashes of every prompt ever seen
    def update(self, prompts, rewards, adv_type='grpo'):
        prompts = np.array(prompts)
        rewards = np.array(rewards, dtype=np.float64)
        unique = np.unique(prompts)
        # np.zeros_like avoids the NaN trap of np.empty_like(rewards) * 0.0:
        # uninitialized memory can contain NaN, and NaN * 0.0 is still NaN.
        advantages = np.zeros_like(rewards)
        # First pass: append this batch's rewards to each prompt's running history.
        for prompt in unique:
            prompt_rewards = rewards[prompts == prompt]
            if prompt not in self.stats:
                self.stats[prompt] = []
            self.stats[prompt].extend(prompt_rewards)
            self.history_prompts.add(hash(prompt))  # record that this prompt has been seen
        # Second pass: turn rewards into advantages, one prompt group at a time.
        for prompt in unique:
            # Stack into a local array; mutating self.stats[prompt] into an
            # ndarray here would break the .extend() above on the next update.
            group = np.stack(self.stats[prompt])
            prompt_rewards = rewards[prompts == prompt]
            mean = np.mean(group, axis=0, keepdims=True)
            if self.global_std:
                std = np.std(rewards, axis=0, keepdims=True) + 1e-4  # std over the whole batch
            else:
                std = np.std(group, axis=0, keepdims=True) + 1e-4    # std over this prompt's history
            if adv_type == 'grpo':
                # Group-relative normalization: (r - mean) / std within the prompt group.
                advantages[prompts == prompt] = (prompt_rewards - mean) / std
            elif adv_type == 'rwr':
                # Reward-weighted regression: use the raw rewards as weights.
                # (A softmax over the group's rewards is a common alternative.)
                advantages[prompts == prompt] = prompt_rewards
            elif adv_type == 'sft':
                # Best-of-n filtering: weight 1.0 for the top-reward sample(s), 0.0 otherwise.
                group_rewards = torch.tensor(prompt_rewards)
                advantages[prompts == prompt] = (group_rewards == group_rewards.max()).float().numpy()
            elif adv_type == 'dpo':
                # Preference pair: +1 for the best sample in the group, -1 for the worst.
                prompt_advantages = torch.tensor(prompt_rewards)
                max_idx = torch.argmax(prompt_advantages)
                min_idx = torch.argmin(prompt_advantages)
                # If all rewards in the group are equal, fall back to an arbitrary
                # pair; this assumes each group has at least two samples.
                if max_idx == min_idx:
                    min_idx = 0
                    max_idx = 1
                result = torch.zeros_like(prompt_advantages).float()
                result[max_idx] = 1.0
                result[min_idx] = -1.0
                advantages[prompts == prompt] = result.numpy()
        return advantages
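
    # Worked example of the 'grpo' branch (numbers are illustrative, not from
    # the source): a group with rewards [1., 3.] has mean 2. and std 1. + 1e-4,
    # so its advantages come out as roughly [-1., +1.].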
    def get_stats(self):
        avg_group_size = sum(len(v) for v in self.stats.values()) / len(self.stats) if self.stats else 0
        history_prompts = len(self.history_prompts)
        return avg_group_size, history_prompts
    def clear(self):
        # Reset the per-prompt reward history; history_prompts is kept so the
        # cumulative unique-prompt count survives clears.
        self.stats = {}

def main():
    tracker = PerPromptStatTracker()
    prompts = ['a', 'b', 'a', 'c', 'b', 'a']
    rewards = [1, 2, 3, 4, 5, 6]
    advantages = tracker.update(prompts, rewards)
    print("Advantages:", advantages)
    avg_group_size, history_prompts = tracker.get_stats()
    print("Average Group Size:", avg_group_size)
    print("History Prompts:", history_prompts)
    tracker.clear()
    print("Stats after clear:", tracker.stats)

if __name__ == "__main__":
    main()