import json
from typing import List, Optional

import fire
from tqdm import tqdm

from llama import Llama


def read_json(file_path):
    """Load a UTF-8 JSON file and return the parsed object."""
    with open(file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)
    return data


def write_json(file_path, data):
    """Write `data` to `file_path` as indented JSON, preserving non-ASCII text."""
    with open(file_path, 'w', encoding='utf-8') as file:
        json.dump(data, file, ensure_ascii=False, indent=4)
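

# Assumed data format (not specified in the source): the file at `json_path`
# is a flat JSON list of prompt strings, e.g.
#   ["Once upon a time", "Simply put, the theory of relativity states that"]
# Each string is fed to the model as one completion prompt.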
|


def main(
    ckpt_dir: str,
    tokenizer_path: str,
    temperature: float = 0.6,
    top_p: float = 0.9,
    max_seq_len: int = 128,
    max_gen_len: int = 64,
    max_batch_size: int = 4,
    json_path: Optional[str] = None,
):
    """
    Run text completion with pre-trained models (no fine-tuning) over a JSON
    file of prompts. Prompts are usually incomplete text prefixes that the
    model then tries to complete.

    The context window of Llama 3 models is 8192 tokens, so `max_seq_len`
    must be <= 8192. `max_gen_len` is needed because pre-trained models
    usually do not stop completions naturally.
    """
    generator = Llama.build(
        ckpt_dir=ckpt_dir,
        tokenizer_path=tokenizer_path,
        max_seq_len=max_seq_len,
        max_batch_size=max_batch_size,
    )
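    # Llama.build loads the sharded checkpoint and tokenizer and returns a
    # generator object exposing text_completion() (and, for instruct models,
    # chat_completion()).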
    data = read_json(json_path)

    ans = []
    cnt = 0  # batches processed so far
    end = len(data)
    for batch_idx in tqdm(range(0, end, max_batch_size)):
        up = min(batch_idx + max_batch_size, end)
        batch = data[batch_idx:up]

        # Collect this batch's prompts; each item is assumed to be a plain
        # prompt string (see the data-format note above).
        text_batch: List[str] = list(batch)
        res = generator.text_completion(
            text_batch,
            max_gen_len=max_gen_len,
            temperature=temperature,
            top_p=top_p,
        )
        ans.extend(res)  # one completion result per prompt
        cnt += 1
        if cnt % 10 == 0:
            print(f"batch {cnt} done")

    # Persist all completions to disk once the loop finishes.
    write_json("ans.json", ans)
| | if __name__ == "__main__": |
| | fire.Fire(main) |
| |
|