Spaces:
Sleeping
Sleeping
| import torch | |
| from dataset import get_data_loaders_phase1, get_data_loaders_phase2 | |
| from transformers import AutoTokenizer | |
| from model import CustomClipPhi2, MainQLoraModel, train_model_phase1, train_model_phase2 | |
| from configs import get_config_phase1, get_config_phase2 | |
| def phase_1(): | |
| # get config | |
| config = get_config_phase1() | |
| # tokenizer | |
| tokenizer = AutoTokenizer.from_pretrained(config.get("phi2_model_name"), trust_remote_code=True) | |
| # data loaders | |
| train_dataloader, val_dataloader = get_data_loaders_phase1(config.get("data_dir"), config.get("clip_model_name"), tokenizer, config.get("train_batch_size"), config.get("val_batch_size"), config.get("num_workers")) | |
| llmModel = CustomClipPhi2(tokenizer, config.get("phi2_model_name"), config.get("clip_model_name"), clip_embed=768, phi_embed=2560).to(config.get("device")) | |
| print(llmModel) | |
| # optimizer | |
| optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, llmModel.parameters()), lr=1e-3) | |
| # train model | |
| train_model_phase1(llmModel, train_dataloader, val_dataloader, optimizer, tokenizer, config) | |
| def phase_2(): | |
| # get config | |
| config = get_config_phase2() | |
| # tokenizer | |
| tokenizer = AutoTokenizer.from_pretrained(config.get("phi2_model_name"), trust_remote_code=True) | |
| # data loaders | |
| train_dataloader, val_dataloader = get_data_loaders_phase2(tokenizer, config) | |
| llmModel = MainQLoraModel(tokenizer, config).to(config.get("device")) | |
| print(llmModel) | |
| # train model | |
| train_model_phase2(llmModel, train_dataloader, val_dataloader, tokenizer, config) | |
| if __name__ == "__main__": | |
| torch.set_float32_matmul_precision('medium') | |
| phase_1() | |
| # phase_2() |