import os
import re
import pathlib
from argparse import ArgumentParser
from typing import List, Dict, Optional
from dataclasses import dataclass, field
import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from transformers import get_cosine_schedule_with_warmup, AutoTokenizer
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
AutoModelForMaskedLM,
AutoProcessor,
)
from datasets import load_dataset, DatasetDict
from peft import get_peft_model, LoraConfig, prepare_model_for_kbit_training
from transformers import BitsAndBytesConfig
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger
from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config
#from unsloth import FastLanguageModel, is_bfloat16_supported
from bioreason.models.dna_llm import DNALLMModel
from bioreason.models.protein_llm import ProteinLLMModel
from bioreason.dna_modules import NucleotideDNAModule
from bioreason.models.dl.processing_dl import DLProcessor
from bioreason.trainer import DNALLMGRPOTrainer, DNALLMGRPOConfig
from bioreason.models.evo2_tokenizer import Evo2Tokenizer, register_evo2_tokenizer
register_evo2_tokenizer()
# Custom TrainerCallback to override the saving mechanism
from transformers import TrainerCallback, TrainerState, TrainerControl
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
class SaveWithPyTorchCallback(TrainerCallback):
    """Custom callback to save models with PyTorch's native save mechanism instead of safetensors."""

    def on_save(self, args, state, control, **kwargs):
        # Resolve the checkpoint folder for the current global step.
        ckpt_dir = os.path.join(
            args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
        )
        os.makedirs(ckpt_dir, exist_ok=True)

        # Unwrap DDP/accelerator wrappers before serializing.
        wrapped = kwargs.get("model")
        target = wrapped.module if hasattr(wrapped, "module") else wrapped

        # Persist weights via torch.save rather than safetensors.
        torch.save(target.state_dict(), os.path.join(ckpt_dir, "pytorch_model.bin"))

        # The composite model has no top-level `config`; persist the text
        # sub-model's config (or its PEFT base model's) so the checkpoint
        # directory is self-describing.
        if hasattr(target, "text_model"):
            text_model = target.text_model
            if hasattr(text_model, "config"):
                text_model.config.save_pretrained(ckpt_dir)
            elif hasattr(text_model, "base_model") and hasattr(text_model.base_model, "config"):
                text_model.base_model.config.save_pretrained(ckpt_dir)

        # Report what was written, including how many LoRA tensors it holds.
        print(f"Saved model checkpoint to {ckpt_dir}")
        lora_keys = [k for k in target.state_dict().keys() if "lora" in k]
        print(f"Checkpoint contains {len(lora_keys)} LoRA parameters")

        # Tell the Trainer we already saved, so it skips its own save path.
        control.should_save = False
        return control
