| Inference code (prints sampled TL;DR summaries for the first 10 validation prompts): | |
| ```python | |
"""Sample TL;DR summaries from the ``pvduy/ppo_pythia6B_sample`` checkpoint.

Loads the ``valid`` split of ``CarperAI/openai_summarize_tldr``, samples a
continuation for the first ``NUM_SAMPLES`` prompts and prints each prompt
together with the generated summary.
"""
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_NAME = "pvduy/ppo_pythia6B_sample"
NUM_SAMPLES = 10  # number of validation prompts to summarise

dataset = load_dataset("CarperAI/openai_summarize_tldr")
val_prompts = [sample["prompt"] for sample in dataset["valid"]]

# Sampling settings: top_k=0 disables top-k filtering and top_p=1 disables
# nucleus filtering, so tokens are sampled from the full distribution.
kwargs = {
    "max_new_tokens": 50,
    "do_sample": True,
    "top_k": 0,
    "top_p": 1,
}

device = "cuda" if torch.cuda.is_available() else "cpu"
# fp16 on GPU roughly halves memory; keep fp32 on CPU, where half-precision
# support is limited.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
).to(device)
model.eval()

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Reuse EOS as the pad token (presumably the tokenizer ships without one).
tokenizer.pad_token_id = tokenizer.eos_token_id

for prompt in val_prompts[:NUM_SAMPLES]:
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            # generate() does not read the tokenizer; pass the pad id
            # explicitly so it does not warn on every call.
            pad_token_id=tokenizer.eos_token_id,
            **kwargs,
        )
    # Decode only the newly generated tokens. Splitting the full decoded text
    # on "TL;DR:" picks the wrong segment if the post itself contains that
    # marker, truncates the summary if the model emits it again, and raises
    # IndexError if it is absent.
    prompt_len = inputs.input_ids.shape[1]
    summary = tokenizer.decode(
        outputs[0][prompt_len:], skip_special_tokens=True
    ).strip()
    print("Prompt:", prompt)
    print("Output:", summary)
    print("=================================")
| ``` |