import transformers
from transformers import AutoModel, AutoTokenizer
from peft import LoraConfig, get_peft_model

from model.modeling_llada import LLaDAModelLM
from model.configuration_llada import LLaDAConfig


def get_model_by_config(config):
    """Select which model to load based on the config file."""
    training_mode = config.get('training_mode', 'dream')
    if training_mode == 'llada':
        return get_llada(config)
    elif training_mode == 'dream':
        return get_model(config)
    else:
        raise ValueError(f"Unsupported training mode: {training_mode}")


def get_model(config):
    """Load the Dream base model, freeze its weights, and attach LoRA adapters."""
    # Use the path from the config; fall back to the default path if it is missing.
    model_path = (
        config.paths.model
        if hasattr(config, 'paths') and hasattr(config.paths, 'model')
        else "/home/wx/data/model/Dream-org/Dream-v0-Base-7B"
    )
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
    # print(model.named_modules())
    # print(model, "model")

    # Freeze all base-model parameters; only the LoRA adapters will be trained.
    for param in model.parameters():
        param.requires_grad = False

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    peft_config = LoraConfig(
        r=32,
        lora_alpha=32,
        lora_dropout=0.1,
        target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    )
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()
    return model, tokenizer


def get_llada(config):
    """Load the LLaDA model, freeze its weights, and attach LoRA adapters."""
    # Use the path from the config; fall back to the default path if it is missing.
    model_path = (
        config.paths.model
        if hasattr(config, 'paths') and hasattr(config.paths, 'model')
        else "/data1/xck/models/llada-8b-instruct"
    )
    config_obj = LLaDAConfig.from_pretrained(model_path)
    model = LLaDAModelLM.from_pretrained(model_path, config=config_obj)
    # print(model.named_modules())
    # print(model, "model")
    # print(model)
    # exit()

    # Freeze all base-model parameters; only the LoRA adapters will be trained.
    for param in model.parameters():
        param.requires_grad = False

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    peft_config = LoraConfig(
        r=32,
        lora_alpha=32,
        lora_dropout=0.1,
        target_modules=["q_proj", "v_proj", "k_proj", "attn_out"],
    )
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()
    return model, tokenizer


# def create_attention_mask(input_ids, mask_id):
#     """
#     Create an attention mask based on the input_ids and mask_id.
#
#     Args:
#         input_ids (torch.Tensor): The input tensor of shape (batch_size, sequence_length).
#         mask_id (int): The ID of the mask token.
#
#     Returns:
#         torch.Tensor: The attention mask of shape (batch_size, sequence_length, sequence_length).
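

# ---------------------------------------------------------------------------
# Example usage (illustrative sketch, not part of the original module).
# It assumes the config is an OmegaConf DictConfig, which supports both the
# dict-style `.get()` and the attribute access (`config.paths.model`) used
# above; the model path below is a placeholder, and running this requires the
# checkpoint to be available locally (or downloadable via trust_remote_code).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from omegaconf import OmegaConf  # assumption: config comes from OmegaConf

    cfg = OmegaConf.create({
        "training_mode": "dream",                          # or "llada"
        "paths": {"model": "/path/to/Dream-v0-Base-7B"},   # placeholder path
    })
    model, tokenizer = get_model_by_config(cfg)
    # The returned model is a PEFT-wrapped model with only LoRA params trainable.
    print(type(model), type(tokenizer))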