def _get_target_modules(model: ProteinLLMModel):
# Apply LoRA to all linear layers in the text model
target_modules = []
# Get all unique linear layer names
seen_names = set()
for name, module in model.text.named_modules():
if isinstance(module, torch.nn.Linear):
names = name.split(".")
target_name = names[-1] # Use the last part of the name
# Skip output head but include all other linear layers
if target_name != "lm_head" and target_name not in seen_names:
target_modules.append(target_name)
seen_names.add(target_name)
# Add attention-specific layers commonly found in transformers
attention_patterns = [
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"out_proj",
"query",
"key",
"value",
"gate_proj",
"up_proj",
"down_proj",
]
for pattern in attention_patterns:
if pattern not in seen_names:
target_modules.append(pattern)
# Return all unique layer names to apply LoRA to all layers
return list(target_modules)
def extract_xml_answer(text: str) -> str:
    """Extract the contents of the last ``<answer>...</answer>`` span.

    Fixed: the original called ``text.split("")``, which raises
    ``ValueError: empty separator`` on every input — the XML tag literals
    were evidently lost. Restored to the conventional ``<answer>`` extraction
    (matching the commented-out two-step split above the original line).
    """
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()
def extract_hash_answer(text: str) -> str | None:
if "####" not in text:
return None
return text.split("####")[1].strip()
def get_kegg_questions() -> DatasetDict:
    """Load the wanglab/kegg dataset and format each row into a chat prompt.

    Each example's prompt is a single user turn with `num_protein_sequences`
    protein placeholders followed by the question text; the fixed example
    protein sequences below are attached under 'protein_sequences'.
    """
    data = load_dataset('wanglab/kegg', 'default') # type: ignore
    # Example protein sequences used as placeholder inputs.
    example_protein_sequences = ["MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG",
    "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQAPILSRVGDGTQDNLSGAEKAVQVKVKALPDAQFEVVHSLAKWKRISSKLLERGKTHYPPHTMVGTGVLVTKMRVAGQEPDVQGPHAGIVVQGAGDAPVVVKPVVEMLNRMVVVVSGSAAPVVVNNNNNGAAAAAAA",
    "MSQVQVQVQNQALNTLVKQLGRVLLQGKGRPPLQGFRIIEQNGGDSPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPP"]
    num_protein_sequences = 2
    data = data.map(lambda x: { # type: ignore
        'prompt': [
            {
                'role': 'user',
                'content': [
                    *({'type': 'protein', 'text': None} for _ in range(num_protein_sequences)),
                    {'type': 'text', 'text': x['question']},
                ],
            },
        ],
        # NOTE(review): every row receives the SAME two example sequences —
        # confirm this placeholder behavior is intended.
        'protein_sequences': [example_protein_sequences[0], example_protein_sequences[1]],
        'answer': x['answer'],
    }) # type: ignore
    return data
# uncomment middle messages for 1-shot prompting
# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Reward 2.0 when the ground-truth answer appears inside the extracted answer.

    NOTE(review): the final zip pairs responses with ``answer[0]``; if
    ``answer[0]`` is a plain string this iterates its *characters*, not whole
    answers — confirm the intended structure of ``answer``.
    """
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    # extracted_responses = [r.lower().replace("answer:", "").strip() for r in extracted_responses]
    print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    return [2.0 if a.lower() in r.lower() else 0.0 for r, a in zip(extracted_responses, answer[0])]
def less_than_4_reward_func(completions, **kwargs) -> list[float]:
    """Reward 0.5 for extracted answers of at most four space-separated words."""
    rewards = []
    for completion in completions:
        extracted = extract_xml_answer(completion[0]['content'])
        rewards.append(0.5 if len(extracted.split(' ')) <= 4 else 0.0)
    return rewards
def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward 0.5 when a completion exactly matches the strict
    ``<reasoning>...</reasoning>`` / ``<answer>...</answer>`` layout.

    NOTE(review): the original pattern read ``r"^\\n.*?\\n\\n.*?\\n$"`` — the
    XML tag literals were evidently stripped. Restored to the conventional
    GRPO-demo reasoning/answer format; confirm against the SFT prompt template.
    """
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]
def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward 0.5 when a completion loosely contains the reasoning/answer tags.

    NOTE(review): the original pattern was ``r".*?\\s*.*?"``, which matches
    every string (including empty) — the XML tag literals were evidently
    stripped. Restored to the tag-based check.
    """
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]
def count_xml(text) -> float:
    """Partial-credit format score: 0.125 per reasoning tag that appears exactly once.

    NOTE(review): the original counted bare ``"\\n"`` / ``"\\n\\n"`` — the XML
    tag literals were evidently stripped from the counted substrings.
    Restored to the ``<reasoning>`` open/close markers.
    """
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    return count
def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    """Score each completion's first message content with count_xml."""
    return [count_xml(completion[0]["content"]) for completion in completions]
