import os
import random
from threading import Thread
from typing import Iterable

import torch
from huggingface_hub import HfApi
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
ground_truth = ""

TOKEN = os.environ.get("HF_TOKEN", None)

type2dataset = {
    "re2text-easy": load_dataset('3B-Group/ConvRe', "en-re2text", token=TOKEN, split="prompt1"),
    "re2text-hard": load_dataset('3B-Group/ConvRe', "en-re2text", token=TOKEN, split="prompt4"),
    "text2re-easy": load_dataset('3B-Group/ConvRe', "en-text2re", token=TOKEN, split="prompt1"),
    "text2re-hard": load_dataset('3B-Group/ConvRe', "en-text2re", token=TOKEN, split="prompt3"),
}
model_id = "meta-llama/Llama-2-7b-chat-hf"
tokenizer = AutoTokenizer.from_pretrained(model_id, token=TOKEN)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, token=TOKEN, device_map="auto").eval()

# model_id = "google/flan-t5-base"
# tokenizer = T5Tokenizer.from_pretrained(model_id)
# model = T5ForConditionalGeneration.from_pretrained(model_id, device_map="auto")
# type2dataset = {}
def generate(input_text, sys_prompt, temperature, max_new_tokens) -> Iterable[str]:
    # Wrap the system prompt and the user query in the Llama-2 chat template.
    sys_prompt = f'''[INST] <<SYS>>
{sys_prompt}
<</SYS>>
'''
    input_str = sys_prompt + input_text + " [/INST]"
    inputs = tokenizer(input_str, return_tensors="pt").to("cuda")

    # Stream tokens as they are produced instead of waiting for the full sequence.
    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=float(temperature),
    )

    # Run generation in a background thread so the streamer can be consumed here.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Pull the generated text from the streamer and yield the running output.
    model_output = ""
    for new_text in streamer:
        model_output += new_text
        yield model_output
    return model_output
def random_examples(dataset_key) -> str:
    # target_dataset = type2dataset[f"{task.lower()}-{type.lower()}"]
    target_dataset = type2dataset[dataset_key]

    # Pick a random item and remember its answer for return_ground_truth().
    idx = random.randint(0, len(target_dataset) - 1)
    item = target_dataset[idx]

    global ground_truth
    ground_truth = item['answer']
    return item['query']


def return_ground_truth() -> str:
    correct_answer = ground_truth
    return correct_answer