import transformers
from transformers import AutoModel, AutoTokenizer
from peft import LoraConfig, get_peft_model

from model.modeling_llada import LLaDAModelLM
from model.configuration_llada import LLaDAConfig


def get_model_by_config(config):
    """Select which model family to load based on the training config."""
    training_mode = config.get('training_mode', 'dream')
    if training_mode == 'llada':
        return get_llada(config)
    elif training_mode == 'dream':
        return get_model(config)
    else:
        raise ValueError(f"Unsupported training mode: {training_mode}")


def get_model(config):
    """Load the Dream base model with frozen weights and LoRA adapters."""
    # Use the model path from the config; fall back to a default local path otherwise.
    model_path = config.paths.model if hasattr(config, 'paths') and hasattr(config.paths, 'model') else "/home/wx/data/model/Dream-org/Dream-v0-Base-7B"
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
    # print(model.named_modules())
    # print(model)

    # Freeze the base model; only the LoRA adapters added below are trained.
    for param in model.parameters():
        param.requires_grad = False

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    # Attach LoRA adapters to the attention projection layers.
    peft_config = LoraConfig(r=32, lora_alpha=32, lora_dropout=0.1, target_modules=["q_proj", "v_proj", "k_proj", "o_proj"])
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()
    return model, tokenizer


def get_llada(config):
    """Load the LLaDA base model with frozen weights and LoRA adapters."""
    # Use the model path from the config; fall back to a default local path otherwise.
    model_path = config.paths.model if hasattr(config, 'paths') and hasattr(config.paths, 'model') else "/data1/xck/models/llada-8b-instruct"
    config_obj = LLaDAConfig.from_pretrained(model_path)
    model = LLaDAModelLM.from_pretrained(model_path, config=config_obj)
    # print(model.named_modules())
    # print(model)
    # exit()

    # Freeze the base model; only the LoRA adapters added below are trained.
    for param in model.parameters():
        param.requires_grad = False

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    # Attach LoRA adapters; in this model the attention output projection is named "attn_out".
    peft_config = LoraConfig(r=32, lora_alpha=32, lora_dropout=0.1, target_modules=["q_proj", "v_proj", "k_proj", "attn_out"])
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()
    return model, tokenizer

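
# Example usage (a minimal sketch, not invoked anywhere in this file): the config is
# assumed here to be an OmegaConf DictConfig, since the functions above rely on both
# `config.get(...)` and attribute access (`config.paths.model`); any dict-like object
# with the same interface would work. The path below is the default from this module.
#
#   from omegaconf import OmegaConf
#   cfg = OmegaConf.create({
#       "training_mode": "llada",
#       "paths": {"model": "/data1/xck/models/llada-8b-instruct"},
#   })
#   model, tokenizer = get_model_by_config(cfg)

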
# def create_attention_mask(input_ids, mask_id):
#     """
#     Create an attention mask based on the input_ids and mask_id.
#     Args:
#         input_ids (torch.Tensor): The input tensor of shape (batch_size, sequence_length).
#         mask_id (int): The ID of the mask token.
#     Returns:
#         torch.Tensor: The attention mask of shape (batch_size, sequence_length, sequence_length).