# Additional reward functions
def repeatness_reward(s: str):
    """Score how non-repetitive a string is; higher means less repetition.

    Builds a suffix array over the character codes (prefix doubling) and sums
    the LCP array (Kasai's algorithm), returning
    ``1 - 2 * sum(LCP) / (n * (n + 1))``: a string with no repeated substrings
    scores 1, heavy repetition pushes the score toward (or below) 0.
    Strings of length <= 1 score 0.
    """
    # Fixed: zip_longest/islice were used without being imported anywhere in
    # the file, so this function raised NameError on every call.
    from itertools import islice, zip_longest

    def ranks(l):
        # Map each value to its rank among the sorted distinct values.
        index = {v: i for i, v in enumerate(sorted(set(l)))}
        return [index[v] for v in l]

    def suffixArray(s):
        # Prefix-doubling: re-rank pairs (rank[i], rank[i+k]) and double k.
        line = ranks(s)
        n, k, ans, sa = len(s), 1, line, [0] * len(s)
        while k < n - 1:
            line = ranks(list(zip_longest(line, islice(line, k, None), fillvalue=-1)))
            ans, k = line, k << 1
        for i, k in enumerate(ans):
            sa[k] = i
        return ans, sa

    def lcp(arr, suffixArr, inv_suff):
        # Kasai's algorithm: LCP between each suffix and its successor in
        # suffix-array order.
        n, ans, k = len(arr), [0] * len(arr), 0
        for i in range(n):
            if inv_suff[i] == n - 1:
                k = 0
                continue
            j = suffixArr[inv_suff[i] + 1]
            while i + k < n and j + k < n and arr[i + k] == arr[j + k]:
                k += 1
            ans[inv_suff[i]] = k
            if k > 0:
                k -= 1
        return ans

    arr = [ord(i) for i in s]
    n = len(arr)
    if n <= 1:
        return 0
    c, sa = suffixArray(arr)
    cnt = sum(lcp(arr, sa, c))
    return 1 - cnt * 2 / (n * (n + 1))
def format_reward(predict_str: str) -> float:
    """Strict format reward: 1.0 iff the whole (stripped) prediction is
    ``<think>...</think>`` followed by ``<answer>...</answer>`` with nothing
    extra in between or after.

    NOTE(review): the original regex read ``r'^.*?\\s*\\s*.*?\\s*$'`` — the XML
    tag literals were evidently stripped. Restored to the think/answer layout
    the (translated) docstring describes.
    """
    pattern = r'^<think>.*?</think>\s*<answer>.*?</answer>\s*$'
    return 1.0 if re.fullmatch(pattern, predict_str.strip(), re.DOTALL) else 0.0
def acc_reward(predict_str: str, ground_truth) -> float:
    """Accuracy reward: 1.0 when the ``<answer>`` content matches ground_truth.

    String ground truths are compared exactly; numeric ones are compared as
    floats, falling back to string equality when the extracted answer does not
    parse; any other type is compared via ``str()``.

    NOTE(review): the original search pattern read ``r'\\s*([^<]*?)\\s*'`` —
    the ``<answer>`` tag literals were evidently stripped and are restored here
    (the ``[^<]`` character class implies an XML closing tag followed).
    """
    match = re.search(r'<answer>\s*([^<]*?)\s*</answer>', predict_str)
    if not match:
        return 0.0
    answer_content = match.group(1).strip()
    # Compare according to the ground-truth type.
    if isinstance(ground_truth, str):
        return 1.0 if answer_content == ground_truth else 0.0
    elif isinstance(ground_truth, (int, float)):
        try:
            # Compare numerically when the extracted answer parses as a float.
            return 1.0 if float(answer_content) == float(ground_truth) else 0.0
        except ValueError:
            # Parsing failed: fall back to string comparison.
            return 1.0 if answer_content == str(ground_truth) else 0.0
    else:
        # Any other type: compare string representations.
        return 1.0 if answer_content == str(ground_truth) else 0.0
# Wrapper functions adapting the above to the existing reward-function interface
def repeatness_reward_func(completions, **kwargs) -> list[float]:
    """Adapter exposing repeatness_reward through the trainer's reward interface."""
    return [repeatness_reward(completion[0]['content']) for completion in completions]
def format_reward_func(completions, **kwargs) -> list[float]:
    """Adapter exposing format_reward through the trainer's reward interface."""
    scores = []
    for completion in completions:
        scores.append(format_reward(completion[0]['content']))
    return scores
