import os
import torch
import random
import numpy as np
import argparse
import json
import cohere
from openai import OpenAI
from tqdm import tqdm
from collections import Counter
from transformers import LlamaForCausalLM, AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
import hashlib

# API credentials: fill these in before running.
OPENAI_TOKEN = ""
COHERE_TOKEN = ""
HF_TOKEN = ""

def argmax(array):
    """Argmax with deterministic pseudorandom tie breaking."""
    array = np.asarray(array)
    max_indices = np.arange(len(array))[array == np.max(array)]
    idx = int(hashlib.sha256(array.tobytes()).hexdigest(), 16) % len(max_indices)
    return max_indices[idx]

def logsumexp(x):
    """Numerically stable log-sum-exp."""
    c = x.max()
    return c + np.log(np.sum(np.exp(x - c)))


def normalize(x):
    """Softmax normalization via the log-sum-exp trick."""
    x = np.array(x)
    return np.exp(x - logsumexp(x))

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

def get_commandr_chat_response(gen_model, gen_model_checkpoint, text, seed):
    response = gen_model.chat(
        model="command-r",
        message=text,
        temperature=0,
        max_tokens=64,
        seed=seed,
        p=1
    )
    return response.text

def get_mt0_response(gen_model, tokenizer, gen_model_checkpoint, text, seed):
    input_ids = tokenizer.encode(text, return_tensors="pt").to(gen_model.device)
    outputs = gen_model.generate(
        input_ids,
        max_new_tokens=10,
        do_sample=True,
        temperature=0.2,
        top_p=1
    )
    response = outputs[0]
    return tokenizer.decode(response, skip_special_tokens=True)

def get_gemma_response(gen_model, tokenizer, gen_model_checkpoint, text, seed):
    messages = [
        {"role": "user", "content": text},
    ]
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(gen_model.device)
    # Gemma marks the end of a turn with "<end_of_turn>", not Llama 3's "<|eot_id|>".
    terminators = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<end_of_turn>")
    ]
    outputs = gen_model.generate(
        input_ids,
        max_new_tokens=10,
        eos_token_id=terminators,
        do_sample=True,
        temperature=0.2,
        top_p=1
    )
    response = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(response, skip_special_tokens=True)

def get_mistral_instruct_chat_response(gen_model, tokenizer, gen_model_checkpoint, text, seed):
    messages = [
        {"role": "user", "content": text},
    ]
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(gen_model.device)
    # Mistral-Instruct ends a turn with its regular EOS token ("</s>"); it has no "<|eot_id|>" token.
    outputs = gen_model.generate(
        input_ids,
        max_new_tokens=10,
        eos_token_id=tokenizer.eos_token_id,
        do_sample=True,
        temperature=0.2,
        top_p=1
    )
    response = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(response, skip_special_tokens=True)

def get_llama3_instruct_chat_response(gen_model, tokenizer, gen_model_checkpoint, text, seed):
    messages = [
        {"role": "user", "content": text},
    ]
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(gen_model.device)
    terminators = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|eot_id|>")
    ]
    outputs = gen_model.generate(
        input_ids,
        max_new_tokens=10,
        eos_token_id=terminators,
        do_sample=True,
        temperature=0.2,
        top_p=1
    )
    response = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(response, skip_special_tokens=True)

def get_openai_chat_response(gen_model, gen_model_checkpoint, text, seed):
    messages = [
        {
            "role": "user",
            "content": text
        }
    ]
    response = gen_model.chat.completions.create(
        model=gen_model_checkpoint,
        messages=messages,
        temperature=0,
        max_tokens=64,
        top_p=1,
        seed=seed
    )
    return response.choices[0].message.content

def load_model(gen_model_checkpoint, load_in_8bit=False):
    """Load the generation model (and, for local HF checkpoints, its tokenizer)."""
    gen_model = None
    tokenizer = None
    if "mistralai/Mistral-7B-Instruct-v0.3" in gen_model_checkpoint or "meta-llama/Meta-Llama-3-8B-Instruct" in gen_model_checkpoint or "google/gemma-1.1-7b-it" in gen_model_checkpoint:
        # 8-bit loading applies to the model only; tokenizers take neither device_map nor load_in_8bit.
        if load_in_8bit:
            gen_model = AutoModelForCausalLM.from_pretrained(gen_model_checkpoint, token=HF_TOKEN,
                                                             device_map="auto", load_in_8bit=True)
        else:
            gen_model = AutoModelForCausalLM.from_pretrained(gen_model_checkpoint, token=HF_TOKEN)
        tokenizer = AutoTokenizer.from_pretrained(gen_model_checkpoint, token=HF_TOKEN)
    elif "CohereForAI/aya-101" in gen_model_checkpoint or "bigscience/mt0" in gen_model_checkpoint:
        if load_in_8bit:
            gen_model = AutoModelForSeq2SeqLM.from_pretrained(gen_model_checkpoint, token=HF_TOKEN,
                                                              device_map="auto", load_in_8bit=True)
        else:
            gen_model = AutoModelForSeq2SeqLM.from_pretrained(gen_model_checkpoint, token=HF_TOKEN)
        tokenizer = AutoTokenizer.from_pretrained(gen_model_checkpoint, token=HF_TOKEN)
    elif "facebook/xglm" in gen_model_checkpoint or "bigscience/bloomz" in gen_model_checkpoint or "aya-23-8B" in gen_model_checkpoint:
        if load_in_8bit:
            gen_model = AutoModelForCausalLM.from_pretrained(gen_model_checkpoint, token=HF_TOKEN,
                                                             device_map="auto", load_in_8bit=True)
        else:
            gen_model = AutoModelForCausalLM.from_pretrained(gen_model_checkpoint, token=HF_TOKEN)
        tokenizer = AutoTokenizer.from_pretrained(gen_model_checkpoint, token=HF_TOKEN)
    elif "gpt-3.5-turbo" in gen_model_checkpoint or "gpt-4" in gen_model_checkpoint:
        gen_model = OpenAI(api_key=OPENAI_TOKEN)
    elif "command-r" in gen_model_checkpoint:
        gen_model = cohere.Client(COHERE_TOKEN)
    return gen_model, tokenizer
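
# Illustrative usage sketch (not part of the original pipeline; the checkpoint,
# prompt, and seed below are placeholders chosen for the example):
#
#     set_seed(42)
#     checkpoint = "meta-llama/Meta-Llama-3-8B-Instruct"
#     gen_model, tokenizer = load_model(checkpoint, load_in_8bit=True)
#     answer = get_llama3_instruct_chat_response(
#         gen_model, tokenizer, checkpoint,
#         "Answer with a single word: what is the capital of Indonesia?", seed=42)
#     print(answer)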