diff --git a/.hydra/config.yaml b/.hydra/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c7b5e78483cb9163122e055603a1b0193118c1cb --- /dev/null +++ b/.hydra/config.yaml @@ -0,0 +1,178 @@ +experiment: + wandb_enabled: true + nb_epochs: 3000 + nb_matches_per_iteration: 64 + reinit_matches_each_it: true + checkpoint_every_n_iterations: 10 + start_epoch: 0 + resume_experiment: true + base_seed: 0 + seed_group_size: 8 + train: true + stat_methods_for_live_wandb: mllm.markov_games.negotiation.negotiation_statistics + name: no_press_10_1_ties_ad_align_nocurrtimestep + agent_buffer: true + keep_agent_buffer_count: ${lora_count} + agent_buffer_recent_k: -1 +logging: + wandb: + enabled: false + project: llm-negotiation + entity: null + mode: online + name: null + group: null + tags: [] + notes: null +temperature: 1.0 +markov_games: + runner_method_name: LinearRunner + runner_kwargs: {} + group_by_round: true + simulation_class_name: NoPressSimulation + simulation_init_args: + nb_of_rounds: 10 + quota_messages_per_agent_per_round: 0 + game_type: 10-1-ties + atleast_one_conflict: true + item_types: + - hats + - books + - balls + agents: + 0: + agent_id: ${agent_0_id} + agent_name: Alice + agent_class_name: NoPressAgent + policy_id: base_llm/agent_adapter + init_kwargs: + goal: Maximize your total points over the whole game. + 1: + agent_id: ${agent_1_id} + agent_name: Bob + agent_class_name: NoPressAgent + policy_id: base_llm/agent_adapter + init_kwargs: + goal: Maximize your total points over the whole game. 
+models: + base_llm: + class: LeanLocalLLM + init_args: + llm_id: base_llm + model_name: Qwen/Qwen2.5-7B-Instruct + inference_backend: vllm + hf_kwargs: + device_map: auto + torch_dtype: bfloat16 + max_memory: + 0: 20GiB + attn_implementation: flash_attention_2 + inference_backend_init_kwargs: + enable_lora: true + seed: ${experiment.base_seed} + enable_prefix_caching: true + max_model_len: 10000.0 + gpu_memory_utilization: 0.5 + dtype: bfloat16 + trust_remote_code: true + max_lora_rank: 32 + enforce_eager: false + max_loras: ${lora_count} + max_cpu_loras: ${lora_count} + enable_sleep_mode: true + inference_backend_sampling_params: + temperature: ${temperature} + top_p: 1.0 + max_tokens: 400 + top_k: -1 + logprobs: 0 + adapter_configs: + agent_adapter: + task_type: CAUSAL_LM + r: 32 + lora_alpha: 64 + lora_dropout: 0.0 + target_modules: all-linear + critic_adapter: + task_type: CAUSAL_LM + r: 32 + lora_alpha: 64 + lora_dropout: 0.0 + target_modules: all-linear + enable_thinking: null + regex_max_attempts: 3 +critics: + agent_critic: + module_pointer: + - base_llm + - critic_adapter +optimizers: + agent_optimizer: + module_pointer: + - base_llm + - agent_adapter + optimizer_class_name: torch.optim.Adam + init_args: + lr: 3.0e-06 + weight_decay: 0.0 + critic_optimizer: + module_pointer: agent_critic + optimizer_class_name: torch.optim.Adam + init_args: + lr: 3.0e-06 + weight_decay: 0.0 +trainers: + agent_trainer: + class: TrainerAdAlign + module_pointers: + policy: + - base_llm + - agent_adapter + policy_optimizer: agent_optimizer + critic: agent_critic + critic_optimizer: critic_optimizer + kwargs: + entropy_coeff: 0.0 + entropy_topk: null + entropy_mask_regex: null + kl_coeff: 0.001 + gradient_clipping: 1.0 + restrict_tokens: null + mini_batch_size: 1 + use_gradient_checkpointing: false + temperature: ${temperature} + device: cuda:0 + use_gae: false + whiten_advantages: false + whiten_advantages_time_step_wise: false + skip_discounted_state_visitation: true + 
use_gae_lambda_annealing: false + gae_lambda_annealing_method: None + gae_lambda_annealing_method_params: None + gae_lambda_annealing_limit: 0.95 + discount_factor: 0.9 + use_rloo: true + enable_tokenwise_logging: false + pg_loss_normalization: nb_tokens + truncated_importance_sampling_ratio_cap: 2.0 + reward_normalizing_constant: 100.0 + ad_align_force_coop_first_step: false + ad_align_clipping: null + ad_align_gamma: 0.9 + ad_align_exclude_k_equals_t: true + ad_align_use_sign: false + ad_align_beta: 1.0 + use_old_ad_align: true + use_time_regularization: false + rloo_branch: false + reuse_baseline: false +train_on_which_data: + agent_trainer: ${agent_ids} +lora_count: 30 +common_agent_kwargs: + goal: Maximize your total points over the whole game. +agent_0_id: Alice +agent_1_id: Bob +agent_ids: +- Alice +- Bob diff --git a/.hydra/hydra.yaml b/.hydra/hydra.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7a65f4ff522ef2a1ea67f0f10b48ef49ca896a7a --- /dev/null +++ b/.hydra/hydra.yaml @@ -0,0 +1,154 @@ +hydra: + run: + dir: ${oc.env:SCRATCH}/llm_negotiation/${now:%Y_%m}/${experiment.name} + sweep: + dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S} + subdir: ${hydra.job.num} + launcher: + _target_: hydra._internal.core_plugins.basic_launcher.BasicLauncher + sweeper: + _target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper + max_batch_size: null + params: null + help: + app_name: ${hydra.job.name} + header: '${hydra.help.app_name} is powered by Hydra. + + ' + footer: 'Powered by Hydra (https://hydra.cc) + + Use --hydra-help to view Hydra specific help + + ' + template: '${hydra.help.header} + + == Configuration groups == + + Compose your configuration from those groups (group=option) + + + $APP_CONFIG_GROUPS + + + == Config == + + Override anything in the config (foo.bar=value) + + + $CONFIG + + + ${hydra.help.footer} + + ' + hydra_help: + template: 'Hydra (${hydra.runtime.version}) + + See https://hydra.cc for more info. 
+ + + == Flags == + + $FLAGS_HELP + + + == Configuration groups == + + Compose your configuration from those groups (For example, append hydra/job_logging=disabled + to command line) + + + $HYDRA_CONFIG_GROUPS + + + Use ''--cfg hydra'' to Show the Hydra config. + + ' + hydra_help: ??? + hydra_logging: + version: 1 + formatters: + simple: + format: '[%(asctime)s][HYDRA] %(message)s' + handlers: + console: + class: logging.StreamHandler + formatter: simple + stream: ext://sys.stdout + root: + level: INFO + handlers: + - console + loggers: + logging_example: + level: DEBUG + disable_existing_loggers: false + job_logging: + version: 1 + formatters: + simple: + format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s' + handlers: + console: + class: logging.StreamHandler + formatter: simple + stream: ext://sys.stdout + file: + class: logging.FileHandler + formatter: simple + filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log + root: + level: INFO + handlers: + - console + - file + disable_existing_loggers: false + env: {} + mode: RUN + searchpath: [] + callbacks: {} + output_subdir: .hydra + overrides: + hydra: + - hydra.mode=RUN + task: [] + job: + name: run + chdir: false + override_dirname: '' + id: ??? + num: ??? 
+ config_name: no_press_10_1_ties_ad_align_nocurrtimestep.yaml + env_set: {} + env_copy: [] + config: + override_dirname: + kv_sep: '=' + item_sep: ',' + exclude_keys: [] + runtime: + version: 1.3.2 + version_base: '1.1' + cwd: /scratch/muqeeth/llm_negotiation + config_sources: + - path: hydra.conf + schema: pkg + provider: hydra + - path: /scratch/muqeeth/llm_negotiation/configs + schema: file + provider: main + - path: '' + schema: structured + provider: schema + output_dir: /scratch/muqeeth/llm_negotiation/2025_11/no_press_10_1_ties_ad_align_nocurrtimestep + choices: + hydra/env: default + hydra/callbacks: null + hydra/job_logging: default + hydra/hydra_logging: default + hydra/hydra_help: default + hydra/help: default + hydra/sweeper: basic + hydra/launcher: basic + hydra/output: default + verbose: false diff --git a/.hydra/overrides.yaml b/.hydra/overrides.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fe51488c7066f6687ef680d6bfaa4f7768ef205c --- /dev/null +++ b/.hydra/overrides.yaml @@ -0,0 +1 @@ +[] diff --git a/seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/README.md b/seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/README.md new file mode 100644 index 0000000000000000000000000000000000000000..952935e8a936512044016a9bc1f922b109c88143 --- /dev/null +++ b/seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/README.md @@ -0,0 +1,207 @@ +--- +base_model: Qwen/Qwen2.5-7B-Instruct +library_name: peft +pipeline_tag: text-generation +tags: +- base_model:adapter:Qwen/Qwen2.5-7B-Instruct +- lora +- transformers +--- + +# Model Card for Model ID + + + + + +## Model Details + +### Model Description + + + + + +- **Developed by:** [More Information Needed] +- **Funded by [optional]:** [More Information Needed] +- **Shared by [optional]:** [More Information Needed] +- **Model type:** [More Information Needed] +- **Language(s) (NLP):** [More Information Needed] +- **License:** [More Information Needed] +- **Finetuned from model [optional]:** [More Information Needed] + 
+### Model Sources [optional] + + + +- **Repository:** [More Information Needed] +- **Paper [optional]:** [More Information Needed] +- **Demo [optional]:** [More Information Needed] + +## Uses + + + +### Direct Use + + + +[More Information Needed] + +### Downstream Use [optional] + + + +[More Information Needed] + +### Out-of-Scope Use + + + +[More Information Needed] + +## Bias, Risks, and Limitations + + + +[More Information Needed] + +### Recommendations + + + +Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations. + +## How to Get Started with the Model + +Use the code below to get started with the model. + +[More Information Needed] + +## Training Details + +### Training Data + + + +[More Information Needed] + +### Training Procedure + + + +#### Preprocessing [optional] + +[More Information Needed] + + +#### Training Hyperparameters + +- **Training regime:** [More Information Needed] + +#### Speeds, Sizes, Times [optional] + + + +[More Information Needed] + +## Evaluation + + + +### Testing Data, Factors & Metrics + +#### Testing Data + + + +[More Information Needed] + +#### Factors + + + +[More Information Needed] + +#### Metrics + + + +[More Information Needed] + +### Results + +[More Information Needed] + +#### Summary + + + +## Model Examination [optional] + + + +[More Information Needed] + +## Environmental Impact + + + +Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700). 
+ +- **Hardware Type:** [More Information Needed] +- **Hours used:** [More Information Needed] +- **Cloud Provider:** [More Information Needed] +- **Compute Region:** [More Information Needed] +- **Carbon Emitted:** [More Information Needed] + +## Technical Specifications [optional] + +### Model Architecture and Objective + +[More Information Needed] + +### Compute Infrastructure + +[More Information Needed] + +#### Hardware + +[More Information Needed] + +#### Software + +[More Information Needed] + +## Citation [optional] + + + +**BibTeX:** + +[More Information Needed] + +**APA:** + +[More Information Needed] + +## Glossary [optional] + + + +[More Information Needed] + +## More Information [optional] + +[More Information Needed] + +## Model Card Authors [optional] + +[More Information Needed] + +## Model Card Contact + +[More Information Needed] +### Framework versions + +- PEFT 0.17.1 \ No newline at end of file diff --git a/seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/agent_adapter/adapter_config.json b/seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/agent_adapter/adapter_config.json new file mode 100644 index 0000000000000000000000000000000000000000..2d773bc0e89a092141ca5b151d2cc8043f7c7dd2 --- /dev/null +++ b/seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/agent_adapter/adapter_config.json @@ -0,0 +1,42 @@ +{ + "alpha_pattern": {}, + "auto_mapping": null, + "base_model_name_or_path": "Qwen/Qwen2.5-7B-Instruct", + "bias": "none", + "corda_config": null, + "eva_config": null, + "exclude_modules": null, + "fan_in_fan_out": false, + "inference_mode": true, + "init_lora_weights": true, + "layer_replication": null, + "layers_pattern": null, + "layers_to_transform": null, + "loftq_config": {}, + "lora_alpha": 64, + "lora_bias": false, + "lora_dropout": 0.0, + "megatron_config": null, + "megatron_core": "megatron.core", + "modules_to_save": null, + "peft_type": "LORA", + "qalora_group_size": 16, + "r": 32, + "rank_pattern": {}, + "revision": null, + "target_modules": [ + "gate_proj", 
+ "v_proj", + "k_proj", + "down_proj", + "up_proj", + "o_proj", + "q_proj" + ], + "target_parameters": null, + "task_type": "CAUSAL_LM", + "trainable_token_indices": null, + "use_dora": false, + "use_qalora": false, + "use_rslora": false +} \ No newline at end of file diff --git a/seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/critic_adapter/adapter_config.json b/seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/critic_adapter/adapter_config.json new file mode 100644 index 0000000000000000000000000000000000000000..2d773bc0e89a092141ca5b151d2cc8043f7c7dd2 --- /dev/null +++ b/seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/critic_adapter/adapter_config.json @@ -0,0 +1,42 @@ +{ + "alpha_pattern": {}, + "auto_mapping": null, + "base_model_name_or_path": "Qwen/Qwen2.5-7B-Instruct", + "bias": "none", + "corda_config": null, + "eva_config": null, + "exclude_modules": null, + "fan_in_fan_out": false, + "inference_mode": true, + "init_lora_weights": true, + "layer_replication": null, + "layers_pattern": null, + "layers_to_transform": null, + "loftq_config": {}, + "lora_alpha": 64, + "lora_bias": false, + "lora_dropout": 0.0, + "megatron_config": null, + "megatron_core": "megatron.core", + "modules_to_save": null, + "peft_type": "LORA", + "qalora_group_size": 16, + "r": 32, + "rank_pattern": {}, + "revision": null, + "target_modules": [ + "gate_proj", + "v_proj", + "k_proj", + "down_proj", + "up_proj", + "o_proj", + "q_proj" + ], + "target_parameters": null, + "task_type": "CAUSAL_LM", + "trainable_token_indices": null, + "use_dora": false, + "use_qalora": false, + "use_rslora": false +} \ No newline at end of file diff --git a/src_code_for_reproducibility/__pycache__/__init__.cpython-312.pyc b/src_code_for_reproducibility/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad6cd81dffe4967949915fd4aae73bacfdaa0a73 Binary files /dev/null and b/src_code_for_reproducibility/__pycache__/__init__.cpython-312.pyc differ diff --git 
a/src_code_for_reproducibility/chat_utils/__pycache__/apply_template.cpython-312.pyc b/src_code_for_reproducibility/chat_utils/__pycache__/apply_template.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..064ba6082198b109d499f9d35f8678a8fe6a41ab Binary files /dev/null and b/src_code_for_reproducibility/chat_utils/__pycache__/apply_template.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/chat_utils/__pycache__/chat_turn.cpython-312.pyc b/src_code_for_reproducibility/chat_utils/__pycache__/chat_turn.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..788adbbff9903e4ef506b14867c297591f06a1f2 Binary files /dev/null and b/src_code_for_reproducibility/chat_utils/__pycache__/chat_turn.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/chat_utils/__pycache__/template_specific.cpython-312.pyc b/src_code_for_reproducibility/chat_utils/__pycache__/template_specific.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5154ddde7d30395315ba112248e798d4a5b53197 Binary files /dev/null and b/src_code_for_reproducibility/chat_utils/__pycache__/template_specific.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/chat_utils/apply_template.py b/src_code_for_reproducibility/chat_utils/apply_template.py new file mode 100644 index 0000000000000000000000000000000000000000..a8cc5d253c2088e624bcf0dff645d95bf5747d6b --- /dev/null +++ b/src_code_for_reproducibility/chat_utils/apply_template.py @@ -0,0 +1,84 @@ +import torch + +from mllm.chat_utils.chat_turn import ChatTurn +from mllm.chat_utils.template_specific import ( + custom_gemma3_template, + custom_llama3_template, + custom_qwen2_template, + custom_qwen3_template, + gemma3_assistant_postfix, + qwen2_assistant_postfix, + qwen3_assistant_postfix, +) + + +def get_custom_chat_template(tokenizer) -> str: + """ + Get the chat template for the tokenizer. 
+ """ + if "qwen2" in tokenizer.name_or_path.lower(): + return custom_qwen2_template + elif "llama" in tokenizer.name_or_path.lower(): + return custom_llama3_template + elif "qwen3" in tokenizer.name_or_path.lower(): + return custom_qwen3_template + elif "gemma" in tokenizer.name_or_path.lower(): + return custom_gemma3_template + else: + raise ValueError(f"Tokenizer {tokenizer.name_or_path} not supported") + + +def get_custom_assistant_postfix(tokenizer) -> torch.Tensor: + """ + Get the custom assistant postfix for the tokenizer. + """ + if "qwen2" in tokenizer.name_or_path.lower(): + return qwen2_assistant_postfix + elif "qwen3" in tokenizer.name_or_path.lower(): + return qwen3_assistant_postfix + elif "gemma" in tokenizer.name_or_path.lower(): + return gemma3_assistant_postfix + return torch.tensor([], dtype=torch.long) + + +def tokenize_chats(chats: list[ChatTurn], tokenizer, enable_thinking) -> None: + """ + Set the chat_template_token_ids for each chat turn. + # TODO: use engine tokens if available + """ + custom_template = get_custom_chat_template(tokenizer) + custom_assistant_postfix: torch.Tensor = get_custom_assistant_postfix(tokenizer) + for i, chat in enumerate(chats): + if chat.chat_template_token_ids is None: + if chat.role == "user": + next_chat = chats[i + 1] if i + 1 < len(chats) else None + add_generation_prompt = True + if next_chat and next_chat.role == "user": + add_generation_prompt = False + encoded_chat = tokenizer.apply_chat_template( + [chat], + return_tensors="pt", + chat_template=custom_template, + add_generation_prompt=add_generation_prompt, + add_system_prompt=True if i == 0 else False, + enable_thinking=enable_thinking, + ).flatten() + previous_chat = chats[i - 1] if i > 0 else None + if previous_chat and previous_chat.role == "assistant": + encoded_chat = torch.cat([custom_assistant_postfix, encoded_chat]) + elif chat.role == "assistant": + encoded_chat = chat.out_token_ids + chat.chat_template_token_ids = encoded_chat + + +def 
chat_turns_to_token_ids( + chats: list[ChatTurn], tokenizer, enable_thinking +) -> list[int]: + """ + Tokenize the chat turns and set the chat_template_token_ids for each chat turn. + """ + tokenize_chats(chats=chats, tokenizer=tokenizer, enable_thinking=enable_thinking) + token_ids = [] + for chat in chats: + token_ids.append(chat.chat_template_token_ids) + return torch.cat(token_ids) diff --git a/src_code_for_reproducibility/chat_utils/chat_turn.py b/src_code_for_reproducibility/chat_utils/chat_turn.py new file mode 100644 index 0000000000000000000000000000000000000000..d22c1b853def374fa0ee21eb523f2b9d104ed35b --- /dev/null +++ b/src_code_for_reproducibility/chat_utils/chat_turn.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +import json +from dataclasses import dataclass +from pathlib import Path +from typing import Any, List, Literal, Optional, Tuple + +import jsonschema +import torch +from pydantic import BaseModel, ConfigDict, Field, model_validator + +AgentId = str + + +class ChatTurn(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) # needed for torch tensors + + role: str = Field(pattern="^(user|assistant)$") + agent_id: AgentId # ID of the agent with which the chat occured + content: str + reasoning_content: str | None = None + chat_template_token_ids: torch.LongTensor | None = None # Token ids of chat template format. 
For example, token ids of "{content}"" + out_token_ids: torch.LongTensor | None = ( + None # tokens generated from inference engine + ) + log_probs: torch.FloatTensor | None = None + is_state_end: bool = False # indicates whether this chat turn marks the end of a state in the trajectory diff --git a/src_code_for_reproducibility/chat_utils/template_specific.py b/src_code_for_reproducibility/chat_utils/template_specific.py new file mode 100644 index 0000000000000000000000000000000000000000..aa5f0513bc10d7597ee305c968540b7bf02ab741 --- /dev/null +++ b/src_code_for_reproducibility/chat_utils/template_specific.py @@ -0,0 +1,109 @@ +import huggingface_hub +import torch +from transformers import AutoTokenizer + +custom_llama3_template = """ +{%- if add_system_prompt %} + {{- '<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n<|eot_id|>' }} +{%- endif %} +{%- for message in messages %} + {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' }} +{%- endfor %} + +{%- if add_generation_prompt %} + {{- '<|start_header_id|>' + 'assistant' + '<|end_header_id|>\n\n' }} +{%- endif %} +""" + +qwen2_assistant_postfix = ( + AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct") + .encode("\n", return_tensors="pt") + .flatten() +) +qwen3_assistant_postfix = ( + AutoTokenizer.from_pretrained("Qwen/Qwen3-8B") + .encode("\n", return_tensors="pt") + .flatten() +) +gemma3_assistant_postfix = ( + AutoTokenizer.from_pretrained("google/gemma-3-4b-it") + .encode("\n", return_tensors="pt") + .flatten() +) +custom_qwen2_template = """ +{%- if add_system_prompt %} + {{- '<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. 
You are a helpful assistant.<|im_end|>\n' }} +{%- endif %} +{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %} +{%- for message in messages %} + {%- if message.content is string %} + {%- set content = message.content %} + {%- else %} + {%- set content = '' %} + {%- endif %} + {%- if (message.role == "user") %} + {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" %} + {%- set reasoning_content = '' %} + {%- if message.reasoning_content is string %} + {%- set reasoning_content = message.reasoning_content %} + {%- else %} + {%- if '' in content %} + {%- set reasoning_content = content.split('')[0].rstrip('\n').split('')[-1].lstrip('\n') %} + {%- set content = content.split('')[-1].lstrip('\n') %} + {%- endif %} + {%- endif %} + {%- if loop.index0 > ns.last_query_index %} + {%- if reasoning_content %} + {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content.strip('\n') + '\n\n\n' + content.lstrip('\n') }} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + content }} + {%- endif %} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + content }} + {%- endif %} + {{- '<|im_end|>\n' }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} +{%- endif %} +""" + +custom_qwen3_template = """ +{%- for message in messages %} + {%- if message.content is string %} + {%- set content = message.content %} + {%- else %} + {%- set content = '' %} + {%- endif %} + {%- if (message.role == "user") %} + {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" %} + {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} + {%- if enable_thinking is defined and enable_thinking is false %} + {{- '\n\n\n\n' }} + {%- endif %} +{%- endif %} +""" + 
+custom_gemma3_template = """ +{%- if add_system_prompt %} +{{- bos_token -}} +{%- endif %} +{%- for message in messages -%} +{%- if message['role'] == 'assistant' -%} +{%- set role = 'model' -%} +{%- else -%} +{%- set role = message['role'] -%} +{%- endif -%} +{{ '' + role + '\n' + message['content'] | trim + '\n' }} +{%- endfor -%} +{%- if add_generation_prompt -%} +{{ 'model\n' }} +{%- endif -%} +""" diff --git a/src_code_for_reproducibility/docs/Makefile b/src_code_for_reproducibility/docs/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..9375655022388557d65910465b09d2a5fa9d2e4a --- /dev/null +++ b/src_code_for_reproducibility/docs/Makefile @@ -0,0 +1,19 @@ +# Minimal makefile for Sphinx documentation + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = source +BUILDDIR = build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(SPHINXFLAGS) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(SPHINXFLAGS) diff --git a/src_code_for_reproducibility/docs/generate_docs.py b/src_code_for_reproducibility/docs/generate_docs.py new file mode 100644 index 0000000000000000000000000000000000000000..e644cbbf091a97420500fb47346c07be5ed141ac --- /dev/null +++ b/src_code_for_reproducibility/docs/generate_docs.py @@ -0,0 +1,249 @@ +#!/usr/bin/env python3 +""" +Script to automatically generate Sphinx documentation for all modules and build the HTML website. 
+""" +import importlib.util +import os +import subprocess +import sys + + +def check_and_install_dependencies(): + """Check for required dependencies and install them if missing.""" + required_packages = [ + "sphinx", + "sphinx-rtd-theme", + "sphinxcontrib-napoleon", + "sphinxcontrib-mermaid", + "sphinx-autodoc-typehints", + ] + + missing_packages = [] + + for package in required_packages: + # Convert package name to module name (replace - with _) + module_name = package.replace("-", "_") + + # Check if the package is installed + if importlib.util.find_spec(module_name) is None: + missing_packages.append(package) + + # Install missing packages + if missing_packages: + print(f"Installing missing dependencies: {', '.join(missing_packages)}") + subprocess.check_call( + [sys.executable, "-m", "pip", "install"] + missing_packages + ) + print("Dependencies installed successfully") + else: + print("All required dependencies are already installed") + + +def create_makefile(docs_dir): + """Create a Makefile for Sphinx documentation if it doesn't exist.""" + makefile_path = os.path.join(docs_dir, "Makefile") + + if os.path.exists(makefile_path): + print(f"Makefile already exists at {makefile_path}") + return + + print(f"Creating Makefile at {makefile_path}") + + makefile_content = """# Minimal makefile for Sphinx documentation + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = source +BUILDDIR = build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(SPHINXFLAGS) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
+%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(SPHINXFLAGS) +""" + + with open(makefile_path, "w") as f: + f.write(makefile_content) + + print("Makefile created successfully") + + +def create_make_bat(docs_dir): + """Create a make.bat file for Windows if it doesn't exist.""" + make_bat_path = os.path.join(docs_dir, "make.bat") + + if os.path.exists(make_bat_path): + print(f"make.bat already exists at {make_bat_path}") + return + + print(f"Creating make.bat at {make_bat_path}") + + make_bat_content = """@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=source +set BUILDDIR=build + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. 
+ echo.If you don't have Sphinx installed, grab it from + echo.https://www.sphinx-doc.org/ + exit /b 1 +) + +if "%1" == "" goto help + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd +""" + + with open(make_bat_path, "w") as f: + f.write(make_bat_content) + + print("make.bat created successfully") + + +def main(): + # Check and install required dependencies + print("=== Checking dependencies ===") + check_and_install_dependencies() + + # Get the directory of this script + script_dir = os.path.dirname(os.path.abspath(__file__)) + + # Path to the project root + project_root = os.path.dirname(script_dir) + + # Path to the source directory + source_dir = os.path.join(project_root, "src") + + # Path to the docs source directory + docs_source_dir = os.path.join(script_dir, "source") + + # Print paths for debugging + print(f"Script directory: {script_dir}") + print(f"Project root: {project_root}") + print(f"Source directory: {source_dir}") + print(f"Docs source directory: {docs_source_dir}") + + # Make sure the source directory exists + if not os.path.exists(source_dir): + print(f"Error: Source directory {source_dir} does not exist!") + sys.exit(1) + + # Make sure the docs source directory exists + if not os.path.exists(docs_source_dir): + print(f"Creating docs source directory: {docs_source_dir}") + os.makedirs(docs_source_dir) + + # Step 1: Run sphinx-apidoc to generate .rst files for all modules + print("\n=== Generating API documentation ===") + cmd = [ + "sphinx-apidoc", + "-f", # Force overwriting of existing files + "-e", # Put module documentation before submodule documentation + "-M", # Put module documentation before subpackage documentation + "-o", + docs_source_dir, # Output directory + source_dir, # Source code directory + ] + + print(f"Running command: {' '.join(cmd)}") + result = subprocess.run(cmd, capture_output=True, text=True) + + # Print the 
output of the command + print("STDOUT:") + print(result.stdout) + + print("STDERR:") + print(result.stderr) + + if result.returncode != 0: + print(f"Error: sphinx-apidoc failed with return code {result.returncode}") + sys.exit(1) + + # List the files in the docs source directory + print("\nFiles in docs/source directory:") + for file in sorted(os.listdir(docs_source_dir)): + print(f" {file}") + + print("\nDocumentation source files generated successfully!") + + # Step 2: Create Makefile and make.bat if they don't exist + create_makefile(script_dir) + create_make_bat(script_dir) + + # Step 3: Build the HTML documentation + print("\n=== Building HTML documentation ===") + + # Determine the build command based on the platform + if os.name == "nt": # Windows + build_cmd = ["make.bat", "html"] + else: # Unix/Linux/Mac + build_cmd = ["make", "html"] + + # Change to the docs directory to run the build command + os.chdir(script_dir) + + print(f"Running command: {' '.join(build_cmd)}") + build_result = subprocess.run(build_cmd, capture_output=True, text=True) + + # Print the output of the build command + print("STDOUT:") + print(build_result.stdout) + + print("STDERR:") + print(build_result.stderr) + + if build_result.returncode != 0: + print(f"Error: HTML build failed with return code {build_result.returncode}") + sys.exit(1) + + # Get the path to the built HTML documentation + html_dir = os.path.join(script_dir, "build", "html") + index_path = os.path.join(html_dir, "index.html") + + if os.path.exists(index_path): + print(f"\nHTML documentation built successfully!") + print(f"You can view it by opening: {index_path}") + + # Try to open the documentation in a browser + try: + import webbrowser + + print("\nAttempting to open documentation in your default browser...") + webbrowser.open(f"file://{index_path}") + except Exception as e: + print(f"Could not open browser automatically: {e}") + else: + print(f"\nWarning: HTML index file not found at {index_path}") + + +if 
__name__ == "__main__": + main() diff --git a/src_code_for_reproducibility/docs/make.bat b/src_code_for_reproducibility/docs/make.bat new file mode 100644 index 0000000000000000000000000000000000000000..dc1312ab09ca6fb0267dee6b28a38e69c253631a --- /dev/null +++ b/src_code_for_reproducibility/docs/make.bat @@ -0,0 +1,35 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=source +set BUILDDIR=build + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.https://www.sphinx-doc.org/ + exit /b 1 +) + +if "%1" == "" goto help + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd diff --git a/src_code_for_reproducibility/markov_games/__init__.py b/src_code_for_reproducibility/markov_games/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src_code_for_reproducibility/markov_games/__pycache__/__init__.cpython-312.pyc b/src_code_for_reproducibility/markov_games/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79ecfdf770d36c61d6e01ac1abfefdbf47bd0445 Binary files /dev/null and b/src_code_for_reproducibility/markov_games/__pycache__/__init__.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/markov_games/__pycache__/agent.cpython-312.pyc b/src_code_for_reproducibility/markov_games/__pycache__/agent.cpython-312.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..5c7a88e6f07a6e33667afa5f45af17ff3e1101f1 Binary files /dev/null and b/src_code_for_reproducibility/markov_games/__pycache__/agent.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/markov_games/__pycache__/alternative_actions_runner.cpython-312.pyc b/src_code_for_reproducibility/markov_games/__pycache__/alternative_actions_runner.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..25d78c9df447299bf96d312e456411dab328cb98 Binary files /dev/null and b/src_code_for_reproducibility/markov_games/__pycache__/alternative_actions_runner.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/markov_games/__pycache__/gather_and_export_utils.cpython-312.pyc b/src_code_for_reproducibility/markov_games/__pycache__/gather_and_export_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ec1c3c09b21b671420bb64159a63b681f017c22 Binary files /dev/null and b/src_code_for_reproducibility/markov_games/__pycache__/gather_and_export_utils.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/markov_games/__pycache__/group_timesteps.cpython-312.pyc b/src_code_for_reproducibility/markov_games/__pycache__/group_timesteps.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a08c083ee205f85e385dc6777c7cee64b4264a16 Binary files /dev/null and b/src_code_for_reproducibility/markov_games/__pycache__/group_timesteps.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/markov_games/__pycache__/linear_runner.cpython-312.pyc b/src_code_for_reproducibility/markov_games/__pycache__/linear_runner.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df5ef6b13b173c22c644103ca62b79d94f6c2d31 Binary files /dev/null and b/src_code_for_reproducibility/markov_games/__pycache__/linear_runner.cpython-312.pyc differ diff --git 
a/src_code_for_reproducibility/markov_games/__pycache__/markov_game.cpython-312.pyc b/src_code_for_reproducibility/markov_games/__pycache__/markov_game.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d8c96718e7c3e5d3cf39f1b82f5205f194c63dd5 Binary files /dev/null and b/src_code_for_reproducibility/markov_games/__pycache__/markov_game.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/markov_games/__pycache__/mg_utils.cpython-312.pyc b/src_code_for_reproducibility/markov_games/__pycache__/mg_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ce2037b2a1998ae2b089e86267e1095a13e028f Binary files /dev/null and b/src_code_for_reproducibility/markov_games/__pycache__/mg_utils.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/markov_games/__pycache__/rollout_tree.cpython-312.pyc b/src_code_for_reproducibility/markov_games/__pycache__/rollout_tree.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6fda4248e901a18f053224972d4960420aab1a4c Binary files /dev/null and b/src_code_for_reproducibility/markov_games/__pycache__/rollout_tree.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/markov_games/__pycache__/simulation.cpython-312.pyc b/src_code_for_reproducibility/markov_games/__pycache__/simulation.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..be84edf7bf77be552ce53cfe5c1b32014b7d7a03 Binary files /dev/null and b/src_code_for_reproducibility/markov_games/__pycache__/simulation.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/markov_games/agent.py b/src_code_for_reproducibility/markov_games/agent.py new file mode 100644 index 0000000000000000000000000000000000000000..9961d860fecca7eaa4d13cb195b50dc658be5bec --- /dev/null +++ b/src_code_for_reproducibility/markov_games/agent.py @@ -0,0 +1,76 @@ +""" +In simple RL paradise, where the action dimensions are constant and 
well defined, +Agent classes are not necessary. But in MARL, with LLM's, there isn't always +a direct path from policy to action. For instance, from the observation of the environment, +a prompt must be created. Then, the outputs of the policy might be incorrect, so a second +request to the LLM must be sent before the action is well defined. This is why this Agent class exists. +It acts as a mini environment, bridging the gap between the core simulation and +the LLM policies. +""" + +from abc import ABC, abstractmethod +from collections.abc import Callable +from typing import Any, Tuple + +from numpy.random import default_rng + +from mllm.markov_games.rollout_tree import AgentActLog + + +class Agent(ABC): + @abstractmethod + def __init__( + self, + seed: int, + agent_id: str, + agent_name: str, + agent_policy: Callable[[list[dict]], str], + *args, + **kwargs, + ): + """ + Initialize the agent state. + """ + self.seed = seed + self.agent_id = agent_id + self.agent_name = agent_name + self.policy = policy + self.rng = default_rng(self.seed) + raise NotImplementedError + + async def act(self, observation) -> Tuple[Any, AgentActLog]: + """ + Query (possibly multiple times) a policy (or possibly a pool of policies) to + obtain the action of the agent. + + Example: + action = None + prompt = self.observation_to_prompt(observation) + while not self.valid(action): + output = await self.policy.generate(prompt) + action = self.policy_output_to_action(output) + return action + + Returns: + action + step_info + """ + raise NotImplementedError + + def get_safe_copy(self): + """ + Return copy of the agent object that is decorrelated from the original object. 
+ """ + raise NotImplementedError + + def reset(self): + raise NotImplementedError + + def render(self): + raise NotImplementedError + + def close(self): + raise NotImplementedError + + def get_agent_info(self): + raise NotImplementedError diff --git a/src_code_for_reproducibility/markov_games/alternative_actions_runner.py b/src_code_for_reproducibility/markov_games/alternative_actions_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..c64db2deda539a1a71e045309cfdf257d2cbc614 --- /dev/null +++ b/src_code_for_reproducibility/markov_games/alternative_actions_runner.py @@ -0,0 +1,138 @@ +import asyncio +import copy +import json +import os.path +from typing import Any, Tuple + +from mllm.markov_games.markov_game import AgentAndActionSafeCopy, MarkovGame +from mllm.markov_games.rollout_tree import ( + AgentActLog, + RolloutTreeBranchNode, + RolloutTreeNode, + RolloutTreeRootNode, + StepLog, +) + +AgentId = str + + + +async def run_with_unilateral_alt_action( + markov_game: MarkovGame, + agent_id: AgentId, + time_step: int, + branch_node: RolloutTreeBranchNode, + max_depth: int, +): + """ + This function is used to generate a new branch for a given agent. 
+ """ + + # Generate alternative action and take a step + await markov_game.set_action_of_agent(agent_id) + terminated: bool = markov_game.take_simulation_step() + step_log = markov_game.get_step_log() + first_alternative_node = RolloutTreeNode( + step_log=step_log, + time_step=time_step, + ) + + # Generate rest of trajectory up to max depth + time_step += 1 + counter = 1 + previous_node = first_alternative_node + while not terminated and counter <= max_depth: + terminated, step_log = await markov_game.step() + current_node = RolloutTreeNode(step_log=step_log, time_step=time_step) + previous_node.child = current_node + previous_node = current_node + counter += 1 + time_step += 1 + + if branch_node.branches == None: + branch_node.branches = {agent_id: [first_alternative_node]} + else: + agent_branches = branch_node.branches.get(agent_id, []) + agent_branches.append(first_alternative_node) + branch_node.branches[agent_id] = agent_branches + + +async def AlternativeActionsRunner( + markov_game: MarkovGame, + output_folder: str, + nb_alternative_actions: int, + max_depth: int, + branch_only_on_new_round: bool = False, +): + """ + This method generates a trajectory with partially completed branches, + where the branching comes from taking unilateraly different actions. + The resulting data is used to estimate the updated advantage alignment policy gradient terms. + Let k := nb_sub_steps. Then the number of steps generated is O(Tk), where T is + the maximum trajectory length. 
+ """ + + tasks = [] + time_step = 0 + terminated = False + root = RolloutTreeRootNode( + id=markov_game.get_id(), + crn_id=markov_game.get_crn_id() + ) + previous_node = root + + while not terminated: + mg_before_action = markov_game.get_safe_copy() + + # Get safe copies for main branch + agent_action_safe_copies: dict[ + AgentId, AgentAndActionSafeCopy + ] = await markov_game.get_actions_of_agents_without_side_effects() + + markov_game.set_actions_of_agents_manually(agent_action_safe_copies) + terminated = markov_game.take_simulation_step() + main_node = RolloutTreeNode( + step_log=markov_game.get_step_log(), time_step=time_step + ) + branch_node = RolloutTreeBranchNode(main_child=main_node) + previous_node.child = branch_node + previous_node = main_node + + # Get alternative branches by generating new unilateral actions + for agent_id in markov_game.agent_ids: + for _ in range(nb_alternative_actions): + # Get safe copies for branches + branch_agent_action_safe_copies: dict[ + AgentId, AgentAndActionSafeCopy + ] = { + agent_id: AgentAndActionSafeCopy( + action=copy.deepcopy(agent_action_safe_copy.action), + action_info=copy.deepcopy(agent_action_safe_copy.action_info), + agent_after_action=agent_action_safe_copy.agent_after_action.get_safe_copy(), + ) + for agent_id, agent_action_safe_copy in agent_action_safe_copies.items() + } + mg_branch: MarkovGame = mg_before_action.get_safe_copy() + other_agent_id = [id for id in mg_branch.agent_ids if id != agent_id][0] + mg_branch.set_action_and_agent_after_action_manually( + agent_id=other_agent_id, + agent_action_safe_copy=branch_agent_action_safe_copies[ + other_agent_id + ], + ) + task = asyncio.create_task( + run_with_unilateral_alt_action( + markov_game=mg_branch, + time_step=time_step, + agent_id=agent_id, + branch_node=branch_node, + max_depth=max_depth, + ) + ) + tasks.append(task) + time_step += 1 + + # wait for all branches to complete + await asyncio.gather(*tasks) + + return root diff --git 
a/src_code_for_reproducibility/markov_games/group_timesteps.py b/src_code_for_reproducibility/markov_games/group_timesteps.py new file mode 100644 index 0000000000000000000000000000000000000000..dad5271c500f539f2719110bd676f183746e51e4 --- /dev/null +++ b/src_code_for_reproducibility/markov_games/group_timesteps.py @@ -0,0 +1,150 @@ +""" +This module contains the logic for grouping time steps. +""" +import copy +from typing import Callable + +from mllm.markov_games.markov_game import MarkovGame +from mllm.markov_games.rollout_tree import ( + AgentActLog, + RolloutTreeBranchNode, + RolloutTreeNode, + RolloutTreeRootNode, + StepLog, +) +from mllm.markov_games.simulation import SimulationStepLog + +AgentId = str + + +def group_time_steps( + rollout_tree: RolloutTreeRootNode, + accumulation_stop_condition: Callable[[StepLog], bool], +) -> RolloutTreeRootNode: + """ + During generation, we create rollout trees according to the real time steps. + However, during training, we might want to treat groups of time steps as a single time step. + As a concrete example, take Trust-and-Split. At each round, say we have X time steps of communication and then one time step for the split. + Then the communication actions will not get any reward, and the split action will get the reward. During REINFORCE training, with discounting, this + can cause training instability. We could instead treat every action in the round as being part of a single action, and give it the reward of the split action. + This method helps to do this sort of grouping. + It accumulates actions until the accumulation_stop_condition is met, and then creates a new node with the accumulated actions. + It then recursively calls itself on the child node. + Details: + - The reward for the group is the reward of the last time step in the group. + - The simulation log for the group is the simulation log of the last time step in the group. + - The state end for the group becomes the first state end in the group. 
+ - The agent info for the group is the agent info of the last time step in the group. + """ + + def group_step_logs(step_logs: list[StepLog]) -> StepLog: + """ + Concatenate per-agent chat turns across steps; keep only the first is_state_end. + """ + last_sim_log = step_logs[-1].simulation_step_log + agent_ids = {aid for s in step_logs for aid in s.action_logs.keys()} + grouped_logs: dict[AgentId, AgentActLog] = {} + for aid in agent_ids: + turns = [] + for s in step_logs: + act = s.action_logs.get(aid) + if act and act.chat_turns: + turns.extend(copy.deepcopy(act.chat_turns)) + disable_is_state_end = False + # Only the first state_end should be True, the rest should be False + for t in turns: + if t.is_state_end: + if disable_is_state_end: + t.is_state_end = False + else: + disable_is_state_end = True + continue + grouped_logs[aid] = AgentActLog( + chat_turns=turns, info=step_logs[-1].action_logs[aid].info + ) + return StepLog(action_logs=grouped_logs, simulation_step_log=last_sim_log) + + def group_time_steps_rec( + current_node: RolloutTreeNode | RolloutTreeBranchNode, + group_time_step: int, + accumulation_step_logs: list[StepLog], + ) -> RolloutTreeNode | RolloutTreeBranchNode: + """ + Groups time steps. Recursion is used to handle branches. + """ + assert isinstance(current_node, RolloutTreeNode) or isinstance( + current_node, RolloutTreeBranchNode + ), "Current node must be a tree node or a branch node. Is of type: " + str( + type(current_node) + ) + first_group_node = None + current_group_node = None + while current_node is not None: + if isinstance(current_node, RolloutTreeBranchNode): + raise Exception( + "Grouping timesteps by round is not supported for branching trajectories yet." 
+ ) + # Special recursive case for branches + # if isinstance(current_node, RolloutTreeBranchNode): + # branches = {} + # for agent_id, branch_nodes in current_node.branches.items(): + # branch_group_nodes = [] + # for branch_node in branch_nodes: + # branch_group_node = group_time_steps_rec( + # current_node=branch_node, + # group_time_step=group_time_step, + # accumulation_step_logs=copy.deepcopy(accumulation_step_logs)) + # branch_group_nodes.append(branch_group_node) + # branches[agent_id] = branch_group_nodes + + # main_child_group_node = group_time_steps_rec( + # current_node=current_node.main_child, + # group_time_step=group_time_step, + # accumulation_step_logs=copy.deepcopy(accumulation_step_logs)) + + # return RolloutTreeBranchNode(main_child=main_child_group_node, branches=branches) + + # Accumulate + accumulation_step_logs.append(current_node.step_log) + if accumulation_stop_condition(current_node.step_log): + grouped_step_logs = group_step_logs(accumulation_step_logs) + accumulation_step_logs = [] + new_group_node = RolloutTreeNode( + step_log=grouped_step_logs, time_step=group_time_step, child=None + ) + if first_group_node == None: + first_group_node = new_group_node + group_time_step += 1 + if current_group_node is not None: + current_group_node.child = new_group_node + current_group_node = new_group_node + current_node = current_node.child + return first_group_node + + node = group_time_steps_rec( + current_node=rollout_tree.child, group_time_step=0, accumulation_step_logs=[] + ) + return RolloutTreeRootNode( + id=rollout_tree.id, + crn_id=rollout_tree.crn_id, + child=node, + agent_ids=rollout_tree.agent_ids, + ) + + +def stop_when_round_ends(step_log: StepLog) -> bool: + """ + Simplest stop condition. Will return True if step log is the last time step of a round. + This will throw an error if this information is not available in the simulation info. 
+ """ + assert ( + "is_last_timestep_in_round" in step_log.simulation_step_log.info.keys() + ), "To group by round, is_last_timestep_in_round must be set in the info of your simulation step log at each time step." + return step_log.simulation_step_log.info["is_last_timestep_in_round"] + + +def group_by_round(rollout_tree: RolloutTreeRootNode) -> RolloutTreeRootNode: + """ + Groups time steps by round. + """ + return group_time_steps(rollout_tree, stop_when_round_ends) diff --git a/src_code_for_reproducibility/markov_games/linear_runner.py b/src_code_for_reproducibility/markov_games/linear_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..f81ab6d37ae9c680b2f2a53388988117d37f8a47 --- /dev/null +++ b/src_code_for_reproducibility/markov_games/linear_runner.py @@ -0,0 +1,30 @@ +import asyncio +import json +import os.path + +from mllm.markov_games.markov_game import MarkovGame +from mllm.markov_games.rollout_tree import RolloutTreeNode, RolloutTreeRootNode + + +async def LinearRunner( + markov_game: MarkovGame, output_folder: str +) -> RolloutTreeRootNode: + """ + This method generates a trajectory without branching. 
+ """ + time_step = 0 + terminated = False + root = RolloutTreeRootNode( + id=markov_game.get_id(), + crn_id=markov_game.get_crn_id(), + agent_ids=markov_game.get_agent_ids(), + ) + previous_node = root + while not terminated: + terminated, step_log = await markov_game.step() + current_node = RolloutTreeNode(step_log=step_log, time_step=time_step) + previous_node.child = current_node + previous_node = current_node + time_step += 1 + + return root diff --git a/src_code_for_reproducibility/markov_games/markov_game.py b/src_code_for_reproducibility/markov_games/markov_game.py new file mode 100644 index 0000000000000000000000000000000000000000..73a48213bddcf0a59976fa0870eec19f59ae47d9 --- /dev/null +++ b/src_code_for_reproducibility/markov_games/markov_game.py @@ -0,0 +1,208 @@ +""" +This class unifies a simulation, and the agents acting in it (see `simulation.py` & `agent.py`). +In a MarkovGame step, + 1) each agent takes an action, + 2) the state transitions with respect to these actions, + 3) all relevant data of the step is appended to the historical data list + +In order to perform 3), the agents and the simulation are expected, at each time step, +to return a log of the state transition (from their perspective). +For instance, the Simulation might send rewards and the agents might send prompting contexts to be used later to generate the training data. +A different approach would be to simply have the agents keep their data private and log it upon completion of a trajectory. +The approach we use here centralizes the data gathering aspect, +making it easy to create sub-trajectories (in the `runners` defined in `runners.py`) descriptions that +only log information for step transitions occuring after the branching out. 
+""" +import asyncio +import copy +import json +import os +from dataclasses import dataclass +from typing import Any, List, Literal, Optional, Tuple + +from transformers.models.idefics2 import Idefics2Config + +from mllm.markov_games.agent import Agent +from mllm.markov_games.rollout_tree import AgentActLog, StepLog +from mllm.markov_games.simulation import Simulation + +AgentId = str + + +@dataclass +class AgentAndActionSafeCopy: + action: Any + action_info: AgentActLog + agent_after_action: type[Agent] + + +class MarkovGame(object): + def __init__( + self, + id: int, + agents: dict[AgentId, type[Agent]], + simulation: type[Simulation], + crn_id: int, + ): + """ + Args: + agents: + output_path: + Path where the step infos are saved. + simulation: + Simulation object. Example: IPDSimulation + """ + self.agents = agents + self.agent_ids = self.agents.keys() + self.simulation = simulation + self.simulation_step_log = None + self.agent_step_logs = {agent_id: None for agent_id in self.agent_ids} + self.actions = {} + self.id = id + self.crn_id = crn_id + + def get_id(self) -> str: + return self.id + + def get_crn_id(self) -> int: + return self.crn_id + + def get_agent_ids(self) -> List[AgentId]: + return list(self.agent_ids) + + async def get_action_of_agent_without_side_effects( + self, agent_id: AgentId + ) -> Tuple[Any, AgentActLog]: + """ + Safe function to get an action of an agent without modifying the agent or the simulation. 
+ """ + agent = self.agents[agent_id] + agent_before_action = agent.get_safe_copy() + obs = self.simulation.get_obs_agent(agent_id) + action, action_info = await agent.act(observation=obs) + self.agents[agent_id] = agent_before_action + agent_after_action = agent.get_safe_copy() + return AgentAndActionSafeCopy(action, action_info, agent_after_action) + + async def get_actions_of_agents_without_side_effects( + self, + ) -> dict[AgentId, AgentAndActionSafeCopy]: + """ + Safe function to get an action of an agent without modifying the agent or the simulation. + """ + tasks = [] + for agent_id in self.agent_ids: + task = asyncio.create_task( + self.get_action_of_agent_without_side_effects(agent_id) + ) + tasks.append(task) + agent_and_action_safe_copies: list[ + AgentAndActionSafeCopy + ] = await asyncio.gather(*tasks) + return { + agent_id: agent_and_action_safe_copy + for agent_id, agent_and_action_safe_copy in zip( + self.agent_ids, agent_and_action_safe_copies + ) + } + + def set_action_and_agent_after_action_manually( + self, + agent_id: AgentId, + agent_action_safe_copy: AgentAndActionSafeCopy, + ): + """ + Set the action and the agent after action manually. + """ + self.actions[agent_id] = agent_action_safe_copy.action + self.agent_step_logs[agent_id] = agent_action_safe_copy.action_info + self.agents[agent_id] = agent_action_safe_copy.agent_after_action + + def set_actions_of_agents_manually( + self, actions: dict[AgentId, AgentAndActionSafeCopy] + ): + """ + Set the actions of agents manually. 
+ """ + for agent_id, agent_action_safe_copy in actions.items(): + self.set_action_and_agent_after_action_manually( + agent_id, agent_action_safe_copy + ) + + async def set_action_of_agent(self, agent_id: AgentId): + """ + TOWRITE + """ + agent = self.agents[agent_id] + obs = self.simulation.get_obs_agent(agent_id) + action, action_info = await agent.act(observation=obs) + self.actions[agent_id] = action + self.agent_step_logs[agent_id] = action_info + + async def set_actions(self): + """ + TOWRITE + """ + # background_tasks = set() + tasks = [] + for agent_id in self.agent_ids: + task = asyncio.create_task(self.set_action_of_agent(agent_id)) + tasks.append(task) + await asyncio.gather(*tasks) + + def take_simulation_step(self): + """ + TOWRITE + """ + terminated, self.simulation_step_log = self.simulation.step(self.actions) + return terminated + + def get_step_log(self) -> StepLog: + """ + TOWRITE + TODO: assert actions and simulation have taken step + """ + step_log = StepLog( + simulation_step_log=self.simulation_step_log, + action_logs=self.agent_step_logs, + ) + return step_log + + async def step(self) -> Tuple[bool, StepLog]: + """ + TOWRITE + """ + await self.set_actions() + terminated = self.take_simulation_step() + step_log = self.get_step_log() + return terminated, step_log + + def get_safe_copy(self): + """ + TOWRITE + """ + + new_markov_game = copy.copy(self) + new_simulation = self.simulation.get_safe_copy() + new_agents = { + agent_id: agent.get_safe_copy() for agent_id, agent in self.agents.items() + } + + # Reassign copied components + new_markov_game.simulation = new_simulation + new_markov_game.agents = new_agents + + # IMPORTANT: ensure agent_ids references the new agents dict, not the original + new_markov_game.agent_ids = new_markov_game.agents.keys() + + # Deep-copy step data to avoid correlation + new_markov_game.simulation_step_log = copy.deepcopy(self.simulation_step_log) + new_markov_game.actions = copy.deepcopy(self.actions) + # Rebuild 
logs to align exactly with new agent ids + old_agent_step_logs = copy.deepcopy(self.agent_step_logs) + new_markov_game.agent_step_logs = { + agent_id: old_agent_step_logs.get(agent_id) + for agent_id in new_markov_game.agent_ids + } + + return new_markov_game diff --git a/src_code_for_reproducibility/markov_games/mg_utils.py b/src_code_for_reproducibility/markov_games/mg_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..27e0eb606ea6cfd5682c3b21a53e783c23e9b43f --- /dev/null +++ b/src_code_for_reproducibility/markov_games/mg_utils.py @@ -0,0 +1,89 @@ +import asyncio +import copy +from collections.abc import Callable +from dataclasses import dataclass + +from mllm.markov_games.ipd.ipd_agent import IPDAgent +from mllm.markov_games.ipd.ipd_simulation import IPD +from mllm.markov_games.markov_game import MarkovGame +from mllm.markov_games.negotiation.dond_agent import DealNoDealAgent +from mllm.markov_games.negotiation.dond_simulation import DealNoDealSimulation +from mllm.markov_games.negotiation.nego_hard_coded_policies import ( + HardCodedNegoGreedyPolicy, + HardCodedNegoWelfareMaximizingPolicy, +) +from mllm.markov_games.ipd.Ipd_hard_coded_agents import AlwaysCooperateIPDAgent, AlwaysDefectIPDAgent +from mllm.markov_games.negotiation.no_press_nego_agent import NoPressAgent +from mllm.markov_games.negotiation.no_press_nego_simulation import NoPressSimulation +from mllm.markov_games.negotiation.tas_agent import TrustAndSplitAgent +from mllm.markov_games.negotiation.tas_rps_agent import TrustAndSplitRPSAgent +from mllm.markov_games.negotiation.tas_rps_simulation import TrustAndSplitRPSSimulation +from mllm.markov_games.negotiation.tas_simple_agent import TrustAndSplitSimpleAgent +from mllm.markov_games.negotiation.tas_simple_simulation import ( + TrustAndSplitSimpleSimulation, +) +from mllm.markov_games.negotiation.tas_simulation import TrustAndSplitSimulation +from mllm.markov_games.rollout_tree import ( + AgentActLog, + 
RolloutTreeBranchNode, + RolloutTreeNode, + RolloutTreeRootNode, + StepLog, +) +from mllm.markov_games.simulation import SimulationStepLog + +AgentId = str + + +@dataclass +class AgentConfig: + agent_id: str + agent_name: str + agent_class_name: str + policy_id: str + init_kwargs: dict + + +@dataclass +class MarkovGameConfig: + id: int + seed: int + simulation_class_name: str + simulation_init_args: dict + agent_configs: list[AgentConfig] + + +def init_markov_game_components( + config: MarkovGameConfig, policies: dict[str, Callable[[list[dict]], str]] +): + """ + TOWRITE + """ + agents = {} + agent_names = [] + for agent_config in config.agent_configs: + agent_id = agent_config.agent_id + agent_name = agent_config.agent_name + agent_class = eval(agent_config.agent_class_name) + agent = agent_class( + seed=config.seed, + agent_id=agent_id, + agent_name=agent_name, + policy=policies[agent_config.policy_id], + **agent_config.init_kwargs, + ) + agents[agent_id] = agent + agent_names.append(agent_name) + simulation = eval(config.simulation_class_name)( + seed=config.seed, + agent_ids=list(agents.keys()), + agent_names=agent_names, + **config.simulation_init_args, + ) + markov_game = MarkovGame( + id=config.id, + crn_id=config.seed, + agents=agents, + simulation=simulation, + ) + return markov_game diff --git a/src_code_for_reproducibility/markov_games/negotiation/dond_simulation.py b/src_code_for_reproducibility/markov_games/negotiation/dond_simulation.py new file mode 100644 index 0000000000000000000000000000000000000000..a27d6ce2cf7e31a0cddd341db39ae7898b086115 --- /dev/null +++ b/src_code_for_reproducibility/markov_games/negotiation/dond_simulation.py @@ -0,0 +1,153 @@ +import copy +from dataclasses import dataclass +from typing import Any, Dict, List, Tuple + +from numpy.random import default_rng + +from mllm.markov_games.rollout_tree import SimulationStepLog +from mllm.markov_games.negotiation.nego_simulation import Split, NegotiationState, NegotiationObs, 
NegotiationSimulation +from mllm.utils.get_coagent_id import get_coagent_id + + +AgentId = str + + +@dataclass +class DealNoDealState(NegotiationState): + item_types: List[str] + values: Dict[AgentId, Dict[str, int]] + +@dataclass +class DealNoDealObs(NegotiationObs): + my_values: Dict[str, int] + item_types: List[str] + previous_values_coagent: Dict[str, int] | None + + +def random_partition_integer(rng, total: int, parts: int) -> List[int]: + if parts <= 0: + return [] + if total <= 0: + return [0 for _ in range(parts)] + cuts = sorted(rng.integers(0, total + 1, size=parts - 1).tolist()) + vals = [] + prev = 0 + for c in cuts + [total]: + vals.append(c - prev) + prev = c + return vals + +class DealNoDealSimulation(NegotiationSimulation): + + def __init__( + self, + item_types: List[str] = ["books", "hats", "balls"], + *args, + **kwargs, + ): + super().__init__(item_types=item_types, *args, **kwargs) + self.reset() + + def _other(self, agent_id: AgentId) -> AgentId: + return get_coagent_id(self.agent_ids, agent_id) + + def _sample_stock(self) -> Dict[str, int]: + # total items between 5 and 7 + total_items = int(self.rng.integers(5, 8)) + # nonnegative per-type counts summing to total_items + parts = random_partition_integer(self.rng, total_items, len(self.item_types)) + # allow zeros per type + return {t: int(c) for t, c in zip(self.item_types, parts)} + + def _sample_values_pair(self) -> Dict[AgentId, Dict[str, int]]: + # Each agent has integer non-negative values that sum to 10 + # Each item type valued by at least one agent + # Some item type valued by both agents + while True: + vals_a = random_partition_integer(self.rng, 10, len(self.item_types)) + vals_b = random_partition_integer(self.rng, 10, len(self.item_types)) + a = {t: int(v) for t, v in zip(self.item_types, vals_a)} + b = {t: int(v) for t, v in zip(self.item_types, vals_b)} + # each item valued by at least one + ok1 = all((a[t] > 0) or (b[t] > 0) for t in self.item_types) + # some item valued by 
both + ok2 = any((a[t] > 0) and (b[t] > 0) for t in self.item_types) + if ok1 and ok2: + return {self.agent_ids[0]: a, self.agent_ids[1]: b} + + def _is_valid_allocation(self, allocation: Dict[str, int], stock: Dict[str, int]) -> bool: + for t in self.item_types: + v = allocation.get(t) + if v is None: + return False + if not isinstance(v, int): + return False + if v < 0 or v > int(stock.get(t, 0)): + return False + return True + + def set_new_round_of_variant(self): + # Keep same values, resample stock + self.state.quantities = self._sample_stock() + + def get_info_of_variant(self, state: NegotiationState, actions: Dict[AgentId, Any]) -> Dict[str, Any]: + return { + "quantities": copy.deepcopy(state.quantities), + "values": copy.deepcopy(state.values), + 'splits': copy.deepcopy(state.splits), + } + + def get_rewards(self, splits: Dict[AgentId, Split]) -> Dict[AgentId, float]: + """ + Returns the rewards for each agent. + """ + split_a = splits[self.agent_ids[0]].items_given_to_self + split_b = splits[self.agent_ids[1]].items_given_to_self + rewards = {self.agent_ids[0]: 0, self.agent_ids[1]: 0} + for t in self.item_types: + # If not complementary, return 0! 
+ if not split_a[t] + split_b[t] == self.state.quantities[t]: + return {self.agent_ids[0]: 0, self.agent_ids[1]: 0} + rewards[self.agent_ids[0]] += split_a[t] * self.state.values[self.agent_ids[0]][t] + rewards[self.agent_ids[1]] += split_b[t] * self.state.values[self.agent_ids[1]][t] + return rewards + + def get_obs(self): + return {agent_id: self.get_obs_agent(agent_id) for agent_id in self.agent_ids} + + def get_obs_agent(self, agent_id): + other_id = self._other(agent_id) + obs = DealNoDealObs( + round_nb=self.state.round_nb, + last_message=self.state.last_message, + current_agent=self.state.current_agent, + quantities=copy.deepcopy(self.state.quantities), + value=0.0, # unused in DOND + other_agent_split=None, # not meaningful until split + split_phase=self.state.split_phase, + quota_messages_per_agent_per_round=self.quota_messages_per_agent_per_round, + my_values=copy.deepcopy(self.state.values[agent_id]), + item_types=list(self.item_types), + previous_values_coagent=copy.deepcopy(self.state.values.get(other_id, {})), + ) + return obs + + def reset(self): + start_agent = self.agent_ids[self._starting_agent_index] + stock = self._sample_stock() + values = self._sample_values_pair() + self.state = DealNoDealState( + round_nb=0, + last_message="", + current_agent=start_agent, + quantities=stock, + values=values, + previous_values=None, + splits={aid: None for aid in self.agent_ids}, + nb_messages_sent={aid: 0 for aid in self.agent_ids}, + split_phase=False, + item_types=list(self.item_types), + ) + return self.get_obs() + + diff --git a/src_code_for_reproducibility/markov_games/negotiation/nego_agent.py b/src_code_for_reproducibility/markov_games/negotiation/nego_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..9b5bf4e3ca4ee7faa982360674e19d9eff6980dc --- /dev/null +++ b/src_code_for_reproducibility/markov_games/negotiation/nego_agent.py @@ -0,0 +1,242 @@ +import copy +from abc import abstractmethod +from collections.abc import 
Callable +from dataclasses import dataclass +from typing import Any, Dict, List, Tuple + +import numpy as np + +from mllm.markov_games.agent import Agent +from mllm.markov_games.negotiation.nego_simulation import Message, NegotiationObs, Split +from mllm.markov_games.rollout_tree import AgentActLog, ChatTurn + + +@dataclass +class NegotiationAgentState: + round_nb: int + nb_messages_sent_this_round: int + chat_counter: int + chat_history: List[ChatTurn] + + +class NegotiationAgent(Agent): + def __init__( + self, + seed: int, + agent_id: str, + agent_name: str, + policy: Callable[[List[Dict]], str], + goal: str, + exploration_prompts: List[str] = [], + exploration_prompt_probs: List[float] = [], + ): + self.seed = seed + self.agent_id = agent_id + self.agent_name = agent_name + self.policy = policy + self.goal = goal + self.exploration_prompts_toggled = len(exploration_prompts) > 0 + if self.exploration_prompts_toggled: + exploration_prompts = copy.deepcopy(exploration_prompts) + exploration_prompts.append(None) + self.exploration_prompts = exploration_prompts + self.exploration_prompt_probs = np.array(exploration_prompt_probs) + assert self.exploration_prompt_probs.sum() <= 1 + assert np.all(self.exploration_prompt_probs >= 0) + self.exploration_prompt_probs = np.append( + self.exploration_prompt_probs, 1 - self.exploration_prompt_probs.sum() + ) + self.state = NegotiationAgentState( + round_nb=0, nb_messages_sent_this_round=0, chat_counter=0, chat_history=[] + ) + + # Implemented in variants + self.intro_prompt = "" + self.new_round_prompt = "" + self.last_round_prompt = "" + self.send_split_prompt = "" + self.wait_for_message_prompt = "" + self.last_message_prompt = "" + self.send_message_prompt = "" + + @abstractmethod + def get_message_regex(self, observation: NegotiationObs) -> str: + pass + + @abstractmethod + def get_split_regex(self, observation: NegotiationObs) -> str: + pass + + @abstractmethod + def get_split_action( + self, policy_output: str, 
observation: NegotiationObs + ) -> Split: + pass + + async def act(self, observation: NegotiationObs) -> Tuple[Any, AgentActLog]: + def dict_to_str(d: dict) -> str: + return ", ".join(f"{v} {k}" for k, v in d.items()) + + def dict_to_eq_str(d: dict) -> str: + return ", ".join(f"{k}={v}" for k, v in d.items()) + + is_our_turn = observation.current_agent == self.agent_id + action: Any = None + round_nb = observation.round_nb + + prompt_parts: List[str] = [] + obs_ctx = vars(observation) + obs_ctx_formmated = obs_ctx.copy() + for key in obs_ctx_formmated: + if isinstance(obs_ctx_formmated[key], dict) and "value" not in key: + obs_ctx_formmated[key] = dict_to_str(obs_ctx_formmated[key]) + elif isinstance(obs_ctx_formmated[key], dict) and "value" in key: + obs_ctx_formmated[key] = dict_to_eq_str(obs_ctx_formmated[key]) + + ####################################### + # build user prompt + ####################################### + + # First-ever call + is_intro = round_nb == 0 and self.state.chat_counter == 0 + if is_intro: + prompt_parts.append( + self.intro_prompt.format( + goal=self.goal, agent=self.agent_name, **obs_ctx_formmated + ) + ) + + # New round + is_new_round = round_nb > self.state.round_nb + if is_new_round or is_intro: + self.state.nb_messages_sent_this_round = 0 + if not is_intro: + prompt_parts.append(self.last_round_prompt.format(**obs_ctx_formmated)) + prompt_parts.append(self.new_round_prompt.format(**obs_ctx_formmated)) + if self.exploration_prompts_toggled: + exploration_prompt = self.exploration_prompts[ + np.random.choice( + len(self.exploration_prompts), p=self.exploration_prompt_probs + ) + ] + if exploration_prompt is not None: + prompt_parts.append(exploration_prompt) + self.state.round_nb = round_nb + + # Wait for message + if not is_our_turn and not observation.split_phase: + prompt_parts.append( + self.wait_for_message_prompt.format(**obs_ctx_formmated) + ) + + # Get last message + if is_our_turn and not is_new_round and not is_intro: + 
prompt_parts.append(self.last_message_prompt.format(**obs_ctx_formmated)) + + # Prompt to send message + must_send_message = not observation.split_phase and is_our_turn + if must_send_message: + prompt_parts.append(self.send_message_prompt.format(**obs_ctx_formmated)) + + # Prompt to give split + must_send_split = not must_send_message and observation.split_phase + if must_send_split: + var_names = ["x", "y", "z", "w"] # Extend as needed + items_str = ", ".join( + [ + f"{var_names[i]} {item}" + for i, item in enumerate(obs_ctx["quantities"].keys()) + ] + ) + ranges_str = ", ".join( + [ + f"{var_names[i]}: 0-{obs_ctx['quantities'][item]} (integer)" + for i, item in enumerate(obs_ctx["quantities"].keys()) + ] + ) + proposal_style = f"Proposal: {items_str} where {ranges_str}." + proposal_style2 = ( + f" {items_str} where {ranges_str}." + ) + prompt_parts.append( + self.send_split_prompt.format( + proposal_style=proposal_style, + proposal_style2=proposal_style2, + **obs_ctx_formmated, + ) + ) + + # Append one ChatTurn with is_state_end=True + user_prompt = "\n".join(prompt_parts) + self.state.chat_history.append( + ChatTurn( + agent_id=self.agent_id, + role="user", + content=user_prompt, + is_state_end=True, + ) + ) + + ####################################### + # Get policy action + ####################################### + + # Query policy for the appropriate format + if must_send_message: + return_regex = self.get_message_regex(observation) + policy_output = await self.policy( + state=self.state.chat_history, + agent_id=self.agent_id, + regex=return_regex, + ) + self.state.chat_history.append( + ChatTurn( + agent_id=self.agent_id, + role="assistant", + content=policy_output.content, + reasoning_content=policy_output.reasoning_content, + log_probs=policy_output.log_probs, + out_token_ids=policy_output.out_token_ids, + is_state_end=False, + ) + ) + action = Message(message=policy_output.content) + self.state.nb_messages_sent_this_round += 1 + + elif must_send_split: + 
return_regex = self.get_split_regex(observation) + policy_output = await self.policy( + state=self.state.chat_history, + agent_id=self.agent_id, + regex=return_regex, + ) + self.state.chat_history.append( + ChatTurn( + agent_id=self.agent_id, + role="assistant", + content=policy_output.content, + reasoning_content=policy_output.reasoning_content, + log_probs=policy_output.log_probs, + out_token_ids=policy_output.out_token_ids, + is_state_end=False, + ) + ) + action = self.get_split_action(policy_output.content, observation) + else: + action = None + + agent_step_log = AgentActLog( + chat_turns=self.state.chat_history[self.state.chat_counter :], info=None + ) + self.state.chat_counter = len(self.state.chat_history) + return action, agent_step_log + + def get_safe_copy(self): + agent_copy = copy.copy(self) + agent_copy.state = copy.deepcopy(self.state) + return agent_copy + + def reset(self): + self.state = NegotiationAgentState( + round_nb=0, nb_messages_sent_this_round=0, chat_counter=0, chat_history=[] + ) diff --git a/src_code_for_reproducibility/markov_games/negotiation/no_press_nego_simulation.py b/src_code_for_reproducibility/markov_games/negotiation/no_press_nego_simulation.py new file mode 100644 index 0000000000000000000000000000000000000000..d182187cc72c889a76f2d1c5be4b3afb6b923ed8 --- /dev/null +++ b/src_code_for_reproducibility/markov_games/negotiation/no_press_nego_simulation.py @@ -0,0 +1,168 @@ +import copy +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, Dict, List, Literal, Tuple + +from mllm.markov_games.negotiation.nego_simulation import ( + NegotiationObs, + NegotiationSimulation, + NegotiationState, + Split, + compute_tas_style_rewards, +) + +AgentId = str + + +@dataclass +class NoPressState(NegotiationState): + pass + + +@dataclass +class NoPressObs(NegotiationObs): + other_value: Dict[str, float] + + +class NoPressSimulation(NegotiationSimulation): + def __init__( + self, + game_type: 
Literal["10-1-exclusive", "10-1-ties", "1-to-20"] = "1-to-20", + same_round_value: bool = True, + atleast_one_conflict: bool = False, + *args, + **kwargs, + ): + self.game_type = game_type + self.same_round_value = same_round_value + self.atleast_one_conflict = atleast_one_conflict + super().__init__(*args, **kwargs) + + def _sample_values(self) -> Dict[AgentId, dict]: + values = defaultdict(dict) + if self.state is None: + item_types = self.item_types + else: + item_types = list(self.state.quantities.keys()) + while True: + for item in item_types: + if self.game_type == "10-1-exclusive": + v = int(self.rng.choice([1, 10])) + values[self.agent_ids[0]][item] = v + values[self.agent_ids[1]][item] = 10 if v == 1 else 1 + elif self.game_type == "10-1-ties": + for aid in self.agent_ids: + values[aid][item] = int(self.rng.choice([1, 10])) + elif self.game_type == "1-to-20": + for aid in self.agent_ids: + values[aid][item] = int(self.rng.integers(1, 21)) + if self.atleast_one_conflict: + has_conflict = False + for item in item_types: + agent_values_for_item = [ + values[aid][item] for aid in self.agent_ids + ] + if len(set(agent_values_for_item)) > 1: + has_conflict = True + break + if not has_conflict: + continue + agent_values = [sum(v.values()) for v in values.values()] + if len(set(agent_values)) == 1 or not self.same_round_value: + break + return values + + def _sample_quantities(self) -> Dict[str, int]: + return {item.lower(): 10 for item in self.item_types} + + def set_new_round_of_variant(self): + self.state.quantities = self._sample_quantities() + self.state.values = self._sample_values() + self.state.split_phase = True + + def get_info_of_variant( + self, state: NegotiationState, actions: Dict[AgentId, Any] + ) -> Dict[str, Any]: + return { + "quantities": copy.deepcopy(state.quantities), + "values": copy.deepcopy(state.values), + "splits": copy.deepcopy(state.splits), + } + + def get_rewards(self, splits: Dict[AgentId, Split]) -> Dict[AgentId, float]: + return 
compute_tas_style_rewards( + self.agent_ids, self.state.values, splits, self.state.quantities + ) + + def get_obs(self): + return {agent_id: self.get_obs_agent(agent_id) for agent_id in self.agent_ids} + + def get_obs_agent(self, agent_id): + other_id = self._other(agent_id) + last_value_coagent = ( + None + if self.state.previous_values is None + else self.state.previous_values.get(other_id) + ) + last_points_coagent = ( + None + if self.state.previous_points is None + else round(self.state.previous_points.get(other_id), 1) + ) + last_value_agent = ( + None + if self.state.previous_values is None + else self.state.previous_values.get(agent_id) + ) + last_points_agent = ( + None + if self.state.previous_points is None + else round(self.state.previous_points.get(agent_id), 1) + ) + last_split_coagent = None + last_split_agent = None + if self.state.previous_splits is not None: + last_split_coagent = self.state.previous_splits[ + other_id + ].items_given_to_self + last_split_agent = self.state.previous_splits[agent_id].items_given_to_self + obs = NoPressObs( + round_nb=self.state.round_nb, + last_message="", + quota_messages_per_agent_per_round=self.quota_messages_per_agent_per_round, + current_agent=self.state.current_agent, + other_agent=self.agent_id_to_name[other_id], + quantities=self.state.quantities, + item_types=self.item_types, + value=self.state.values[agent_id], + split_phase=self.state.split_phase, + last_split_agent=last_split_agent, + last_value_agent=last_value_agent, + last_points_agent=last_points_agent, + last_split_coagent=last_split_coagent, + last_value_coagent=last_value_coagent, + last_points_coagent=last_points_coagent, + other_value=self.state.values[other_id], + last_quantities=self.state.previous_quantities, + ) + return obs + + def reset(self): + start_agent = self.agent_ids[self._starting_agent_index] + quantities = self._sample_quantities() + values = self._sample_values() + self.state = NoPressState( + round_nb=0, + last_message="", + 
current_agent=start_agent, + quantities=quantities, + values=values, + previous_values=None, + splits={aid: None for aid in self.agent_ids}, + nb_messages_sent={aid: 0 for aid in self.agent_ids}, + split_phase=True, + previous_splits=None, + previous_points=None, + previous_quantities=None, + ) + return self.get_obs() diff --git a/src_code_for_reproducibility/markov_games/negotiation/tas_agent.py b/src_code_for_reproducibility/markov_games/negotiation/tas_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..002160873969ab7292f0f62a091e12ec376022c6 --- /dev/null +++ b/src_code_for_reproducibility/markov_games/negotiation/tas_agent.py @@ -0,0 +1,108 @@ +from mllm.markov_games.negotiation.nego_agent import NegotiationAgent +from mllm.markov_games.negotiation.nego_simulation import Split +from mllm.markov_games.negotiation.tas_simulation import TrustAndSplitObs + + +class TrustAndSplitAgent(NegotiationAgent): + def __init__(self, num_message_chars, *args, **kwargs): + self.num_message_chars = num_message_chars + super().__init__(*args, **kwargs) + self.intro_prompt = ( + "Welcome to an iterated game. You are {agent}. The other agent is {other_agent}.\n" + "Setup:\n" + "1. The game has multiple independent rounds.\n" + "2. In each round, there are multiple items to split between the two agents.\n" + "3. Both agents are assigned a per-item value between 1 and 20 (inclusive) in each round.\n" + "4. You can only observe your own per-item values.\n" + "5. Because assignments are random, both agents are equally likely to have same expected per-item value.\n" + "\n" + "Protocol:\n" + "1. At the start of the round, one agent begins the conversation. The starting role alternates each round.\n" + "2. Agents exchange a short chat ({quota_messages_per_agent_per_round} messages per round per agent) to negotiate how to split the item.\n" + " - Use this chat to communicate your private per-item value to make informed proposals.\n" + "3. 
After the chat, both agents simultaneously propose the amount of each item they will keep.\n" + "4. If the total sum of proposals is less than or equal to the item quantity, both agents receive their proposed amounts.\n" + "5. If the total sum of proposals exceeds the item quantity, they are allocated proportionally.\n" + "6. Your points for the round = (amount you receive per item) x (your per-item value for that round), added across all items.\n" + "7. Points are accumulated across rounds.\n" + "Your goal: {goal}\n" + ) + self.new_round_prompt = ( + "A New Round Begins\n" + "The items to split are {quantities}.\n" + "Your per-item values are {value}." + ) + self.last_round_prompt = ( + "Last Round Summary:\n" + " - Items to split: {last_quantities}\n" + " - Your per-item values: {last_value_agent}\n" + " - {other_agent}'s per-item values: {last_value_coagent}\n" + " - You proposed: {last_split_agent}\n" + " - You earned: {last_points_agent} points\n" + " - {other_agent} proposed: {last_split_coagent}\n" + " - {other_agent} earned: {last_points_coagent} points\n" + " - Round Complete.\n" + ) + self.send_split_prompt = ( + "Message quota is finished for this round.\n" + "{other_agent} has finalized their proposal.\n" + "Submit your finalization now\n" + "Respond with {proposal_style2}" + ) + # self.wait_for_message_prompt = "Wait for {other_agent} to send a message..." + self.wait_for_message_prompt = "" + self.last_message_prompt = "{other_agent} said: {last_message}" + # self.send_message_prompt = ( + # f"Send your message now (max {self.num_message_chars} chars)." + # ) + self.send_message_prompt = f"Send your message now in ... (<={self.num_message_chars} chars)." 
+ + def get_message_regex(self, observation: TrustAndSplitObs) -> str: + return rf"[\s\S]{{0,{self.num_message_chars}}}" + + # def get_message_regex(self, observation: TrustAndSplitObs) -> str: + # return rf"(?s).{{0,{self.num_message_chars}}}" + + def get_split_regex(self, observation: TrustAndSplitObs) -> str: + items = list(observation.quantities.keys()) + # Accept both singular and plural forms + item_pattern = "|".join( + [f"{item[:-1]}s?" if item.endswith("s") else f"{item}s?" for item in items] + ) + regex = rf"(?i) ?((?:\s*(?P<amount>(10|[0-9]))\s*(?P<item>{item_pattern})\s*,?)+) ?" + return regex + + def get_split_action( + self, policy_output: str, observation: TrustAndSplitObs + ) -> Split: + items = list(observation.quantities.keys()) + import re as _re + + split_regex = self.get_split_regex(observation) + items_given_to_self = {item: 0 for item in items} + m = _re.match(split_regex, policy_output.strip()) + if m: + # Find all (number, item) pairs + item_pattern = "|".join( + [ + f"{item[:-1]}s?" if item.endswith("s") else f"{item}s?" 
+ for item in items + ] + ) + inner_regex = rf"(?i)(10|[0-9])\s*({item_pattern})" + + def normalize_item_name(item_str): + for orig in items: + if item_str.lower() == orig.lower(): + return orig + if orig.endswith("s") and item_str.lower() == orig[:-1].lower(): + return orig + if ( + not orig.endswith("s") + and item_str.lower() == orig.lower() + "s" + ): + return orig + + for num, item in _re.findall(inner_regex, m.group(1)): + items_given_to_self[normalize_item_name(item)] = int(num) + return Split(items_given_to_self=items_given_to_self) diff --git a/src_code_for_reproducibility/markov_games/negotiation/tas_rps_agent.py b/src_code_for_reproducibility/markov_games/negotiation/tas_rps_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..e711c2a65d336e4d9b991c68662069e96b4dfee8 --- /dev/null +++ b/src_code_for_reproducibility/markov_games/negotiation/tas_rps_agent.py @@ -0,0 +1,118 @@ +import copy +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any, Dict, List, Tuple + +from mllm.markov_games.agent import Agent +from mllm.markov_games.negotiation.nego_agent import ( + Message, + NegotiationAgent, + NegotiationAgentState, + Split, +) +from mllm.markov_games.negotiation.tas_rps_simulation import TrustAndSplitRPSObs +from mllm.markov_games.rollout_tree import AgentActLog, ChatTurn + + +class TrustAndSplitRPSAgent(NegotiationAgent): + def __init__( + self, + num_message_chars: int, + message_start_end_format: bool = False, + proposal_start_end_format: bool = False, + *args, + **kwargs, + ): + self.num_message_chars = num_message_chars + self.message_start_end_format = message_start_end_format + self.proposal_start_end_format = proposal_start_end_format + super().__init__(*args, **kwargs) + self.intro_prompt = ( + "Welcome to an iterated game. You are {agent}. The other agent is {other_agent}.\n" + "\n" + "Setup:\n" + "1. The game has multiple independent rounds.\n" + "2. 
In each round, there are 10 coins to split between the two agents.\n" + "3. Each agent's per-coin value for that round is determined as follows:\n" + " - Both agents are randomly assigned a rock, paper or scissors hands\n" + " - Rock has the upper hand over scissors, scissors has the upper hand over paper and paper has the upper hand over rock.\n" + " - The agent with the upper hand has a per-coin value of 10.\n" + " - The agent with the lower hand has a per-coin value of 1.\n" + "4. You only see your own hand, but you may communicate it in messages and infer your value based on the other agent's hand.\n" + "5. Over many rounds both agents are equally likely to have the upper and lower hand.\n" + "\n" + "Protocol:\n" + "1. At the start of the round, one agent begins the conversation. The starting role alternates each round.\n" + "2. Agents exchange a short chat ({quota_messages_per_agent_per_round} messages per round per agent) to negotiate how to split the 10 coins.\n" + " - Use this chat to communicate your hand so that both agents can determine their per-coin values.\n" + "3. After the chat, both agents simultaneously propose how many coins they keep.\n" + "4. If the total sum of proposals is less than or equal to 10, both agents receive their proposals.\n" + "5. If the total sum of proposals exceeds 10, the coins are allocated proportionally.\n" + "6. Your points for the round = (coins you receive) x (your per-coin value for that round). \n" + "7. The points are accumulated across rounds.\n" + "Your goal: {goal}\n" + ) + self.new_round_prompt = ( + "A New Round Begins\n" + "Your hand is {hand}. 
You don't know {other_agent}'s hand yet.\n" + ) + # self.last_round_prompt = ( + # "Last Round Summary:\n" + # " - Your hand: {last_hand_agent}\n" + # " - {other_agent}'s hand: {last_hand_coagent}\n" + # " - Your value per coin: {last_value_agent}\n" + # " - {other_agent}'s value per coin: {last_value_coagent}\n" + # " - You proposed: {last_split_agent} coins\n" + # " - You earned: {last_points_agent} points\n" + # " - {other_agent} proposed: {last_split_coagent} coins\n" + # " - {other_agent} earned: {last_points_coagent} points\n" + # " - Round Complete.\n" + # ) + self.last_round_prompt = "In the previous round, {other_agent} had a {last_hand_value_coagent} hand and proposed {last_split_coagent} coins.\n" + if self.proposal_start_end_format: + self.send_split_prompt = ( + "Submit your proposal\n" + "Respond with <> x <> where x is an integer in [0, 10]." + ) + else: + self.send_split_prompt = ( + "Submit your proposal\n" + "Respond with x where x is an integer in [0, 10]." + ) + self.wait_for_message_prompt = "Wait for {other_agent} to send a message..." + # self.wait_for_message_prompt = "" + self.last_message_prompt = "{other_agent} said: {last_message}" + if self.message_start_end_format: + self.send_message_prompt = f"Send your message now in <>...<> (<={self.num_message_chars} chars)." + else: + self.send_message_prompt = f"Send your message now in ... (<={self.num_message_chars} chars)." + + def get_message_regex(self, observation: TrustAndSplitRPSObs) -> str: + if self.message_start_end_format: + return ( + rf"<>[\s\S]{{0,{self.num_message_chars}}}<>" + ) + else: + return rf"[\s\S]{{0,{self.num_message_chars}}}" + + def get_split_regex(self, observation: TrustAndSplitRPSObs) -> str: + if self.proposal_start_end_format: + return r"<> ?(10|[0-9]) ?<>" + else: + return r" ?(10|[0-9]) ?" 
+ + def get_split_action( + self, policy_output: str, observation: TrustAndSplitRPSObs + ) -> Split: + import re as _re + + if self.proposal_start_end_format: + m = _re.search( + r"<> ?(10|[0-9]) ?<>", policy_output + ) + else: + m = _re.search( + r" ?(10|[0-9]) ?", policy_output + ) + coins_int = int(m.group(1)) if m else int(policy_output) + return Split(items_given_to_self={"coins": coins_int}) diff --git a/src_code_for_reproducibility/markov_games/negotiation/tas_simple_simulation.py b/src_code_for_reproducibility/markov_games/negotiation/tas_simple_simulation.py new file mode 100644 index 0000000000000000000000000000000000000000..3dbd0c43d73e3f7b18204b62e71d72b2df1d13e6 --- /dev/null +++ b/src_code_for_reproducibility/markov_games/negotiation/tas_simple_simulation.py @@ -0,0 +1,169 @@ +import copy +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, Dict, List, Literal + +from numpy.random import default_rng + +from mllm.markov_games.negotiation.nego_simulation import ( + NegotiationObs, + NegotiationSimulation, + NegotiationState, + Split, + compute_tas_style_rewards, +) + +AgentId = str + + +@dataclass +class TrustAndSplitSimpleState(NegotiationState): + pass + + +@dataclass +class TrustAndSplitSimpleObs(NegotiationObs): + last_value_str_coagent: str | None + + +class TrustAndSplitSimpleSimulation(NegotiationSimulation): + def __init__( + self, + game_type: Literal["10-1-exclusive", "1-to-10"] = "1-to-10", + dist_type: Literal["uniform", "bimodal"] = "uniform", + beta_dist_alpha: float = 0.1, + beta_dist_beta: float = 0.1, + *args, + **kwargs, + ): + self.game_type = game_type + self.dist_type = dist_type + self.beta_dist_alpha = beta_dist_alpha + self.beta_dist_beta = beta_dist_beta + super().__init__(*args, **kwargs) + + def _sample_values(self) -> Dict[AgentId, dict]: + values = {} + while True: + if self.game_type == "10-1-exclusive": + v = int(self.rng.choice([1, 10])) + values[self.agent_ids[0]] = v + 
values[self.agent_ids[1]] = 10 if v == 1 else 1 + elif self.game_type == "1-to-10": + for aid in self.agent_ids: + if self.dist_type == "uniform": + values[aid] = int(self.rng.integers(1, 11)) + elif self.dist_type == "bimodal": + alpha, beta = self.beta_dist_alpha, self.beta_dist_beta + values[aid] = int(round(self.rng.beta(alpha, beta) * 9) + 1) + if len(set(values.values())) != 1: + break + return values + + def _sample_quantities(self) -> Dict[str, int]: + return {"coins": 10} + + def set_new_round_of_variant(self): + self.state.quantities = self._sample_quantities() + self.state.values = self._sample_values() + self.state.split_phase = False + + def get_info_of_variant( + self, state: NegotiationState, actions: Dict[AgentId, Any] + ) -> Dict[str, Any]: + return { + "quantities": copy.deepcopy(state.quantities), + "values": copy.deepcopy(state.values), + # "previous_values": copy.deepcopy(state.previous_values), + "splits": copy.deepcopy(state.splits), + } + + def get_rewards(self, splits: Dict[AgentId, Split]) -> Dict[AgentId, float]: + return compute_tas_style_rewards( + self.agent_ids, self.state.values, splits, self.state.quantities + ) + + def get_obs(self): + return {agent_id: self.get_obs_agent(agent_id) for agent_id in self.agent_ids} + + def get_obs_agent(self, agent_id): + other_id = self._other(agent_id) + last_value_coagent = ( + None + if self.state.previous_values is None + else self.state.previous_values.get(other_id) + ) + last_points_coagent = ( + None + if self.state.previous_points is None + else round(self.state.previous_points.get(other_id), 1) + ) + last_value_agent = ( + None + if self.state.previous_values is None + else self.state.previous_values.get(agent_id) + ) + last_points_agent = ( + None + if self.state.previous_points is None + else round(self.state.previous_points.get(agent_id), 1) + ) + last_split_coagent = None + last_split_agent = None + if self.state.previous_splits is not None: + last_split_coagent = 
self.state.previous_splits[ + other_id + ].items_given_to_self["coins"] + last_split_agent = self.state.previous_splits[agent_id].items_given_to_self[ + "coins" + ] + if last_value_agent is None or last_value_coagent is None: + last_value_str_coagent = None + else: + if last_value_coagent > last_value_agent: + last_value_str_coagent = "higher" + elif last_value_coagent < last_value_agent: + last_value_str_coagent = "lower" + else: + raise ValueError("Should not be equal values") + + obs = TrustAndSplitSimpleObs( + round_nb=self.state.round_nb, + last_message=self.state.last_message, + quota_messages_per_agent_per_round=self.quota_messages_per_agent_per_round, + current_agent=self.state.current_agent, + other_agent=self.agent_id_to_name[other_id], + quantities=self.state.quantities, + item_types=self.item_types, + value=self.state.values[agent_id], + split_phase=self.state.split_phase, + last_split_agent=last_split_agent, + last_value_agent=last_value_agent, + last_points_agent=last_points_agent, + last_split_coagent=last_split_coagent, + last_value_coagent=last_value_coagent, + last_points_coagent=last_points_coagent, + last_quantities=self.state.previous_quantities, + last_value_str_coagent=last_value_str_coagent, + ) + return obs + + def reset(self): + start_agent = self.agent_ids[self._starting_agent_index] + quantities = self._sample_quantities() + values = self._sample_values() + self.state = TrustAndSplitSimpleState( + round_nb=0, + last_message="", + current_agent=start_agent, + quantities=quantities, + values=values, + previous_values=None, + splits={aid: None for aid in self.agent_ids}, + nb_messages_sent={aid: 0 for aid in self.agent_ids}, + split_phase=False, + previous_splits=None, + previous_points=None, + previous_quantities=None, + ) + return self.get_obs() diff --git a/src_code_for_reproducibility/markov_games/negotiation/tas_simulation.py b/src_code_for_reproducibility/markov_games/negotiation/tas_simulation.py new file mode 100644 index 
0000000000000000000000000000000000000000..5499a146e9da491757a8105965b2d210f8327134 --- /dev/null +++ b/src_code_for_reproducibility/markov_games/negotiation/tas_simulation.py @@ -0,0 +1,172 @@ +import copy +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, Dict, List, Literal + +from numpy.random import default_rng + +from mllm.markov_games.negotiation.nego_simulation import ( + NegotiationObs, + NegotiationSimulation, + NegotiationState, + Split, + compute_tas_style_rewards, +) + +AgentId = str + + +@dataclass +class TrustAndSplitState(NegotiationState): + pass + + +@dataclass +class TrustAndSplitObs(NegotiationObs): + pass + + +class TrustAndSplitSimulation(NegotiationSimulation): + def __init__( + self, + game_type: Literal["10-1-exclusive", "10-1-ties", "1-to-20"] = "1-to-20", + same_round_value: bool = True, + atleast_one_conflict: bool = False, + *args, + **kwargs, + ): + self.game_type = game_type + self.same_round_value = same_round_value + self.atleast_one_conflict = atleast_one_conflict + super().__init__(*args, **kwargs) + + def _sample_values(self) -> Dict[AgentId, dict]: + values = defaultdict(dict) + if self.state is None: + item_types = self.item_types + else: + item_types = list(self.state.quantities.keys()) + while True: + for item in item_types: + if self.game_type == "10-1-exclusive": + v = int(self.rng.choice([1, 10])) + values[self.agent_ids[0]][item] = v + values[self.agent_ids[1]][item] = 10 if v == 1 else 1 + elif self.game_type == "10-1-ties": + for aid in self.agent_ids: + values[aid][item] = int(self.rng.choice([1, 10])) + elif self.game_type == "1-to-20": + for aid in self.agent_ids: + values[aid][item] = int(self.rng.integers(1, 21)) + agent_values = [sum(v.values()) for v in values.values()] + if self.atleast_one_conflict: + has_conflict = False + for item in item_types: + agent_values_for_item = [ + values[aid][item] for aid in self.agent_ids + ] + if ( + len(set(agent_values_for_item)) > 
1 + ): # Different values for this item + has_conflict = True + break + if not has_conflict: + continue + if len(set(agent_values)) == 1 or not self.same_round_value: + break + return values + + def _sample_quantities(self) -> Dict[str, int]: + return {item.lower(): 10 for item in self.item_types} + + def set_new_round_of_variant(self): + self.state.quantities = self._sample_quantities() + self.state.values = self._sample_values() + self.state.split_phase = False + + def get_info_of_variant( + self, state: NegotiationState, actions: Dict[AgentId, Any] + ) -> Dict[str, Any]: + return { + "quantities": copy.deepcopy(state.quantities), + "values": copy.deepcopy(state.values), + # "previous_values": copy.deepcopy(state.previous_values), + "splits": copy.deepcopy(state.splits), + } + + def get_rewards(self, splits: Dict[AgentId, Split]) -> Dict[AgentId, float]: + return compute_tas_style_rewards( + self.agent_ids, self.state.values, splits, self.state.quantities + ) + + def get_obs(self): + return {agent_id: self.get_obs_agent(agent_id) for agent_id in self.agent_ids} + + def get_obs_agent(self, agent_id): + other_id = self._other(agent_id) + last_value_coagent = ( + None + if self.state.previous_values is None + else self.state.previous_values.get(other_id) + ) + last_points_coagent = ( + None + if self.state.previous_points is None + else round(self.state.previous_points.get(other_id), 1) + ) + last_value_agent = ( + None + if self.state.previous_values is None + else self.state.previous_values.get(agent_id) + ) + last_points_agent = ( + None + if self.state.previous_points is None + else round(self.state.previous_points.get(agent_id), 1) + ) + last_split_coagent = None + last_split_agent = None + if self.state.previous_splits is not None: + last_split_coagent = self.state.previous_splits[ + other_id + ].items_given_to_self + last_split_agent = self.state.previous_splits[agent_id].items_given_to_self + obs = TrustAndSplitObs( + round_nb=self.state.round_nb, + 
last_message=self.state.last_message, + quota_messages_per_agent_per_round=self.quota_messages_per_agent_per_round, + current_agent=self.state.current_agent, + other_agent=self.agent_id_to_name[other_id], + quantities=self.state.quantities, + item_types=self.item_types, + value=self.state.values[agent_id], + split_phase=self.state.split_phase, + last_split_agent=last_split_agent, + last_value_agent=last_value_agent, + last_points_agent=last_points_agent, + last_split_coagent=last_split_coagent, + last_value_coagent=last_value_coagent, + last_points_coagent=last_points_coagent, + last_quantities=self.state.previous_quantities, + ) + return obs + + def reset(self): + start_agent = self.agent_ids[self._starting_agent_index] + quantities = self._sample_quantities() + values = self._sample_values() + self.state = TrustAndSplitState( + round_nb=0, + last_message="", + current_agent=start_agent, + quantities=quantities, + values=values, + previous_values=None, + splits={aid: None for aid in self.agent_ids}, + nb_messages_sent={aid: 0 for aid in self.agent_ids}, + split_phase=False, + previous_splits=None, + previous_points=None, + previous_quantities=None, + ) + return self.get_obs() diff --git a/src_code_for_reproducibility/markov_games/rollout_tree.py b/src_code_for_reproducibility/markov_games/rollout_tree.py new file mode 100644 index 0000000000000000000000000000000000000000..5e90fa6d7a0e0b644ccaf3533a1513568d21868f --- /dev/null +++ b/src_code_for_reproducibility/markov_games/rollout_tree.py @@ -0,0 +1,86 @@ +""" +TODO: add parent to nodes so that some verification can be done. For instance, to ensure that node reward keys match the parent node. 
+""" + +from __future__ import annotations + +import json +from dataclasses import dataclass +from pathlib import Path +from typing import Any, List, Literal, Optional, Tuple + +import jsonschema +from pydantic import BaseModel, Field, model_validator + +from mllm.chat_utils.chat_turn import ChatTurn + +AgentId = str + + +class SimulationStepLog(BaseModel): + rewards: dict[AgentId, float] + info: Any = None + + +class AgentActLog(BaseModel): + chat_turns: list[ChatTurn] | None + info: Any = None + + @model_validator(mode="after") + def _exactly_one_state_end(self): + """ + This method is used to enforce that for each AgentActLog, there is exactly one ChatTurn which is a state end. + """ + if self.chat_turns != []: + n = sum(1 for t in self.chat_turns if t.is_state_end) + if n != 1: + raise ValueError( + f"AgentActLog must have exactly one ChatTurn with is_state_end=True; got {self.chat_turns}." + ) + return self + else: + return self + + +class StepLog(BaseModel): + action_logs: dict[AgentId, AgentActLog] + simulation_step_log: SimulationStepLog + + +# BranchType = Literal["unilateral_deviation", "common_deviation"] # might not be necessary +# class BranchNodeInfo(BaseModel): +# branch_id: str +# branch_for: AgentId +# branch_type: BranchType + + +class RolloutTreeNode(BaseModel): + step_log: StepLog + time_step: int + child: RolloutTreeNode | RolloutTreeBranchNode | None = None + + +class RolloutTreeBranchNode(BaseModel): + """ + First item of the tuple indicates which agent "called" for an alternative branch. 
+ """ + + main_child: RolloutTreeNode + branches: dict[AgentId, list[RolloutTreeNode]] | None = None + + +class RolloutTreeRootNode(BaseModel): + id: int + crn_id: int # ID of the rng used to generate this rollout tree + child: RolloutTreeNode | RolloutTreeBranchNode | None = None + agent_ids: List[AgentId] = Field(min_length=1) + + +# class RolloutTreeLeafNode(BaseModel): +# step_log: StepLog +# time_step: int + + +# Necessary for self-referential stuff in pydantic +RolloutTreeBranchNode.model_rebuild() +RolloutTreeNode.model_rebuild() diff --git a/src_code_for_reproducibility/markov_games/run_markov_games.py b/src_code_for_reproducibility/markov_games/run_markov_games.py new file mode 100644 index 0000000000000000000000000000000000000000..08b84024668ed4375453d1fd515a78eb9bb23414 --- /dev/null +++ b/src_code_for_reproducibility/markov_games/run_markov_games.py @@ -0,0 +1,24 @@ +import asyncio +from collections.abc import Callable +from dataclasses import dataclass + +from torch._C import ClassType + +from mllm.markov_games.markov_game import MarkovGame +from mllm.markov_games.rollout_tree import RolloutTreeRootNode + + +async def run_markov_games( + runner: Callable[[MarkovGame], RolloutTreeRootNode], + runner_kwargs: dict, + output_folder: str, + markov_games: list[MarkovGame], +) -> list[RolloutTreeRootNode]: + tasks = [] + for mg in markov_games: + tasks.append( + asyncio.create_task( + runner(markov_game=mg, output_folder=output_folder, **runner_kwargs) + ) + ) + return await asyncio.gather(*tasks) diff --git a/src_code_for_reproducibility/markov_games/simulation.py b/src_code_for_reproducibility/markov_games/simulation.py new file mode 100644 index 0000000000000000000000000000000000000000..b0b804e2aa4c288b3d98cc8106cfd727f1cc1e1a --- /dev/null +++ b/src_code_for_reproducibility/markov_games/simulation.py @@ -0,0 +1,87 @@ +""" +A Simulation is the environment of a Markov Game. 
class Simulation(ABC):
    """Abstract multi-agent environment interface for a Markov Game.

    Modeled after ``gymnasium`` environments but adapted to the multi-agent
    setting: observations, available actions, etc. are keyed per agent. Only
    ``__init__`` and ``step`` are abstract; the remaining methods are optional
    hooks that raise ``NotImplementedError`` until a subclass provides them.
    """

    @abstractmethod
    def __init__(self, seed: int, *args, **kwargs):
        """Store the seed and build the simulation's random generator."""
        # Subclasses are expected to call this so that self.rng is always a
        # numpy Generator seeded deterministically from `seed`.
        self.seed = seed
        self.rng = default_rng(self.seed)

    @abstractmethod
    def step(self, actions: Any) -> Tuple[bool, SimulationStepLog]:
        """
        Advance the simulation by one step given all agents' actions.

        Returns terminated, info
        """
        raise NotImplementedError

    def get_obs(self):
        """Returns all agent observations in dict

        Returns:
            observations
        """
        raise NotImplementedError

    def get_obs_agent(self, agent_id):
        """Returns observation for agent_id"""
        raise NotImplementedError

    def get_obs_size(self):
        """Returns the shape of the observation"""
        raise NotImplementedError

    def get_state(self):
        """Return the full (global) simulation state."""
        raise NotImplementedError

    def get_state_size(self):
        """Returns the shape of the state"""
        raise NotImplementedError

    def get_avail_actions(self):
        """Return available actions for every agent."""
        raise NotImplementedError

    def get_avail_agent_actions(self, agent_id):
        """Returns the available actions for agent_id"""
        raise NotImplementedError

    def get_total_actions(self):
        """Returns the total number of actions an agent could ever take"""
        # TODO: This is only suitable for a discrete 1 dimensional action space for each agent
        raise NotImplementedError

    def get_safe_copy(self):
        """
        Return copy of the agent object that is decorrelated from the original object.
        """
        raise NotImplementedError

    def reset(self):
        """Returns initial observations and states"""
        raise NotImplementedError

    def render(self):
        """Render the current simulation state (e.g. for debugging)."""
        raise NotImplementedError

    def close(self):
        """Release any resources held by the simulation."""
        raise NotImplementedError

    # def seed(self):
    #     raise NotImplementedError

    def save_replay(self):
        """Persist a replay of the episode, if supported."""
        raise NotImplementedError

    def get_simulation_info(self):
        """Return simulation-level metadata/diagnostics."""
        raise NotImplementedError
+ """ + current = root.child + while current is not None: + if isinstance(current, RolloutTreeNode): + yield current + current = current.child + elif isinstance(current, RolloutTreeBranchNode): + # Follow only the main child on the main trajectory + current = current.main_child + else: + break + + +def iterate_main_simulation_logs( + root: RolloutTreeRootNode, +) -> Iterator[SimulationStepLog]: + for node in _iterate_main_nodes(root): + yield node.step_log.simulation_step_log + + +def stream_rollout_files(iteration_folder: Path) -> Iterator[Path]: + for p in iteration_folder.rglob("*.rt.pkl"): + if p.is_file(): + yield p + + +def load_root(path: Path) -> RolloutTreeRootNode: + with open(path, "rb") as f: + data = pickle.load(f) + return RolloutTreeRootNode.model_validate(data) + + +@dataclass +class StatRecord: + mgid: int + crn_id: Optional[int] + iteration: str + values: Dict[str, Any] + + +class StatComputer: + """ + Stateful stat computer that consumes SimulationStepLog instances + and produces final aggregated values for one rollout (mgid). + """ + + def update(self, sl: SimulationStepLog) -> None: # pragma: no cover - interface + raise NotImplementedError + + def finalize(self) -> Dict[str, Any]: # pragma: no cover - interface + raise NotImplementedError + + +def run_stats( + data_root: Path, + game_name: str, + make_computers: Callable[[], List[StatComputer]], + output_filename: Optional[str] = None, + output_format: str = "json", # "json" (dict of lists) or "jsonl" +) -> Path: + """ + Compute stats across all iteration_* folders under data_root. + Writes JSONL to data_root/statistics/. 
+ """ + data_root = Path(data_root) + outdir = data_root / "statistics" + outdir.mkdir(parents=True, exist_ok=True) + # Choose extension by format + default_name = ( + f"{game_name}.stats.json" + if output_format == "json" + else f"{game_name}.stats.jsonl" + ) + outfile = outdir / ( + output_filename if output_filename is not None else default_name + ) + + # Rewrite file each run to keep it clean and small + if outfile.exists(): + outfile.unlink() + + iteration_folders = find_iteration_folders(str(data_root)) + + # If writing JSONL, stream directly; otherwise accumulate minimal records + if output_format == "jsonl": + with open(outfile, "w", encoding="utf-8") as w: + for iteration_folder in iteration_folders: + iteration_name = Path(iteration_folder).name + for pkl_path in stream_rollout_files(Path(iteration_folder)): + root = load_root(pkl_path) + + computers = make_computers() + for sl in iterate_main_simulation_logs(root): + for comp in computers: + try: + comp.update(sl) + except Exception: + continue + + values: Dict[str, Any] = {} + for comp in computers: + try: + values.update(comp.finalize()) + except Exception: + continue + + rec = { + "mgid": getattr(root, "id", None), + "crn_id": getattr(root, "crn_id", None), + "iteration": iteration_name, + "stats": values, + } + w.write(json.dumps(rec, ensure_ascii=False) + "\n") + + del root + del computers + gc.collect() + else: + # Aggregate to dict-of-lists for easier plotting + records: List[Dict[str, Any]] = [] + # Process in deterministic order + for iteration_folder in iteration_folders: + iteration_name = Path(iteration_folder).name + for pkl_path in stream_rollout_files(Path(iteration_folder)): + root = load_root(pkl_path) + + computers = make_computers() + for sl in iterate_main_simulation_logs(root): + for comp in computers: + try: + comp.update(sl) + except Exception: + continue + + values: Dict[str, Any] = {} + for comp in computers: + try: + values.update(comp.finalize()) + except Exception: + continue 
+ + records.append( + { + "mgid": getattr(root, "id", None), + "crn_id": getattr(root, "crn_id", None), + "iteration": iteration_name, + "stats": values, + } + ) + + del root + del computers + gc.collect() + + # Build dict-of-lists with nested stats preserved + # Collect all stat keys and nested agent keys where needed + mgids: List[Any] = [] + crn_ids: List[Any] = [] + iterations_out: List[str] = [] + # stats_out is a nested structure mirroring keys but with lists + stats_out: Dict[str, Any] = {} + + # First pass to collect union of keys + stat_keys: set[str] = set() + nested_agent_keys: Dict[str, set[str]] = {} + for r in records: + stats = r.get("stats", {}) or {} + for k, v in stats.items(): + stat_keys.add(k) + if isinstance(v, dict): + nested = nested_agent_keys.setdefault(k, set()) + for ak in v.keys(): + nested.add(str(ak)) + + # Initialize structure + for k in stat_keys: + if k in nested_agent_keys: + stats_out[k] = {ak: [] for ak in sorted(nested_agent_keys[k])} + else: + stats_out[k] = [] + + # Fill lists + for r in records: + mgids.append(r.get("mgid")) + crn_ids.append(r.get("crn_id")) + iterations_out.append(r.get("iteration")) + stats = r.get("stats", {}) or {} + for k in stat_keys: + val = stats.get(k) + if isinstance(stats_out[k], dict): + # per-agent dict + agent_dict = val if isinstance(val, dict) else {} + for ak in stats_out[k].keys(): + stats_out[k][ak].append(agent_dict.get(ak)) + else: + stats_out[k].append(val) + + with open(outfile, "w", encoding="utf-8") as w: + json.dump( + { + "mgid": mgids, + "crn_id": crn_ids, + "iteration": iterations_out, + "stats": stats_out, + }, + w, + ensure_ascii=False, + ) + + return outfile + + +def run_stats_functional( + data_root: Path, + game_name: str, + metrics: Dict[str, Callable[[SimulationStepLog], Optional[Dict[str, float]]]], + output_filename: Optional[str] = None, + output_format: str = "json", +) -> Path: + """ + Functional variant where metrics is a dict of name -> f(SimulationStepLog) -> 
def finalize_rollout(
    agg: Dict[str, Dict[str, List[float]]]
) -> Dict[str, Dict[str, Optional[float]]]:
    """Average the collected values per metric per agent.

    An agent whose value list is empty yields ``None`` (not a number), so the
    per-agent keys stay aligned across rollouts. The previous return
    annotation claimed ``float`` while the body emitted ``None``; the
    annotation now matches the behavior.
    """
    result: Dict[str, Dict[str, Optional[float]]] = {}
    for mname, agent_values in agg.items():
        result[mname] = {
            aid: (sum(vals) / len(vals)) if vals else None
            for aid, vals in agent_values.items()
        }
    return result
del root + gc.collect() + else: + records: List[Dict[str, Any]] = [] + for iteration_folder in iteration_folders: + iteration_name = Path(iteration_folder).name + for pkl_path in stream_rollout_files(Path(iteration_folder)): + root = load_root(pkl_path) + + agg: Dict[str, Dict[str, List[float]]] = {m: {} for m in metrics.keys()} + for sl in iterate_main_simulation_logs(root): + for mname, fn in metrics.items(): + try: + vals = fn(sl) + except Exception: + vals = None + if not vals: + continue + for aid, v in vals.items(): + if v is None: + continue + lst = agg[mname].setdefault(str(aid), []) + try: + lst.append(float(v)) + except Exception: + continue + + values = finalize_rollout(agg) + records.append( + { + "mgid": getattr(root, "id", None), + "crn_id": getattr(root, "crn_id", None), + "iteration": iteration_name, + "stats": values, + } + ) + + del root + gc.collect() + + # Build dict-of-lists output + mgids: List[Any] = [] + crn_ids: List[Any] = [] + iterations_out: List[str] = [] + stats_out: Dict[str, Any] = {} + + stat_keys: set[str] = set() + nested_agent_keys: Dict[str, set[str]] = {} + for r in records: + stats = r.get("stats", {}) or {} + for k, v in stats.items(): + stat_keys.add(k) + if isinstance(v, dict): + nested = nested_agent_keys.setdefault(k, set()) + for ak in v.keys(): + nested.add(str(ak)) + + for k in stat_keys: + if k in nested_agent_keys: + stats_out[k] = {ak: [] for ak in sorted(nested_agent_keys[k])} + else: + stats_out[k] = [] + + for r in records: + mgids.append(r.get("mgid")) + crn_ids.append(r.get("crn_id")) + iterations_out.append(r.get("iteration")) + stats = r.get("stats", {}) or {} + for k in stat_keys: + val = stats.get(k) + if isinstance(stats_out[k], dict): + agent_dict = val if isinstance(val, dict) else {} + for ak in stats_out[k].keys(): + stats_out[k][ak].append(agent_dict.get(ak)) + else: + stats_out[k].append(val) + + with open(outfile, "w", encoding="utf-8") as w: + json.dump( + { + "mgid": mgids, + "crn_id": crn_ids, + 
"iteration": iterations_out, + "stats": stats_out, + }, + w, + ensure_ascii=False, + ) + + return outfile diff --git a/src_code_for_reproducibility/markov_games/vine_ppo.py b/src_code_for_reproducibility/markov_games/vine_ppo.py new file mode 100644 index 0000000000000000000000000000000000000000..f3cd2c89133fa00b0d1d0fa260688efa642d4ded --- /dev/null +++ b/src_code_for_reproducibility/markov_games/vine_ppo.py @@ -0,0 +1,10 @@ +from anytree import Node, RenderTree +from anytree.exporter import DotExporter +import os.path +import asyncio +from mllm.markov_games.markov_game import MarkovGame + +async def VinePPORunner( + markov_game: MarkovGame, + **kwargs): + pass diff --git a/src_code_for_reproducibility/models/__init__.py b/src_code_for_reproducibility/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src_code_for_reproducibility/models/__pycache__/__init__.cpython-312.pyc b/src_code_for_reproducibility/models/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..632ac87b9f56eb9bafde1439f0fb8e11d82c8e3a Binary files /dev/null and b/src_code_for_reproducibility/models/__pycache__/__init__.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/models/__pycache__/inference_backend.cpython-312.pyc b/src_code_for_reproducibility/models/__pycache__/inference_backend.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dcecc73f903ad488a90bff2b6f23d4511d757ba3 Binary files /dev/null and b/src_code_for_reproducibility/models/__pycache__/inference_backend.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/models/__pycache__/inference_backend_dummy.cpython-312.pyc b/src_code_for_reproducibility/models/__pycache__/inference_backend_dummy.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fea16b5f626dc11d3c4194776cf7c149549e4012 Binary files 
class AdapterWrapper(nn.Module):
    """
    A thin façade that
      • keeps a reference to a *shared* PEFT-wrapped model,
      • ensures `set_adapter(adapter)` is called on every forward,
      • exposes only the parameters that should be trained for that adapter
        (plus whatever extra modules you name).
    """
    def __init__(
        self,
        shared_llm: nn.Module,
        adapter_id: str,
        lora_config: dict,
        path: Union[str, None] = None,
    ):
        """Attach a LoRA adapter named ``adapter_id`` to ``shared_llm``.

        Args:
            shared_llm: The base model shared across all adapters.
            adapter_id: Name under which the LoRA adapter is registered.
            lora_config: Keyword arguments for ``peft.LoraConfig``.
            path: Optional location of pre-trained adapter weights
                (local path or HF Hub id). Load failures are logged, not raised.
        """
        super().__init__()
        self.shared_llm = shared_llm
        self.adapter_id = adapter_id
        lora_config = LoraConfig(**lora_config)
        # this modifies the shared llm in place, adding a lora adapter inside
        # (the returned PeftModel wraps the same underlying weights).
        self.shared_llm = get_peft_model(
            model=shared_llm,
            peft_config=lora_config,
            adapter_name=adapter_id,
        )
        self.shared_llm.train()
        # Load external adapter weights if provided
        loaded_from: str | None = None
        if path:
            try:
                # Supports both local filesystem paths and HF Hub repo IDs
                self.shared_llm.load_adapter(
                    is_trainable=True,
                    model_id=path,
                    adapter_name=adapter_id,
                )
                loaded_from = path
            except Exception as exc:  # noqa: BLE001 - want to log any load failure context
                logger.warning(
                    f"Adapter '{adapter_id}': failed to load from '{path}': {exc}"
                )

        if loaded_from:
            logger.info(
                f"Adapter '{adapter_id}': loaded initial weights from '{loaded_from}'."
            )
        else:
            logger.info(
                f"Adapter '{adapter_id}': initialized with fresh weights (no initial weights found)."
            )

    def parameters(self, recurse: bool = True):
        """
        "recurse" is just for pytorch compatibility

        Returns only the trainable (requires_grad) parameters of this adapter.
        NOTE(review): unlike nn.Module.parameters this returns a list, not a
        generator, and it activates this adapter as a side effect — confirm
        callers (e.g. optimizer construction) rely on that.
        """
        self.shared_llm.set_adapter(self.adapter_id)
        params = [p for p in self.shared_llm.parameters() if p.requires_grad]

        return params

    def get_base_model_logits(self, contexts):
        """
        Run the base model (without adapter) in inference mode, without tracking gradients.
        This is useful to get reference logits for KL-divergence computation.
        """
        with torch.no_grad():
            with self.shared_llm.disable_adapter():
                # [0] selects the logits from the model output tuple.
                return self.shared_llm(input_ids=contexts)[0]

    def forward(self, *args, **kwargs):
        # Always (re)activate this wrapper's adapter: other AdapterWrappers
        # sharing the same model may have switched the active adapter.
        self.shared_llm.set_adapter(self.adapter_id)
        return self.shared_llm(*args, **kwargs)

    def save_pretrained(self, save_path):
        # Delegates to PEFT; saves adapter weights/config, not the base model.
        self.shared_llm.save_pretrained(save_path)

    def gradient_checkpointing_enable(self, *args, **kwargs):
        # Pass-through so trainers can toggle checkpointing on the wrapper.
        self.shared_llm.gradient_checkpointing_enable(*args, **kwargs)

    @property
    def dtype(self):
        # Mirror the shared model's dtype for torch interop.
        return self.shared_llm.dtype

    @property
    def device(self):
        # Mirror the shared model's device for torch interop.
        return self.shared_llm.device
+ """ + if sys.stdout.isatty(): + os.system("cls" if os.name == "nt" else "clear") + + +def _terminal_width(default: int = 100) -> int: + try: + return shutil.get_terminal_size().columns + except Exception: + return default + + +def _horizontal_rule(char: str = "─") -> str: + width = max(20, _terminal_width() - 2) + return char * width + + +class _Style: + # ANSI colors (bright, readable) + RESET = "\033[0m" + BOLD = "\033[1m" + DIM = "\033[2m" + # Foreground colors + FG_BLUE = "\033[94m" # user/system headers + FG_GREEN = "\033[92m" # human response header + FG_YELLOW = "\033[93m" # notices + FG_RED = "\033[91m" # errors + FG_MAGENTA = "\033[95m" # regex + FG_CYAN = "\033[96m" # tips + + +def _render_chat(state) -> str: + """ + Render prior messages in a compact, readable terminal format. + + Expected message dict keys: {"role": str, "content": str, ...} + """ + lines: List[str] = [] + lines.append(_horizontal_rule()) + lines.append(f"{_Style.FG_BLUE}{_Style.BOLD} Conversation so far {_Style.RESET}") + lines.append(_horizontal_rule()) + for chat in state: + role = chat.role + content = str(chat.content).strip() + # Map roles to display names and colors/emojis + if role == "assistant": + header = f"{_Style.FG_GREEN}{_Style.BOLD}HUMAN--🧑‍💻{_Style.RESET}" + elif role == "user": + header = f"{_Style.FG_BLUE}{_Style.BOLD}USER--⚙️{_Style.RESET}" + else: + header = f"[{_Style.DIM}{role.upper()}{_Style.RESET}]" + lines.append(header) + # Indent content for readability + for line in content.splitlines() or [""]: + lines.append(f" {line}") + lines.append("") + lines.append(_horizontal_rule()) + return "\n".join(lines) + + +async def _async_input(prompt_text: str) -> str: + """Non-blocking input using a background thread.""" + return await asyncio.to_thread(input, prompt_text) + + +def _short_regex_example(regex: str, max_len: int = 30) -> Optional[str]: + """ + Try to produce a short example string that matches the regex. 
+ We attempt multiple times and pick the first <= max_len. + """ + if rstr is None: + return None + try: + for _ in range(20): + candidate = rstr.xeger(regex) + if len(candidate) <= max_len: + return candidate + # Fallback to truncation (may break match, so don't return) + return None + except Exception: + return None + + +def _detect_input_type(regex: str | None) -> tuple[str, str, str]: + """ + Detect what type of input is expected based on the regex pattern. + Returns (input_type, start_tag, end_tag) + """ + if regex is None: + return "text", "", "" + + if "message_start" in regex and "message_end" in regex: + return "message", "<>", "<>" + elif "proposal_start" in regex and "proposal_end" in regex: + return "proposal", "<>", "<>" + else: + return "text", "", "" + + +async def human_policy(state, agent_id, regex: str | None = None) -> str: + """ + Async human-in-the-loop policy. + + - Displays prior conversation context in the terminal. + - Prompts the user for a response. + - If a regex is provided, validates and re-prompts until it matches. + - Automatically adds formatting tags based on expected input type. + + Args: + prompt: Chat history as a list of {role, content} dicts. + regex: Optional fullmatch validation pattern. + + Returns: + The user's validated response string. 
+ """ + # Detect input type and formatting + input_type, start_tag, end_tag = _detect_input_type(regex) + + while True: + _clear_terminal() + print(_render_chat(state)) + + if regex: + example = _short_regex_example(regex, max_len=30) + print( + f"{_Style.FG_MAGENTA}{_Style.BOLD}Expected format (regex fullmatch):{_Style.RESET}" + ) + print(f" {_Style.FG_MAGENTA}{regex}{_Style.RESET}") + if example: + print( + f"{_Style.FG_CYAN}Example (random, <=30 chars):{_Style.RESET} {example}" + ) + print(_horizontal_rule(".")) + + # Custom prompt based on input type + if input_type == "message": + print( + f"{_Style.FG_YELLOW}Type your message content (formatting will be added automatically):{_Style.RESET}" + ) + elif input_type == "proposal": + print( + f"{_Style.FG_YELLOW}Type your proposal (number only, formatting will be added automatically):{_Style.RESET}" + ) + else: + print( + f"{_Style.FG_YELLOW}Type your response and press Enter.{_Style.RESET}" + ) + + print( + f"{_Style.DIM}Commands: /help to view commands, /refresh to re-render, /quit to abort{_Style.RESET}" + ) + else: + print( + f"{_Style.FG_YELLOW}Type your response and press Enter.{_Style.RESET} {_Style.DIM}(/help for commands){_Style.RESET}" + ) + + user_in = (await _async_input("> ")).rstrip("\n") + + # Commands + if user_in.strip().lower() in {"/help", "/h"}: + print(f"\n{_Style.FG_CYAN}{_Style.BOLD}Available commands:{_Style.RESET}") + print( + f" {_Style.FG_CYAN}/help{_Style.RESET} or {_Style.FG_CYAN}/h{_Style.RESET} Show this help" + ) + print( + f" {_Style.FG_CYAN}/refresh{_Style.RESET} or {_Style.FG_CYAN}/r{_Style.RESET} Re-render the conversation and prompt" + ) + print( + f" {_Style.FG_CYAN}/quit{_Style.RESET} or {_Style.FG_CYAN}/q{_Style.RESET} Abort the run (raises KeyboardInterrupt)" + ) + await asyncio.sleep(1.0) + continue + if user_in.strip().lower() in {"/refresh", "/r"}: + continue + if user_in.strip().lower() in {"/quit", "/q"}: + raise KeyboardInterrupt("Human aborted run from human_policy") 
+ + # Add formatting tags if needed + if start_tag and end_tag: + formatted_input = f"{start_tag}{user_in}{end_tag}" + else: + formatted_input = user_in + + if regex is None: + return ChatTurn( + role="assistant", agent_id=agent_id, content=formatted_input + ) + + # Validate against regex (fullmatch) + try: + pattern = re.compile(regex) + except re.error as e: + # If regex is invalid, fall back to accepting any input + print( + f"{_Style.FG_RED}Warning:{_Style.RESET} Provided regex is invalid: {e}. Accepting input without validation." + ) + await asyncio.sleep(0.5) + return ChatTurn( + role="assistant", agent_id=agent_id, content=formatted_input + ) + + if pattern.fullmatch(formatted_input): + return ChatTurn( + role="assistant", agent_id=agent_id, content=formatted_input + ) + + # Show validation error and re-prompt + print("") + print( + f"{_Style.FG_RED}{_Style.BOLD}Input did not match the required format.{_Style.RESET} Please try again." + ) + + if input_type == "message": + print( + f"You entered: {_Style.FG_CYAN}{start_tag}{user_in}{end_tag}{_Style.RESET}" + ) + print(f"Just type the message content without tags.") + elif input_type == "proposal": + print( + f"You entered: {_Style.FG_CYAN}{start_tag}{user_in}{end_tag}{_Style.RESET}" + ) + print(f"Just type the number without tags.") + else: + print(f"Expected (regex):") + print(f" {_Style.FG_MAGENTA}{regex}{_Style.RESET}") + + print(_horizontal_rule(".")) + print(f"{_Style.FG_YELLOW}Press Enter to retry...{_Style.RESET}") + await _async_input("") + + +def get_human_policies() -> Dict[str, Callable[[List[Dict]], str]]: + """ + Expose the human policy in the same map shape used elsewhere. + """ + # Type hint says Callable[[List[Dict]], str] but we intentionally return the async callable. 
+ return {"human_policy": human_policy} # type: ignore[return-value] diff --git a/src_code_for_reproducibility/models/inference_backend.py b/src_code_for_reproducibility/models/inference_backend.py new file mode 100644 index 0000000000000000000000000000000000000000..2fd31f0f51e00dbcbf1db954e8eab42550e05d4e --- /dev/null +++ b/src_code_for_reproducibility/models/inference_backend.py @@ -0,0 +1,39 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Optional + + +@dataclass +class LLMInferenceOutput: + content: str + reasoning_content: str | None = None + log_probs: list[float] | None = None + out_token_ids: list[int] | None = None + + +class LLMInferenceBackend(ABC): + @abstractmethod + def __init__(self, **kwargs): + ... + + @abstractmethod + def prepare_adapter( + self, adapter_id: str, weights_got_updated: bool = False + ) -> None: + """Ensure adapter is ready/loaded for next generation call.""" + + @abstractmethod + async def generate(self, prompt: list[dict], regex: Optional[str] = None) -> str: + ... + + @abstractmethod + def toggle_training_mode(self) -> None: + ... + + @abstractmethod + def toggle_eval_mode(self) -> None: + ... + + @abstractmethod + def shutdown(self) -> None: + ... 
import asyncio
from typing import Optional

import rstr

from mllm.models.inference_backend import LLMInferenceBackend, LLMInferenceOutput


class DummyInferenceBackend(LLMInferenceBackend):
    """No-op backend for tests: returns canned text, or a random regex match.

    Accepts and ignores any configuration so it can stand in for real backends.
    (Removed unused ``AutoTokenizer`` / ``generate_short_id`` imports from the
    original module.)
    """

    def __init__(self, *args, **kwargs):
        # Nothing to configure.
        pass

    def prepare_adapter(
        self,
        adapter_id: Optional[str],
        weights_got_updated: bool,
        adapter_path: Optional[str] = None,
    ) -> None:
        # Nothing to load for the dummy backend.
        pass

    async def toggle_training_mode(self) -> None:
        # Yield control once to mimic the real async backends.
        await asyncio.sleep(0)

    async def toggle_eval_mode(self) -> None:
        await asyncio.sleep(0)

    def shutdown(self) -> None:
        pass

    async def generate(
        self,
        prompt_text: str,
        regex: Optional[str] = None,
        extract_thinking: bool = False,
    ) -> LLMInferenceOutput:
        """Return a canned answer; with ``regex``, a random string matching it."""
        reasoning = "I don't think, I am a dummy backend."
        if regex:
            # rstr.xeger builds a random string that fullmatches the pattern.
            return LLMInferenceOutput(
                content=rstr.xeger(regex),
                reasoning_content=reasoning,
            )
        return LLMInferenceOutput(
            content="I am a dummy backend without a regex.",
            reasoning_content=reasoning,
        )
Optional + +# import sglang as sgl + +from mllm.models.inference_backend import LLMInferenceBackend + + +class SGLangOfflineBackend(LLMInferenceBackend): + def __init__( + self, + model_name: str, + tokenizer, # unused but kept for parity + adapter_paths: dict[str, str], + device: str = "cuda", + max_model_len: Optional[int] = None, + enable_lora: bool = True, + lora_target_modules: Optional[list[str] | str] = None, + max_loras_per_batch: int = 8, + engine_kwargs: dict[str, Any] = None, + ): + self.model_name = model_name + self.adapter_paths = adapter_paths + self.current_adapter: Optional[str] = None + engine_kwargs = dict(engine_kwargs or {}) + # Map server-style LoRA flags to offline engine ctor + if enable_lora and adapter_paths: + engine_kwargs.setdefault("enable_lora", True) + # The offline Engine mirrors server args; pass a mapping name->path + engine_kwargs.setdefault("lora_paths", adapter_paths) + if lora_target_modules is not None: + engine_kwargs.setdefault("lora_target_modules", lora_target_modules) + engine_kwargs.setdefault("max_loras_per_batch", max_loras_per_batch) + + if max_model_len is not None: + engine_kwargs.setdefault("context_length", max_model_len) + + # Launch in-process engine (no HTTP server) + self.llm = sgl.Engine(model_path=model_name, **engine_kwargs) # async-ready + # SGLang supports: generate(), async_generate(), and async streaming helpers. :contentReference[oaicite:2]{index=2} + + def is_ready(self) -> bool: + return True + + def toggle_training_mode(self) -> None: + # No explicit KV release API offline; typically you pause usage here. + pass + + def toggle_eval_mode(self) -> None: + pass + + def shutdown(self) -> None: + # Engine cleans up on GC; explicit close not required. + pass + + def prepare_adapter(self, adapter_id: Optional[str]) -> None: + # With offline Engine, when LoRA is enabled at init, + # you select adapter per request via the input batch mapping. 
+ self.current_adapter = adapter_id + + async def generate( + self, prompt_text: str, sampling_params: dict, adapter_id: Optional[str] + ) -> str: + # Non-streaming async (batch of 1). For batched prompts, pass a list. + params = { + "temperature": sampling_params.get("temperature", 1.0), + "top_p": sampling_params.get("top_p", 1.0), + "max_new_tokens": sampling_params.get("max_new_tokens", 128), + } + if (tk := sampling_params.get("top_k", -1)) and tk > 0: + params["top_k"] = tk + if (mn := sampling_params.get("min_new_tokens")) is not None: + params["min_new_tokens"] = mn + if (fp := sampling_params.get("frequency_penalty")) is not None: + params["frequency_penalty"] = fp + + # If using multi-LoRA, SGLang lets you provide adapter names aligned to each input. + prompts = [prompt_text] + adapters = [adapter_id] if adapter_id else None # or omit for base + outs = await self.llm.async_generate( + prompts, params, adapters + ) # :contentReference[oaicite:3]{index=3} + return outs[0]["text"] diff --git a/src_code_for_reproducibility/models/inference_backend_sglang_local_server.py b/src_code_for_reproducibility/models/inference_backend_sglang_local_server.py new file mode 100644 index 0000000000000000000000000000000000000000..c29f4d01a0bc0a6a0435461a02cd67074771abb0 --- /dev/null +++ b/src_code_for_reproducibility/models/inference_backend_sglang_local_server.py @@ -0,0 +1,127 @@ +import os + +import httpx +import requests +from sglang.utils import launch_server_cmd, wait_for_server + +from mllm.models.inference_backend import LLMInferenceBackend + + +class HttpSGLangBackend(LLMInferenceBackend): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.port = None + self.proc = None + self.urls = {} + # track sglang adapter ids separately from your logical ids + self.sglang_names = {aid: aid for aid in self.adapter_paths.keys()} + self.needs_loading = {aid: True for aid in self.adapter_paths.keys()} + + # defaults you already used: + self.mem_fraction = 
import os

import httpx
import requests
from sglang.utils import launch_server_cmd, wait_for_server

from mllm.models.inference_backend import LLMInferenceBackend


class HttpSGLangBackend(LLMInferenceBackend):
    """Backend that launches and talks to a local SGLang HTTP server."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.port = None
        self.proc = None
        self.urls = {}
        # NOTE(review): the original read self.model_name / self.adapter_paths
        # below without ever assigning them (AttributeError at init); populate
        # them from kwargs here — confirm the key names against the call site.
        self.model_name = kwargs.get("model_name")
        self.adapter_paths = kwargs.get("adapter_paths", {})
        # track sglang adapter ids separately from your logical ids
        self.sglang_names = {aid: aid for aid in self.adapter_paths.keys()}
        self.needs_loading = {aid: True for aid in self.adapter_paths.keys()}

        # defaults you already used:
        self.mem_fraction = kwargs.get("mem_fraction_static", 0.6)
        self.dtype = kwargs.get("dtype", "bfloat16")
        self.extra_cli = kwargs.get("extra_cli", "")
        self.disable_radix_cache = kwargs.get("disable_radix_cache", True)

    def launch(self) -> None:
        """Start the SGLang server subprocess and record its endpoint URLs."""
        # find local hf cache path for server
        from transformers.utils import cached_file

        local_llm_path = os.path.split(cached_file(self.model_name, "config.json"))[0]

        lora_str = ""
        if self.adapter_paths:
            lora_str = "--lora-paths " + " ".join(
                f"{aid}={path}" for aid, path in self.adapter_paths.items()
            )

        cmd = f"""
        python3 -m sglang.launch_server --model-path {local_llm_path} \
            --host 0.0.0.0 {lora_str} \
            {'--disable-radix-cache' if self.disable_radix_cache else ''} \
            --mem-fraction-static {self.mem_fraction} --dtype {self.dtype} {self.extra_cli}
        """
        self.proc, self.port = launch_server_cmd(cmd)
        wait_for_server(f"http://localhost:{self.port}")
        base = f"http://localhost:{self.port}"
        self.urls = dict(
            generate=f"{base}/generate",
            release=f"{base}/release_memory_occupation",
            resume=f"{base}/resume_memory_occupation",
            load_lora=f"{base}/load_lora_adapter",
            unload_lora=f"{base}/unload_lora_adapter",
        )

    def is_ready(self) -> bool:
        # Guard: before launch() self.urls is empty and indexing it raised
        # KeyError in the original.
        if not self.urls:
            return False
        try:
            requests.get(self.urls["generate"], timeout=2)
            return True
        except Exception:
            return False

    def prepare_adapter(self, adapter_id: str) -> None:
        """Reload the adapter's weights on the server if they changed."""
        if adapter_id is None:
            return
        if self.needs_loading.get(adapter_id, False):
            # Best-effort unload of the previous server-side name.
            try:
                requests.post(
                    self.urls["unload_lora"],
                    json={"lora_name": self.sglang_names[adapter_id]},
                    timeout=10,
                )
            except Exception:
                pass
            # Register under a fresh name so the server re-reads the weights.
            new_name = self._short_id()
            self.sglang_names[adapter_id] = new_name
            requests.post(
                self.urls["load_lora"],
                json={
                    "lora_name": new_name,
                    "lora_path": self.adapter_paths[adapter_id],
                },
            ).raise_for_status()
            self.needs_loading[adapter_id] = False

    async def generate(
        self, prompt_text: str, sampling_params: dict, adapter_id: str | None
    ) -> str:
        """POST a single-prompt generation request and return the text."""
        lora_name = self.sglang_names.get(adapter_id) if adapter_id else None
        payload = {
            "text": [prompt_text],
            "sampling_params": sampling_params,
        }
        if lora_name:
            payload["lora_path"] = [lora_name]

        # Very long timeout: generation can legitimately take a while.
        timeout = httpx.Timeout(3600.0, connect=3600.0)
        async with httpx.AsyncClient(timeout=timeout) as client:
            resp = await client.post(self.urls["generate"], json=payload)
            resp.raise_for_status()
            return resp.json()[0]["text"]

    def toggle_training_mode(self) -> None:
        # free KV space while training adapters
        requests.post(
            self.urls["release"], json={"tags": ["kv_cache"]}
        ).raise_for_status()

    def toggle_eval_mode(self) -> None:
        # re-allocate KV space; best-effort (server may already hold it)
        try:
            requests.post(
                self.urls["resume"], json={"tags": ["kv_cache"]}
            ).raise_for_status()
        except Exception:
            pass

    def shutdown(self) -> None:
        from sglang.utils import terminate_process

        if self.proc:
            terminate_process(self.proc)

    def _short_id(self) -> str:
        # Random 8-digit name for server-side adapter registration.
        import uuid

        return str(uuid.uuid4().int)[:8]
import asyncio
import re
from typing import Optional

import torch
from transformers import AutoTokenizer
from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
from vllm.inputs import TokensPrompt
from vllm.lora.request import LoRARequest
from vllm.sampling_params import GuidedDecodingParams, RequestOutputKind

from mllm.models.inference_backend import LLMInferenceBackend, LLMInferenceOutput
from mllm.utils.short_id_gen import generate_short_id


class VLLMAsyncBackend(LLMInferenceBackend):
    """Async vLLM engine backend with per-request LoRA selection."""

    def __init__(
        self,
        model_name: str,
        tokenizer: AutoTokenizer,
        engine_init_kwargs: Optional[dict] = None,
        sampling_params: Optional[dict] = None,
    ):
        # None defaults replace the original mutable `{}` defaults.
        self.model_name = model_name
        # adapter_id -> vLLM-internal id; a fresh id forces a weight reload.
        self.vllm_adapter_ids = {}
        ea = dict(model=model_name, **(engine_init_kwargs or {}))
        self.engine = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**ea))

        self.sampling_params = dict(sampling_params or {})
        self.tokenizer = tokenizer
        # Set by prepare_adapter(); the original never initialized this, so
        # calling generate() before prepare_adapter() raised AttributeError.
        self.current_lora_request = None

    def prepare_adapter(
        self,
        adapter_id: Optional[str],
        adapter_path: Optional[str],
        weights_got_updated: bool,
    ) -> None:
        """Select the LoRA adapter used by subsequent generate() calls."""
        # Also mint an id on first sight of the adapter: the original only did
        # so when weights_got_updated was True, so the first call with
        # weights_got_updated=False raised KeyError below.
        if weights_got_updated or adapter_id not in self.vllm_adapter_ids:
            self.vllm_adapter_ids[adapter_id] = generate_short_id()
        self.current_lora_request = LoRARequest(
            adapter_id,
            self.vllm_adapter_ids[adapter_id],
            adapter_path,
        )

    async def toggle_training_mode(self) -> None:
        # Level-1 sleep releases KV-cache memory while adapters train.
        await self.engine.sleep(level=1)

    async def toggle_eval_mode(self) -> None:
        await self.engine.wake_up()

    def shutdown(self) -> None:
        # No explicit close call; engine stops when process exits.
        pass

    async def generate(
        self,
        input_token_ids: list[int],
        regex: Optional[str] = None,
        extract_thinking: bool = False,
    ) -> LLMInferenceOutput:
        """Generate from pre-tokenized input, optionally regex-constrained."""
        guided = GuidedDecodingParams(regex=regex) if regex else None
        sp = SamplingParams(
            **self.sampling_params,
            guided_decoding=guided,
            output_kind=RequestOutputKind.FINAL_ONLY,
        )

        prompt = TokensPrompt(prompt_token_ids=input_token_ids)
        # Unique per request; the original used loop.time(), which can collide
        # for concurrent requests issued within the same clock tick.
        request_id = f"req-{generate_short_id()}"
        result_generator = self.engine.generate(
            prompt,
            sp,
            request_id,
            lora_request=self.current_lora_request,
        )

        async for out in result_generator:  # with FINAL_ONLY this runs once
            res = out

        raw_text = res.outputs[0].text
        out_token_ids = res.outputs[0].token_ids
        # NOTE(review): assumes logprobs were requested in sampling_params
        # (config sets logprobs: 0); res.outputs[0].logprobs is None otherwise.
        log_probs = [
            logprob_dict[token_id].logprob
            for token_id, logprob_dict in zip(out_token_ids, res.outputs[0].logprobs)
        ]
        log_probs = torch.tensor(log_probs)
        out_token_ids = torch.tensor(out_token_ids, dtype=torch.long)
        content = raw_text
        reasoning_content = None

        if extract_thinking:
            # NOTE(review): this pattern splits on two blank-line pairs; it
            # looks like literal think-tags may have been lost in transit —
            # confirm against the model's actual thinking delimiters.
            m = re.match(
                r"^\n\n([\s\S]*?)\n\n(.*)$", raw_text, flags=re.DOTALL
            )
            if m:
                reasoning_content = m.group(1)
                content = m.group(2)
        return LLMInferenceOutput(
            content=content,
            reasoning_content=reasoning_content,
            log_probs=log_probs,
            out_token_ids=out_token_ids,
        )
import json
import os
import subprocess
import time

import httpx
import requests

from mllm.models.inference_backend import LLMInferenceBackend


class HttpVLLMBackend(LLMInferenceBackend):
    """Backend that launches and talks to a local vLLM OpenAI-compatible server."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.port = kwargs.get("port", 8000)
        self.host = kwargs.get("host", "0.0.0.0")
        self.proc = None
        self.base_url = f"http://{self.host}:{self.port}"
        # NOTE(review): the original read self.model_name / self.adapter_paths
        # below without ever assigning them; populate them from kwargs here —
        # confirm the key names against the call site.
        self.model_name = kwargs.get("model_name")
        self.adapter_paths = kwargs.get("adapter_paths", {})
        # vLLM memory safety knobs
        self.gpu_mem_util = kwargs.get("gpu_memory_utilization", 0.9)
        self.max_model_len = kwargs.get("max_model_len", None)
        self.max_num_seqs = kwargs.get("max_num_seqs", None)
        self.max_batched_tokens = kwargs.get("max_num_batched_tokens", None)
        self.dtype = kwargs.get("dtype", "bfloat16")
        self.trust_remote_code = kwargs.get("trust_remote_code", False)
        # LoRA strategy: "preload" (CLI) or "runtime" (endpoints) depending on your vLLM build
        self.lora_mode = kwargs.get(
            "lora_mode", "preload"
        )  # "runtime" supported in newer builds
        self.runtime_lora_enabled = self.lora_mode == "runtime"

        # If preloading: build CLI args (adapter name -> path)
        self._preload_lora_args = []
        if self.adapter_paths and self.lora_mode == "preload":
            # vLLM supports multiple LoRA modules via CLI in recent versions
            # Example flag shapes can vary; adapt as needed for your version:
            # --lora-modules adapter_id=path
            for aid, pth in self.adapter_paths.items():
                self._preload_lora_args += ["--lora-modules", f"{aid}={pth}"]

    def launch(self):
        """Start the vLLM OpenAI API server subprocess and wait until ready."""
        cmd = [
            "python3",
            "-m",
            "vllm.entrypoints.openai.api_server",
            "--model",
            self.model_name,
            "--host",
            self.host,
            "--port",
            str(self.port),
            "--dtype",
            self.dtype,
            "--gpu-memory-utilization",
            str(self.gpu_mem_util),
        ]
        if self.trust_remote_code:
            cmd += ["--trust-remote-code"]
        if self.max_model_len:
            cmd += ["--max-model-len", str(self.max_model_len)]
        if self.max_num_seqs:
            cmd += ["--max-num-seqs", str(self.max_num_seqs)]
        if self.max_batched_tokens:
            cmd += ["--max-num-batched-tokens", str(self.max_batched_tokens)]
        cmd += self._preload_lora_args

        self.proc = subprocess.Popen(
            cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
        )
        self._wait_ready()

    def _wait_ready(self, timeout=120):
        """Poll /v1/models until the server answers or *timeout* seconds pass."""
        url = f"{self.base_url}/v1/models"
        t0 = time.time()
        while time.time() - t0 < timeout:
            try:
                r = requests.get(url, timeout=2)
                if r.status_code == 200:
                    return
            except Exception:
                pass
            time.sleep(1)
        raise RuntimeError("vLLM server did not become ready in time")

    def is_ready(self) -> bool:
        try:
            return (
                requests.get(f"{self.base_url}/v1/models", timeout=2).status_code == 200
            )
        except Exception:
            return False

    def prepare_adapter(self, adapter_id: str) -> None:
        """Load a LoRA adapter at runtime (no-op in preload mode)."""
        if not adapter_id or not self.runtime_lora_enabled:
            return
        # Newer vLLM builds expose runtime LoRA endpoints. If yours differs,
        # adjust the path/body here and keep the interface stable.
        try:
            requests.post(
                f"{self.base_url}/v1/load_lora_adapter",
                json={
                    "adapter_name": adapter_id,
                    "adapter_path": self.adapter_paths[adapter_id],
                },
                timeout=10,
            ).raise_for_status()
        except Exception:
            # Best-effort: adapter may already be loaded or the endpoint may
            # not exist in this vLLM build.
            pass

    async def generate(
        self, prompt_text: str, sampling_params: dict, adapter_id: str | None
    ) -> str:
        """Send one chat-completion request and return the assistant text."""
        # Map your sampling params to OpenAI schema
        body = {
            "model": self.model_name,
            "messages": [{"role": "user", "content": prompt_text}],
            "temperature": sampling_params.get("temperature", 1.0),
            "top_p": sampling_params.get("top_p", 1.0),
            "max_tokens": sampling_params.get("max_new_tokens", 128),
        }
        # Optional knobs:
        if sampling_params.get("top_k", -1) and sampling_params["top_k"] > 0:
            # vLLM accepts top_k via extra params; put under "extra_body"
            body.setdefault("extra_body", {})["top_k"] = sampling_params["top_k"]
        if sampling_params.get("min_new_tokens", None) is not None:
            body.setdefault("extra_body", {})["min_tokens"] = sampling_params[
                "min_new_tokens"
            ]
        if sampling_params.get("frequency_penalty", None) is not None:
            body["frequency_penalty"] = sampling_params["frequency_penalty"]

        # Select LoRA adapter by name. (The original had identical branches
        # for runtime vs preloaded mode; collapsed to one.)
        if adapter_id:
            body.setdefault("extra_body", {})["lora_adapter"] = adapter_id

        url = f"{self.base_url}/v1/chat/completions"
        timeout = httpx.Timeout(3600.0, connect=3600.0)
        async with httpx.AsyncClient(timeout=timeout) as client:
            resp = await client.post(url, json=body)
            resp.raise_for_status()
            data = resp.json()
            return data["choices"][0]["message"]["content"]

    def toggle_training_mode(self) -> None:
        # vLLM doesn’t expose an explicit KV “release” toggle via API.
        # Strategy: keep inference server idle during training, or run training in a separate process.
        pass

    def toggle_eval_mode(self) -> None:
        pass

    def shutdown(self) -> None:
        if self.proc:
            self.proc.terminate()
api_key or os.getenv("OPENAI_API_KEY") + if not key: + raise RuntimeError( + "Set OPENAI_API_KEY as global environment variable or pass api_key." + ) + client_kwargs: Dict[str, Any] = {"api_key": key, "timeout": timeout_s} + if base_url: + client_kwargs["base_url"] = base_url + self.client = AsyncOpenAI(**client_kwargs) + + # Sampling/default request params set at init + self.sampling_params = sampling_params + self.use_reasoning = model in reasoning_models + if self.use_reasoning: + self.sampling_params["reasoning"] = { + "effort": "low", + "summary": "detailed", + } + self.regex_max_attempts = max(1, int(regex_max_attempts)) + + def get_inference_policies(self) -> Dict[str, Callable]: + return { + self.llm_id: self.get_action, + } + + async def prepare_adapter_for_inference(self, *args: Any, **kwargs: Any) -> None: + await asyncio.sleep(0) + pass + + async def toggle_eval_mode(self, *args: Any, **kwargs: Any) -> None: + await asyncio.sleep(0) + pass + + async def toggle_training_mode(self, *args: Any, **kwargs: Any) -> None: + await asyncio.sleep(0) + pass + + async def export_adapters(self, *args: Any, **kwargs: Any) -> None: + await asyncio.sleep(0) + pass + + async def checkpoint_all_adapters(self, *args: Any, **kwargs: Any) -> None: + await asyncio.sleep(0) + pass + + def extract_output_from_response(self, resp: Response) -> LLMInferenceOutput: + if len(resp.output) > 1: + summary = resp.output[0].summary + if summary != []: + reasoning_content = summary[0].text + reasoning_content = f"OpenAI Reasoning Summary: {reasoning_content}" + else: + reasoning_content = None + content = resp.output[1].content[0].text + else: + reasoning_content = None + content = resp.output[0].content[0].text + + return LLMInferenceOutput( + content=content, + reasoning_content=reasoning_content, + ) + + @backoff.on_exception( + backoff.expo, Exception, max_time=10**10, max_tries=10**10 + ) + async def get_action( + self, + state: list[ChatTurn], + agent_id: str, + regex: 
Optional[str] = None, + ) -> LLMInferenceOutput: + # Remove any non-role/content keys from the prompt else openai will error + + # TODO: + prompt = [{"role": p.role, "content": p.content} for p in state] + + # if self.sleep_between_requests: + # await self.wait_random_time() + + # If regex is required, prime the model and validate client-side + if regex: + constraint_msg = { + "role": "user", + "content": ( + f"Output must match this regex exactly: {regex} \n" + "Return only the matching string, with no quotes or extra text." + ), + } + prompt = [constraint_msg, *prompt] + pattern = re.compile(regex) + for _ in range(self.regex_max_attempts): + resp = await self.client.responses.create( + model=self.model, + input=prompt, + **self.sampling_params, + ) + policy_output = self.extract_output_from_response(resp) + if pattern.fullmatch(policy_output.content): + return policy_output + prompt = [ + *prompt, + { + "role": "user", + "content": ( + f"Invalid response format. Expected format (regex): {regex}\n Please try again and provide ONLY a response that matches this regex." + ), + }, + ] + return policy_output + + # Simple, unconstrained generation + resp = await self.client.responses.create( + model=self.model, + input=prompt, + **self.sampling_params, + ) + policy_output = self.extract_output_from_response(resp) + return policy_output + + def shutdown(self) -> None: + self.client = None diff --git a/src_code_for_reproducibility/models/large_language_model_local.py b/src_code_for_reproducibility/models/large_language_model_local.py new file mode 100644 index 0000000000000000000000000000000000000000..7eac1c32c0233cf04106fa12be333ebf74319c2a --- /dev/null +++ b/src_code_for_reproducibility/models/large_language_model_local.py @@ -0,0 +1,384 @@ +""" +TODO: Figure out how to tweak SGlang not to go OOM when batch size is 32. See https://github.com/sgl-project/sglang/issues/6309. 
+""" + +import logging +import os +import re +import sys +import uuid +from collections.abc import Callable +from copy import deepcopy +from datetime import datetime +from typing import Literal + +import httpx +import requests +import torch +import torch.nn as nn + +# from sglang.utils import ( +# launch_server_cmd, +# print_highlight, +# terminate_process, +# wait_for_server, +# ) +from torch.optim import SGD, Adam, AdamW, RMSprop +from transformers import AutoModelForCausalLM, AutoTokenizer +from trl import AutoModelForCausalLMWithValueHead + +from mllm.chat_utils.apply_template import chat_turns_to_token_ids +from mllm.markov_games.rollout_tree import ChatTurn +from mllm.models.adapter_training_wrapper import AdapterWrapper +from mllm.models.inference_backend import LLMInferenceOutput +from mllm.models.inference_backend_dummy import DummyInferenceBackend +from mllm.models.inference_backend_sglang import SGLangOfflineBackend +from mllm.models.inference_backend_vllm import VLLMAsyncBackend + +logger = logging.getLogger(__name__) +logger.addHandler(logging.StreamHandler(sys.stdout)) + +AdapterID = str +PolicyID = str + + +class LeanLocalLLM: + """ + TOWRITE + """ + + def __init__( + self, + llm_id: str = "base_llm", + model_name: str = "Qwen/Qwen3-4B-Instruct-2507", + device: str = "cuda", + hf_kwargs: dict = {}, + adapter_configs: dict = {}, + output_directory: str = "./models/", + inference_backend: Literal["vllm", "sglang", "dummy"] = "vllm", + inference_backend_sampling_params: dict = {}, + inference_backend_init_kwargs: dict = {}, + initial_adapter_paths: dict[str, str] | None = None, + initial_buffer_paths: list[str] | None = None, + enable_thinking: bool = None, + regex_max_attempts: int = -1, + max_thinking_characters: int = 0, + ): + self.inference_backend_name = inference_backend + self.output_directory = output_directory + self.llm_id = llm_id + self.device = torch.device(device) if device else torch.device("cuda") + self.model_name = model_name + 
self.adapter_configs = adapter_configs + self.adapter_ids = list(adapter_configs.keys()) + self.enable_thinking = enable_thinking + self.regex_max_attempts = regex_max_attempts + self.initial_buffer_paths = initial_buffer_paths + self.max_thinking_characters = max_thinking_characters + self.regex_retries_count = 0 + + # Optional user-specified initial adapter weight locations (local or HF Hub) + # Format: {adapter_id: path_or_repo_id} + self.initial_adapter_paths: dict[str, str] | None = initial_adapter_paths + + # Path management / imports + self.save_path = str(os.path.join(output_directory, model_name, "adapters")) + self.adapter_paths = { + adapter_id: os.path.join(self.save_path, adapter_id) + for adapter_id in self.adapter_ids + } + checkpoints_dir = os.path.join(self.output_directory, "checkpoints") + self.past_agent_adapter_paths = {} + if os.path.isdir(checkpoints_dir): + for dirname in os.listdir(checkpoints_dir): + dirpath = os.path.join(checkpoints_dir, dirname) + if os.path.isdir(dirpath): + self.past_agent_adapter_paths[f"{dirname}_buffer"] = os.path.join( + dirpath, "agent_adapter" + ) + logger.info( + f"Loaded {len(self.past_agent_adapter_paths)} past agent adapters from checkpoints directory." + ) + if self.initial_buffer_paths is not None: + previous_count = len(self.past_agent_adapter_paths) + for path in self.initial_buffer_paths: + if os.path.isdir(path): + for dirname in os.listdir(path): + dirpath = os.path.join(path, dirname) + if os.path.isdir(dirpath): + self.past_agent_adapter_paths[ + f"{dirname}_buffer" + ] = os.path.join(dirpath, "agent_adapter") + else: + logger.warning( + f"Initial buffer path {path} does not exist or is not a directory." + ) + logger.info( + f"Loaded {len(self.past_agent_adapter_paths) - previous_count} past agent adapters from user-specified initial buffer paths." 
+ ) + self.past_agent_adapter_ids = list(self.past_agent_adapter_paths.keys()) + + # ID management for tracking adapter versions + self.adapter_train_ids = { + adapter_id: self.short_id_generator() for adapter_id in self.adapter_ids + } + # Initialize tokenizer + self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) + # Setup padding token to be same as EOS token + self.tokenizer.pad_token_id = self.tokenizer.eos_token_id + self.tokenizer.pad_token = self.tokenizer.eos_token + + self.weights_got_updated: dict[AdapterID, bool] = { + adapter_id: False for adapter_id in self.adapter_ids + } + self.weights_got_updated.update( + {adapter_id: False for adapter_id in self.past_agent_adapter_ids} + ) + self.current_lora_request = None + self.currently_loaded_adapter_id = None + + # --------------------------------------------------------- + # Init HF model, peft adapters + # --------------------------------------------------------- + self.shared_hf_llm = AutoModelForCausalLM.from_pretrained( + pretrained_model_name_or_path=model_name, + **hf_kwargs, + ) + self.hf_adapters = {} + self.optimizers = {} + for adapter_id in self.adapter_ids: + # Prefer output-folder path if it exists; else fall back to user-specified initial path if provided + output_path = os.path.join(self.save_path, adapter_id) + chosen_path: str | None = None + if os.path.isdir(output_path) and os.listdir(output_path): + chosen_path = output_path + logger.info( + f"Initializing adapter '{adapter_id}': using existing weights from output folder '{chosen_path}'." + ) + elif ( + self.initial_adapter_paths and adapter_id in self.initial_adapter_paths + ): + chosen_path = self.initial_adapter_paths[adapter_id] + logger.info( + f"Initializing adapter '{adapter_id}': using provided initial path '{chosen_path}'." + ) + else: + logger.info( + f"Initializing adapter '{adapter_id}': no initial weights provided or found; starting from scratch." 
+ ) + hf_adapter = AdapterWrapper( + shared_llm=self.shared_hf_llm, + adapter_id=adapter_id, + lora_config=adapter_configs[adapter_id], + path=chosen_path, + ).to(device) + self.hf_adapters[adapter_id] = hf_adapter + # Persist current state of all adapters (ensures remote loads are cached to disk) + self.export_adapters() + + # --------------------------------------------------------- + # Init inference inference_backend + # --------------------------------------------------------- + + if inference_backend == "sglang": + self.inference_backend = SGLangOfflineBackend( + model_name=self.model_name, + save_path=self.save_path, + adapter_paths=self.adapter_paths, + tokenizer=self.tokenizer, + kwargs=inference_backend_init_kwargs, + ) + elif inference_backend == "vllm": + self.inference_backend = VLLMAsyncBackend( + model_name=self.model_name, + # adapter_paths=self.adapter_paths, + tokenizer=self.tokenizer, + engine_init_kwargs=inference_backend_init_kwargs, + sampling_params=inference_backend_sampling_params, + ) + elif inference_backend == "dummy": + self.inference_backend = DummyInferenceBackend() + else: + raise ValueError(f"Unknown inference_backend: {inference_backend}") + + def reset_regex_retries_count(self) -> None: + self.regex_retries_count = 0 + + def get_inference_policies(self) -> dict[PolicyID, Callable]: + """ + TOWRITE + """ + policies = {} + for adapter_id in self.adapter_ids: + # define policy func + async def policy( + state: list[ChatTurn], + agent_id: str, + regex: str | None = None, + _adapter_id=adapter_id, + ): + self.prepare_adapter_for_inference(adapter_id=_adapter_id) + response = await self.get_action(state, agent_id, regex) + return response + + policies[self.llm_id + "/" + adapter_id] = policy + + for adapter_id in self.past_agent_adapter_ids: + # define policy func + async def policy( + state: list[ChatTurn], + agent_id: str, + regex: str | None = None, + _adapter_id=adapter_id, + ): + 
self.prepare_adapter_for_inference(adapter_id=_adapter_id) + response = await self.get_action(state, agent_id, regex) + return response + + policies[self.llm_id + "/" + adapter_id] = policy + return policies + + def get_adapter_modules(self) -> dict[PolicyID, nn.Module]: + """ + Returns wrappers over the adapters which allows them be + interfaced like regular PyTorch models. + # TODO: create the adapter wrappers here + See adapter_wrapper.py + """ + trainable_objects = {an: self.hf_adapters[an] for an in self.adapter_ids} + return trainable_objects + + async def toggle_training_mode(self) -> None: + for adn in self.adapter_ids: + self.adapter_train_ids[adn] = self.short_id_generator() + await self.inference_backend.toggle_training_mode() + + async def toggle_eval_mode(self) -> None: + await self.inference_backend.toggle_eval_mode() + + def prepare_adapter_for_inference(self, adapter_id: AdapterID) -> None: + self.inference_backend.prepare_adapter( + adapter_id, + adapter_path=self.adapter_paths.get( + adapter_id, self.past_agent_adapter_paths.get(adapter_id, None) + ), + weights_got_updated=self.weights_got_updated[adapter_id], + ) + self.currently_loaded_adapter_id = adapter_id + self.weights_got_updated[adapter_id] = False + + # def _make_prompt_text(self, prompt: list[dict]) -> str: + # if self.enable_thinking is not None: + # prompt_text = self.tokenizer.apply_chat_template( + # prompt, + # tokenize=False, + # add_generation_prompt=True, + # enable_thinking=self.enable_thinking, + # ) + # else: + # prompt_text = self.tokenizer.apply_chat_template( + # prompt, + # tokenize=False, + # add_generation_prompt=True, + # ) + + # return prompt_text + + async def get_action( + self, state: list[ChatTurn], agent_id: str, regex: str | None = None + ) -> ChatTurn: + current_regex = regex if self.regex_max_attempts == -1 else None + pattern = re.compile(regex) if regex else None + nb_attempts = 0 + state = state[:] + while True: + context_token_ids = 
chat_turns_to_token_ids( + chats=state, + tokenizer=self.tokenizer, + enable_thinking=self.enable_thinking, + ) + # print(f"context is {self.tokenizer.decode(context_token_ids)}") + policy_output = await self.inference_backend.generate( + input_token_ids=context_token_ids.tolist(), + extract_thinking=(self.max_thinking_characters > 0), + regex=current_regex, + ) + # print(f"generated: {self.tokenizer.decode(policy_output.out_token_ids)}") + if ( + pattern is None + or (pattern.fullmatch(policy_output.content)) + or (nb_attempts >= self.regex_max_attempts) + ): + return ChatTurn( + agent_id=agent_id, + role="assistant", + content=policy_output.content, + reasoning_content=policy_output.reasoning_content, + out_token_ids=policy_output.out_token_ids, + log_probs=policy_output.log_probs, + is_state_end=False, + ) + else: + self.regex_retries_count += 1 + nb_attempts += 1 + logger.warning( + f"Response {policy_output.content} did not match regex: {regex}, retry {nb_attempts}/{self.regex_max_attempts}" + ) + if nb_attempts == self.regex_max_attempts: + current_regex = regex + # regex_prompt = ChatTurn( + # role="user", + # content=f"Invalid response format. Expected format (regex): {current_regex}\n Please try again and provide ONLY a response that matches this regex.", + # reasoning_content=None, + # log_probs=None, + # out_token_ids=None, + # is_state_end=False, + # ) + # state.append(regex_prompt) + + def export_adapters(self) -> None: + """ + Any peft wrapper, by default, saves all adapters, not just the one currently loaded. 
+ """ + + # New version of the adapters available + for adapter_id in self.adapter_ids: + self.weights_got_updated[adapter_id] = True + for adapter_id in self.past_agent_adapter_ids: + self.weights_got_updated[adapter_id] = True + + # import random + # self.save_path = self.save_path + str(random.randint(1,500)) + # print(f"Save path: {self.save_path}") + # self.adapter_paths = {adapter_id:os.path.join(self.save_path, adapter_id) for adapter_id in self.adapter_ids} + + adapter_id = self.adapter_ids[0] + self.hf_adapters[adapter_id].save_pretrained(self.save_path) + + def checkpoint_all_adapters(self, checkpoint_indicator: str) -> None: + """ + Checkpoints all adapters to the configured output directory. + """ + adapter_id = self.adapter_ids[0] + output_dir = os.path.join(self.output_directory, "checkpoints") + os.makedirs(output_dir, exist_ok=True) + date_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + agent_adapter_dir = f"{adapter_id}-{checkpoint_indicator}-{date_str}" + export_path = os.path.join(output_dir, agent_adapter_dir) + for adapter_id in self.adapter_ids: + if "agent" in adapter_id: + self.past_agent_adapter_paths[ + f"{agent_adapter_dir}_buffer" + ] = os.path.join(export_path, adapter_id) + self.past_agent_adapter_ids.append(f"{agent_adapter_dir}_buffer") + self.weights_got_updated[f"{agent_adapter_dir}_buffer"] = False + self.hf_adapters[adapter_id].save_pretrained(export_path) + + def short_id_generator(self) -> str: + """ + Generates a short unique ID for tracking adapter versions. + + Returns: + int: An 8-digit integer ID. 
+ """ + return str(uuid.uuid4().int)[:8] diff --git a/src_code_for_reproducibility/models/scalar_critic.py b/src_code_for_reproducibility/models/scalar_critic.py new file mode 100644 index 0000000000000000000000000000000000000000..b0cabf6acfb6757db2871778a397bdbe38b813dd --- /dev/null +++ b/src_code_for_reproducibility/models/scalar_critic.py @@ -0,0 +1,54 @@ +import torch, torch.nn as nn, torch.optim as optim +from transformers import AutoModelForCausalLM, AutoTokenizer +from peft import LoraConfig, get_peft_model + +from mllm.models.adapter_training_wrapper import AdapterWrapper + + +class ScalarCritic(nn.Module): + """ + A causal-LM critic_adapter + a scalar value head: + V_φ(s) = wᵀ h_last + b + Only LoRA adapters (inside critic_adapter) and the value head are trainable. + """ + def __init__(self, critic_adapter: AdapterWrapper): + super().__init__() + self.critic_adapter = critic_adapter + hidden_size = self.critic_adapter.shared_llm.config.hidden_size + self.value_head = nn.Linear(hidden_size, 1).to( + dtype=critic_adapter.dtype, + device=critic_adapter.device) + + def forward(self, + input_ids, + attention_mask=None, + **kwargs): + # AdapterWrapper activates its own adapter internally + outputs = self.critic_adapter( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + **kwargs, + ) + h_last = outputs.hidden_states[-1] # (B, S, H) + values = self.value_head(h_last).squeeze(-1) # (B, S) + return values + + def parameters(self, recurse: bool = True): + """Iterator over *trainable* parameters for this critic.""" + # 1) LoRA params for *this* adapter + for p in self.critic_adapter.parameters(): + yield p + # 2) scalar head + yield from self.value_head.parameters() + + def gradient_checkpointing_enable(self, *args, **kwargs): + self.critic_adapter.gradient_checkpointing_enable(*args, **kwargs) + + @property + def dtype(self): + return self.critic_adapter.dtype + + @property + def device(self): + return self.critic_adapter.device 
diff --git a/src_code_for_reproducibility/training/README.md b/src_code_for_reproducibility/training/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..e1d56e2e3d2fb427f8da2488de4dd9de89650ca2
--- /dev/null
+++ b/src_code_for_reproducibility/training/README.md
@@ -0,0 +1,20 @@
+Suppose we have a trajectory with 3 timesteps.
+token: "0 1 2 3 4 5 6 7 8 9 . . . . ."
+string: "A B C a b c A a A a b c A B C" (Capitalized = User, Lowercased = Assistant)
+action_mask: "x x x ✓ ✓ ✓ x ✓ x ✓ ✓ ✓ x x x" (x = False, ✓ = True)
+rewards: "r r r r r r R R R R R R r r r"
+timestep: "0 0 0 0 0 0 1 1 1 1 1 1 2 2 2"
+state_ends: "x x ✓ x x x ✓ x x x x x x x ✓"
+
+There must be one baseline flag per timestep!
+
+Then, we might have
+
+A naive way to interpret this is to think of the number of assistant messages as the number of
+steps in the environment. However, this is not the case in practice. Indeed, in a
+single simulation step,
+
+
+
+
+A subtlety arises with credit assignment.
In the multi-agent case, we might diff --git a/src_code_for_reproducibility/training/__init__.py b/src_code_for_reproducibility/training/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src_code_for_reproducibility/training/annealing_methods.py b/src_code_for_reproducibility/training/annealing_methods.py new file mode 100644 index 0000000000000000000000000000000000000000..1d46d6fe04d9482ab4f2c5763f4735390d7ecdf6 --- /dev/null +++ b/src_code_for_reproducibility/training/annealing_methods.py @@ -0,0 +1,6 @@ +import numpy as np + + +def sigmoid_annealing(step: int, temperature: float) -> float: + return 2 / (1 + np.exp(-step / temperature)) - 1 + diff --git a/src_code_for_reproducibility/training/credit_methods.py b/src_code_for_reproducibility/training/credit_methods.py new file mode 100644 index 0000000000000000000000000000000000000000..33d07c6bffebf07959f4b3d9f0ebaf515a1b8783 --- /dev/null +++ b/src_code_for_reproducibility/training/credit_methods.py @@ -0,0 +1,304 @@ +import torch + + +def whiten_advantages(advantages: torch.Tensor) -> torch.Tensor: + """ + Whitens the advantages. + """ + whitened_advantages = (advantages - torch.mean(advantages)) / ( + torch.std(advantages) + 1e-9 + ) + return whitened_advantages + + +def whiten_advantages_time_step_wise( + advantages: torch.Tensor, # (B, T) +) -> torch.Tensor: + """ + Whitens the advantages. + """ + assert advantages.dim() == 2, "Wrong dimensions." + whitened_advantages_time_step_wise = ( + advantages - advantages.mean(dim=0, keepdim=True) + ) / (advantages.std(dim=0, keepdim=True) + 1e-9) + return whitened_advantages_time_step_wise + + +def get_discounted_state_visitation_credits( + credits: torch.Tensor, discount_factor: float # (B, T) +) -> torch.Tensor: + """ + Computes discounted state visitation credits for a sequence of credits. 
+ """ + return credits * ( + discount_factor ** torch.arange(credits.shape[1], device=credits.device) + ) + + +def get_discounted_returns( + rewards: torch.Tensor, # (B, T) + discount_factor: float, +) -> torch.Tensor: + """ + Computes Monte Carlo discounted returns for a sequence of rewards. + + Args: + rewards (torch.Tensor): Array of rewards for each timestep. + + Returns: + torch.Tensor: Array of discounted returns. + """ + assert rewards.dim() == 2, "Wrong dimensions." + B, T = rewards.shape + discounted_returns = torch.zeros_like(rewards) + accumulator = torch.zeros(B, device=rewards.device, dtype=rewards.dtype) + for t in reversed(range(T)): + accumulator = rewards[:, t] + discount_factor * accumulator + discounted_returns[:, t] = accumulator + return discounted_returns + + +def get_rloo_credits(credits: torch.Tensor): # (B, S) + assert credits.dim() == 2, "Wrong dimensions." + rloo_baselines = torch.zeros_like(credits) + n = credits.shape[0] + if n == 1: + return credits, rloo_baselines + rloo_baselines = (torch.sum(credits, dim=0, keepdim=True) - credits) / (n - 1) + rloo_credits = credits - rloo_baselines + return rloo_credits, rloo_baselines + + +def get_generalized_advantage_estimates( + rewards: torch.Tensor, # (B, T) + value_estimates: torch.Tensor, # (B, T+1) + discount_factor: float, + lambda_coef: float, +) -> torch.Tensor: + """ + Computes Generalized Advantage Estimates (GAE) for a sequence of rewards and value estimates. + See https://arxiv.org/pdf/1506.02438 for details. + + + Returns: + torch.Tensor: Array of GAE values. + """ + assert rewards.dim() == value_estimates.dim() == 2, "Wrong dimensions." + + assert ( + rewards.shape[0] == value_estimates.shape[0] + ), f"Got shapes {rewards.shape} and {value_estimates.shape} of rewards and value estimates." + assert ( + rewards.shape[1] == value_estimates.shape[1] - 1 + ), f"Got shapes {rewards.shape} and {value_estimates.shape} of rewards and value estimates." 
+ + T = rewards.shape[1] + tds = rewards + discount_factor * value_estimates[:, 1:] - value_estimates[:, :-1] + gaes = torch.zeros_like(tds) + acc = 0.0 + for t in reversed(range(T)): + acc = tds[:, t] + lambda_coef * discount_factor * acc + gaes[:, t] = acc + return gaes + + +def get_advantage_alignment_weights( + advantages: torch.Tensor, # (B, T) + exclude_k_equals_t: bool, + gamma: float, + discount_t: bool, +) -> torch.Tensor: + """ + The advantage alignment credit is calculated as + + \[ + A^*(s_t, a_t, b_t) = A^1(s_t, a_t, b_t) + \beta \cdot + \left( \sum_{k < t} \gamma^{t-k} A^1(s_k, a_k, b_k) \right) + A^2(s_t, a_t, b_t) + \] + + Here, the weights are defined as \( \beta \cdot + \left( \sum_{k < t} \gamma^{t-k} A^1(s_k, a_k, b_k) \) + """ + T = advantages.shape[1] + discounted_advantages = advantages * ( + gamma * torch.ones((1, T), device=advantages.device) + ) ** (-torch.arange(0, T, 1, device=advantages.device)) + if exclude_k_equals_t: + sub = torch.eye(T, device=advantages.device) + else: + sub = torch.zeros((T, T), device=advantages.device) + # Identity is for \( k < t \), remove for \( k \leq t \) + ad_align_weights = discounted_advantages @ ( + torch.triu(torch.ones((T, T), device=advantages.device)) - sub + ) + t_discounts = (gamma * torch.ones((1, T), device=advantages.device)) ** ( + torch.arange(0, T, 1, device=advantages.device) + ) + ad_align_weights = t_discounts * ad_align_weights + if discount_t: + time_discounted_advantages = advantages * ( + gamma * torch.ones((1, T), device=advantages.device) + ) ** (torch.arange(0, T, 1, device=advantages.device)) + ad_align_weights = ad_align_weights - advantages + time_discounted_advantages + return ad_align_weights + + +def get_advantage_alignment_credits( + a1: torch.Tensor, # (B, S) + a1_alternative: torch.Tensor, # (B, S, A) + a2: torch.Tensor, # (B, S) + exclude_k_equals_t: bool, + beta: float, + gamma: float = 1.0, + use_old_ad_align: bool = False, + use_sign: bool = False, + clipping: float | 
None = None, + use_time_regularization: bool = False, + force_coop_first_step: bool = False, + use_variance_regularization: bool = False, + rloo_branch: bool = False, + reuse_baseline: bool = False, + mean_normalize_ad_align: bool = False, + whiten_adalign_advantages: bool = False, + whiten_adalign_advantages_time_step_wise: bool = False, + discount_t: bool = False, +) -> torch.Tensor: + """ + Calculate the advantage alignment credits with vectorization, as described in https://arxiv.org/abs/2406.14662. + + Recall that the advantage opponent shaping term of the AdAlign policy gradient is: + \[ + \beta \mathbb{E}_{\substack{ + \tau \sim \text{Pr}_{\mu}^{\pi^1, \pi^2} \\ + a_t' \sim \pi^1(\cdot \mid s_t) + }} + \left[\sum_{t=0}^\infty \gamma^{t}\left( \sum_{k\leq t} A^1(s_k,a^{\prime}_k,b_k) \right) A^{2}(s_t,a_t, b_t)\nabla_{\theta^1}\text{log } \pi^1(a_t|s_t) \right] + \] + + This method computes the following: + \[ + Credit(s_t, a_t, b_t) = \gamma^t \left[ A^1(s_t, a_t, b_t) + \beta \left( \sum_{k\leq t} A^1(s_k,a^{\prime}_k,b_k) \right) A^{2}(s_t,a_t, b_t) \right] + \] + + Args: + a1: Advantages of the main trajectories for the current agent. + a1_alternative: Advantages of the alternative trajectories for the current agent. + a2: Advantages of the main trajectories for the other agent. + discount_factor: Discount factor for the advantage alignment. + beta: Beta parameter for the advantage alignment. + gamma: Gamma parameter for the advantage alignment. + use_sign_in_ad_align: Whether to use sign in the advantage alignment. + + Returns: + torch.Tensor: The advantage alignment credits. 
+ """ + + assert a1.dim() == a2.dim() == 2, "Advantages must be of shape (B, S)" + if a1_alternative is not None: + assert ( + a1_alternative.dim() == 3 + ), "Alternative advantages must be of shape (B, S, A)" + B, T, A = a1_alternative.shape + else: + B, T = a1.shape + assert a1.shape == a2.shape, "Not the same shape" + + sub_tensors = {} + + if use_old_ad_align: + ad_align_weights = get_advantage_alignment_weights( + advantages=a1, + exclude_k_equals_t=exclude_k_equals_t, + gamma=gamma, + discount_t=discount_t, + ) + sub_tensors["ad_align_weights_prev"] = ad_align_weights + if exclude_k_equals_t: + ad_align_weights = gamma * ad_align_weights + else: + assert a1_alternative is not None, "Alternative advantages must be provided" + if rloo_branch: + a1_alternative = torch.cat([a1.unsqueeze(2), a1_alternative], dim=2) + a1_alternative = a1_alternative.mean(dim=2) + # print(f"a1_alternative: {a1_alternative}, a1: {a1}\n") + a1, baseline = get_rloo_credits(a1) + if reuse_baseline: + a1_alternative = a1_alternative - baseline + else: + a1_alternative, _ = get_rloo_credits(a1_alternative) + assert a1.shape == a1_alternative.shape, "Not the same shape" + ad_align_weights = get_advantage_alignment_weights( + advantages=a1_alternative, + exclude_k_equals_t=exclude_k_equals_t, + gamma=gamma, + ) + sub_tensors["ad_align_weights"] = ad_align_weights + + # Use sign + if use_sign: + assert beta == 1.0, "beta should be 1.0 when using sign" + positive_signs = ad_align_weights > 0 + negative_signs = ad_align_weights < 0 + ad_align_weights[positive_signs] = 1 + ad_align_weights[negative_signs] = -1 + sub_tensors["ad_align_weights_sign"] = ad_align_weights + # (rest are 0) + + ################### + # Process weights + ################### + + # Use clipping + if clipping not in [0.0, None]: + upper_mask = ad_align_weights > 1 + lower_mask = ad_align_weights < -1 + + ad_align_weights = torch.clip( + ad_align_weights, + -clipping, + clipping, + ) + clipping_ratio = ( + 
torch.sum(upper_mask) + torch.sum(lower_mask) + ) / upper_mask.size + sub_tensors["clipped_ad_align_weights"] = ad_align_weights + + # 1/1+t Regularization + if use_time_regularization: + t_values = torch.arange(1, T + 1).to(ad_align_weights.device) + ad_align_weights = ad_align_weights / t_values + sub_tensors["time_regularized_ad_align_weights"] = ad_align_weights + + # Use coop on t=0 + if force_coop_first_step: + ad_align_weights[:, 0] = 1 + sub_tensors["coop_first_step_ad_align_weights"] = ad_align_weights + # # Normalize alignment terms (across same time step) + # if use_variance_regularization_in_ad_align: + # # TODO: verify + # reg_coef = torch.std(a1[:, -1]) / (torch.std(opp_shaping_terms[:, -1]) + 1e-9) + # opp_shaping_terms *= reg_coef + + #################################### + # Compose elements together + #################################### + + opp_shaping_terms = beta * ad_align_weights * a2 + sub_tensors["ad_align_opp_shaping_terms"] = opp_shaping_terms + + credits = a1 + opp_shaping_terms + if mean_normalize_ad_align: + credits = credits - credits.mean(dim=0) + sub_tensors["mean_normalized_ad_align_credits"] = credits + if whiten_adalign_advantages: + credits = (credits - credits.mean()) / (credits.std() + 1e-9) + sub_tensors["whitened_ad_align_credits"] = credits + if whiten_adalign_advantages_time_step_wise: + credits = (credits - credits.mean(dim=0, keepdim=True)) / ( + credits.std(dim=0, keepdim=True) + 1e-9 + ) + sub_tensors["whitened_ad_align_credits_time_step_wise"] = credits + sub_tensors["final_ad_align_credits"] = credits + + return credits, sub_tensors diff --git a/src_code_for_reproducibility/training/tally_metrics.py b/src_code_for_reproducibility/training/tally_metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..9d3675c78a66b65a7d538975586c5882ee11b84c --- /dev/null +++ b/src_code_for_reproducibility/training/tally_metrics.py @@ -0,0 +1,55 @@ +import os +from numbers import Number +from typing import Union 
+ +import wandb + + +class Tally: + """ + Minimal scalar-first tally. + - Keys are strings. + - First add stores a scalar; subsequent adds upgrade to a list of scalars. + """ + + def __init__(self): + self.stats = {} + + def reset(self): + self.stats = {} + + def _coerce_scalar(self, value: Union[int, float]) -> Union[int, float]: + if hasattr(value, "item") and callable(getattr(value, "item")): + try: + value = value.item() + except Exception: + pass + if isinstance(value, Number): + return value + raise AssertionError("Metric must be a scalar number") + + def add_metric(self, path: str, metric: Union[int, float]): + metric = float(metric) + assert isinstance(path, str), "Path must be a string." + assert isinstance(metric, float), "Metric must be a scalar number." + + scalar = self._coerce_scalar(metric) + existing = self.stats.get(path) + if existing is None: + self.stats[path] = scalar + elif isinstance(existing, list): + existing.append(scalar) + else: + self.stats[path] = [existing, scalar] + + def save(self, identifier: str, folder: str): + os.makedirs(name=folder, exist_ok=True) + try: + import pickle + + pkl_path = os.path.join(folder, f"{identifier}.tally.pkl") + payload = self.stats + with open(pkl_path, "wb") as f: + pickle.dump(payload, f, protocol=pickle.HIGHEST_PROTOCOL) + except Exception: + pass diff --git a/src_code_for_reproducibility/training/tally_rollout.py b/src_code_for_reproducibility/training/tally_rollout.py new file mode 100644 index 0000000000000000000000000000000000000000..18bbff345510b55f36aafc88f3f8be9ef46657cc --- /dev/null +++ b/src_code_for_reproducibility/training/tally_rollout.py @@ -0,0 +1,137 @@ +import json +import os +from copy import deepcopy +from typing import Union + +import numpy as np +import pandas as pd +import torch +from transformers import AutoTokenizer + + +class RolloutTallyItem: + def __init__(self, crn_ids: list[str], rollout_ids: list[str], agent_ids: list[str], metric_matrix: torch.Tensor): + """ + 
Initializes the RolloutTallyItem object. + + Args: + crn_ids (list[str]): List of CRN IDs. + rollout_ids (list[str]): List of rollout IDs. + agent_ids (list[str]): List of agent IDs. + metric_matrix (torch.Tensor): Metric matrix. + """ + if isinstance(crn_ids, torch.Tensor): + crn_ids = crn_ids.detach().cpu().numpy() + if isinstance(rollout_ids, torch.Tensor): + rollout_ids = rollout_ids.detach().cpu().numpy() + if isinstance(agent_ids, torch.Tensor): + agent_ids = agent_ids.detach().cpu().numpy() + self.crn_ids = crn_ids + self.rollout_ids = rollout_ids + self.agent_ids = agent_ids + metric_matrix = metric_matrix.detach().cpu() + assert 0 < metric_matrix.ndim <= 2, "Metric matrix must have less than or equal to 2 dimensions" + if metric_matrix.ndim == 1: + metric_matrix = metric_matrix.reshape(1, -1) + # Convert to float32 if tensor is in BFloat16 format (not supported by numpy) + if metric_matrix.dtype == torch.bfloat16: + metric_matrix = metric_matrix.float() + self.metric_matrix = metric_matrix.numpy() + +class RolloutTally: + """ + Tally is a utility class for collecting and storing training metrics. + It supports adding metrics at specified paths and saving them to disk. + """ + + def __init__(self): + """ + Initializes the RolloutTally object. + + Args: + tokenizer (AutoTokenizer): Tokenizer for converting token IDs to strings. + max_context_length (int, optional): Maximum context length for contextualized metrics. Defaults to 30. + """ + # Array-preserving structure (leaf lists hold numpy arrays / scalars) + self.metrics = {} + # Global ordered list of sample identifiers (crn_id, rollout_id) added in the order samples are processed + + def reset(self): + """ + Resets the base and contextualized tallies to empty dictionaries. + """ + self.metrics = {} + + def get_from_nested_dict(self, dictio: dict, path: str): + """ + Retrieves the value at a nested path in a dictionary. + + Args: + dictio (dict): The dictionary to search. 
+ path (list): List of keys representing the path. + + Returns: + Any: The value at the specified path, or None if not found. + """ + assert isinstance(path, list), "Path must be list." + for sp in path[:-1]: + dictio = dictio.setdefault(sp, {}) + return dictio.get(path[-1], None) + + def set_at_path(self, dictio: dict, path: str, value): + """ + Sets a value at a nested path in a dictionary, creating intermediate dictionaries as needed. + + Args: + dictio (dict): The dictionary to modify. + path (list): List of keys representing the path. + value (Any): The value to set at the specified path. + """ + for sp in path[:-1]: + dictio = dictio.setdefault(sp, {}) + dictio[path[-1]] = value + + + def add_metric( + self, path: list[str], rollout_tally_item: RolloutTallyItem + ): + """ + Adds a metric to the base tally at the specified path. + + Args: + path (list): List of keys representing the path in the base tally. + rollout_tally_item (RolloutTallyItem): The rollout tally item to add. + """ + rollout_tally_item = deepcopy(rollout_tally_item) + + # Update array-preserving tally + array_list = self.get_from_nested_dict(dictio=self.metrics, path=path) + if array_list is None: + self.set_at_path(dictio=self.metrics, path=path, value=[rollout_tally_item]) + else: + array_list.append(rollout_tally_item) + + + def save(self, identifier: str, folder: str): + """ + Saves the base and contextualized tallies to disk as JSON files, and also saves contextualized tallies as CSV files for each game/rollout. + + Args: + path (str): Directory path where the metrics will be saved. 
+ """ + os.makedirs(name=folder, exist_ok=True) + + from datetime import datetime + + now = datetime.now() + + # Pickle only (fastest, exact structure with numpy/scalars at leaves) + try: + import pickle + + pkl_path = os.path.join(folder, f"{identifier}.rt_tally.pkl") + payload = {"metrics": self.metrics} + with open(pkl_path, "wb") as f: + pickle.dump(payload, f, protocol=pickle.HIGHEST_PROTOCOL) + except Exception: + pass diff --git a/src_code_for_reproducibility/training/tally_tokenwise.py b/src_code_for_reproducibility/training/tally_tokenwise.py new file mode 100644 index 0000000000000000000000000000000000000000..5c04f480dedc113e4004c1143ff46ff5cc34785b --- /dev/null +++ b/src_code_for_reproducibility/training/tally_tokenwise.py @@ -0,0 +1,276 @@ +import json +import os +from typing import Any, Dict, List, Tuple, Union + +import numpy as np +import pandas as pd +import torch +from transformers import AutoTokenizer + + +class ContextualizedTokenwiseTally: + """ + Collect, store, and save token-level metrics per rollout. 
+ + - One DataFrame per rollout_id in `paths` + - Index = timestep (int) + - Columns are added incrementally via `add_contexts()` and `add_data()` + - Cells may contain scalars, strings, or lists (dtype=object) + """ + + def __init__( + self, + tokenizer: AutoTokenizer, + paths: List[str], + max_context_length: int = 30, + ): + """ + Args: + tokenizer: HuggingFace tokenizer used to convert tids -> tokens + paths: rollout identifiers (parallel to batch dimension) + max_context_length: truncate context token lists to this length + """ + self.tokenizer = tokenizer + self.paths = paths + self.max_context_length = max_context_length + self.tally: Dict[str, pd.DataFrame] = {path: pd.DataFrame() for path in paths} + + # set later by setters + self.contexts: torch.Tensor | None = None + self.action_mask: torch.Tensor | None = None + self.range: Tuple[int, int] | None = None + + # --------- Utilities --------- + + def tids_to_str(self, tids: List[int]) -> List[str]: + """Convert a list of token IDs to a list of token strings.""" + return self.tokenizer.convert_ids_to_tokens(tids) + + def _ensure_ready(self): + assert self.action_mask is not None, "call set_action_mask(mask) first" + assert self.range is not None, "call set_range((start, end)) first" + + @staticmethod + def _sanitize_filename(name: Any) -> str: + """Make a safe filename from any rollout_id.""" + s = str(name) + bad = {os.sep, " ", ":", "|", "<", ">", '"', "'"} + if os.altsep is not None: + bad.add(os.altsep) + for ch in bad: + s = s.replace(ch, "_") + return s + + @staticmethod + def _pad_left(seq: List[Any], length: int, pad_val: Any = "") -> List[Any]: + """Left-pad a sequence to `length` with `pad_val`.""" + if len(seq) >= length: + return seq[-length:] + return [pad_val] * (length - len(seq)) + list(seq) + + # --------- Setters --------- + + def set_action_mask(self, action_mask: torch.Tensor): + """ + action_mask: (B, S) bool or 0/1 indicating valid steps + """ + self.action_mask = action_mask + + def 
def set_range(self, range: Tuple[int, int]):
    """Select the window of ``self.paths`` that the column builders
    (``add_contexts`` / ``add_data``) operate on for the current batch.

    Args:
        range: (start, end) slice into ``self.paths``. (Parameter name kept
            for caller compatibility although it shadows the builtin.)
    """
    self.range = range


# --------- Column builders ---------


def _rows_for_valid_steps(df, idx_list):
    """Return ``df`` whose index contains every timestep in ``idx_list``.

    An empty frame becomes a fresh frame indexed by ``idx_list``; otherwise
    the index is merged (sorted union) and existing rows/columns are kept.
    """
    if df.empty:
        return pd.DataFrame(index=idx_list)
    merged = sorted(set(df.index.tolist()) | set(idx_list))
    if list(df.index) != merged:
        df = df.reindex(merged)
    return df


def add_contexts(self, contexts: torch.Tensor):
    """Add a single 'context' column (list[str]) for valid steps.

    Expects ``contexts`` with shape (B, S): a token id per timestep. For each
    valid timestep t (where ``self.action_mask`` is True) the context is the
    last N tokens up to and including t, with N = ``self.max_context_length``:

        window = contexts[i, max(0, t - N + 1) : t + 1]

    The token list is left-padded with "" so it always has length N.

    Args:
        contexts: (B, S) integer tensor of token ids.
    """
    self._ensure_ready()

    current_paths = self.paths[self.range[0] : self.range[1]]
    B, S = contexts.shape
    N = self.max_context_length

    # Move to CPU once; windows below slice this copy.
    contexts_cpu = contexts.detach().to("cpu")

    for i in range(B):
        rollout_id = current_paths[i]
        df = self.tally.get(rollout_id, pd.DataFrame())

        valid_idx = torch.nonzero(
            self.action_mask[i].bool(), as_tuple=False
        ).squeeze(-1)
        if valid_idx.numel() == 0:
            self.tally[rollout_id] = df
            continue

        idx_list = valid_idx.tolist()
        df = _rows_for_valid_steps(df, idx_list)

        # Build the fixed-length (left-padded) context window per valid step.
        ctx_token_lists = []
        for t in idx_list:
            start = max(0, t - N + 1)
            window_ids = contexts_cpu[i, start : t + 1].tolist()
            window_toks = self.tids_to_str([int(x) for x in window_ids])
            if len(window_toks) < N:
                window_toks = [""] * (N - len(window_toks)) + window_toks
            else:
                window_toks = window_toks[-N:]
            ctx_token_lists.append(window_toks)

        # Single 'context' column, object dtype so each cell holds a list.
        if "context" not in df.columns:
            df["context"] = pd.Series(index=df.index, dtype=object)
        df.loc[idx_list, "context"] = pd.Series(
            ctx_token_lists, index=idx_list, dtype=object
        )

        self.tally[rollout_id] = df


def add_data(
    self,
    metric_id: str,
    metrics: torch.Tensor,
    to_tids: bool = False,
):
    """Add a metric column for valid (action) steps of the current batch.

    Args:
        metric_id: Column name.
        metrics: (B, S) for scalars/ids or (B, S, K) for top-k vectors.
        to_tids: If True, ints / lists of ints are treated as token ids and
            converted to token strings via ``self.tids_to_str``.

    Raises:
        ValueError: If ``metrics`` is not 2- or 3-dimensional, or the number
            of extracted values does not match the number of valid steps.
    """
    self._ensure_ready()
    current_paths = self.paths[self.range[0] : self.range[1]]

    if metrics.dim() in (2, 3):
        B = metrics.shape[0]
    else:
        raise ValueError("metrics must be (B, S) or (B, S, K)")

    for i in range(B):
        rollout_id = current_paths[i]
        df = self.tally.get(rollout_id, pd.DataFrame())

        valid_idx = torch.nonzero(
            self.action_mask[i].bool(), as_tuple=False
        ).squeeze(-1)
        if valid_idx.numel() == 0:
            self.tally[rollout_id] = df
            continue

        idx_list = valid_idx.detach().cpu().tolist()
        df = _rows_for_valid_steps(df, idx_list)

        # Slice metrics at valid steps -> pure python values
        # (a 1D list, or a list-of-lists for (B, S, K) input).
        values = metrics[i][valid_idx].detach().cpu().tolist()

        # Optional tids -> token strings.
        if to_tids:

            def _to_tokish(x):
                # A list of ids becomes a list of tokens; a scalar id
                # becomes a single token string.
                if isinstance(x, list):
                    return self.tids_to_str([int(v) for v in x])
                return self.tids_to_str([int(x)])[0]

            values = [_to_tokish(v) for v in values]

        # Object dtype keeps lists/strings intact; assign via aligned Series.
        if metric_id not in df.columns:
            df[metric_id] = pd.Series(index=df.index, dtype=object)

        if len(values) != len(idx_list):
            raise ValueError(
                f"Length mismatch for '{metric_id}': values={len(values)} vs idx_list={len(idx_list)}"
            )

        df.loc[idx_list, metric_id] = pd.Series(
            values, index=idx_list, dtype=object
        )
        self.tally[rollout_id] = df


# --------- Saving ---------


def save(self, path: str):
    """Write a manifest JSON and one CSV per rollout.

    - The manifest carries metadata only (JSON-safe).
    - Each rollout CSV is written with index label 'timestep'.
    - The single 'context' column (list[str]) is placed first.

    Rollouts whose CSV cannot be written (e.g. no 'context' column yet) are
    logged and skipped; they do not appear in the manifest's rollout list.

    Args:
        path: Output directory (created if missing).
    """
    if not self.tally or all(df.empty for df in self.tally.values()):
        return

    os.makedirs(path, exist_ok=True)
    from datetime import datetime

    now = datetime.now()

    manifest = {
        "created_at": f"{now:%Y-%m-%d %H:%M:%S}",
        "max_context_length": self.max_context_length,
        # NOTE(review): counts every rollout in the tally, including any
        # whose CSV write fails below and is therefore absent from
        # manifest["rollouts"].
        "num_rollouts": len(self.tally),
        "rollouts": [],
    }

    for rid, df in self.tally.items():
        rid_str = str(rid)
        safe_name = self._sanitize_filename(rid_str)
        csv_path = os.path.join(path, f"{safe_name}_tokenwise.csv")

        # Put 'context' first, then the remaining metric columns.
        cols = ["context"] + [c for c in df.columns if c != "context"]
        try:
            df[cols].to_csv(csv_path, index=True, index_label="timestep")
        except Exception:
            # Previously a silent `continue`; keep skipping but leave a trace.
            logging.getLogger(__name__).warning(
                "Skipping rollout %s: failed to write %s",
                rid_str,
                csv_path,
                exc_info=True,
            )
            continue

        manifest["rollouts"].append(
            {
                "rollout_id": rid_str,
                "csv": csv_path,
                "num_rows": int(df.shape[0]),
                "columns": cols,
            }
        )

    manifest_path = os.path.join(
        path, f"tokenwise_manifest_{now:%Y-%m-%d___%H-%M-%S}.json"
    )
    with open(manifest_path, "w") as fp:
        json.dump(manifest, fp, indent=2)
def process_training_chat(
    tokenizer: "AutoTokenizer",
    chat_history: "list[TrainingChatTurn]",
    entropy_mask_regex: "str | None" = None,
    exploration_prompts_to_remove: "list[str] | None" = None,
    use_engine_out_token_ids: bool = False,
) -> "tuple[torch.LongTensor, torch.BoolTensor, torch.BoolTensor, torch.LongTensor, torch.BoolTensor, torch.FloatTensor]":
    """Flatten one training chat into concatenated token ids plus aligned masks.

    Token ids are taken from each turn's precomputed fields
    (``out_token_ids`` for assistant turns, ``chat_template_token_ids``
    otherwise) and concatenated in order; no tokenizer call is made here.
    Five parallel 1D masks/vectors are built alongside:

    - action_mask: True on tokens of assistant turns (model actions).
    - entropy_mask: True on tokens of turns whose content matches
      ``entropy_mask_regex`` (everywhere when the regex is None).
    - timesteps: each token carries its turn's ``time_step`` (int64).
    - state_ends_mask: True only on the last token of turns flagged
      ``is_state_end``.
    - engine_log_probs: the engine's per-token log-probs for assistant turns,
      zeros elsewhere.

    Side effect: occurrences of ``exploration_prompts_to_remove`` are
    stripped from the turns' ``content`` in place.

    Args:
        tokenizer: Unused at present; kept for interface stability.
        chat_history: Ordered list of ``TrainingChatTurn`` for one dialogue.
        entropy_mask_regex: Optional pattern gating the entropy mask per turn.
        exploration_prompts_to_remove: Substrings removed from turn contents.
        use_engine_out_token_ids: Currently unused; kept for interface
            stability.

    Returns:
        (input_ids, action_mask, entropy_mask, timesteps, state_ends_mask,
        engine_log_probs): six 1D tensors of equal length N, the total number
        of tokens across all turns. ``input_ids`` is int64, ``timesteps`` is
        cast to int64, the masks are bool, ``engine_log_probs`` is float.
    """
    if not chat_history:
        # Degenerate case: nothing to concatenate (torch.cat rejects []).
        return (
            torch.empty(0, dtype=torch.long),
            torch.empty(0, dtype=torch.bool),
            torch.empty(0, dtype=torch.bool),
            torch.empty(0, dtype=torch.long),
            torch.empty(0, dtype=torch.bool),
            torch.empty(0, dtype=torch.float),
        )

    prompts_to_remove = exploration_prompts_to_remove or []
    input_ids = []
    action_mask = []
    entropy_mask = []
    timesteps = []
    state_ends_mask = []
    engine_log_probs = []

    for turn in chat_history:
        is_action = turn.role == "assistant"

        # Remove exploration prompts from training data (mutates the turn).
        for prompt in prompts_to_remove:
            if prompt in turn.content:
                turn.content = turn.content.replace(prompt, "")

        if entropy_mask_regex is not None:
            entropy_on = regex.search(entropy_mask_regex, turn.content) is not None
        else:
            entropy_on = True

        if is_action:
            # Assistant turns carry the engine's sampled ids and log-probs.
            turn_ids = turn.out_token_ids
            engine_log_probs.append(turn.log_probs)
        else:
            turn_ids = turn.chat_template_token_ids
            engine_log_probs.append(
                torch.zeros(turn_ids.numel(), dtype=torch.float)
            )

        n_tokens = turn_ids.numel()
        input_ids.append(turn_ids)
        action_mask.append(torch.full((n_tokens,), is_action, dtype=torch.bool))
        entropy_mask.append(torch.full((n_tokens,), entropy_on, dtype=torch.bool))

        ends = torch.zeros(n_tokens, dtype=torch.bool)
        if turn.is_state_end:
            ends[-1] = True  # last token of the turn marks the state end
        state_ends_mask.append(ends)

        # Float intermediate (as before) so non-integer time_steps truncate
        # identically via the final .to(torch.long).
        timesteps.append(torch.ones(n_tokens) * turn.time_step)

    return (
        torch.cat(input_ids),
        torch.cat(action_mask),
        torch.cat(entropy_mask),
        torch.cat(timesteps).to(torch.long),
        torch.cat(state_ends_mask),
        torch.cat(engine_log_probs),
    )
@dataclass
class AdAlignTrainingData:
    """Per-agent container for one Advantage-Alignment update.

    Attributes:
        agent_id: Agent this data belongs to.
        main_data: Trajectory batch for the B main rollouts.
        main_advantages: Per-rollout advantage vectors, each of length jT
            (jagged: rollouts may differ in length).
        alternative_advantages: Per-rollout (jT, A) advantage matrices for
            the A alternative actions branched at each step.
        advantage_alignment_credits: Final per-rollout credits (filled later).
    """

    agent_id: str
    main_data: "TrajectoryBatch"
    main_advantages: "list[torch.FloatTensor] | None" = None
    alternative_advantages: "list[torch.FloatTensor] | None" = None
    advantage_alignment_credits: "list[torch.FloatTensor] | None" = None


def get_alternative_chat_histories(
    agent_id: str, root: "RolloutTreeRootNode"
) -> "tuple[list[list[TrainingChatTurn]], list[torch.FloatTensor]]":
    """Collect every alternative (branched) trajectory for one agent.

    Walks the main chain of the rollout tree; at each step the tree stores A
    alternative branches per agent. Each alternative chat is the main-path
    prefix up to (excluding) that step's main action, followed by the
    branch's own continuation.

    Args:
        agent_id: The agent we want the chat histories for.
        root: The root of the rollout tree.

    Returns:
        alternative_chats: jT*A chat histories, ordered step-major with the
            A alternatives innermost.
        alternative_rewards: matching jT*A reward vectors.
    """
    current_node = root.child
    pre_branch_chat = []
    pre_branch_rewards = []
    alternative_rewards = []
    alternative_chats = []
    while current_node is not None:
        assert isinstance(
            current_node, RolloutTreeBranchNode
        ), "Current node should be a branch node."
        main_node = current_node.main_child
        branches = current_node.branches
        current_node = main_node.child

        # The A alternative trajectories branching off at this step; the
        # shared prefix still excludes this step's main action.
        for alt_node in branches[agent_id]:
            post_branch_chat, post_branch_rewards = get_main_chat_list_and_rewards(
                agent_id=agent_id, root=alt_node
            )
            alternative_chats.append(pre_branch_chat + post_branch_chat)
            alternative_rewards.append(
                torch.cat([torch.tensor(pre_branch_rewards), post_branch_rewards])
            )

        # Append this step's main-path turns and reward to the shared prefix.
        chat_turns = main_node.step_log.action_logs[agent_id].chat_turns
        pre_branch_chat.extend(
            TrainingChatTurn(time_step=main_node.time_step, **turn.model_dump())
            for turn in chat_turns
        )
        pre_branch_rewards.append(
            main_node.step_log.simulation_step_log.rewards[agent_id]
        )

    return alternative_chats, alternative_rewards


class TrainerAdAlign(BaseTrainer):
    """REINFORCE trainer extended with Advantage Alignment (opponent
    shaping): each agent's credits combine its own advantages with the
    co-agents' shared advantage estimates."""

    def __init__(
        self,
        ad_align_beta: float,
        ad_align_gamma: float,
        ad_align_exclude_k_equals_t: bool,
        ad_align_use_sign: bool,
        ad_align_clipping: float,
        ad_align_force_coop_first_step: bool,
        use_old_ad_align: bool,
        use_time_regularization: bool,
        rloo_branch: bool,
        reuse_baseline: bool,
        ad_align_beta_anneal_step: int = -1,
        ad_align_beta_anneal_rate: float = 0.5,
        min_ad_align_beta: float = 0.1,
        mean_normalize_ad_align: bool = False,
        whiten_adalign_advantages: bool = False,
        whiten_adalign_advantages_time_step_wise: bool = False,
        ad_align_discount_t: bool = False,
        *args,
        **kwargs,
    ):
        """Initialize the advantage alignment trainer.

        Args:
            ad_align_beta: Beta (shaping strength); may be annealed, see
                ``ad_align_beta_anneal_step``.
            ad_align_gamma: Gamma parameter for the advantage alignment.
            ad_align_exclude_k_equals_t: Whether to exclude k == t terms.
            ad_align_use_sign: Whether to use the sign in the alignment term.
            ad_align_clipping: Clipping value for the alignment term.
            ad_align_force_coop_first_step: Force cooperation on the first
                step of the advantage alignment.
            use_old_ad_align: Use the legacy alignment formulation.
            use_time_regularization: Enable time regularization.
            rloo_branch: Use RLOO-style baselines over branches.
            reuse_baseline: Reuse the baseline across computations.
            ad_align_beta_anneal_step: Anneal beta every this many rollouts;
                <= 0 disables annealing.
            ad_align_beta_anneal_rate: Multiplicative annealing factor.
            min_ad_align_beta: Lower bound for the annealed beta.
            mean_normalize_ad_align: Mean-normalize the alignment terms.
            whiten_adalign_advantages: Whiten the aligned advantages.
            whiten_adalign_advantages_time_step_wise: Whiten per time step.
            ad_align_discount_t: Apply discounting over t.
            *args: Forwarded to ``BaseTrainer``.
            **kwargs: Forwarded to ``BaseTrainer``.

        Note:
            The exact semantics of most flags are consumed by
            ``get_advantage_alignment_credits``; see that function.
        """
        super().__init__(*args, **kwargs)
        self.ad_align_beta = ad_align_beta
        self.ad_align_gamma = ad_align_gamma
        self.ad_align_exclude_k_equals_t = ad_align_exclude_k_equals_t
        self.ad_align_use_sign = ad_align_use_sign
        self.ad_align_clipping = ad_align_clipping
        self.ad_align_force_coop_first_step = ad_align_force_coop_first_step
        self.use_old_ad_align = use_old_ad_align
        self.use_time_regularization = use_time_regularization
        self.rloo_branch = rloo_branch
        self.reuse_baseline = reuse_baseline
        self.ad_align_beta_anneal_step = ad_align_beta_anneal_step
        self.ad_align_beta_anneal_rate = ad_align_beta_anneal_rate
        self.min_ad_align_beta = min_ad_align_beta
        self.past_ad_align_step = -1
        self.mean_normalize_ad_align = mean_normalize_ad_align
        self.whiten_adalign_advantages = whiten_adalign_advantages
        self.whiten_adalign_advantages_time_step_wise = (
            whiten_adalign_advantages_time_step_wise
        )
        self.ad_align_discount_t = ad_align_discount_t
        self.training_data: dict[AgentId, AdAlignTrainingData] = {}
        self.debug_path_list: list[str] = []

    def set_agent_trajectory_data(
        self, agent_id: str, roots: "list[RolloutTreeRootNode]"
    ):
        """Build main + branch trajectory batches for one agent and compute
        their advantages.

        For each of the B rollout trees the agent's main trajectory is
        tokenized; when the trees contain branches, the A alternative
        trajectories per main step (∑jT * A overall, jagged lengths) are
        tokenized too. Advantages are computed for all of them (accumulating
        critic gradients along the way) and the result is stored in
        ``self.training_data[agent_id]``.

        Args:
            agent_id: Agent whose perspective is extracted.
            roots: One rollout tree root per rollout.
        """
        # Accumulators for the B main rollouts.
        batch_rollout_ids = []
        batch_crn_ids = []
        batch_input_ids = []
        batch_action_mask = []
        batch_entropy_mask = []
        batch_timesteps = []
        batch_state_ends_mask = []
        batch_engine_log_probs = []
        batch_rewards = []

        # Accumulators for the alternative-action rollouts.
        batch_branching_time_steps = []
        alternative_batch_input_ids = []
        alternative_batch_action_mask = []
        alternative_batch_entropy_mask = []
        alternative_batch_timesteps = []
        alternative_batch_state_ends_mask = []
        alternative_batch_engine_log_probs = []
        alternative_batch_rewards = []
        jT_list = []

        # Number of alternative actions per step; 0 when the tree has no
        # branches for this agent.
        try:
            A = len(roots[0].child.branches[agent_id])
        except Exception:  # was a bare `except:`; don't mask KeyboardInterrupt
            A = 0

        for root in roots:
            rollout_id = root.id
            self.debug_path_list.append(f"mgid:{rollout_id}_agent_id:{agent_id}")
            # Main trajectory for this rollout.
            batch_rollout_ids.append(rollout_id)
            batch_crn_ids.append(root.crn_id)
            main_chat, main_rewards = get_main_chat_list_and_rewards(
                agent_id=agent_id, root=root
            )
            (
                input_ids,
                action_mask,
                entropy_mask,
                timesteps,
                state_ends_mask,
                engine_log_probs,
            ) = process_training_chat(
                tokenizer=self.tokenizer,
                chat_history=main_chat,
                entropy_mask_regex=self.entropy_mask_regex,
                exploration_prompts_to_remove=self.exploration_prompts_to_remove,
            )
            batch_input_ids.append(input_ids)
            batch_action_mask.append(action_mask)
            batch_entropy_mask.append(entropy_mask)
            batch_timesteps.append(timesteps)
            batch_state_ends_mask.append(state_ends_mask)
            batch_engine_log_probs.append(engine_log_probs)
            batch_rewards.append(main_rewards)
            jT = main_rewards.numel()  # steps in this main trajectory
            jT_list.append(jT)
            if A > 0:
                # Branching time step for each of the jT*A alternatives
                # (step-major, the A alternatives innermost).
                batch_branching_time_steps.extend(
                    bt for step in range(jT) for bt in A * [step]
                )

                # All jT*A alternative trajectories of this tree.
                alternative_chats, alternative_rewards = (
                    get_alternative_chat_histories(agent_id=agent_id, root=root)
                )
                assert (
                    len(alternative_chats) == A * jT
                ), "Incorrect number of alternative trajectories."

                for chat, rewards in zip(alternative_chats, alternative_rewards):
                    (
                        input_ids,
                        action_mask,
                        entropy_mask,
                        timesteps,
                        state_ends_mask,
                        engine_log_probs,
                    ) = process_training_chat(
                        tokenizer=self.tokenizer,
                        chat_history=chat,
                        entropy_mask_regex=self.entropy_mask_regex,
                        exploration_prompts_to_remove=self.exploration_prompts_to_remove,
                    )
                    alternative_batch_input_ids.append(input_ids)
                    alternative_batch_action_mask.append(action_mask)
                    alternative_batch_entropy_mask.append(entropy_mask)
                    alternative_batch_timesteps.append(timesteps)
                    alternative_batch_state_ends_mask.append(state_ends_mask)
                    alternative_batch_engine_log_probs.append(engine_log_probs)
                    alternative_batch_rewards.append(rewards)

        trajectory_batch = TrajectoryBatch(
            rollout_ids=torch.tensor(batch_rollout_ids, dtype=torch.int32),  # (B,)
            crn_ids=torch.tensor(batch_crn_ids, dtype=torch.int32),
            agent_ids=[agent_id] * len(batch_rollout_ids),
            batch_input_ids=batch_input_ids,
            batch_action_mask=batch_action_mask,
            batch_entropy_mask=batch_entropy_mask,
            batch_timesteps=batch_timesteps,
            batch_state_ends_mask=batch_state_ends_mask,
            batch_engine_log_probs=batch_engine_log_probs,
            batch_rewards=batch_rewards,
        )

        # Get advantages & train critic.
        with resource_logger_context(
            logger, "Get advantages with critic gradient accumulation"
        ):
            self.batch_advantages: torch.FloatTensor = (
                self.get_advantages_with_critic_gradient_accumulation(trajectory_batch)
            )  # (B, jT)

        if A > 0:
            # ∑jT * A alternative trajectories in total (jagged main lengths).
            sum_jT = sum(jT_list)
            with resource_logger_context(logger, "Create alternative trajectory batch"):
                alternative_trajectory_batch = TrajectoryBatch(
                    rollout_ids=torch.zeros(A * sum_jT, dtype=torch.int32),
                    crn_ids=torch.zeros(A * sum_jT, dtype=torch.int32),
                    agent_ids=[agent_id] * (A * sum_jT),
                    batch_input_ids=alternative_batch_input_ids,
                    batch_action_mask=alternative_batch_action_mask,
                    batch_entropy_mask=alternative_batch_entropy_mask,
                    batch_timesteps=alternative_batch_timesteps,
                    batch_state_ends_mask=alternative_batch_state_ends_mask,
                    batch_engine_log_probs=alternative_batch_engine_log_probs,
                    batch_rewards=alternative_batch_rewards,
                )

            with resource_logger_context(
                logger, "Compute alternative advantage estimates"
            ):
                # BAAs = batch alternative advantages; jagged list of (jT',).
                BAAs_list = self.get_advantages_with_critic_gradient_accumulation(
                    alternative_trajectory_batch
                )
                # Pad to (∑jT*A, P) and pick, for each alternative trajectory,
                # the advantage at its branching time step.
                BAAs_padded = pad_sequence(
                    BAAs_list, batch_first=True, padding_value=0.0
                )
                branch_idx = torch.tensor(
                    batch_branching_time_steps,
                    device=BAAs_padded.device,
                    dtype=torch.long,
                )
                gathered = BAAs_padded.gather(
                    dim=1, index=branch_idx.unsqueeze(1)
                ).squeeze(1)
                # The flat batch was appended rollout-major / step-major with
                # the A alternatives innermost, so reshape as (∑jT, A) and
                # split per rollout. (A previous revision used
                # .view(A, sum_jT) + transpose, which does not match this
                # append order.)
                gathered = gathered.view(sum_jT, A)
                BAAs = [
                    blk.contiguous()
                    for blk in torch.split(gathered, jT_list, dim=0)
                ]  # list of (jT_i, A)

        if self.ad_align_beta_anneal_step > 0:
            # Anneal beta once whenever a new annealing boundary is crossed.
            max_rollout_id = torch.max(trajectory_batch.rollout_ids) + 1
            if (
                max_rollout_id % self.ad_align_beta_anneal_step == 0
                and self.past_ad_align_step != max_rollout_id
            ):
                self.ad_align_beta = max(
                    self.ad_align_beta * self.ad_align_beta_anneal_rate,
                    self.min_ad_align_beta,
                )
                logger.info("Annealing ad_align_beta to %s", self.ad_align_beta)
                self.past_ad_align_step = max_rollout_id

        self.training_data[agent_id] = AdAlignTrainingData(
            agent_id=agent_id,
            main_data=trajectory_batch,
            main_advantages=self.batch_advantages,
            alternative_advantages=BAAs if A > 0 else None,
        )

    def share_advantage_data(self) -> "list[AdvantagePacket]":
        """Package this trainer's per-agent main advantages for other agents.

        Returns:
            One ``AdvantagePacket`` per stored agent, carrying its rollout
            ids and self-estimated main advantages.
        """
        logger.info("Sharing advantage alignment data.")
        return [
            AdvantagePacket(
                agent_id=agent_data.agent_id,
                rollout_ids=agent_data.main_data.rollout_ids,
                main_advantages=agent_data.main_advantages,
            )
            for agent_data in self.training_data.values()
        ]
+ """ + logger.info(f"Receiving advantage packets.") + + assert ( + len(advantage_packets) > 0 + ), "At least one advantage packet must be provided." + + for agent_id, agent_data in self.training_data.items(): + coagent_advantage_packets = [ + packet for packet in advantage_packets if packet.agent_id != agent_id + ] + agent_rollout_ids = agent_data.main_data.rollout_ids + agent_advantages = agent_data.main_advantages + co_agent_advantages = [] + for rollout_id in agent_rollout_ids: + for co_agent_packet in coagent_advantage_packets: + if rollout_id in co_agent_packet.rollout_ids: + index = torch.where(rollout_id == co_agent_packet.rollout_ids)[ + 0 + ].item() + co_agent_advantages.append( + co_agent_packet.main_advantages[index] + ) + # assumes that its two player game, with one co-agent + break + assert len(co_agent_advantages) == len(agent_advantages) + B = len(agent_advantages) + assert all( + a.shape[0] == b.shape[0] + for a, b in zip(co_agent_advantages, agent_advantages) + ), "Number of advantages must match for advantage alignment." 
+ + # Get padded tensors (advantage alignment is invariant to padding) + lengths = torch.tensor( + [len(t) for t in agent_advantages], + device=self.device, + dtype=torch.long, + ) + padded_main_advantages = pad_sequence( + agent_advantages, batch_first=True, padding_value=0.0 + ) + if agent_data.alternative_advantages: + padded_alternative_advantages = pad_sequence( + agent_data.alternative_advantages, + batch_first=True, + padding_value=0.0, + ) # (B, P, A) + else: + padded_alternative_advantages = None + padded_co_agent_advantages = pad_sequence( + co_agent_advantages, batch_first=True, padding_value=0.0 + ) + + # Create training batch data + credits, sub_tensors = get_advantage_alignment_credits( + a1=padded_main_advantages, + a1_alternative=padded_alternative_advantages, + a2=padded_co_agent_advantages, + beta=self.ad_align_beta, + gamma=self.ad_align_gamma, + exclude_k_equals_t=self.ad_align_exclude_k_equals_t, + use_sign=self.ad_align_use_sign, + clipping=self.ad_align_clipping, + force_coop_first_step=self.ad_align_force_coop_first_step, + use_old_ad_align=self.use_old_ad_align, + use_time_regularization=self.use_time_regularization, + rloo_branch=self.rloo_branch, + reuse_baseline=self.reuse_baseline, + mean_normalize_ad_align=self.mean_normalize_ad_align, + whiten_adalign_advantages=self.whiten_adalign_advantages, + whiten_adalign_advantages_time_step_wise=self.whiten_adalign_advantages_time_step_wise, + discount_t=self.ad_align_discount_t, + ) + for key, value in sub_tensors.items(): + self.rollout_tally.add_metric( + path=[key], + rollout_tally_item=RolloutTallyItem( + crn_ids=agent_data.main_data.crn_ids, + rollout_ids=agent_data.main_data.rollout_ids, + agent_ids=agent_data.main_data.agent_ids, + metric_matrix=value, + ), + ) + + if not self.skip_discounted_state_visitation: + credits = get_discounted_state_visitation_credits( + credits, + self.discount_factor, + ) + self.rollout_tally.add_metric( + path=["discounted_state_visitation_credits"], + 
rollout_tally_item=RolloutTallyItem( + crn_ids=agent_data.main_data.crn_ids, + rollout_ids=agent_data.main_data.rollout_ids, + agent_ids=agent_data.main_data.agent_ids, + metric_matrix=sub_tensors[ + "discounted_state_visitation_credits" + ], + ), + ) + + # Slice back to jagged + advantage_alignment_credits = [credits[i, : lengths[i]] for i in range(B)] + # Replace stored training data for this agent by the concrete trajectory batch + # and attach the computed credits for policy gradient. + self.training_data[agent_id] = agent_data.main_data + self.training_data[agent_id].batch_credits = advantage_alignment_credits diff --git a/src_code_for_reproducibility/training/trainer_common.py b/src_code_for_reproducibility/training/trainer_common.py new file mode 100644 index 0000000000000000000000000000000000000000..8da913096e80194d87d4f6f69e0a5ebc3b4c9cb6 --- /dev/null +++ b/src_code_for_reproducibility/training/trainer_common.py @@ -0,0 +1,1054 @@ +""" +TODO: Add coefficients for losses (depend on total number of tokens or batch) +TODO: adapt reinforce step for torch.compile +TODO: add lr schedulers support +""" +import logging +import os +import pickle +import sys +from abc import ABC, abstractmethod +from typing import Callable, Literal, Union + +import numpy as np +import torch +import torch.nn.functional as F +from accelerate import Accelerator +from pandas._libs.tslibs.offsets import CBMonthBegin +from peft import LoraConfig +from torch.nn.utils.rnn import pad_sequence +from transformers import AutoModelForCausalLM, AutoTokenizer + +from mllm.markov_games.rollout_tree import * +from mllm.markov_games.rollout_tree import RolloutTreeRootNode +from mllm.training.annealing_methods import sigmoid_annealing +from mllm.training.credit_methods import ( + get_discounted_returns, + get_generalized_advantage_estimates, + get_rloo_credits, + whiten_advantages, + whiten_advantages_time_step_wise, +) +from mllm.training.tally_metrics import Tally +from mllm.training.tally_rollout 
@dataclass
class TrainerAnnealingState:
    """Small persisted state driving GAE-lambda annealing across restarts."""

    # Number of advantage-computation steps taken so far (pickled to disk
    # under ``save_path`` so annealing resumes after a restart).
    annealing_step_counter: int = 0


class BaseTrainer(ABC):
    """
    Common REINFORCE-style trainer base.

    Owns the policy/critic models and their optimizers (wrapped by HF
    Accelerate), restores optimizer and annealing state from ``save_path``
    if present, and provides the shared policy-gradient / advantage
    machinery used by the concrete trainers.
    """

    def __init__(
        self,
        policy: AutoModelForCausalLM,
        policy_optimizer: torch.optim.Optimizer,
        critic: Union[AutoModelForCausalLM, None],
        critic_optimizer: Union[torch.optim.Optimizer, None],
        tokenizer: AutoTokenizer,
        lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
        critic_lr_scheduler: Union[torch.optim.lr_scheduler.LRScheduler, None],
        ######################################################################
        entropy_coeff: float,
        entropy_topk: Union[int, None],
        entropy_mask_regex: Union[str, None],
        kl_coeff: float,
        gradient_clipping: Union[float, None],
        restrict_tokens: Union[list[str], None],
        mini_batch_size: int,
        use_gradient_checkpointing: bool,
        temperature: float,
        device: str,
        whiten_advantages: bool,
        whiten_advantages_time_step_wise: bool,
        use_gae: bool,
        use_gae_lambda_annealing: bool,
        gae_lambda_annealing_limit: float,
        gae_lambda_annealing_method: Literal["sigmoid_annealing"],
        gae_lambda_annealing_method_params: dict,
        pg_loss_normalization: Literal["batch", "nb_tokens"],
        use_rloo: bool,
        skip_discounted_state_visitation: bool,
        discount_factor: float,
        enable_tokenwise_logging: bool,
        save_path: str,
        reward_normalizing_constant: float = 1.0,
        critic_loss_type: Literal["mse", "huber"] = "huber",
        exploration_prompts_to_remove: Union[list[str], None] = None,
        filter_higher_refprob_tokens_kl: bool = False,
        truncated_importance_sampling_ratio_cap: float = 0.0,
        importance_sampling_strategy: Literal[
            "per_token", "per_sequence"
        ] = "per_token",
    ):
        """
        Initialize the REINFORCE trainer with reward shaping for multi-agent
        or single-agent training.

        Args:
            policy (AutoModelForCausalLM): The main policy model.
            policy_optimizer (torch.optim.Optimizer): Optimizer for the policy model.
            critic (AutoModelForCausalLM or None): Critic model for value estimation (optional).
            critic_optimizer (torch.optim.Optimizer or None): Optimizer for the critic (optional).
            tokenizer (AutoTokenizer): Tokenizer for the model.
            lr_scheduler (torch.optim.lr_scheduler.LRScheduler): LR scheduler for the policy.
            critic_lr_scheduler (torch.optim.lr_scheduler.LRScheduler or None): LR scheduler for the critic.
            save_path (str): Directory where optimizer/annealing state is persisted and restored.

        Side effects:
            Restores trainer annealing state and optimizer state_dicts from
            ``save_path`` when the corresponding files exist.
        """
        self.tokenizer = tokenizer
        # self.tokenizer.padding_side = "left" # needed for flash attention
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        self.lr_scheduler = lr_scheduler
        self.accelerator = Accelerator()
        (
            self.policy,
            self.policy_optimizer,
            self.critic,
            self.critic_optimizer,
        ) = self.accelerator.prepare(policy, policy_optimizer, critic, critic_optimizer)

        self.critic_lr_scheduler = critic_lr_scheduler
        # Single Tally instance for scalar training metrics (was previously
        # assigned twice; the duplicate assignment has been removed).
        self.tally = Tally()

        if use_gradient_checkpointing:
            self.policy.gradient_checkpointing_enable(dict(use_reentrant=False))
            if critic is not None:
                self.critic.gradient_checkpointing_enable(dict(use_reentrant=False))

        self.save_path = save_path

        # Load trainer state if it exists
        self.trainer_annealing_state_path = os.path.join(
            self.save_path, "trainer_annealing_state.pkl"
        )
        if os.path.exists(self.trainer_annealing_state_path):
            logger.info(
                f"Loading trainer state from {self.trainer_annealing_state_path}"
            )
            # Context manager ensures the file handle is closed (the previous
            # code used pickle.load(open(...)) and leaked the handle).
            with open(self.trainer_annealing_state_path, "rb") as f:
                self.trainer_annealing_state = pickle.load(f)
        else:
            self.trainer_annealing_state = TrainerAnnealingState()

        # Load policy optimizer state if it exists
        self.policy_optimizer_path = os.path.join(
            self.save_path, "policy_optimizer_state.pt"
        )
        if os.path.exists(self.policy_optimizer_path):
            logger.info(
                f"Loading policy optimizer state from {self.policy_optimizer_path}"
            )
            self.policy_optimizer.load_state_dict(
                torch.load(self.policy_optimizer_path)
            )

        # Load critic optimizer state if it exists
        self.critic_optimizer_path = os.path.join(
            self.save_path, "critic_optimizer_state.pt"
        )
        if (
            os.path.exists(self.critic_optimizer_path)
            and self.critic_optimizer is not None
        ):
            logger.info(
                f"Loading critic optimizer state from {self.critic_optimizer_path}"
            )
            self.critic_optimizer.load_state_dict(
                torch.load(self.critic_optimizer_path)
            )
        # NOTE(review): the `device` constructor argument is not used; the
        # accelerator's device takes precedence — confirm this is intended.
        self.device = self.accelerator.device
        self.entropy_coeff = entropy_coeff
        self.entropy_topk = entropy_topk
        self.entropy_mask_regex = entropy_mask_regex
        self.kl_coeff = kl_coeff
        self.gradient_clipping = gradient_clipping
        self.restrict_tokens = restrict_tokens
        self.mini_batch_size = mini_batch_size
        self.use_gradient_checkpointing = use_gradient_checkpointing
        self.temperature = temperature
        self.use_gae = use_gae
        self.whiten_advantages = whiten_advantages
        self.whiten_advantages_time_step_wise = whiten_advantages_time_step_wise
        self.use_rloo = use_rloo
        self.skip_discounted_state_visitation = skip_discounted_state_visitation
        self.use_gae_lambda_annealing = use_gae_lambda_annealing
        self.gae_lambda_annealing_limit = gae_lambda_annealing_limit
        if use_gae_lambda_annealing:
            # Resolve the annealing schedule by explicit name lookup instead
            # of eval() on a config-supplied string (avoids arbitrary code
            # execution from configuration; behavior is identical for the
            # declared Literal value "sigmoid_annealing").
            _annealing_methods: dict[str, Callable] = {
                "sigmoid_annealing": sigmoid_annealing,
            }
            _annealing_fn = _annealing_methods[gae_lambda_annealing_method]
            self.gae_lambda_annealing_method: Callable[[int], float] = (
                lambda step: _annealing_fn(
                    step=step, **gae_lambda_annealing_method_params
                )
            )
        self.discount_factor = discount_factor
        self.enable_tokenwise_logging = enable_tokenwise_logging
        self.reward_normalizing_constant = reward_normalizing_constant
        self.pg_loss_normalization = pg_loss_normalization
        self.critic_loss_type = critic_loss_type
        # None default replaces the previous mutable-default ([]) argument;
        # callers that omitted the argument still observe an empty list.
        self.exploration_prompts_to_remove = (
            exploration_prompts_to_remove
            if exploration_prompts_to_remove is not None
            else []
        )
        # Common containers used by all trainers
        self.training_data: dict = {}
        self.debug_path_list: list[str] = []
        self.policy_gradient_data = None
        self.rollout_tally = RolloutTally()
        self.tokenwise_tally: Union[ContextualizedTokenwiseTally, None] = None
        self.filter_higher_refprob_tokens_kl = filter_higher_refprob_tokens_kl
        self.truncated_importance_sampling_ratio_cap = (
            truncated_importance_sampling_ratio_cap
        )
        self.importance_sampling_strategy = importance_sampling_strategy

    def mask_non_restricted_token_logits(self, logits: torch.Tensor) -> torch.Tensor:
        """
        Masks logits so that only allowed tokens (as specified in restrict_tokens)
        and the EOS token are active.
        All other logits are set to -inf, effectively removing them from the softmax.

        Args:
            logits (torch.Tensor): The logits tensor of shape (B, S, V).

        Returns:
            torch.Tensor: The masked logits tensor (unchanged when
            ``restrict_tokens`` is None).
        """
        # TODO: verify. Not sure what we do here is differentiable
        # also, we recompute for nothing

        if self.restrict_tokens is not None:
            allowed_token_ids = []
            for token in self.restrict_tokens:
                # Only the first sub-token of each restricted string is kept.
                token_ids = self.tokenizer(token, add_special_tokens=False)["input_ids"]
                allowed_token_ids.append(token_ids[0])
            allowed_token_ids.append(
                self.tokenizer.eos_token_id
            )  # This token should always be active
            allowed_token_ids = torch.tensor(allowed_token_ids, device=logits.device)
            # Mask log_probs and probs to only allowed tokens
            mask = torch.zeros_like(logits).bool()  # (B, S, V)
            mask[..., allowed_token_ids] = True
            logits = torch.where(
                mask,
                logits,
                torch.tensor(-float("inf"), device=logits.device),
            )

        return logits
    def apply_reinforce_step(
        self,
        training_batch: TrainingBatch,
    ) -> dict:
        """
        Applies a single REINFORCE policy gradient step using the provided batch of rollouts.
        Handles batching, loss computation (including entropy and KL regularization), gradient accumulation, and optimizer step.
        Optionally logs various metrics and statistics.

        Args:
            training_batch (TrainingBatch): Rollouts carrying token ids, action
                masks, entropy masks, per-token credits, engine (sampling
                backend) log-probs and per-token timestep ids.

        Returns:
            dict: Running-mean diagnostics accumulated over all mini-batches.
        """
        with resource_logger_context(logger, "Apply reinforce step"):
            self.policy.train()
            mb_size = self.mini_batch_size
            nb_rollouts = len(training_batch)

            # Initialize running mean logs
            running_mean_logs = {
                "rl_objective": 0.0,
                "policy_gradient_loss": 0.0,
                "policy_gradient_norm": 0.0,
                "log_probs": 0.0,
                "credits": 0.0,
                "entropy": 0.0,
                "engine_log_probs_diff_clampfrac": 0.0,
                "tis_imp_ratio": 0.0,
                "ref_log_probs_diff_clampfrac": 0.0,
                "higher_refprob_frac": 0.0,
                "tis_imp_ratio_clampfrac": 0.0,
            }
            # NOTE(review): "entropy" is already initialized above; this
            # conditional re-init is redundant but harmless.
            if self.entropy_coeff != 0.0:
                running_mean_logs["entropy"] = 0.0
            if self.kl_coeff != 0.0:
                running_mean_logs["kl_divergence"] = 0.0

            # Get total number of tokens generated
            total_tokens_generated = 0
            for att_mask in training_batch.batch_action_mask:
                total_tokens_generated += att_mask.sum()

            # Obtain loss normalization
            if self.pg_loss_normalization == "nb_tokens":
                normalization_factor = total_tokens_generated
            elif self.pg_loss_normalization == "batch":
                normalization_factor = np.ceil(nb_rollouts / mb_size).astype(int)
            else:
                raise ValueError(
                    f"Invalid pg_loss_normalization: {self.pg_loss_normalization}"
                )

            # Gradient accumulation for each mini-batch
            for mb in range(0, nb_rollouts, mb_size):
                logger.info(f"Processing mini-batch {mb} of {nb_rollouts}")
                loss = 0.0
                training_mb = training_batch[mb : mb + mb_size]
                training_mb = training_mb.get_padded_tensors()
                training_mb.to(self.device)
                (
                    tokens_mb,
                    action_mask_mb,
                    entropy_mask_mb,
                    credits_mb,
                    engine_log_probs_mb,
                    timesteps_mb,
                ) = (
                    training_mb.batch_input_ids,
                    training_mb.batch_action_mask,
                    training_mb.batch_entropy_mask,
                    training_mb.batch_credits,
                    training_mb.batch_engine_log_probs,
                    training_mb.batch_timesteps,
                )

                # Next token prediction: shift so position t predicts token t+1.
                contexts_mb = tokens_mb[:, :-1]
                shifted_contexts_mb = tokens_mb[:, 1:]
                action_mask_mb = action_mask_mb[:, 1:]
                entropy_mask_mb = entropy_mask_mb[:, 1:]
                credits_mb = credits_mb[:, 1:]
                engine_log_probs_mb = engine_log_probs_mb[:, 1:]
                timesteps_mb = timesteps_mb[:, 1:]

                if self.enable_tokenwise_logging:
                    self.tokenwise_tally.set_action_mask(action_mask=action_mask_mb)
                    self.tokenwise_tally.set_range(range=(mb, mb + mb_size))
                    self.tokenwise_tally.add_contexts(contexts=contexts_mb)
                    self.tokenwise_tally.add_data(
                        metric_id="next_token",
                        metrics=shifted_contexts_mb,
                        to_tids=True,
                    )
                    self.tokenwise_tally.add_data(
                        metric_id="entropy_mask",
                        metrics=entropy_mask_mb,
                    )

                if self.enable_tokenwise_logging:
                    self.tokenwise_tally.add_data(
                        metric_id="next_token_credit", metrics=credits_mb
                    )

                # Forward pass + cast to FP-32 for higher prec.
                # TODO: create attention mask if not relying on default (assume causal llm)
                logits = self.policy(input_ids=contexts_mb)[0]  # (B, S, V)

                # Mask non-restricted tokens
                if self.restrict_tokens is not None:
                    logits = self.mask_non_restricted_token_logits(logits)

                logits /= self.temperature  # (B, S, V)

                # Compute new log probabilities
                log_probs = F.log_softmax(logits, dim=-1)  # (B, S, V)

                # Get log probabilities of actions taken during rollouts
                action_log_probs = log_probs.gather(
                    dim=-1, index=shifted_contexts_mb.unsqueeze(-1)
                ).squeeze(
                    -1
                )  # (B, S)
                # Denominator for the running means: per-token for "batch"
                # normalization, fixed factor otherwise.
                if self.pg_loss_normalization == "batch":
                    den_running_mean = action_mask_mb.sum() * normalization_factor
                else:
                    den_running_mean = normalization_factor
                running_mean_logs["log_probs"] += (
                    action_log_probs * action_mask_mb
                ).sum().item() / den_running_mean
                running_mean_logs["credits"] += (
                    credits_mb * action_mask_mb
                ).sum().item() / den_running_mean

                if self.enable_tokenwise_logging:
                    self.tokenwise_tally.add_data(
                        metric_id="next_token_log_prob",
                        metrics=action_log_probs,
                    )
                    self.tokenwise_tally.add_data(
                        metric_id="engine_next_token_log_prob",
                        metrics=engine_log_probs_mb,
                    )
                    self.tokenwise_tally.add_data(
                        metric_id="next_token_prob",
                        metrics=torch.exp(action_log_probs),
                    )
                    top_k_indices = torch.topk(logits, k=5, dim=-1).indices
                    self.tokenwise_tally.add_data(
                        metric_id=f"top_{5}_tids",
                        metrics=top_k_indices,
                        to_tids=True,
                    )
                    self.tokenwise_tally.add_data(
                        metric_id=f"top_{5}_probs",
                        metrics=torch.exp(log_probs).gather(
                            dim=-1, index=top_k_indices
                        ),
                    )

                # Credit-weighted log-probs: the REINFORCE surrogate term.
                rewarded_action_log_probs = (
                    action_mask_mb * credits_mb * action_log_probs
                )
                # (B, S)
                # Sentinel written into masked positions so diffs below stay finite.
                INVALID_LOGPROB = 1.0
                CLAMP_VALUE = 40.0
                masked_action_log_probs = torch.masked_fill(
                    action_log_probs, ~action_mask_mb, INVALID_LOGPROB
                )
                masked_engine_log_probs = torch.masked_fill(
                    engine_log_probs_mb, ~action_mask_mb, INVALID_LOGPROB
                )
                with torch.no_grad():
                    # Trainer-vs-sampling-engine log-prob gap, clamped for
                    # numerical safety before exponentiation.
                    action_engine_log_probs_diff = (
                        masked_action_log_probs - masked_engine_log_probs
                    ).clamp(-CLAMP_VALUE, CLAMP_VALUE)
                    running_mean_logs["engine_log_probs_diff_clampfrac"] += (
                        action_engine_log_probs_diff.abs()
                        .eq(CLAMP_VALUE)
                        .float()
                        .sum()
                        .item()
                        / den_running_mean
                    )
                    if self.importance_sampling_strategy == "per_sequence":
                        # One importance ratio per timestep (sequence chunk):
                        # sum token log-prob gaps within a timestep, exp, and
                        # broadcast the ratio back to that timestep's tokens.
                        tis_imp_ratio = torch.zeros_like(action_engine_log_probs_diff)
                        for mb_idx in range(action_engine_log_probs_diff.shape[0]):
                            valid_token_mask = action_mask_mb[mb_idx]
                            timestep_ids = timesteps_mb[mb_idx][valid_token_mask]
                            timestep_logprob_diffs = action_engine_log_probs_diff[mb_idx][
                                valid_token_mask
                            ]
                            max_timestep = int(timestep_ids.max().item()) + 1
                            timestep_sums = torch.zeros(
                                max_timestep,
                                device=action_engine_log_probs_diff.device,
                                dtype=action_engine_log_probs_diff.dtype,
                            )
                            timestep_sums.scatter_add_(
                                0, timestep_ids, timestep_logprob_diffs
                            )
                            timestep_ratios = torch.exp(timestep_sums)
                            tis_imp_ratio[
                                mb_idx, valid_token_mask
                            ] = timestep_ratios.gather(0, timestep_ids)
                    else:
                        # Per-token truncated importance sampling ratio.
                        tis_imp_ratio = torch.exp(action_engine_log_probs_diff)
                running_mean_logs["tis_imp_ratio"] += (
                    tis_imp_ratio * action_mask_mb
                ).sum().item() / den_running_mean
                if self.truncated_importance_sampling_ratio_cap > 0.0:
                    tis_imp_ratio = torch.clamp(
                        tis_imp_ratio, max=self.truncated_importance_sampling_ratio_cap
                    )
                    running_mean_logs["tis_imp_ratio_clampfrac"] += (
                        tis_imp_ratio.eq(self.truncated_importance_sampling_ratio_cap)
                        .float()
                        .sum()
                        .item()
                    ) / den_running_mean
                    rewarded_action_log_probs = (
                        rewarded_action_log_probs * tis_imp_ratio
                    )

                if self.enable_tokenwise_logging:
                    self.tokenwise_tally.add_data(
                        metric_id="next_token_clogπ",
                        metrics=rewarded_action_log_probs,
                    )

                # Add value term to loss
                if self.pg_loss_normalization == "batch":
                    nb_act_tokens = action_mask_mb.sum()
                    mb_value = -rewarded_action_log_probs.sum() / nb_act_tokens
                else:
                    mb_value = -rewarded_action_log_probs.sum()

                loss += mb_value
                running_mean_logs["rl_objective"] += mb_value.item() / den_running_mean

                # -------------------------------------------------
                # Entropy Regularization
                # -------------------------------------------------
                # Only apply entropy on distribution defined over most probable tokens
                if self.entropy_topk is not None:
                    top_k_indices = torch.topk(
                        logits, k=self.entropy_topk, dim=-1
                    ).indices
                    entropy_logits = logits.gather(dim=-1, index=top_k_indices)
                else:
                    entropy_logits = logits

                token_entropy_terms = -F.softmax(
                    entropy_logits, dim=-1
                ) * F.log_softmax(
                    entropy_logits, dim=-1
                )  # (B, S, T)
                token_entropy_terms *= (
                    action_mask_mb[:, :, None] * entropy_mask_mb[:, :, None]
                )  # only get loss on specific action tokens

                mb_entropy = token_entropy_terms.sum(dim=-1)

                if self.enable_tokenwise_logging:
                    self.tokenwise_tally.add_data(
                        metric_id="entropy",
                        metrics=mb_entropy,
                    )
                if self.pg_loss_normalization == "batch":
                    nb_act_tokens = action_mask_mb.sum()
                    mb_entropy = -mb_entropy.sum() / nb_act_tokens
                else:
                    mb_entropy = -mb_entropy.sum()
                running_mean_logs["entropy"] += -mb_entropy.item() / den_running_mean
                # Entropy contributes to the loss only when its coefficient
                # is non-zero (entropy is still computed above for logging).
                if self.entropy_coeff != 0.0:
                    mb_entropy *= self.entropy_coeff
                    loss += mb_entropy

                # -------------------------------------------------
                # KL-DIVERGENCE
                # -------------------------------------------------
                if self.kl_coeff != 0.0:
                    ref_model_logits = self.policy.get_base_model_logits(contexts_mb)
                    ref_model_logits = ref_model_logits / self.temperature
                    # (B, S, V)
                    ref_model_logits = self.mask_non_restricted_token_logits(
                        logits=ref_model_logits
                    )
                    # (B, S, V)
                    ref_model_log_probs = F.log_softmax(ref_model_logits, dim=-1)
                    # (B, S, V)
                    ref_model_action_log_probs = ref_model_log_probs.gather(
                        dim=-1, index=shifted_contexts_mb.unsqueeze(-1)
                    ).squeeze(
                        -1
                    )  # (B,S)
                    # Approximating KL Divergence (see refs in docstring)
                    # Ref 1: http://joschu.net/blog/kl-approx.html
                    # Ref 2: https://github.dev/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py#L1332
                    masked_ref_model_action_log_probs = torch.masked_fill(
                        ref_model_action_log_probs, ~action_mask_mb, INVALID_LOGPROB
                    )
                    action_log_probs_diff = (
                        masked_ref_model_action_log_probs - masked_action_log_probs
                    ).clamp(-CLAMP_VALUE, CLAMP_VALUE)
                    running_mean_logs["ref_log_probs_diff_clampfrac"] += (
                        action_log_probs_diff.abs().eq(CLAMP_VALUE).float().sum().item()
                        / den_running_mean
                    )
                    if self.filter_higher_refprob_tokens_kl:
                        # Drop the KL contribution of tokens where the
                        # reference model is MORE confident than the policy.
                        higher_refprob_tokens_mask = action_log_probs_diff > 0.0
                        running_mean_logs["higher_refprob_frac"] += (
                            higher_refprob_tokens_mask.sum().item() / den_running_mean
                        )
                        action_log_probs_diff = action_log_probs_diff * (
                            ~higher_refprob_tokens_mask
                        )
                    # k3 estimator: exp(d) - 1 - d  (non-negative, unbiased-ish)
                    kl_div = torch.expm1(action_log_probs_diff) - action_log_probs_diff
                    kl_div *= action_mask_mb  # We only care about KLD of action tokens
                    if self.truncated_importance_sampling_ratio_cap > 0.0:
                        kl_div = kl_div * tis_imp_ratio
                    kl_div *= self.kl_coeff
                    if self.enable_tokenwise_logging:
                        self.tokenwise_tally.add_data(
                            metric_id="ref_model_next_token_log_prob",
                            metrics=ref_model_action_log_probs,
                        )
                        self.tokenwise_tally.add_data(
                            metric_id="kl_divergence",
                            metrics=kl_div,
                        )

                    if self.pg_loss_normalization == "batch":
                        nb_act_tokens = action_mask_mb.sum()
                        mb_kl = kl_div.sum() / nb_act_tokens
                    else:
                        mb_kl = kl_div.sum()
                    running_mean_logs["kl_divergence"] += (
                        mb_kl.item() / den_running_mean
                    )
                    loss += mb_kl

                # Accumulate gradient
                running_mean_logs["policy_gradient_loss"] += (
                    loss.item() / den_running_mean
                )
                loss /= normalization_factor
                self.accelerator.backward(loss)

                # ensure gpu memory is freed
                del training_mb
                del log_probs
                del logits
                del loss
                del action_log_probs
                del rewarded_action_log_probs

            logger.info(
                f"Accumulated the policy gradient loss for {total_tokens_generated} tokens."
            )

            # Clip gradients and take step
            if self.gradient_clipping is not None:
                grad_norm = self.accelerator.clip_grad_norm_(
                    self.policy.parameters(), self.gradient_clipping
                )
                running_mean_logs["policy_gradient_norm"] += grad_norm.item()

            # Take step
            self.policy_optimizer.step()
            self.policy_optimizer.zero_grad()

            # Store logs
            for key, value in running_mean_logs.items():
                self.tally.add_metric(path=key, metric=value)

            # Clear
            # TODO: verify
            self.accelerator.clear(self.policy, self.policy_optimizer)
            import gc

            gc.collect()
            torch.cuda.empty_cache()
            return running_mean_logs
    def get_advantages_with_critic_gradient_accumulation(
        self, trajectories: TrajectoryBatch, critic_loss_scaling_factor: float = 2.0
    ) -> list[torch.FloatTensor]:
        """
        Compute per-rollout advantages for ``trajectories``.
        Uses GAE if enabled, otherwise uses Monte Carlo returns.
        Optionally trains the critic if GAE is used (accumulates its gradient;
        the critic optimizer step happens later in ``train``).

        Args:
            trajectories (TrajectoryBatch): Jagged batch of rollouts.
            critic_loss_scaling_factor (float): Extra divisor on the critic
                loss; comes from learning a single critic for two agents.

        Returns:
            advantages: list of 1-D float tensors, one per rollout (jagged lengths).
        """

        mb_size = self.mini_batch_size
        batch_size = trajectories.rollout_ids.shape[0]
        # All rollouts in the batch belong to the same agent.
        agent_id = trajectories.agent_ids[0]
        batch_rewards = trajectories.batch_rewards

        ######################################
        # use critic for advantage estimation
        ######################################
        if self.use_gae:
            # "buffer" agents are frozen opponents: evaluate only, no critic update.
            if "buffer" in agent_id:
                self.critic.eval()
                training = False
            else:
                self.critic.train()
                training = True
            advantages = []
            # critic_loss_scaling_factor comes learning single critic for two agents
            normalization_factor = (
                np.ceil(batch_size / mb_size).astype(int) * critic_loss_scaling_factor
            )
            # For each minibatch
            for mb in range(0, batch_size, mb_size):
                trajectory_mb = trajectories[mb : mb + mb_size]
                trajectory_mb.to(self.device)
                rewards_mb = trajectory_mb.batch_rewards
                (
                    tokens_mb,
                    state_ends_mask_mb,
                    timestep_counts,
                ) = trajectory_mb.get_padded_tensors_for_critic()
                # critic causal attention up to end flags
                if training:
                    vals_estimate_full = self.critic(tokens_mb)
                else:
                    with torch.no_grad():
                        vals_estimate_full = self.critic(tokens_mb)

                # if vals_estimate_full.dim() == 3:
                #     vals_estimate_full = vals_estimate_full.squeeze(-1)

                # Select only positions where states end, per sample → list of (jT,)
                B = tokens_mb.shape[0]
                vals_list = [
                    vals_estimate_full[b][state_ends_mask_mb[b]] for b in range(B)
                ]

                # Pad to (B, max_jT) = (B, S)
                vals_estimate_mb = pad_sequence(
                    vals_list, batch_first=True, padding_value=0.0
                )
                dtype = vals_estimate_mb.dtype
                rewards_mb = pad_sequence(
                    rewards_mb, batch_first=True, padding_value=0.0
                ).to(
                    dtype=dtype
                )  # (B, S)
                self.rollout_tally.add_metric(
                    path=["batch_rewards"],
                    rollout_tally_item=RolloutTallyItem(
                        crn_ids=trajectory_mb.crn_ids,
                        rollout_ids=trajectory_mb.rollout_ids,
                        agent_ids=trajectory_mb.agent_ids,
                        metric_matrix=rewards_mb,
                    ),
                )
                # Rewards are logged raw (above) and normalized afterwards.
                if self.reward_normalizing_constant != 1.0:
                    rewards_mb /= self.reward_normalizing_constant

                det_vals_estimate_mb = vals_estimate_mb.detach()  # (B, max_jT)
                self.rollout_tally.add_metric(
                    path=["mb_value_estimates_critic"],
                    rollout_tally_item=RolloutTallyItem(
                        crn_ids=trajectory_mb.crn_ids,
                        rollout_ids=trajectory_mb.rollout_ids,
                        agent_ids=trajectory_mb.agent_ids,
                        metric_matrix=det_vals_estimate_mb,
                    ),
                )

                # Append a 0 value to the end of the value estimates
                # (terminal bootstrap value for GAE).
                if det_vals_estimate_mb.shape[1] == rewards_mb.shape[1]:
                    Bsize = det_vals_estimate_mb.shape[0]
                    device = det_vals_estimate_mb.device
                    dtype = det_vals_estimate_mb.dtype
                    det_vals_estimate_mb = torch.cat(
                        [
                            det_vals_estimate_mb,
                            torch.zeros((Bsize, 1), device=device, dtype=dtype),
                        ],
                        dim=1,
                    )  # (B, max_jT+1)
                else:
                    raise ValueError(
                        "Incompatible shapes for value estimates and rewards."
                    )

                # Get annealed lambda
                if self.use_gae_lambda_annealing:
                    annealing_constant = self.gae_lambda_annealing_method(
                        step=self.trainer_annealing_state.annealing_step_counter
                    )
                    annealed_lambda = (
                        self.gae_lambda_annealing_limit * annealing_constant
                    )
                    self.tally.add_metric(
                        path="annealed_lambda", metric=annealed_lambda
                    )
                else:
                    annealed_lambda = self.gae_lambda_annealing_limit

                # Get GAE advantages
                gae_advantages = get_generalized_advantage_estimates(
                    rewards=rewards_mb,
                    value_estimates=det_vals_estimate_mb,
                    discount_factor=self.discount_factor,
                    lambda_coef=annealed_lambda,
                )  # (B, max_jT)
                self.rollout_tally.add_metric(
                    path=["mb_gae_advantages"],
                    rollout_tally_item=RolloutTallyItem(
                        crn_ids=trajectory_mb.crn_ids,
                        rollout_ids=trajectory_mb.rollout_ids,
                        agent_ids=trajectory_mb.agent_ids,
                        metric_matrix=gae_advantages,
                    ),
                )
                if training:
                    targets = (
                        gae_advantages.to(dtype=dtype) + det_vals_estimate_mb[:, :-1]
                    )  # (B, max_jT) # A(s, a, b) + V(s) = Q(s, a, b)
                    self.rollout_tally.add_metric(
                        path=["mb_targets_critic"],
                        rollout_tally_item=RolloutTallyItem(
                            crn_ids=trajectory_mb.crn_ids,
                            rollout_ids=trajectory_mb.rollout_ids,
                            agent_ids=trajectory_mb.agent_ids,
                            metric_matrix=targets,
                        ),
                    )
                    # NOTE(review): the loss is computed over padded positions
                    # too (both sides padded with 0.0) — confirm this is the
                    # intended weighting for jagged batches.
                    if self.critic_loss_type == "mse":
                        loss = F.mse_loss(
                            input=vals_estimate_mb,
                            target=targets,
                        )
                    elif self.critic_loss_type == "huber":
                        loss = F.huber_loss(
                            input=vals_estimate_mb,
                            target=targets,
                        )
                    self.tally.add_metric(path=["mb_critic_loss"], metric=loss.item())
                    # Accumulate gradient
                    loss /= normalization_factor
                    self.accelerator.backward(loss)
                    del loss
                    del targets
                del vals_estimate_mb
                del trajectory_mb
                del vals_estimate_full

                # Get jagged back using timestep_counts
                advantages.extend(
                    [gae_advantages[i, : timestep_counts[i]] for i in range(B)]
                )

        ######################################
        # use exclusively Monte Carlo returns & rloo for advantage estimation
        ######################################
        else:
            lengths = [len(c) for c in batch_rewards]
            padded_rewards = pad_sequence(
                batch_rewards, batch_first=True, padding_value=0.0
            )
            self.rollout_tally.add_metric(
                path=["mb_rewards"],
                rollout_tally_item=RolloutTallyItem(
                    crn_ids=trajectories.crn_ids,
                    rollout_ids=trajectories.rollout_ids,
                    agent_ids=trajectories.agent_ids,
                    metric_matrix=padded_rewards,
                ),
            )
            if self.reward_normalizing_constant != 1.0:
                padded_rewards /= self.reward_normalizing_constant
            padded_advantages = get_discounted_returns(
                rewards=padded_rewards,
                discount_factor=self.discount_factor,
            )  # no baseline for now
            if self.use_rloo:
                # Repeated crn ids ⇒ rollouts were grouped by common random
                # numbers; apply the leave-one-out baseline within each group.
                is_grouped_by_rng = (
                    trajectories.crn_ids.unique().shape[0]
                    != trajectories.crn_ids.shape[0]
                )
                if is_grouped_by_rng:
                    for crn_id in trajectories.crn_ids.unique():
                        rng_mask = trajectories.crn_ids == crn_id
                        rng_advantages = padded_advantages[rng_mask]
                        rng_advantages, _ = get_rloo_credits(credits=rng_advantages)
                        padded_advantages[rng_mask] = rng_advantages
                else:
                    padded_advantages, _ = get_rloo_credits(credits=padded_advantages)
                self.rollout_tally.add_metric(
                    path=["mb_rloo_advantages"],
                    rollout_tally_item=RolloutTallyItem(
                        crn_ids=trajectories.crn_ids,
                        rollout_ids=trajectories.rollout_ids,
                        agent_ids=trajectories.agent_ids,
                        metric_matrix=padded_advantages,
                    ),
                )
            advantages = [
                padded_advantages[i, : lengths[i]]
                for i in range(padded_advantages.shape[0])
            ]

        # Optional whitening (time-step-wise takes precedence over global).
        if self.whiten_advantages_time_step_wise or self.whiten_advantages:
            lengths = [len(c) for c in advantages]
            padded_advantages = pad_sequence(
                advantages, batch_first=True, padding_value=0.0
            )
            if self.whiten_advantages_time_step_wise:
                whitened_padded_advantages = whiten_advantages_time_step_wise(
                    padded_advantages
                )
                path = ["mb_whitened_advantages_time_step_wise"]
            elif self.whiten_advantages:
                whitened_padded_advantages = whiten_advantages(padded_advantages)
                path = ["mb_whitened_advantages"]
            self.rollout_tally.add_metric(
                path=path,
                rollout_tally_item=RolloutTallyItem(
                    crn_ids=trajectories.crn_ids,
                    rollout_ids=trajectories.rollout_ids,
                    agent_ids=trajectories.agent_ids,
                    metric_matrix=whitened_padded_advantages,
                ),
            )
            advantages = [
                whitened_padded_advantages[i, : lengths[i]]
                for i in range(whitened_padded_advantages.shape[0])
            ]

        self.trainer_annealing_state.annealing_step_counter += 1

        return advantages

    @abstractmethod
    def set_agent_trajectory_data(
        self, agent_id: str, roots: list[RolloutTreeRootNode]
    ) -> None:
        """
        Build and store the training trajectory data for one agent from the
        rollout trees (implemented by concrete trainers).
        """
        pass

    def set_trajectory_data(
        self, roots: list[RolloutTreeRootNode], agent_ids: list[str]
    ) -> None:
        """
        Build trajectory data for every agent in ``agent_ids`` by delegating
        to ``set_agent_trajectory_data``.
        """
        for agent_id in agent_ids:
            self.set_agent_trajectory_data(agent_id, roots)

    @abstractmethod
    def share_advantage_data(self) -> list[AdvantagePacket]:
        # Concrete trainers return advantage packets to exchange with peers.
        pass

    @abstractmethod
    def receive_advantage_data(self, advantage_packets: list[AdvantagePacket]) -> None:
        # Concrete trainers consume advantage packets produced by peers.
        pass
separate and clean + """ + self.policy_gradient_data = None + # for agent_id, trajectory_batch in self.training_data.items(): + # if "buffer" in agent_id: + # continue + for agent_id in agent_ids: + assert "buffer" not in agent_id, "Buffer agents do not train policy" + trajectory_batch = self.training_data[agent_id] + tokenwise_batch_credits = get_tokenwise_credits( + batch_timesteps=trajectory_batch.batch_timesteps, + batch_credits=trajectory_batch.batch_credits, + ) + policy_gradient_data = TrainingBatch( + rollout_ids=trajectory_batch.rollout_ids, + batch_input_ids=trajectory_batch.batch_input_ids, + batch_action_mask=trajectory_batch.batch_action_mask, + batch_entropy_mask=trajectory_batch.batch_entropy_mask, + batch_credits=tokenwise_batch_credits, + batch_engine_log_probs=trajectory_batch.batch_engine_log_probs, + batch_timesteps=trajectory_batch.batch_timesteps, + ) + if self.policy_gradient_data is None: + self.policy_gradient_data = policy_gradient_data + else: + self.policy_gradient_data.append(policy_gradient_data) + + self.training_data = {} + self.tokenwise_tally = ContextualizedTokenwiseTally( + tokenizer=self.tokenizer, + paths=self.debug_path_list, + ) + + def train(self) -> None: + """ + TOWRITE + """ + assert self.policy_gradient_data is not None, "Policy gradient data is not set" + if self.critic_optimizer is not None: + if self.gradient_clipping is not None: + grad_norm = self.accelerator.clip_grad_norm_( + self.critic.parameters(), self.gradient_clipping + ) + self.tally.add_metric( + path="gradient_norm_critic", metric=grad_norm.item() + ) + # Take step + self.critic_optimizer.step() + self.critic_optimizer.zero_grad() + self.accelerator.clear(self.critic, self.critic_optimizer) + import gc + + gc.collect() + torch.cuda.empty_cache() + running_mean_logs = self.apply_reinforce_step( + training_batch=self.policy_gradient_data + ) + return running_mean_logs + + def export_training_tally(self, identifier: str, folder: str) -> None: + """ + Saves 
    def export_training_tally(self, identifier: str, folder: str) -> None:
        """
        Saves and resets the collected training metrics using the tally object.

        Writes the scalar tally, the tokenwise tally (CSV) and the rollout
        tally under ``folder``, then clears all of them for the next iteration.
        """
        os.makedirs(folder, exist_ok=True)
        self.tally.save(identifier=identifier, folder=folder)
        self.tokenwise_tally.save(
            path=os.path.join(folder, f"{identifier}_tokenwise.csv")
        )
        self.rollout_tally.save(identifier=identifier, folder=folder)
        self.tally.reset()
        # Tokenwise tally is rebuilt in set_policy_gradient_data.
        self.tokenwise_tally = None
        self.rollout_tally.reset()
        self.debug_path_list = []

    def export_optimizer_states(self) -> None:
        """
        Saves the optimizer states for both the main model and critic (if it exists).

        Raises:
            Exception: re-raises any failure after logging it.
        """
        try:
            os.makedirs(self.save_path, exist_ok=True)

            torch.save(self.policy_optimizer.state_dict(), self.policy_optimizer_path)
            logger.info(f"Saved main optimizer state to {self.policy_optimizer_path}")

            if self.critic_optimizer is not None:
                torch.save(
                    self.critic_optimizer.state_dict(), self.critic_optimizer_path
                )
                logger.info(
                    f"Saved critic optimizer state to {self.critic_optimizer_path}"
                )
        except Exception as e:
            logger.error(f"Error saving optimizer states: {str(e)}")
            raise

    def export_trainer_annealing_state(self) -> None:
        """
        Saves the trainer state (annealing counter) via pickle.
        """
        with open(self.trainer_annealing_state_path, "wb") as f:
            pickle.dump(self.trainer_annealing_state, f)
        logger.info(f"Saved trainer state to {self.trainer_annealing_state_path}")

    def export_trainer_states(self) -> None:
        """
        Saves the trainer states (optimizers + annealing state).
        """
        self.export_optimizer_states()
        self.export_trainer_annealing_state()
None: + """ + TOWRITE + """ + # TODO: append to current batch data instead, else we will only train for one agent! + self.policy_gradient_data = None + + # Tensorize Chats + rollout_ids = [] + crn_ids = [] # common random number id + batch_input_ids = [] + batch_action_mask = [] + batch_entropy_mask = [] + batch_timesteps = [] + batch_state_ends_mask = [] + batch_engine_log_probs = [] + batch_rewards = [] + for root in roots: + rollout_id = root.id + self.debug_path_list.append( + "mgid:" + str(rollout_id) + "_agent_id:" + agent_id + ) + rollout_ids.append(rollout_id) + crn_ids.append(root.crn_id) + chat, rewards = get_main_chat_list_and_rewards(agent_id=agent_id, root=root) + ( + input_ids, + action_mask, + entropy_mask, + timesteps, + state_ends_mask, + engine_log_probs, + ) = process_training_chat( + tokenizer=self.tokenizer, + chat_history=chat, + entropy_mask_regex=self.entropy_mask_regex, + exploration_prompts_to_remove=self.exploration_prompts_to_remove, + ) + batch_input_ids.append(input_ids) + batch_action_mask.append(action_mask) + batch_entropy_mask.append(entropy_mask) + batch_timesteps.append(timesteps) + batch_state_ends_mask.append(state_ends_mask) + batch_engine_log_probs.append(engine_log_probs) + batch_rewards.append(rewards) + + trajectory_batch = TrajectoryBatch( + rollout_ids=torch.tensor(rollout_ids, dtype=torch.int32), + crn_ids=torch.tensor(crn_ids, dtype=torch.int32), + agent_ids=[agent_id] * len(rollout_ids), + batch_input_ids=batch_input_ids, + batch_action_mask=batch_action_mask, + batch_entropy_mask=batch_entropy_mask, + batch_timesteps=batch_timesteps, + batch_state_ends_mask=batch_state_ends_mask, + batch_rewards=batch_rewards, + batch_engine_log_probs=batch_engine_log_probs, + ) + + # Get Advantages + batch_advantages: torch.FloatTensor = ( + self.get_advantages_with_critic_gradient_accumulation(trajectory_batch) + ) + + # Discount state visitation (the mathematically correct way) + if not self.skip_discounted_state_visitation: + for 
i in range(len(batch_advantages)): + batch_advantages[i] = get_discounted_state_visitation_credits( + batch_advantages[i].unsqueeze(0), + self.discount_factor, + ).squeeze(0) + + self.training_data[agent_id] = TrainingData( + agent_id=agent_id, + main_data=trajectory_batch, + main_advantages=batch_advantages, + ) + + def receive_advantage_data(self, advantage_packets: list[AdvantagePacket]): + """ + This trainer ignores the advantages of the other trainers. + """ + for agent_id, agent_data in self.training_data.items(): + self.training_data[agent_id] = agent_data.main_data + self.training_data[agent_id].batch_credits = agent_data.main_advantages + + def share_advantage_data(self) -> list[AdvantagePacket]: + """ + Share the advantage data with other agents. + Returns: + AdvantagePacket: The advantage packet containing the agent's advantages. + """ + logger.info(f"Sharing advantage data.") + advantage_packets = [] + for agent_id, agent_data in self.training_data.items(): + advantage_packets.append( + AdvantagePacket( + agent_id=agent_id, + rollout_ids=agent_data.main_data.rollout_ids, + main_advantages=agent_data.main_advantages, + ) + ) + return advantage_packets diff --git a/src_code_for_reproducibility/training/trainer_sum_rewards.py b/src_code_for_reproducibility/training/trainer_sum_rewards.py new file mode 100644 index 0000000000000000000000000000000000000000..8f2ca9d16bf3b3990f01aeb99b6941f6a1df87d2 --- /dev/null +++ b/src_code_for_reproducibility/training/trainer_sum_rewards.py @@ -0,0 +1,127 @@ +""" + +""" +import logging +import os +import sys +from typing import Union + +import torch +import torch.nn.functional as F +from accelerate import Accelerator +from pandas._libs.tslibs.offsets import CBMonthBegin +from peft import LoraConfig +from torch.nn.utils.rnn import pad_sequence +from transformers import AutoModelForCausalLM, AutoTokenizer + +from mllm.markov_games.rollout_tree import * +from mllm.markov_games.rollout_tree import RolloutTreeRootNode +from 
mllm.training.credit_methods import ( + get_discounted_returns, + get_discounted_state_visitation_credits, + get_generalized_advantage_estimates, + get_rloo_credits, +) +from mllm.training.tally_metrics import Tally +from mllm.training.tally_rollout import RolloutTally, RolloutTallyItem +from mllm.training.tally_tokenwise import ContextualizedTokenwiseTally +from mllm.training.tokenize_chats import * +from mllm.training.tokenize_chats import process_training_chat +from mllm.training.trainer_common import BaseTrainer +from mllm.training.trainer_independent import TrainerNaive, TrainingData +from mllm.training.training_data_utils import * +from mllm.training.training_data_utils import ( + AdvantagePacket, + TrainingBatch, + TrajectoryBatch, + get_tokenwise_credits, +) +from mllm.utils.resource_context import resource_logger_context + +logger = logging.getLogger(__name__) +logger.addHandler(logging.StreamHandler(sys.stdout)) + + +class TrainerSumRewards(TrainerNaive): + def receive_advantage_data(self, advantage_packets: list[AdvantagePacket]): + """ + Sums the advantages of the other trainers + """ + logger.info(f"Receiving advantage packets.") + + assert ( + len(advantage_packets) > 0 + ), "At least one advantage packet must be provided." 
+ + for agent_id, agent_data in self.training_data.items(): + coagent_advantage_packets = [ + packet for packet in advantage_packets if packet.agent_id != agent_id + ] + agent_rollout_ids = agent_data.main_data.rollout_ids + agent_advantages = agent_data.main_advantages + co_agent_advantages = [] + for rollout_id in agent_rollout_ids: + for co_agent_packet in coagent_advantage_packets: + if rollout_id in co_agent_packet.rollout_ids: + index = torch.where(rollout_id == co_agent_packet.rollout_ids)[ + 0 + ].item() + co_agent_advantages.append( + co_agent_packet.main_advantages[index] + ) + # assumes that its two player game, with one co-agent + break + assert len(co_agent_advantages) == len(agent_advantages) + B = len(agent_advantages) + assert all( + a.shape[0] == b.shape[0] + for a, b in zip(co_agent_advantages, agent_advantages) + ), "Number of advantages must match in order to sum them up." + + # Get padded tensors (advantage alignment is invariant to padding) + lengths = torch.tensor( + [len(t) for t in agent_advantages], + device=self.device, + dtype=torch.long, + ) + padded_main_advantages = pad_sequence( + agent_advantages, batch_first=True, padding_value=0.0 + ) + + padded_co_agent_advantages = pad_sequence( + co_agent_advantages, batch_first=True, padding_value=0.0 + ) + + # Create training batch data + sum_of_ad_credits = padded_main_advantages + padded_co_agent_advantages + self.rollout_tally.add_metric( + path=["sum_of_ad_credits"], + rollout_tally_item=RolloutTallyItem( + crn_ids=agent_data.main_data.crn_ids, + rollout_ids=agent_data.main_data.rollout_ids, + agent_ids=agent_data.main_data.agent_ids, + metric_matrix=sum_of_ad_credits, + ), + ) + + if not self.skip_discounted_state_visitation: + sum_of_ad_credits = get_discounted_state_visitation_credits( + sum_of_ad_credits, + self.discount_factor, + ) + self.rollout_tally.add_metric( + path=["discounted_state_visitation_credits"], + rollout_tally_item=RolloutTallyItem( + 
crn_ids=agent_data.main_data.crn_ids,
+                        rollout_ids=agent_data.main_data.rollout_ids,
+                        agent_ids=agent_data.main_data.agent_ids,
+                        # BUGFIX: `sub_tensors` was never defined in this scope and
+                        # raised NameError whenever skip_discounted_state_visitation
+                        # was False. The discounted credits computed just above are
+                        # stored in `sum_of_ad_credits`, so log that tensor.
+                        metric_matrix=sum_of_ad_credits,
+                    ),
+                )
+
+            # Slice back to jagged and convert to tokenwise credits
+            sum_of_ad_credits = [sum_of_ad_credits[i, : lengths[i]] for i in range(B)]
+            self.training_data[agent_id] = agent_data.main_data
+            self.training_data[agent_id].batch_credits = sum_of_ad_credits
diff --git a/src_code_for_reproducibility/training/training_data_utils.py b/src_code_for_reproducibility/training/training_data_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..deb59e5c575c8d18e8f49d348c8aa342f1c40d22
--- /dev/null
+++ b/src_code_for_reproducibility/training/training_data_utils.py
@@ -0,0 +1,394 @@
+from dataclasses import dataclass
+from typing import Literal, Optional, Tuple
+
+import torch
+from torch.nn.utils.rnn import pad_sequence
+
+from mllm.markov_games.rollout_tree import (
+    ChatTurn,
+    RolloutTreeBranchNode,
+    RolloutTreeNode,
+    RolloutTreeRootNode,
+)
+
+
+@dataclass
+class AdvantagePacket:
+    agent_id: str
+    rollout_ids: torch.IntTensor  # (B,)
+    # list-of-tensors
+    main_advantages: list[torch.FloatTensor]
+
+
+class TrainingChatTurn:
+    # TODO: simplify by making this a child of ChatTurn
+    """
+    This class contains the chat turns for a single agent.
+    It is like ChatTurn, but with the time step added.
+ """ + + def __init__( + self, + time_step: int, + role: str, + agent_id: str, + content: str, + chat_template_token_ids: list[int], + reasoning_content: str, + is_state_end: bool, + out_token_ids: Optional[list[int]] = None, + log_probs: Optional[list[float]] = None, + ) -> None: + self.time_step = time_step + self.role = role + self.agent_id = agent_id + self.content = content + self.chat_template_token_ids = chat_template_token_ids + self.reasoning_content = reasoning_content + self.is_state_end = is_state_end + self.out_token_ids = out_token_ids + self.log_probs = log_probs + + def dict(self): + return { + "time_step": self.time_step, + "role": self.role, + "agent_id": self.agent_id, + "content": self.content, + "chat_template_token_ids": self.chat_template_token_ids, + "reasoning_content": self.reasoning_content, + "is_state_end": self.is_state_end, + "out_token_ids": self.out_token_ids, + "log_probs": self.log_probs, + } + + +def get_main_chat_list_and_rewards( + agent_id: str, root: RolloutTreeRootNode | RolloutTreeNode +) -> Tuple[list[TrainingChatTurn], torch.FloatTensor]: + """ + This method traverses a rollout tree and returns a the list of ChatTurn + for an agent. If it encounters a branch node, it follows the main path. 
+ """ + # TODO; extend for all trees, not just linear + if isinstance(root, RolloutTreeRootNode): + current_node = root.child + else: + current_node = root + + chat = [] + rewards = [] + while current_node is not None: + if isinstance(current_node, RolloutTreeBranchNode): + current_node = current_node.main_child + reward: float = current_node.step_log.simulation_step_log.rewards[agent_id] + rewards.append(reward) + chat_turns: list[TrainingChatTurn] = current_node.step_log.action_logs[ + agent_id + ].chat_turns + chat_turns = [ + TrainingChatTurn(time_step=current_node.time_step, **turn.model_dump()) + for turn in chat_turns + ] + chat.extend(chat_turns) + current_node = current_node.child + return chat, torch.FloatTensor(rewards) + + +def get_tokenwise_credits( + # B := batch size, S := number of tokens / seq. length, T := number of states. `j` stands for jagged (see pytorch nested tensors.) + batch_timesteps: torch.IntTensor | torch.Tensor, # (B, jS), + batch_credits: torch.FloatTensor | torch.Tensor, # (B, jT) +) -> torch.FloatTensor | torch.Tensor: # (B, jS) + """ + TOWRITE + """ + # TODO vectorize this code + batch_token_credits = [] + for credits, timesteps in zip(batch_credits, batch_timesteps): + token_credits = torch.zeros_like( + timesteps, + dtype=credits.dtype, + device=timesteps.device, + ) + for idx, credit in enumerate(credits): + token_credits[timesteps == idx] = credit + batch_token_credits.append(token_credits) + return batch_token_credits + + +@dataclass +class TrajectoryBatch: + """ + Tensorized batch of trajectories using list-of-tensors for jagged dimensions. + """ + + # B := batch size, S := number of tokens / seq. length, T := number of states. 
+ rollout_ids: torch.IntTensor # (B,) + crn_ids: torch.IntTensor # (B,) + agent_ids: list[str] # (B,) + batch_input_ids: list[torch.LongTensor] # List[(jS,)] + batch_action_mask: list[torch.BoolTensor] # List[(jS,)] + batch_entropy_mask: list[torch.BoolTensor] # List[(jS,)] + batch_timesteps: list[torch.IntTensor] # List[(jS,)] + batch_state_ends_mask: list[torch.BoolTensor] # List[(jS,)] + batch_engine_log_probs: Optional[list[torch.FloatTensor]] # List[(jS,)] + batch_rewards: list[torch.FloatTensor] # List[(jT,)] + batch_credits: Optional[list[torch.FloatTensor]] = None # List[(jS,)] + + def __post_init__(self): + """ + Validate per-sample consistency. + """ + B = self.rollout_ids.shape[0] + assert ( + self.crn_ids.shape[0] == B + ), "RNG IDs must have length equal to batch size." + assert ( + len(self.agent_ids) == B + ), "agent_ids must have length equal to batch size." + assert ( + len(self.batch_input_ids) + == len(self.batch_action_mask) + == len(self.batch_entropy_mask) + == len(self.batch_timesteps) + == len(self.batch_state_ends_mask) + == len(self.batch_engine_log_probs) + == len(self.batch_rewards) + == B + ), "Jagged lists must all have length equal to batch size." + + for b in range(B): + nb_rewards = int(self.batch_rewards[b].shape[0]) + nb_timesteps = int(torch.max(self.batch_timesteps[b]).item()) + 1 + assert ( + nb_rewards == nb_timesteps + ), "Number of rewards and timesteps mismatch." + assert ( + self.batch_input_ids[b].shape[0] + == self.batch_action_mask[b].shape[0] + == self.batch_entropy_mask[b].shape[0] + == self.batch_engine_log_probs[b].shape[0] + == self.batch_timesteps[b].shape[0] + ), "Tensors must have the same shape along the jagged dimension." + assert ( + int(self.batch_state_ends_mask[b].sum()) + == self.batch_rewards[b].shape[0] + ), "Number of rewards must match number of state ends." + + """ + Entries: + Here, we ignore the batch dimension. + input_ids: + All of the tokens of both the user and the assistant, flattened. 
+ action_mask: + Set to true on the tokens of the assistant (tokens generated by the model). + timesteps: + Therefore, max(timesteps) = Ns - 1. + state_ends_idx: + Indices of the tokens at which state descriptions end. + rewards: + rewards[t] := R_t(s_t, a_t) + Example: + position: "0 1 2 3 4 5 6 7 8 9 10 11 12 13 14" + input_ids: "U U U a a a U a U a a a U U U" (U := User, a := Assistant) + action_mask: "x x x ✓ ✓ ✓ x ✓ x ✓ ✓ ✓ x x x" + timestep: "0 0 0 0 0 0 1 1 1 1 1 1 2 2 2" + state_ends_dx: [2, 6, 14] + rewards: [r0, r1, r2] + """ + + def __getitem__(self, key) -> "TrajectoryBatch": + if isinstance(key, slice): + return TrajectoryBatch( + rollout_ids=self.rollout_ids.__getitem__(key), + crn_ids=self.crn_ids.__getitem__(key), + agent_ids=self.agent_ids[key], + batch_input_ids=self.batch_input_ids[key], + batch_action_mask=self.batch_action_mask[key], + batch_entropy_mask=self.batch_entropy_mask[key], + batch_timesteps=self.batch_timesteps[key], + batch_state_ends_mask=self.batch_state_ends_mask[key], + batch_engine_log_probs=self.batch_engine_log_probs[key], + batch_rewards=self.batch_rewards[key], + batch_credits=self.batch_credits[key] if self.batch_credits else None, + ) + + def __len__(self): + return len(self.batch_input_ids) + + def to(self, device): + self.rollout_ids = self.rollout_ids.to(device) + self.crn_ids = self.crn_ids.to(device) + self.batch_input_ids = [t.to(device) for t in self.batch_input_ids] + self.batch_action_mask = [t.to(device) for t in self.batch_action_mask] + self.batch_entropy_mask = [t.to(device) for t in self.batch_entropy_mask] + self.batch_timesteps = [t.to(device) for t in self.batch_timesteps] + self.batch_state_ends_mask = [t.to(device) for t in self.batch_state_ends_mask] + self.batch_engine_log_probs = [ + t.to(device) for t in self.batch_engine_log_probs + ] + self.batch_rewards = [t.to(device) for t in self.batch_rewards] + self.batch_credits = ( + [t.to(device) for t in self.batch_credits] if self.batch_credits else 
None + ) + + def get_padded_tensors_for_critic(self): + """ + Returns: + padded_batch_input_ids: (B, P) + padded_batch_state_ends_mask: (B, P) + timestep_counts: (B,) tensor of ints indicating number of states per sample + """ + padded_batch_input_ids = pad_sequence( + self.batch_input_ids, batch_first=True, padding_value=0 + ) + padded_batch_state_ends_mask = pad_sequence( + self.batch_state_ends_mask, batch_first=True, padding_value=0 + ).bool() + # number of states equals number of True in state_ends_mask + timestep_counts = torch.tensor( + [int(mask.sum().item()) for mask in self.batch_state_ends_mask], + device=padded_batch_input_ids.device, + dtype=torch.long, + ) + return padded_batch_input_ids, padded_batch_state_ends_mask, timestep_counts + + +timestep = int + + +@dataclass +class PaddedTensorTrainingBatch: + batch_input_ids: torch.LongTensor | torch.Tensor + batch_action_mask: torch.BoolTensor | torch.Tensor + batch_entropy_mask: Optional[torch.BoolTensor | torch.Tensor] + batch_credits: torch.FloatTensor | torch.Tensor + batch_engine_log_probs: torch.FloatTensor | torch.Tensor + batch_timesteps: torch.IntTensor | torch.Tensor + + def __len__(self): + return self.batch_input_ids.shape[0] + + def to(self, device): + self.batch_input_ids = self.batch_input_ids.to(device) + self.batch_action_mask = self.batch_action_mask.to(device) + self.batch_entropy_mask = self.batch_entropy_mask.to(device) + self.batch_credits = self.batch_credits.to(device) + self.batch_engine_log_probs = self.batch_engine_log_probs.to(device) + self.batch_timesteps = self.batch_timesteps.to(device) + + +@dataclass +class TrainingBatch: + rollout_ids: torch.IntTensor | torch.Tensor # (B,) + batch_input_ids: list[torch.LongTensor] # List[(jS,)] + batch_action_mask: list[torch.BoolTensor] # List[(jS,)] + batch_entropy_mask: Optional[list[torch.BoolTensor]] # List[(jS,)] + batch_credits: list[torch.FloatTensor] # List[(jS,)] + batch_engine_log_probs: list[torch.FloatTensor] # List[(jS,)] + 
batch_timesteps: list[torch.IntTensor] # List[(jS,)] + + def __post_init__(self): + # Put everything in the right device + # self.rollout_ids = self.rollout_ids.to("cuda" if torch.cuda.is_available() else "cpu") + # self.batch_input_ids = self.batch_input_ids.to("cuda" if torch.cuda.is_available() else "cpu") + # self.batch_action_mask = self.batch_action_mask.to("cuda" if torch.cuda.is_available() else "cpu") + # self.batch_credits = self.batch_credits.to("cuda" if torch.cuda.is_available() else "cpu") + # Ensure batch dimension is present + assert ( + len(self.batch_input_ids) + == len(self.batch_action_mask) + == len(self.batch_entropy_mask) + == len(self.batch_credits) + == len(self.batch_engine_log_probs) + == len(self.batch_timesteps) + == self.rollout_ids.shape[0] + ), "Jagged lists must all have length equal to batch size." + for inp, mask, cred, engine_log_prob, timestep in zip( + self.batch_input_ids, + self.batch_action_mask, + self.batch_credits, + self.batch_engine_log_probs, + self.batch_timesteps, + ): + assert ( + inp.shape[0] + == mask.shape[0] + == cred.shape[0] + == engine_log_prob.shape[0] + == timestep.shape[0] + ), "Tensors must have the same shapes along the jagged dimension." 
+
+    def __getitem__(self, key) -> "TrainingBatch":
+        """
+        Slice the batch along the batch dimension, returning a new TrainingBatch.
+        Only slice keys are supported (integer indexing would drop the batch dim
+        of `rollout_ids` and break the jagged-list invariants).
+        """
+        if isinstance(key, slice):
+            return TrainingBatch(
+                rollout_ids=self.rollout_ids[key],
+                batch_input_ids=self.batch_input_ids[key],
+                batch_action_mask=self.batch_action_mask[key],
+                batch_entropy_mask=self.batch_entropy_mask[key],
+                batch_credits=self.batch_credits[key],
+                batch_engine_log_probs=self.batch_engine_log_probs[key],
+                batch_timesteps=self.batch_timesteps[key],
+            )
+        # BUGFIX: previously fell through and silently returned None for any
+        # non-slice key, deferring the failure to the caller. Fail loudly here.
+        raise TypeError(
+            f"TrainingBatch indices must be slices, got {type(key).__name__}"
+        )
+
+    def __len__(self):
+        return len(self.batch_input_ids)
+
+    def to(self, device):
+        self.rollout_ids = self.rollout_ids.to(device)
+        self.batch_input_ids = [t.to(device) for t in self.batch_input_ids]
+        self.batch_action_mask = [t.to(device) for t in self.batch_action_mask]
+        self.batch_entropy_mask = [t.to(device) for t in self.batch_entropy_mask]
+        self.batch_credits = [t.to(device) for t in self.batch_credits]
+        self.batch_engine_log_probs = [
+            t.to(device) for t in self.batch_engine_log_probs
+        ]
+        self.batch_timesteps = [t.to(device) for t in self.batch_timesteps]
+
+    def get_padded_tensors(self, padding: float = 0.0):
+        """
+        Right-pad the jagged per-rollout tensors into dense (B, P) tensors and
+        wrap them in a PaddedTensorTrainingBatch.
+        Always pad to the right.
+ """ + padded_batch_input_ids = pad_sequence( + self.batch_input_ids, batch_first=True, padding_value=int(padding) + ) + padded_batch_action_mask = pad_sequence( + [m.to(dtype=torch.bool) for m in self.batch_action_mask], + batch_first=True, + padding_value=False, + ) + padded_batch_entropy_mask = pad_sequence( + self.batch_entropy_mask, batch_first=True, padding_value=False + ) + padded_batch_credits = pad_sequence( + self.batch_credits, batch_first=True, padding_value=float(padding) + ) + padded_batch_engine_log_probs = pad_sequence( + self.batch_engine_log_probs, batch_first=True, padding_value=float(padding) + ) + padded_batch_timesteps = pad_sequence( + self.batch_timesteps, batch_first=True, padding_value=0 + ) + + return PaddedTensorTrainingBatch( + padded_batch_input_ids, + padded_batch_action_mask, + padded_batch_entropy_mask, + padded_batch_credits, + padded_batch_engine_log_probs, + padded_batch_timesteps, + ) + + def append(self, other: "TrainingBatch"): + self.rollout_ids = torch.cat([self.rollout_ids, other.rollout_ids]) + self.batch_input_ids.extend(other.batch_input_ids) + self.batch_action_mask.extend(other.batch_action_mask) + self.batch_entropy_mask.extend(other.batch_entropy_mask) + self.batch_credits.extend(other.batch_credits) + self.batch_engine_log_probs.extend(other.batch_engine_log_probs) + self.batch_timesteps.extend(other.batch_timesteps) + + +timestep = int diff --git a/src_code_for_reproducibility/utils/__init__.py b/src_code_for_reproducibility/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src_code_for_reproducibility/utils/__pycache__/__init__.cpython-312.pyc b/src_code_for_reproducibility/utils/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba6cc80392e0f2f7225603138d2e0ed90ba72a7c Binary files /dev/null and b/src_code_for_reproducibility/utils/__pycache__/__init__.cpython-312.pyc 
differ diff --git a/src_code_for_reproducibility/utils/__pycache__/dict_get_path.cpython-312.pyc b/src_code_for_reproducibility/utils/__pycache__/dict_get_path.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cdd2e2245a161c486c78cb1b2fa0999ebc64e794 Binary files /dev/null and b/src_code_for_reproducibility/utils/__pycache__/dict_get_path.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/utils/__pycache__/get_coagent_id.cpython-312.pyc b/src_code_for_reproducibility/utils/__pycache__/get_coagent_id.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9165e24ec1d3a0d719bba1e440f942f5a0efc5fe Binary files /dev/null and b/src_code_for_reproducibility/utils/__pycache__/get_coagent_id.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/utils/__pycache__/kill_sglang.cpython-312.pyc b/src_code_for_reproducibility/utils/__pycache__/kill_sglang.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..009dcac8975aea1519d156413ea672cd03642827 Binary files /dev/null and b/src_code_for_reproducibility/utils/__pycache__/kill_sglang.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/utils/__pycache__/resource_context.cpython-312.pyc b/src_code_for_reproducibility/utils/__pycache__/resource_context.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..563e7cf2dacb0dfe7f81d089366d103620ae3716 Binary files /dev/null and b/src_code_for_reproducibility/utils/__pycache__/resource_context.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/utils/__pycache__/rollout_tree_gather_utils.cpython-312.pyc b/src_code_for_reproducibility/utils/__pycache__/rollout_tree_gather_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c9cbf2cb5430eae74fa89c6a2c410ba5aa4cad00 Binary files /dev/null and 
b/src_code_for_reproducibility/utils/__pycache__/rollout_tree_gather_utils.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/utils/__pycache__/rollout_tree_stats.cpython-312.pyc b/src_code_for_reproducibility/utils/__pycache__/rollout_tree_stats.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c7ee269de5a2423778b5bbc71093903f3fc08cd Binary files /dev/null and b/src_code_for_reproducibility/utils/__pycache__/rollout_tree_stats.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/utils/__pycache__/short_id_gen.cpython-312.pyc b/src_code_for_reproducibility/utils/__pycache__/short_id_gen.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..796e0282d192e19efa6a60429321d7affa658b02 Binary files /dev/null and b/src_code_for_reproducibility/utils/__pycache__/short_id_gen.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/utils/__pycache__/stat_pack.cpython-312.pyc b/src_code_for_reproducibility/utils/__pycache__/stat_pack.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d09130d5baaf644a3bbc95e330ec65df39764a28 Binary files /dev/null and b/src_code_for_reproducibility/utils/__pycache__/stat_pack.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/utils/__pycache__/update_start_epoch.cpython-312.pyc b/src_code_for_reproducibility/utils/__pycache__/update_start_epoch.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..098889d4cd8aa59c5514fbc18955a62ce012157b Binary files /dev/null and b/src_code_for_reproducibility/utils/__pycache__/update_start_epoch.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/utils/__pycache__/wandb_utils.cpython-312.pyc b/src_code_for_reproducibility/utils/__pycache__/wandb_utils.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..085c6327b4801fb26f38ed0e8efc0c6b7c5b4f89 Binary files /dev/null and 
b/src_code_for_reproducibility/utils/__pycache__/wandb_utils.cpython-312.pyc differ diff --git a/src_code_for_reproducibility/utils/dict_get_path.py b/src_code_for_reproducibility/utils/dict_get_path.py new file mode 100644 index 0000000000000000000000000000000000000000..cad0acbc4ecd3003dc3ece71a586b4796c0796af --- /dev/null +++ b/src_code_for_reproducibility/utils/dict_get_path.py @@ -0,0 +1,12 @@ + +def get_from_nested_dict(a:dict, path) -> any: + # path is string or list of string + try: + if isinstance(path, str): + return a[path] + else: + for p in path: + a = a[p] + return a + except Exception: + return None diff --git a/src_code_for_reproducibility/utils/format_time.py b/src_code_for_reproducibility/utils/format_time.py new file mode 100644 index 0000000000000000000000000000000000000000..1ad936b01738baaaa9aa03bc8436dcfef1bb4825 --- /dev/null +++ b/src_code_for_reproducibility/utils/format_time.py @@ -0,0 +1,7 @@ +def format_time(seconds): + if seconds >= 3600: + return f"{int(seconds // 3600)}h {int((seconds % 3600) // 60)}m {int(seconds % 60)}s" + elif seconds >= 60: + return f"{int(seconds // 60)}m {int(seconds % 60)}s" + else: + return f"{int(seconds)}s" diff --git a/src_code_for_reproducibility/utils/gather_training_stats.py b/src_code_for_reproducibility/utils/gather_training_stats.py new file mode 100644 index 0000000000000000000000000000000000000000..98ae3d9cad04748059b1b471033178a1f7e4f385 --- /dev/null +++ b/src_code_for_reproducibility/utils/gather_training_stats.py @@ -0,0 +1,257 @@ +import copy +import csv +import gc +import json +import logging +import os +import pickle +import random +import re +import subprocess +import sys +import time +from datetime import datetime +from statistics import mean +from typing import Any, Dict + +import hydra +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import torch +from omegaconf import OmegaConf + +from mllm.training.tally_metrics import Tally +from mllm.utils.stat_pack import 
StatPack
+
+
+def get_from_nested_dict(dictio: dict, path: list[str]):
+    """Follow `path` through nested dicts; the final lookup uses .get (None if absent)."""
+    for sp in path[:-1]:
+        dictio = dictio[sp]
+    return dictio.get(path[-1])
+
+
+def set_at_path(dictio: dict, path: list[str], value):
+    """Set `value` at `path` in nested dicts, creating intermediate levels as needed."""
+    for sp in path[:-1]:
+        if sp not in dictio:
+            dictio[sp] = {}
+        dictio = dictio[sp]
+    dictio[path[-1]] = value
+
+
+def produce_tabular_render(inpath: str, outpath: str = None):
+    """
+    Render the contextualized-metrics JSON at `inpath` into one CSV per rollout path.
+
+    When `outpath` is None, a path under
+    `<dir of inpath>/contextualized_tabular_renders/` is derived from each rollout
+    path; otherwise `outpath` is used as the target CSV path.
+    """
+    with open(inpath, "r") as f:
+        data = json.load(f)
+    rollout_paths = data.keys()
+    for rollout_path in rollout_paths:
+        if outpath is None:
+            m_path = rollout_path.replace("/", "|")
+            m_path = m_path.replace(".json", "")
+            m_path = (
+                os.path.split(inpath)[0]
+                + "/contextualized_tabular_renders/"
+                + m_path
+                + "_tabular_render.render.csv"
+            )
+            os.makedirs(os.path.split(m_path)[0], exist_ok=True)
+        else:
+            # BUGFIX: `m_path` was only bound in the branch above, so passing an
+            # explicit `outpath` raised NameError at the to_csv call below.
+            # NOTE(review): with multiple rollout paths, later renders overwrite
+            # `outpath` — confirm intended usage with callers.
+            m_path = outpath
+            os.makedirs(os.path.split(m_path)[0] or ".", exist_ok=True)
+        metrics = data[rollout_path]
+        d = {k: [] for k in metrics[0].keys()}
+        for m in metrics:
+            for k, v in m.items():
+                d[k].append(v)
+        d = pd.DataFrame(d)
+        d.to_csv(m_path)
+
+
+def get_metric_paths(data: list[dict]):
+    """Return the list of key-paths to every leaf in the first record of `data`."""
+    d = data[0]
+    paths = []
+
+    def traverse_dict(d, current_path=[]):
+        for key, value in d.items():
+            new_path = current_path + [key]
+            if isinstance(value, dict):
+                traverse_dict(value, new_path)
+            else:
+                paths.append(new_path)
+
+    traverse_dict(d)
+    return paths
+
+
+def print_metric_paths(data: list[dict]):
+    paths = get_metric_paths(data)
+    for p in paths:
+        print(p)
+
+
+def get_metric_iteration_list(data: list[dict], metric_path: list[str]):
+    """Collect the value at `metric_path` from every record in `data` (one per iteration)."""
+    if isinstance(metric_path, str):
+        metric_path = [metric_path]
+    sgl = []
+    for d in data:
+        sgl.append(get_from_nested_dict(d, metric_path))
+    return sgl
+
+
+def to_1d_numeric(x):
+    """Return a 1-D float array (or None if not numeric). 
Accepts scalars, numpy arrays, or nested list/tuple of them."""
+    if x is None:
+        return None
+    if isinstance(x, (int, float, np.number)):
+        return np.array([float(x)], dtype=float)
+    if isinstance(x, np.ndarray):
+        try:
+            return x.astype(float).ravel()
+        except Exception:
+            return None
+    if isinstance(x, (list, tuple)):
+        parts = []
+        for e in x:
+            arr = to_1d_numeric(e)
+            if arr is not None and arr.size > 0:
+                parts.append(arr)
+        if parts:
+            return np.concatenate(parts)
+        return None
+    return None
+
+
+def get_single_metric_vector(data, metric_path, iterations=None):
+    """Concatenate the numeric values at `metric_path` across all records into one 1-D array."""
+    if isinstance(metric_path, str):
+        metric_path = [metric_path]
+    # PEP 8: compare to None with `is`, not `==`.
+    # NOTE(review): `iterations` is computed but never used below (dead
+    # parameter) — confirm whether `data` should be truncated to it.
+    if iterations is None:
+        iterations = len(data)
+    vecs = []
+    for d in data:
+        ar = get_from_nested_dict(d, metric_path)
+        arr = to_1d_numeric(ar)
+        if arr is not None:
+            vecs.append(arr)
+
+    return np.concatenate(vecs) if vecs else np.empty(0, dtype=float)
+
+
+def _load_metrics_file(file_path: str):
+    """Unpickle and return the metrics tree stored in a *.tally.pkl file."""
+    if not (file_path.endswith(".tally.pkl") or file_path.endswith(".pkl")):
+        raise ValueError("Only *.tally.pkl files are supported.")
+    # pickle is imported at module top level; the redundant local import was removed.
+    with open(file_path, "rb") as f:
+        tree = pickle.load(f)
+    return tree
+
+
+def get_leaf_items(array_tally: dict, prefix: list[str] = None):
+    """Yield (key_path, leaf_value) pairs for every non-dict leaf in a nested dict."""
+    if prefix is None:
+        prefix = []
+    for key, value in array_tally.items():
+        next_prefix = prefix + [str(key)]
+        if isinstance(value, dict):
+            yield from get_leaf_items(value, next_prefix)
+        else:
+            yield next_prefix, value
+
+
+def _sanitize_filename_part(part: str) -> str:
+    s = part.replace("/", "|")
+    s = s.replace(" ", "_")
+    return s
+
+
+def render_rt_tally_pkl_to_csvs(pkl_path: str, outdir: str):
+    """
+    This method takes care of tokenwise logging. 
+ """ + with open(pkl_path, "rb") as f: + payload = pickle.load(f) + # Backward compatibility: older tallies stored the dict directly + if isinstance(payload, dict) and "array_tally" in payload: + array_tally = payload.get("array_tally", {}) + else: + array_tally = payload + + os.makedirs(outdir, exist_ok=True) + trainer_id = os.path.basename(pkl_path).replace(".rt_tally.pkl", "") + for path_list, rollout_tally_items in get_leaf_items(array_tally): + # Create file and initiate writer + path_part = ".".join(_sanitize_filename_part(p) for p in path_list) + filename = f"{trainer_id}__{path_part}.render.csv" + out_path = os.path.join(outdir, filename) + + # Write metric rows to CSV + with open(out_path, "w", newline="") as f: + writer = csv.writer(f) + + # Write header row - need to determine metric column count from first rollout_tally_item + first_item = rollout_tally_items[0] + metric_cols = ( + first_item.metric_matrix.shape[1] + if first_item.metric_matrix.ndim > 1 + else 1 + ) + header = ["agent_id", "crn_id", "rollout_id"] + [ + f"t_{i}" for i in range(metric_cols) + ] + writer.writerow(header) + + for rollout_tally_item in rollout_tally_items: + crn_ids = rollout_tally_item.crn_ids + rollout_ids = rollout_tally_item.rollout_ids + agent_ids = rollout_tally_item.agent_ids + metric_matrix = rollout_tally_item.metric_matrix + for i in range(metric_matrix.shape[0]): + row_vals = metric_matrix[i].reshape(-1) + # Convert row_vals to a list to avoid numpy concatenation issues + row_vals = ( + row_vals.tolist() + if hasattr(row_vals, "tolist") + else list(row_vals) + ) + row_prefix = [ + agent_ids[i], + crn_ids[i], + rollout_ids[i], + ] + writer.writerow(row_prefix + row_vals) + + +def tally_to_stat_pack(tally: Dict[str, Any]): + stat_pack = StatPack() + if "array_tally" in tally: + tally = tally["array_tally"] + + # backward compatibility: will remove later, flatten keys in tally + def get_from_nested_dict(dictio: dict, path: list[str]): + for sp in path[:-1]: + dictio 
= dictio[sp]
+        return dictio.get(path[-1])
+
+    def get_metric_paths(tally: dict):
+        paths = []
+
+        def traverse_dict(tally, current_path=[]):
+            for key, value in tally.items():
+                new_path = current_path + [key]
+                if isinstance(value, dict):
+                    traverse_dict(value, new_path)
+                else:
+                    paths.append(new_path)
+
+        traverse_dict(tally)
+        return paths
+
+    paths = get_metric_paths(tally)
+    modified_tally = {}
+    for p in paths:
+        val = get_from_nested_dict(tally, p)
+        modified_tally["_".join(p)] = np.mean(val)
+    del tally
+    tally = modified_tally
+    for key, value in tally.items():
+        stat_pack.add_stat(key, value)
+    return stat_pack
diff --git a/src_code_for_reproducibility/utils/get_coagent_id.py b/src_code_for_reproducibility/utils/get_coagent_id.py
new file mode 100644
index 0000000000000000000000000000000000000000..16f0dff5c3f1834aa74c9d8257573b829d62603d
--- /dev/null
+++ b/src_code_for_reproducibility/utils/get_coagent_id.py
@@ -0,0 +1,4 @@
+
+def get_coagent_id(ids: list[str], agent_id: str) -> str | None:
+    """Return the first id in ``ids`` different from ``agent_id``, else None."""
+    # Renamed the loop variable from ``id``, which shadowed the builtin.
+    for other_id in ids:
+        if other_id != agent_id:
+            return other_id
+    return None
diff --git a/src_code_for_reproducibility/utils/get_stochastic_game_lengths.py b/src_code_for_reproducibility/utils/get_stochastic_game_lengths.py
new file mode 100644
index 0000000000000000000000000000000000000000..a43c386aa1764ae2de1d6e177a0238c633c74bba
--- /dev/null
+++ b/src_code_for_reproducibility/utils/get_stochastic_game_lengths.py
@@ -0,0 +1,30 @@
+import numpy as np
+
+def get_stochastic_game_lengths(
+    max_length,
+    nb_games,
+    continuation_prob,
+    same_length_batch=False
+):
+    """
+    Generates stochastic game lengths based on a geometric distribution.
+
+    Args:
+        max_length (int): The maximum length a game can have.
+        nb_games (int): The number of games to generate lengths for.
+        continuation_prob (float): The probability of the game continuing after each round.
+        same_length_batch (bool): If True, all games will have the same length.
+
+    Returns:
+        Array: An array of game lengths.
+ """ + if continuation_prob == 1: + return [max_length] * nb_games + if same_length_batch: + length = np.random.geometric(1 - continuation_prob, 1) + game_lengths = np.repeat(length, nb_games) + else: + game_lengths = np.random.geometric(1 - continuation_prob, nb_games) + + game_lengths = np.where(game_lengths > max_length, max_length, game_lengths) + return game_lengths.tolist() diff --git a/src_code_for_reproducibility/utils/kill_sglang.py b/src_code_for_reproducibility/utils/kill_sglang.py new file mode 100644 index 0000000000000000000000000000000000000000..e1e1fe9059e4262e995e0876f7eaec1c7aae4464 --- /dev/null +++ b/src_code_for_reproducibility/utils/kill_sglang.py @@ -0,0 +1,17 @@ +import psutil +import signal + +target_name = "sglang::scheduler" +killed = [] + +def kill_sglang(): + for proc in psutil.process_iter(['pid', 'name', 'cmdline']): + try: + # Some processes may not have a name or cmdline + cmdline = " ".join(proc.info['cmdline']) if proc.info['cmdline'] else "" + if target_name in cmdline: + print(f"Killing PID {proc.pid}: {cmdline}") + proc.send_signal(signal.SIGKILL) + killed.append(proc.pid) + except (psutil.NoSuchProcess, psutil.AccessDenied): + pass diff --git a/src_code_for_reproducibility/utils/output_source_code.py b/src_code_for_reproducibility/utils/output_source_code.py new file mode 100644 index 0000000000000000000000000000000000000000..42b51ecb2b818e5e225af10c58898b8be21dee4d --- /dev/null +++ b/src_code_for_reproducibility/utils/output_source_code.py @@ -0,0 +1,6 @@ +def output_source_code(model, output_path: str) -> None: + """ + Outputs the source code of the model to the given path. 
+ """ + with open(output_path, "w") as f: + f.write(model.source_code) diff --git a/src_code_for_reproducibility/utils/resource_context.py b/src_code_for_reproducibility/utils/resource_context.py new file mode 100644 index 0000000000000000000000000000000000000000..43a3a55d0ca0d4a69eadd0c57650a5afd2ae4831 --- /dev/null +++ b/src_code_for_reproducibility/utils/resource_context.py @@ -0,0 +1,78 @@ +import logging +import time +from contextlib import contextmanager + +import torch + + +def vram_usage(): + output = "" + for i in range(torch.cuda.device_count()): + gpu_memory_allocated = torch.cuda.memory_allocated(i) / ( + 1024**3 + ) # Convert bytes to GB + gpu_memory_reserved = torch.cuda.memory_reserved(i) / ( + 1024**3 + ) # Convert bytes to GB + output += f"GPU {i}: Memory Allocated: {gpu_memory_allocated:.2f} GB, Memory Reserved: {gpu_memory_reserved:.2f} GB" + return output + + +def ram_usage(): + import psutil + + process = psutil.Process() + memory_info = process.memory_info() + ram_used = memory_info.rss / (1024**3) # Convert bytes to GB + return f"RAM Usage: {ram_used:.2f} GB" + + +@contextmanager +def resource_logger_context(logger: logging.Logger, task_description: str): + """ + Context manager to log the resource usage of the current task. + Args: + logger: The logger to use to log the resource usage. + task_description: The description of the task to log. 
+    Returns:
+        None
+    """
+    try:
+        initial_time = time.time()
+        # Assume CUDA is available and use device 0 only
+        total_mem_bytes = torch.cuda.get_device_properties(0).total_memory
+        # Baseline footprint before the block runs (allocated + reserved).
+        initial_total_bytes = (
+            torch.cuda.memory_allocated(0) + torch.cuda.memory_reserved(0)
+        )
+        # Reset the peak counter so max_memory_allocated below reflects
+        # only this block.
+        torch.cuda.reset_peak_memory_stats(0)
+        yield None
+    finally:
+        final_time = time.time()
+        # Ensure kernels within the block are accounted for
+        # NOTE(review): final_time is captured *before* synchronize(), so
+        # the reported ΔTime may undercount pending GPU work — confirm.
+        torch.cuda.synchronize()
+
+        # Compute metrics
+        final_allocated_bytes = torch.cuda.memory_allocated(0)
+        final_reserved_bytes = torch.cuda.memory_reserved(0)
+        final_total_bytes = final_allocated_bytes + final_reserved_bytes
+
+        # NOTE(review): allocated and reserved are summed here and then
+        # divided by device capacity; treat these percentages as relative
+        # indicators rather than exact occupancy.
+        delta_vram_percent_total = (
+            100 * (final_total_bytes - initial_total_bytes) / total_mem_bytes
+            if total_mem_bytes
+            else 0.0
+        )
+        current_percent_vram_taken = (
+            100 * final_total_bytes / total_mem_bytes if total_mem_bytes else 0.0
+        )
+        # Peak allocation observed inside the block (since the reset above).
+        block_peak_percent = (
+            100 * torch.cuda.max_memory_allocated(0) / total_mem_bytes
+            if total_mem_bytes
+            else 0.0
+        )
+        delta_time_str = time.strftime(
+            '%H:%M:%S', time.gmtime(final_time - initial_time)
+        )
+
+        logger.info(
+            f"For task: {task_description}, ΔVRAM % (total): {delta_vram_percent_total:.2f}%, Current % of VRAM taken: {current_percent_vram_taken:.2f}%, Block Peak % of device VRAM: {block_peak_percent:.2f}%, ΔTime: {delta_time_str}"
+        )
diff --git a/src_code_for_reproducibility/utils/rollout_tree_chat_htmls.py b/src_code_for_reproducibility/utils/rollout_tree_chat_htmls.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f4ce6b0eeddfab84447f164e3c2432e00a50ba1
--- /dev/null
+++ b/src_code_for_reproducibility/utils/rollout_tree_chat_htmls.py
@@ -0,0 +1,1921 @@
+from pathlib import Path
+from typing import List
+
+from mllm.utils.rollout_tree_gather_utils import *
+
+
+def html_from_chat_turns(chat_turns: List[ChatTurnLog]) -> str:
+    """
+    Render chat turns as a single, wrapping sequence of messages in time order.
+ Keep badge and message bubble styles, include time on every badge and + include rewards on assistant badges. Each message is individually + hide/show by click; when hidden, only the badge remains and "(...)" is + shown inline (not inside a bubble). + """ + import html + import re as _re + + # Prepare ordering: sort by (time_step, original_index) to keep stable order within same step + indexed_turns = list(enumerate(chat_turns)) + indexed_turns.sort(key=lambda t: (t[1].time_step, t[0])) + assistant_agents = sorted({t.agent_id for t in chat_turns if t.role == "assistant"}) + enable_split_view = len(assistant_agents) == 2 + + # CSS styles (simplified layout; no time-step or agent-column backgrounds) + css = """ + + """ + + # HTML structure + html_parts = [ + "", + "", + "", + "", + "Chat Turns", + css, + "", + "", + "", + '
', + '
', + '
', + '', + '', + "timesteps", + '', + '', + '', + '', + "to", + '', + '', + '', + ( + '' + if enable_split_view + else "" + ), + '', + '', + '', + '', + '', + '900px', + '', + '', + '', + '', + '', + '', + '', + '', + 'px', + '', + '', + '', + '', + '', + '|', + '', + '', + '', + '', + "
", + "
", + '") # close linear flow + if enable_split_view: + import html as _html_mod + + html_parts.append( + '") # flow-split + + # Add Chat View + import html as _html_mod + html_parts.append('
') + + # Helper function to add context annotation areas + def add_context_area(position: str, time_step: int): + context_key = f"round-context-{position}-{time_step}" + placeholder = f"Add context {position} round {time_step}..." + color_buttons = "" + # Add default/reset color button first + color_buttons += ( + f'
' + ) + for color_name, color_value in [ + ('red', '#d32f2f'), + ('orange', '#f57c00'), + ('yellow', '#f9a825'), + ('green', '#388e3c'), + ('blue', '#1976d2'), + ('purple', '#7b1fa2'), + ('gray', '#666666'), + ]: + color_buttons += ( + f'
' + ) + + html_parts.append( + f'
' + f'
' + f'
{color_buttons}
' + f'
' + ) + + # Helper function to add split agent context boxes + def add_split_agent_contexts(position: str, time_step: int): + color_buttons = "" + # Add default/reset color button first + color_buttons += ( + f'
' + ) + for color_name, color_value in [ + ('red', '#d32f2f'), + ('orange', '#f57c00'), + ('yellow', '#f9a825'), + ('green', '#388e3c'), + ('blue', '#1976d2'), + ('purple', '#7b1fa2'), + ('gray', '#666666'), + ]: + color_buttons += ( + f'
' + ) + + html_parts.append('
') + + # Alice box + alice_key = f"agent-context-alice-{position}-{time_step}" + alice_placeholder = f"..." + html_parts.append( + f'
' + f'
' + f'
{color_buttons}
' + f'
' + ) + + # Bob box + bob_key = f"agent-context-bob-{position}-{time_step}" + bob_placeholder = f"..." + html_parts.append( + f'
' + f'
' + f'
{color_buttons}
' + f'
' + ) + + html_parts.append('
') # split-agent-context + + last_time_step_chat = None + for original_index, turn in indexed_turns: + agent_class = f"agent-{re.sub('[^a-z0-9_-]', '-', turn.agent_id.lower())}" + role_class = f"role-{turn.role}" + + # Add time step divider and beginning context + if last_time_step_chat is None or turn.time_step != last_time_step_chat: + # Add end contexts for previous round (only regular context, not prompt summary) + if last_time_step_chat is not None: + add_context_area("end", last_time_step_chat) + + html_parts.append( + f'
' + f'⏱ Round {turn.time_step + 1}' + f'
' + ) + + # Add beginning contexts for new round (both context and prompt summary) + add_context_area("beginning", turn.time_step) + add_split_agent_contexts("beginning", turn.time_step) + + last_time_step_chat = turn.time_step + + # Build chat message with merge controls + html_parts.append(f'
') + + # Add merge control button + html_parts.append( + f'' + ) + + html_parts.append('
') + + # Header with agent name and reward (always show reward) + agent_id_clean = _html_mod.escape(turn.agent_id).lower() + if turn.role == "assistant": + name = _html_mod.escape(turn.agent_id) + raw_val = turn.reward + if isinstance(raw_val, (int, float)): + reward_val = f"{raw_val:.4f}".rstrip("0").rstrip(".") + if len(reward_val) > 8: + reward_val = reward_val[:8] + "…" + else: + reward_val = str(raw_val) + header_html = ( + f'
' + f'🤖 {name}' + f'⚑ {reward_val}' + f'
' + ) + else: + name = _html_mod.escape(turn.agent_id) + header_html = f'
Prompt of {name}
' + + html_parts.append(header_html) + + # Reasoning content if present + if turn.reasoning_content: + _raw_reasoning = turn.reasoning_content.replace("\r\n", "\n") + _raw_reasoning = _re.sub(r"^\s*\n+", "", _raw_reasoning) + esc_reasoning = _html_mod.escape(_raw_reasoning) + html_parts.append( + f'' + ) + + # Message bubble + esc_content = _html_mod.escape(turn.content) + html_parts.append(f'
{esc_content}
') + + html_parts.append('
') # chat-message-content + html_parts.append('
') # chat-message + + # Add end contexts for the last round (only regular context, not prompt summary) + if last_time_step_chat is not None: + add_context_area("end", last_time_step_chat) + + html_parts.append("
") # flow-chat + html_parts.extend(["", ""]) + + return "\n".join(html_parts) + + +def export_html_from_rollout_tree(path: Path, outdir: Path, main_only: bool = False): + """Process a rollout tree file and generate HTML files for each path. + Creates separate HTML files for the main path and each branch path. + The main path is saved in the root output directory, while branch paths + are saved in a 'branches' subdirectory. + + Args: + path: Path to the rollout tree JSON file + outdir: Output directory for HTML files + main_only: If True, only export the main trajectory (default: False) + """ + root = load_rollout_tree(path) + mgid = root.id + + main_path, branch_paths = get_rollout_tree_paths(root) + + outdir.mkdir(parents=True, exist_ok=True) + + # Create branches subdirectory if we have branch paths + if not main_only and branch_paths: + branches_dir = outdir / f"mgid:{mgid}_branches_html_renders" + branches_dir.mkdir(parents=True, exist_ok=True) + + # Generate HTML for the main path + chat_turns = gather_all_chat_turns_for_path(main_path) + html_content = html_from_chat_turns(chat_turns) + output_file = outdir / f"mgid:{mgid}_main_html_render.render.html" + with open(output_file, "w", encoding="utf-8") as f: + f.write(html_content) + + # Generate HTML for each branch path + for path_obj in branch_paths: + chat_turns = gather_all_chat_turns_for_path(path_obj) + + html_content = html_from_chat_turns(chat_turns) + + path_id: str = path_obj.id + output_filename = f"{path_id}_html_render.render.html" + + output_file = branches_dir / output_filename + + with open(output_file, "w", encoding="utf-8") as f: + f.write(html_content) diff --git a/src_code_for_reproducibility/utils/rollout_tree_gather_utils.py b/src_code_for_reproducibility/utils/rollout_tree_gather_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d1844aa5a20c01b591865cad108f1fb1577d3ef4 --- /dev/null +++ b/src_code_for_reproducibility/utils/rollout_tree_gather_utils.py @@ -0,0 
+1,314 @@
+from __future__ import annotations
+
+import csv
+import os
+import pickle
+import re
+from collections import defaultdict
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple
+
+from mllm.markov_games.rollout_tree import *
+
+
+
+
+
+def load_rollout_tree(path: Path) -> RolloutTreeRootNode:
+    """Load a rollout tree from a PKL file containing a dict."""
+    # NOTE(review): pickle.load on untrusted files can execute arbitrary
+    # code — only load tallies produced by this project.
+    with open(path, "rb") as f:
+        data = pickle.load(f)
+    return RolloutTreeRootNode.model_validate(data)
+
+
+@dataclass
+class RolloutNodeList:
+    # Identifier of the path (e.g. "mgid:<id>_type:main").
+    id: str
+    # Nodes along the path, in traversal order.
+    nodes: List[RolloutTreeNode]
+
+
+def get_rollout_tree_paths(
+    root: RolloutTreeRootNode, mgid: Optional[str] = None
+) -> Tuple[RolloutNodeList, List[RolloutNodeList]]:
+    """
+    Returns:
+        main_path: The main path from the root to the end of the tree.
+        branch_paths: A list of all branch paths from the root to the end of the tree.
+        Each branch path contains a list of nodes that are part of the branch, including the nodes from the main path before the branch was taken.
+ """ + branch_paths = [] + + def collect_path_nodes(current) -> List[RolloutTreeNode]: + """Recursively collect all nodes in a path starting from current node.""" + if current is None: + return [] + + if isinstance(current, RolloutTreeNode): + return [current] + collect_path_nodes(current.child) + + elif isinstance(current, RolloutTreeBranchNode): + # For branch nodes, we only follow the main_child for path collection + if current.main_child: + return [current.main_child] + collect_path_nodes( + current.main_child.child + ) + else: + return [] + + def traverse_for_branches( + current, + main_path_prefix: List[RolloutTreeNode], + path_id: str, + current_time_step: Optional[int] = 0, + ): + """Traverse tree to collect all branch paths.""" + if current is None: + return + + if isinstance(current, RolloutTreeNode): + # Continue traversing with this node added to the main path prefix + new_prefix = main_path_prefix + [current] + traverse_for_branches(current.child, new_prefix, path_id, current.time_step) + + elif isinstance(current, RolloutTreeBranchNode): + # Collect all branch paths + if current.branches: + for agent_id, branch_node_list in current.branches.items(): + if branch_node_list: + # Start with the main path prefix, then recursively collect all nodes in this branch + branch_path_nodes = main_path_prefix.copy() + for branch_node in branch_node_list: + branch_path_nodes.extend(collect_path_nodes(branch_node)) + + # Create proper branch path ID with mgid, agent_id, and time_step + mgid_str = mgid or str(root.id) + branch_path_id = f"mgid:{mgid_str}_type:branch_agent:{agent_id}_time_step:{current_time_step}" + branch_paths.append( + RolloutNodeList(id=branch_path_id, nodes=branch_path_nodes) + ) + + # Process the main child and add to prefix + new_prefix = main_path_prefix + if current.main_child: + new_prefix = main_path_prefix + [current.main_child] + + # Continue traversing the main path + if current.main_child: + traverse_for_branches( + 
current.main_child.child, + new_prefix, + path_id, + current.main_child.time_step, + ) + + # Collect the main path nodes + main_path_nodes = collect_path_nodes(root.child) + + # Traverse to collect all branch paths + traverse_for_branches(root.child, [], "") + + # Create the main path with proper mgid format + mgid_str = mgid or str(root.id) + main_path = RolloutNodeList(id=f"mgid:{mgid_str}_type:main", nodes=main_path_nodes) + + return main_path, branch_paths + + +class ChatTurnLog(BaseModel): + time_step: int + agent_id: str + role: str + content: str + reasoning_content: Optional[str] = None + is_state_end: bool + reward: float + + +def gather_agent_chat_turns_for_path( + agent_id: str, path: RolloutNodeList +) -> List[ChatTurnLog]: + """Iterate through all chat turns for a specific agent in a path sorted by time step.""" + turns = [] + for node in path.nodes: + action_log = node.step_log.action_logs.get(agent_id, []) + if action_log: + for chat_turn in action_log.chat_turns or []: + turns.append( + ChatTurnLog( + time_step=node.time_step, + agent_id=agent_id, + role=chat_turn.role, + content=chat_turn.content, + reasoning_content=getattr(chat_turn, "reasoning_content", None), + is_state_end=chat_turn.is_state_end, + reward=node.step_log.simulation_step_log.rewards.get( + agent_id, 0 + ), + ) + ) + return turns + + +def gather_all_chat_turns_for_path(path: RolloutNodeList) -> List[ChatTurnLog]: + """Iterate through all chat turns for all agents in a path sorted by time step.""" + turns = [] + + # Collect turns from all agents, but interleave them per timestep by (user, assistant) pairs + for node in path.nodes: + # Build (user[, assistant]) pairs for each agent at this timestep + agent_ids = sorted(list(node.step_log.action_logs.keys())) + per_agent_pairs: Dict[str, List[List[ChatTurnLog]]] = {} + + for agent_id in agent_ids: + action_log = node.step_log.action_logs.get(agent_id) + pairs: List[List[ChatTurnLog]] = [] + current_pair: List[ChatTurnLog] = [] + + if 
action_log and action_log.chat_turns: + for chat_turn in action_log.chat_turns: + turn_log = ChatTurnLog( + time_step=node.time_step, + agent_id=agent_id, + role=chat_turn.role, + content=chat_turn.content, + reasoning_content=getattr(chat_turn, "reasoning_content", None), + is_state_end=chat_turn.is_state_end, + reward=node.step_log.simulation_step_log.rewards.get( + agent_id, 0 + ), + ) + + if chat_turn.role == "user": + # If a previous pair is open, close it and start a new one + if current_pair: + pairs.append(current_pair) + current_pair = [] + current_pair = [turn_log] + else: + # assistant: attach to an open user message if present; otherwise stand alone + if ( + current_pair + and len(current_pair) == 1 + and current_pair[0].role == "user" + ): + current_pair.append(turn_log) + pairs.append(current_pair) + current_pair = [] + else: + # No preceding user or already paired; treat as its own unit + pairs.append([turn_log]) + + if current_pair: + # Unpaired trailing user message + pairs.append(current_pair) + + per_agent_pairs[agent_id] = pairs + + # Interleave pairs across agents: A1, B1, A2, B2, ... 
+ index = 0 + while True: + added_any = False + for agent_id in agent_ids: + agent_pairs = per_agent_pairs.get(agent_id, []) + if index < len(agent_pairs): + for tl in agent_pairs[index]: + turns.append(tl) + added_any = True + if not added_any: + break + index += 1 + + return turns + + +def chat_turns_to_dict(chat_turns: Iterator[ChatTurnLog]) -> Iterator[Dict[str, Any]]: + """Render all chat turns for a path as structured data for JSON.""" + for chat_turn in chat_turns: + yield chat_turn.model_dump() + + +def get_all_agents(root: RolloutTreeRootNode) -> List[str]: + """list of all agent IDs that appear in the tree.""" + if root.child is None: + return [] + + # Get the first node to extract all agent IDs + first_node = root.child + if isinstance(first_node, RolloutTreeBranchNode): + first_node = first_node.main_child + + if first_node is None: + return [] + + # All agents should be present in the first node + agents = set(first_node.step_log.action_logs.keys()) + agents.update(first_node.step_log.simulation_step_log.rewards.keys()) + + return sorted(list(agents)) + + +def gather_agent_main_rewards(agent_id: str, path: RolloutNodeList) -> List[float]: + """Gather main rewards for a specific agent in a path.""" + rewards = [] + for node in path.nodes: + reward = node.step_log.simulation_step_log.rewards[agent_id] + rewards.append(reward) + return rewards + + +def gather_all_rewards(path: RolloutNodeList) -> List[Dict[AgentId, float]]: + """Gather main rewards from main trajectory in a path.""" + rewards = [] + for node in path.nodes: + rewards.append(node.step_log.simulation_step_log.rewards.copy()) + return rewards + + +def gather_simulation_stats( + path: RolloutNodeList, + filter: Callable[[SimulationStepLog], bool], + stat_func: Callable[[SimulationStepLog], Any], +) -> List[Any]: + """Gather stats from main trajectory in a path.""" + stats = [] + for node in path.nodes: + sl = node.step_log.simulation_step_log + if filter(sl): + stats.append(stat_func(sl)) + 
return stats + + +def gather_simulation_step_logs(path: RolloutNodeList) -> List[SimulationStepLog]: + """Gather simulation information from main trajectory in a path.""" + infos = [] + for node in path.nodes: + infos.append(node.step_log.simulation_step_log) + return infos + + +def export_chat_logs(path: Path, outdir: Path): + """Process a rollout tree PKL file and generate a JSONL of chat turns as dicts. + Each line contains an object with path_id and chat_turns for a single path. + """ + import json + + root = load_rollout_tree(path) + mgid = root.id + + main_path, branch_paths = get_rollout_tree_paths(root) + all_paths = [main_path] + branch_paths + + outdir.mkdir(parents=True, exist_ok=True) + output_file = outdir / f"mgid:{mgid}_plucked_chats.render.jsonl" + + with open(output_file, "w", encoding="utf-8") as f: + for path_obj in all_paths: + chat_turns = gather_all_chat_turns_for_path(path_obj) + output_obj = { + "path_id": str(path_obj.id), + "chat_turns": list(chat_turns_to_dict(iter(chat_turns))), + } + f.write(json.dumps(output_obj, ensure_ascii=False) + "\n") + + diff --git a/src_code_for_reproducibility/utils/rollout_tree_stats.py b/src_code_for_reproducibility/utils/rollout_tree_stats.py new file mode 100644 index 0000000000000000000000000000000000000000..1ac3cd0e34212e7fdbeba19e501b7d96a5f128e1 --- /dev/null +++ b/src_code_for_reproducibility/utils/rollout_tree_stats.py @@ -0,0 +1,50 @@ +from typing import Any, Callable, List, Tuple + +from mllm.markov_games.rollout_tree import RolloutTreeRootNode +from mllm.markov_games.simulation import SimulationStepLog +from mllm.utils.rollout_tree_gather_utils import ( + gather_simulation_step_logs, + get_rollout_tree_paths, +) +from mllm.utils.stat_pack import StatPack + + +def get_rollout_tree_stat_tally( + rollout_tree: RolloutTreeRootNode, + metrics: List[Callable[[SimulationStepLog], List[Tuple[str, float]]]], +) -> StatPack: + stat_tally = StatPack() + # get simulation step logs + node_list = 
get_rollout_tree_paths(rollout_tree)[0]
+    simulation_step_logs = gather_simulation_step_logs(node_list)
+    # Apply every metric to every step log; a metric may return None
+    # (skipped) or a list of (key, value) pairs to accumulate.
+    for simulation_step_log in simulation_step_logs:
+        for metric in metrics:
+            metric_result = metric(simulation_step_log)
+            if metric_result is not None:
+                for key, value in metric_result:
+                    stat_tally.add_stat(key, value)
+    return stat_tally
+
+
+def get_rollout_tree_mean_stats(
+    rollout_tree: RolloutTreeRootNode, metrics: List[Callable[[SimulationStepLog], Any]]
+) -> StatPack:
+    """Get the mean stats for a rollout tree."""
+    stat_tally = get_rollout_tree_stat_tally(rollout_tree, metrics)
+    return stat_tally.mean()
+
+
+def get_mean_rollout_tree_stats(
+    rollout_trees: List[RolloutTreeRootNode],
+    metrics: List[Callable[[SimulationStepLog], Any]],
+) -> StatPack:
+    """Get the mean stats for a list of rollout trees."""
+    # NOTE(review): this averages per-tree means and then averages those
+    # again (a mean of means), weighting every tree equally regardless of
+    # its number of steps — confirm this is the intended statistic.
+    # TODO complete this
+    stat_tallies = [
+        get_rollout_tree_mean_stats(rollout_tree, metrics)
+        for rollout_tree in rollout_trees
+    ]
+    mean_stat_tally = StatPack()
+    for stat_tally in stat_tallies:
+        mean_stat_tally.add_stats(stat_tally)
+    return mean_stat_tally.mean()
diff --git a/src_code_for_reproducibility/utils/short_id_gen.py b/src_code_for_reproducibility/utils/short_id_gen.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb275cad8fe5e0dd52bebb14e7424e98dcbaba54
--- /dev/null
+++ b/src_code_for_reproducibility/utils/short_id_gen.py
@@ -0,0 +1,11 @@
+import uuid
+
+
+def generate_short_id() -> int:
+    """
+    Generates a short unique ID for tracking adapter versions.
+
+    Returns:
+        int: An 8-digit integer ID.
+ """ + return int(str(uuid.uuid4().int)[:8]) diff --git a/src_code_for_reproducibility/utils/stat_pack.py b/src_code_for_reproducibility/utils/stat_pack.py new file mode 100644 index 0000000000000000000000000000000000000000..46b397139a1a8a4149030a9cc33d2b3afb7b4a12 --- /dev/null +++ b/src_code_for_reproducibility/utils/stat_pack.py @@ -0,0 +1,113 @@ +import csv +import json +import os +import pickle +from collections import Counter +from copy import deepcopy +from locale import strcoll +from statistics import mean +from typing import Any, Dict, Iterator, List, Optional, Tuple, TypedDict + +import matplotlib.pyplot as plt +import numpy as np +from torch.utils.tensorboard import SummaryWriter + +plt.style.use( + "https://raw.githubusercontent.com/dereckpiche/DedeStyle/refs/heads/main/dedestyle.mplstyle" +) + +import wandb + +from . import wandb_utils + + +class StatPack: + def __init__(self): + self.data = {} + + def add_stat(self, key: str, value: float | int | None): + assert ( + isinstance(value, float) or isinstance(value, int) or value is None + ), f"Value {value} is not a valid type" + if key not in self.data: + self.data[key] = [] + self.data[key].append(value) + + def add_stats(self, other: "StatPack"): + for key in other.keys(): + self.add_stat(key, other[key]) + + def __getitem__(self, key: str): + return self.data[key] + + def __setitem__(self, key: str, value: Any): + self.data[key] = value + + def __contains__(self, key: str): + return key in self.data + + def __len__(self): + return len(self.data) + + def __iter__(self): + return iter(self.data) + + def keys(self): + return self.data.keys() + + def values(self): + return self.data.values() + + def items(self): + return self.data.items() + + def mean(self): + mean_st = StatPack() + for key in self.keys(): + if isinstance(self[key], list): + # TODO: exclude None values + non_none_values = [v for v in self[key] if v is not None] + if non_none_values: + mean_st[key] = np.mean(np.array(non_none_values)) + 
else: + mean_st[key] = None + return mean_st + + def store_plots(self, folder: str): + os.makedirs(folder, exist_ok=True) + for key in self.keys(): + plt.figure(figsize=(10, 5)) + plt.plot(self[key]) + plt.title(key) + plt.savefig(os.path.join(folder, f"{key}.pdf")) + plt.close() + + def store_numpy(self, folder: str): + os.makedirs(folder, exist_ok=True) + for key in self.keys(): + # Sanitize filename components (avoid slashes, spaces, etc.) + safe_key = str(key).replace(os.sep, "_").replace("/", "_").replace(" ", "_") + values = self[key] + # Convert None to NaN for numpy compatibility + arr = np.array( + [(np.nan if (v is None) else v) for v in values], dtype=float + ) + np.save(os.path.join(folder, f"{safe_key}.npy"), arr) + + def store_json(self, folder: str, filename: str = "stats.json"): + os.makedirs(folder, exist_ok=True) + with open(os.path.join(folder, filename), "w") as f: + json.dump(self.data, f, indent=4) + + def store_csv(self, folder: str): + os.makedirs(folder, exist_ok=True) + for key in self.keys(): + with open(os.path.join(folder, f"stats.csv"), "w") as f: + writer = csv.writer(f) + writer.writerow([key] + self[key]) + + def store_pickle(self, folder: str): + os.makedirs(folder, exist_ok=True) + for key in self.keys(): + with open(os.path.join(folder, f"stats.pkl"), "wb") as f: + pickle.dump(self[key], f) diff --git a/src_code_for_reproducibility/utils/update_start_epoch.py b/src_code_for_reproducibility/utils/update_start_epoch.py new file mode 100644 index 0000000000000000000000000000000000000000..036ddce31b12e7a6547c5099dd37962a88055643 --- /dev/null +++ b/src_code_for_reproducibility/utils/update_start_epoch.py @@ -0,0 +1,9 @@ +import os + +# During run, set hydra.run.dir=./outputs/{folder} +def update_start_epoch(cfg, output_directory): + if cfg["experiment"]["resume_experiment"]: + folders = [f for f in os.listdir(output_directory) if f.startswith("iteration_")] + iterations = [int(f.split("_")[1]) for f in folders] if folders else [0] + 
cfg["experiment"]["start_epoch"] = max(iterations) + return None diff --git a/src_code_for_reproducibility/utils/wandb_utils.py b/src_code_for_reproducibility/utils/wandb_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3e5d83ed1a5304f208288f78457582b0acdb58c4 --- /dev/null +++ b/src_code_for_reproducibility/utils/wandb_utils.py @@ -0,0 +1,164 @@ +import os +from typing import Any, Dict, Optional + + +_WANDB_AVAILABLE = False +_WANDB_RUN = None + + +def _try_import_wandb(): + global _WANDB_AVAILABLE + if _WANDB_AVAILABLE: + return True + try: + import wandb # type: ignore + + _WANDB_AVAILABLE = True + return True + except Exception: + _WANDB_AVAILABLE = False + return False + + +def _safe_get(cfg: Dict[str, Any], path: list[str], default: Any = None) -> Any: + cur: Any = cfg + for key in path: + if not isinstance(cur, dict) or key not in cur: + return default + cur = cur[key] + return cur + + +def is_enabled(cfg: Dict[str, Any]) -> bool: + return bool(_safe_get(cfg, ["logging", "wandb", "enabled"], False)) + + +def init(cfg: Dict[str, Any], run_dir: str, run_name: Optional[str] = None) -> None: + """ + Initialize Weights & Biases if enabled in config. No-op if disabled or wandb not installed. 
+ """ + global _WANDB_RUN + if not is_enabled(cfg): + return + if not _try_import_wandb(): + return + + import wandb # type: ignore + + project = _safe_get(cfg, ["logging", "wandb", "project"], "llm-negotiation") + entity = _safe_get(cfg, ["logging", "wandb", "entity"], None) + mode = _safe_get(cfg, ["logging", "wandb", "mode"], "online") + tags = _safe_get(cfg, ["logging", "wandb", "tags"], []) or [] + notes = _safe_get(cfg, ["logging", "wandb", "notes"], None) + group = _safe_get(cfg, ["logging", "wandb", "group"], None) + name = _safe_get(cfg, ["logging", "wandb", "name"], run_name) + + # Ensure files are written into the hydra run directory + os.makedirs(run_dir, exist_ok=True) + os.environ.setdefault("WANDB_DIR", run_dir) + + # Convert cfg to plain types for W&B config; fallback to minimal dictionary + try: + from omegaconf import OmegaConf # type: ignore + + cfg_container = OmegaConf.to_container(cfg, resolve=True) # type: ignore + except Exception: + cfg_container = cfg + + _WANDB_RUN = wandb.init( + project=project, + entity=entity, + mode=mode, + name=name, + group=group, + tags=tags, + notes=notes, + config=cfg_container, + dir=run_dir, + reinit=True, + ) + + +def log(metrics: Dict[str, Any], step: Optional[int] = None) -> None: + """Log a flat dictionary of metrics to W&B if active.""" + if not _WANDB_AVAILABLE or _WANDB_RUN is None: + return + try: + import wandb # type: ignore + + wandb.log(metrics if step is None else dict(metrics, step=step)) + except Exception: + pass + + +def _flatten(prefix: str, data: Dict[str, Any], out: Dict[str, Any]) -> None: + for k, v in data.items(): + key = f"{prefix}.{k}" if prefix else k + if isinstance(v, dict): + _flatten(key, v, out) + else: + out[key] = v + + +def _summarize_value(value: Any) -> Dict[str, Any]: + import numpy as np # local import to avoid hard dependency during disabled mode + + if value is None: + return {"none": 1} + # Scalars + if isinstance(value, (int, float)): + return {"value": float(value)} 
+ # Lists or arrays + try: + arr = np.asarray(value) + if arr.size == 0: + return {"size": 0} + return { + "mean": float(np.nanmean(arr)), + "min": float(np.nanmin(arr)), + "max": float(np.nanmax(arr)), + "last": float(arr.reshape(-1)[-1]), + "size": int(arr.size), + } + except Exception: + # Fallback: string repr + return {"text": str(value)} + + +def log_tally(array_tally: Dict[str, Any], prefix: str = "", step: Optional[int] = None) -> None: + """ + Flatten and summarize Tally.array_tally and log to WandB. + Each leaf list/array is summarized with mean/min/max/last/size. + """ + if not _WANDB_AVAILABLE or _WANDB_RUN is None: + return + summarized: Dict[str, Any] = {} + + def walk(node: Any, path: list[str]): + if isinstance(node, dict): + for k, v in node.items(): + walk(v, path + [k]) + return + # node is a list of values accumulated over time + key = ".".join([p for p in ([prefix] if prefix else []) + path]) + try: + summary = _summarize_value(node) + for sk, sv in summary.items(): + summarized[f"{key}.{sk}"] = sv + except Exception: + summarized[f"{key}.error"] = 1 + + walk(array_tally, []) + if summarized: + log(summarized, step=step) + + +def log_flat_stats(stats: Dict[str, Any], prefix: str = "", step: Optional[int] = None) -> None: + if not _WANDB_AVAILABLE or _WANDB_RUN is None: + return + flat: Dict[str, Any] = {} + _flatten(prefix, stats, flat) + if flat: + log(flat, step=step) + +