def acc_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Accuracy reward wrapper: scores each completion against its ground truth.

    Accepts `answer` either as a batched list (answer[0] is a list of ground
    truths) or as a single value applied to every response; returns all zeros
    on any unexpected structure.
    """
    responses = [completion[0]['content'] for completion in completions]
    # Debug information.
    print(f"DEBUG acc_reward_func - answer type: {type(answer)}, answer: {answer}")
    # Following the pattern used elsewhere in this file, `answer` may be nested.
    try:
        if isinstance(answer, list) and len(answer) > 0:
            # If answer[0] is itself a list, this is batched data.
            if isinstance(answer[0], list):
                ground_truths = answer[0]
            else:
                # answer[0] is a single value: reuse it for every response.
                ground_truths = [answer[0]] * len(responses)
        else:
            # Unexpected format: return all zeros.
            print(f"DEBUG: Unexpected answer format, returning zeros")
            return [0.0] * len(responses)
    except (IndexError, TypeError) as e:
        print(f"DEBUG: Error processing answer: {e}, returning zeros")
        return [0.0] * len(responses)
    print(f"DEBUG: ground_truths: {ground_truths}")
    # Keep responses and ground truths aligned by index.
    rewards = []
    for i, response in enumerate(responses):
        if i < len(ground_truths):
            reward = acc_reward(response, ground_truths[i])
            print(f"DEBUG: response {i}: '{response[:100]}...', ground_truth: '{ground_truths[i]}', reward: {reward}")
        else:
            # Not enough ground truths: fall back to the first one.
            reward = acc_reward(response, ground_truths[0] if ground_truths else "")
            print(f"DEBUG: response {i} (fallback): reward: {reward}")
        rewards.append(reward)
    return rewards
#
# Format into conversation
def make_conversation(example):
    """Wrap a bare 'problem' string in a system+user chat prompt.

    NOTE(review): SYSTEM_PROMPT is not defined anywhere in this file —
    calling this function raises NameError unless it is supplied elsewhere.
    """
    return {
        "prompt": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": example["problem"]},
        ],
    }
def make_conversation_image(example):
    """Build a chat prompt containing only an image placeholder.

    The `example` argument is accepted for interface parity but unused.
    """
    user_turn = {
        "role": "user",
        "content": [
            {"type": "image"},
        ],
    }
    return {"prompt": [user_turn]}
@dataclass
class GRPOModelConfig(ModelConfig):
    """Model, checkpoint, and LoRA/QFormer hyperparameters for GRPO training."""

    model_name_or_path: str = field(default="Qwen/Qwen3-0.6B", metadata={"help": "Model checkpoint for LLM weights initialization."})
    protein_model_name_or_path: str = field(default="esm2_t33_650M_UR50D", metadata={"help": "Model checkpoint for ESM-2 protein weights initialization."})
    # None means "use the default HF cache location".
    cache_dir: Optional[str] = field(default=None, metadata={"help": "Path to model cache directory."})
    max_length_text: int = field(default=800, metadata={"help": "Maximum length of text sequences."})
    max_length_protein: int = field(default=800, metadata={"help": "Maximum length of protein sequences (number of amino acids)."})
    # Path to an SFT checkpoint: either a PEFT adapter directory or a state-dict file.
    sft_checkpoint: Optional[str] = field(default=None, metadata={"help": "Path to the checkpoint for SFT."})
    lora_r: int = field(default=32, metadata={"help": "LoRA R value."})
    lora_alpha: int = field(default=64, metadata={"help": "LoRA alpha."})
    lora_dropout: float = field(default=0.05, metadata={"help": "LoRA dropout."})
    lora_modules_to_save: Optional[list[str]] = field(
        default_factory=lambda: ["embed_tokens", "lm_head"],
        metadata={"help": "Model layers to unfreeze & train with LoRA."},
    )
    # Updated: Renamed `freeze_dna_modules` to `freeze_protein_model`
    freeze_protein_model: bool = field(default=True, metadata={"help": "Whether to freeze the ESM-2 protein model during training."})
    num_query_tokens: int = field(default=32, metadata={"help": "Number of query tokens for QFormer."})
    qformer_num_layers: int = field(default=6, metadata={"help": "Number of layers in QFormer."})
    qformer_num_heads: int = field(default=8, metadata={"help": "Number of attention heads in QFormer."})
    qformer_dropout: float = field(default=0.1, metadata={"help": "Dropout rate for QFormer."})
@dataclass
class GRPOScriptArguments(ScriptArguments):
    """
    Script arguments for the GRPO training script.
    """

    dataset_name: str = field(default="wanglab/kegg", metadata={"help": "Dataset name with default."})
    # None means "no explicit data files" — the dataset is loaded by name.
    data_file_paths: Optional[str] = field(
        default=None,
        metadata={"help": "Paths to data files, separated by ':'"},
    )
    arrow_cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Path to arrow cache directory"},
    )
    val_split_ratio: float = field(
        default=0.0,
        metadata={"help": "Ratio of validation split, default 0.0"},
    )
    reward_funcs: list[str] = field(
        # Default reward list: the three newer functions plus two format checks.
        default_factory=lambda: ["repeatness", "format", "acc", "xmlcount", "soft_format"],
        metadata={"help": "List of reward functions. Possible values: 'repeatness', 'format', 'acc', 'xmlcount', 'soft_format', 'strict_format', 'less_than_4', 'correctness'"},
    )
# Registry mapping --reward_funcs names to their implementations.
reward_funcs_registry = {
    # "accuracy": accuracy_reward,
    # "format": format_reward,
    "repeatness": repeatness_reward_func,
    "format": format_reward_func,
    "acc": acc_reward_func,
    "xmlcount": xmlcount_reward_func,
    "soft_format": soft_format_reward_func,
    "strict_format": strict_format_reward_func,
    "less_than_4": less_than_4_reward_func,
    "correctness": correctness_reward_func,
}
def get_vlm_module(model_name_or_path):
    """Pick the sequence-module class for the given backbone name.

    Qwen/Smol backbones use a dedicated protein module when the package
    provides one, otherwise fall back to the nucleotide DNA module; any other
    backbone raises ValueError.
    """
    lowered = model_name_or_path.lower()
    if not any(tag in lowered for tag in ["qwen", "smol"]):
        raise ValueError(f"Unsupported model: {model_name_or_path}")
    # Prefer a dedicated protein module when available.
    try:
        from bioreason.protein_modules import ProteinModule
    except ImportError:
        # No dedicated protein module: reuse the DNA module instead.
        print("Warning: Using NucleotideDNAModule for protein processing. Consider creating a dedicated ProteinModule.")
        return NucleotideDNAModule
    return ProteinModule
def _prep_for_training(model: "ProteinLLMModel", model_args, protein_model_finetune: bool = False) -> LoraConfig:
    """Prepare a ProteinLLMModel for LoRA training.

    Freezes the protein encoder unless `protein_model_finetune` is True, wraps
    the text model for k-bit training and applies a LoRA adapter, and marks
    the QFormer projection layer trainable.

    Args:
        model: Composite protein+text model; modified in place.
        model_args: Carries lora_r / lora_alpha / lora_dropout.
        protein_model_finetune: When False (default), freeze the protein encoder.

    Returns:
        The LoraConfig applied to the text model.
    """
    # Freeze protein encoder parameters if not finetuning
    if not protein_model_finetune:
        for param in model.protein_model.parameters():
            param.requires_grad = False
        print("Frozen protein model parameters")
    else:
        print("Protein model parameters will be finetuned")
    # Get target modules for LoRA
    target_modules = _get_target_modules(model)
    print(f"LoRA target modules: {target_modules}")
    lora_config = LoraConfig(
        r=model_args.lora_r,
        lora_alpha=model_args.lora_alpha,
        lora_dropout=model_args.lora_dropout,
        target_modules=target_modules,
        init_lora_weights="gaussian",
        bias="none",
        task_type="CAUSAL_LM",
    )
    # Prepare text model for training
    model.text_model = prepare_model_for_kbit_training(model.text_model)
    model.text_model = get_peft_model(model.text_model, lora_config)
    # Make QFormer projection layer trainable
    for param in model.protein_projection.parameters():
        param.requires_grad = True
    print("QFormer projection layer set as trainable")
    # Print trainable parameters info
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Trainable parameters: {trainable_params:,} / {total_params:,} ({100 * trainable_params / total_params:.2f}%)")
    return lora_config
######################################################################
######################################################################
def main(script_args, training_args, model_args):
    """Initialize the protein-LLM, optionally restore an SFT checkpoint, and run GRPO.

    Args:
        script_args: GRPOScriptArguments (dataset + reward-function selection).
        training_args: DNALLMGRPOConfig (GRPO/Trainer hyperparameters).
        model_args: GRPOModelConfig (model paths, LoRA settings).
    """
    print(training_args.output_dir)
    torch.cuda.empty_cache()
    torch.set_float32_matmul_precision("medium")

    print("Initializing ProteinLLMModel...")
    model = ProteinLLMModel(
        text_model_name=model_args.model_name_or_path,
        protein_model_name=model_args.protein_model_name_or_path,
        biomedbert_model_name=getattr(model_args, 'biomedbert_model_name',
                                      "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext"),
        cache_dir=model_args.cache_dir,
        max_length_text=model_args.max_length_text,
        max_length_protein=model_args.max_length_protein,
        text_model_finetune=True,
        # Fixed: GRPOModelConfig declares `freeze_protein_model`; the original
        # referenced a non-existent `freeze_protein_modules` attribute.
        protein_model_finetune=not model_args.freeze_protein_model,
        biomedbert_finetune=getattr(model_args, 'biomedbert_finetune', True),  # controls BiomedBERT finetuning
        qformer_num_query_tokens=getattr(model_args, 'qformer_num_query_tokens', 8),
    )

    # Optionally restore an SFT checkpoint before GRPO training.
    if model_args.sft_checkpoint is not None:
        print(f"Loading SFT checkpoint from {model_args.sft_checkpoint}")
        # A directory is a PEFT adapter checkpoint; a file is a raw state dict.
        is_directory = os.path.isdir(model_args.sft_checkpoint)
        if is_directory:
            from peft import PeftModel

            print("Loading as PEFT checkpoint directory")
            model.text_model = PeftModel.from_pretrained(
                model.text_model,
                model_args.sft_checkpoint,
                is_trainable=True
            )
            print("Loaded LoRA adapters:", model.text_model.active_adapter)
            # Merge the SFT LoRA weights so GRPO starts from a plain base model.
            print("Merging SFT LoRA weights into base model...")
            model.text_model = model.text_model.merge_and_unload()
            print("Successfully merged SFT knowledge into base model")
        else:
            print("Loading as PyTorch state dict file")
            checkpoint = torch.load(model_args.sft_checkpoint, map_location="cpu")

            def new_key(k):
                # Strip wrapper prefixes so checkpoint keys line up with this
                # model. Fixed: the original tested for "=model." (7 chars) but
                # sliced 6 chars, leaving a stray "."; the prefix is "model.".
                if k.startswith("model."):
                    return k[len("model."):]
                elif k.startswith("_forward_module."):
                    return k[len("_forward_module."):]
                else:
                    return k

            # Fixed: the remapped dict was bound to `magic`, but every later
            # use referenced `state_dict`, which was never defined (NameError).
            if "state_dict" in checkpoint:
                state_dict = {new_key(k): v for k, v in checkpoint["state_dict"].items()}
            elif "module" in checkpoint:
                state_dict = {new_key(k): v for k, v in checkpoint["module"].items()}
            elif isinstance(checkpoint, dict) and all(isinstance(k, str) for k in checkpoint.keys()):
                # Direct state dict - the checkpoint itself is the state dict.
                print("Detected direct state dict format")
                state_dict = {new_key(k): v for k, v in checkpoint.items()}
            else:
                raise ValueError(f"Unsupported checkpoint format: {model_args.sft_checkpoint}")

            # LoRA checkpoints need the adapters instantiated before loading.
            lora_prefix = any("lora" in key for key in state_dict.keys())
            if lora_prefix:
                print("Detected LoRA weights in state dict")
                # Fixed: use the real `freeze_protein_model` field, negated to
                # mean "finetune", matching the no-checkpoint branch below.
                _prep_for_training(model, model_args, protein_model_finetune=not model_args.freeze_protein_model)
                # Diagnostic info about key overlap.
                model_keys = set(model.state_dict().keys())
                checkpoint_keys = set(state_dict.keys())
                print(f"Model has {len(model_keys)} keys")
                print(f"Checkpoint has {len(checkpoint_keys)} keys")
                # Remap common prefix mismatches between PEFT-wrapped and bare models.
                new_state_dict = {}
                for k, v in state_dict.items():
                    if "base_model.model" in k and k not in model_keys:
                        new_k = k.replace("text_model.base_model.model", "text_model")
                        if new_k in model_keys:
                            new_state_dict[new_k] = v
                            continue
                    if k.startswith("text_model.") and k not in model_keys:
                        new_k = "text_model.base_model.model." + k[len("text_model."):]
                        if new_k in model_keys:
                            new_state_dict[new_k] = v
                            continue
                    # Keep the original key when no remap applies.
                    new_state_dict[k] = v
                state_dict = new_state_dict
                print(f"After key mapping: {len(state_dict)} keys")
                # Load with missing/unexpected keys allowed.
                result = model.load_state_dict(state_dict, strict=False)
                if len(result.unexpected_keys) > 0:
                    print(f"Sample unexpected keys: {result.unexpected_keys[:5]}")
                if len(result.missing_keys) > 0:
                    print(f"Sample missing keys: {result.missing_keys[:5]}")
                print(f"Loaded checkpoint with {len(result.missing_keys)} missing keys and {len(result.unexpected_keys)} unexpected keys")
            else:
                print("Standard weights detected - loading before LoRA setup")
                # Clone lm_head tensors to avoid shared-memory aliasing with
                # tied embeddings.
                for key in list(state_dict.keys()):
                    if 'lm_head.weight' in key:
                        state_dict[key] = state_dict[key].clone()
                result = model.load_state_dict(state_dict, strict=False)
                print(f"Loaded checkpoint with {len(result.missing_keys)} missing keys and {len(result.unexpected_keys)} unexpected keys")
                # Only now wrap the freshly loaded weights with LoRA.
                _prep_for_training(model, model_args, protein_model_finetune=not model_args.freeze_protein_model)
    else:
        # No checkpoint: prepare the freshly initialized model for training.
        _prep_for_training(model, model_args, protein_model_finetune=not model_args.freeze_protein_model)

    # Resolve the requested reward functions by name.
    reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]
    print("reward_funcs:", [func.__name__ for func in reward_funcs])

    vlm_module_cls = get_vlm_module(model_args.model_name_or_path)
    print("using vlm module:", vlm_module_cls.__name__)
    question_prompt = vlm_module_cls.get_question_template()

    dataset = get_kegg_questions()
    print(dataset)

    # Custom callback that saves with PyTorch's native mechanism.
    custom_save_callback = SaveWithPyTorchCallback()

    trainer = DNALLMGRPOTrainer(
        model=model,
        reward_funcs=reward_funcs,
        args=training_args,
        dna_module=vlm_module_cls(),
        train_dataset=dataset['train'],
        # NOTE(review): assumes the dataset exposes a 'val' split — verify.
        eval_dataset=dataset['val'] if training_args.eval_strategy != "no" else None,
        peft_config=get_peft_config(model_args),
        attn_implementation=model_args.attn_implementation,
        torch_dtype=model_args.torch_dtype,
        callbacks=[custom_save_callback],  # Add our custom callback
    )

    # Save in PyTorch format instead of safetensors.
    training_args.save_safetensors = False

    trainer.train()
if __name__ == "__main__":
    # os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
    print(f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES')}")
    # Parse CLI/config into (script, training, model) argument groups.
    parser = TrlParser((GRPOScriptArguments, DNALLMGRPOConfig, GRPOModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()
    # Ensure we use PyTorch's save mechanism instead of safetensors
    training_args.save_safetensors = False
    main(script_args, training_args, model_args)
# parser.add_argument("--wandb_project", type=str, default="dna-text-finetune")
# parser.add_argument("--wandb_entity", type=str, default="adibvafa")
# args = parser.parse_args()