| import torch |
| import torch.distributed as dist |
| import torch.multiprocessing as mp |
| from transformers import AutoTokenizer, LlamaForCausalLM |
| from torch.nn.parallel import DistributedDataParallel as DDP |
| from evalplus.data import get_human_eval_plus, write_jsonl |
| import os |
| from tqdm import tqdm |
|
|
def setup(rank, world_size):
    """Initialise the default process group for this worker.

    Points the rendezvous at ``localhost:12355`` through the ``MASTER_ADDR``
    and ``MASTER_PORT`` environment variables (overwriting any existing
    values) and then joins a ``gloo`` process group.

    Args:
        rank: Index of the calling process within the group.
        world_size: Total number of processes in the group.
    """
    rendezvous = {"MASTER_ADDR": "localhost", "MASTER_PORT": "12355"}
    os.environ.update(rendezvous)
    dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)
|
|
def cleanup():
    """Destroy the default process group if one is currently initialised.

    Calling this when no group exists (for example because ``setup`` failed
    or was never called) is a no-op rather than an error, so it is safe to
    use from a ``finally`` clause and to call more than once.
    """
    # destroy_process_group() raises when there is no default group, so
    # check first instead of letting teardown mask the original failure.
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()
|
|
def generate_one_completion(ddp_model, tokenizer, prompt: str):
    """Sample one completion for ``prompt`` and return it without the prompt.

    Args:
        ddp_model: Wrapper that exposes the underlying model as ``.module``;
            generation is done through ``ddp_model.module.generate``.
        tokenizer: Tokenizer used to encode the prompt and decode the output.
            Its ``pad_token`` is overwritten with its ``eos_token``.
        prompt: Text to complete; tokenisation truncates it to 4096 tokens.

    Returns:
        The decoded generation with every occurrence of ``prompt`` removed,
        cut off at the first run of three consecutive newlines. The result is
        also printed to stdout.
    """
    tokenizer.pad_token = tokenizer.eos_token
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=4096)

    model = ddp_model.module
    # Send the inputs to wherever the model's weights live instead of the
    # bare "cuda" default device, which is only correct for the first GPU.
    device = next(model.parameters()).device

    generate_kwargs = dict(
        max_new_tokens=384,
        do_sample=True,
        top_p=0.75,
        top_k=40,
        temperature=0.1,
        pad_token_id=tokenizer.eos_token_id,
    )
    # pad == eos here, so the mask cannot be inferred from the ids; pass the
    # tokenizer's own mask explicitly when it provides one.
    attention_mask = inputs.get("attention_mask")
    if attention_mask is not None:
        generate_kwargs["attention_mask"] = attention_mask.to(device)

    generate_ids = model.generate(inputs.input_ids.to(device), **generate_kwargs)
    decoded = tokenizer.batch_decode(
        generate_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]
    # The decoded text echoes the prompt; drop it and keep only the first
    # chunk before a triple newline.
    completion = decoded.replace(prompt, "").split("\n\n\n")[0]

    print("-------------------")
    print(completion)
    return completion
|
|
def run(rank, world_size):
    """Worker entry point: generate HumanEval+ completions and save them.

    Joins the process group, loads the model and tokenizer, generates
    ``num_samples_per_task`` completions for each problem and writes them to
    ``samples-Nondzu-Mistral-7B-codealpaca-lora-rank{rank}.jsonl``.

    Args:
        rank: Index of this worker; also used as the DDP device id and in the
            output file name.
        world_size: Total number of workers in the process group.
    """
    setup(rank, world_size)
    # Tear the process group down even if loading, generation or writing
    # fails, so a crashing worker does not leave the group behind.
    try:
        model_path = "Nondzu/Mistral-7B-codealpaca-lora"
        model = LlamaForCausalLM.from_pretrained(model_path, load_in_8bit=True)
        ddp_model = DDP(model, device_ids=[rank])
        tokenizer = AutoTokenizer.from_pretrained(model_path)

        problems = get_human_eval_plus()
        num_samples_per_task = 1

        # NOTE: the problem set is not split by rank; every worker walks all
        # of the tasks and writes its own per-rank file.
        samples = [
            dict(
                task_id=task_id,
                completion=generate_one_completion(
                    ddp_model, tokenizer, problems[task_id]["prompt"]
                ),
            )
            for task_id in tqdm(problems)
            for _ in range(num_samples_per_task)
        ]
        write_jsonl(f"samples-Nondzu-Mistral-7B-codealpaca-lora-rank{rank}.jsonl", samples)
    finally:
        cleanup()
| |
def main():
    """Spawn the worker processes and block until all of them finish."""
    n_workers = 1
    # run() takes (rank, world_size); only world_size is supplied here, the
    # spawner provides the process index as the first argument.
    mp.spawn(fn=run, args=(n_workers,), nprocs=n_workers, join=True)
|
|
# Launch the workers only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
|
|