"""Prompt construction and prediction helpers for tagging the title and
abstract of a scientific paper with a user-provided schema via a causal LM."""

import json
import random
from typing import List, Union

# import spacy
import torch

from config import (
    DEFAULT_FEW_SHOT_NUM,
    DEFAULT_FEW_SHOT_SELECTION,
    DEFAULT_TEMPERATURE,
    DEFAULT_TOP_P,
    DEFAULT_KIND,
)
from .extractions import extract_all_tagged_phrases

# nlp = spacy.load("en_core_web_sm")

# TODO: run with constituency tests
# TODO: review instruction and system-level prompt (currently they are repetitive)


def get_sentences(text: str) -> List[str]:
    """Split text into sentences on ". " (a deliberately naive splitter).

    The spaCy alternative below is kept for reference; it is disabled because
    its segmentation can yield different sentence counts for an abstract and
    its tagged counterpart, which breaks the strict pairing downstream.
    """
    # TODO: spacy splitting results in unequal lengths
    # doc = nlp(text)
    # sentences = [sent.text.strip() for sent in doc.sents]
    # sentences = [s for s in sentences if s]
    # return sentences
    return text.split(". ")
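
# Example (illustrative only): the naive splitter keeps the final period on the
# last sentence and will also split on abbreviations such as "e.g. ":
# get_sentences("We present X. It works well.")
# -> ["We present X", "It works well."]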


def format_instance(sentence: str, extraction: Union[str, None]) -> str:
    """Format one sentence (and, for demonstrations, its extractions) as a prompt block."""
    return "".join(
        [
            f"Sentence: {sentence}\n",
            (
                f"Extractions:\n{extraction}\n"
                if extraction is not None
                else "Extractions:\n"
            ),
        ]
    )
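
# Example (illustrative only):
# format_instance("We use BERT.", "model: BERT\n")
# -> "Sentence: We use BERT.\nExtractions:\nmodel: BERT\n\n"
# format_instance("We use BERT.", None)  # query instance awaiting a completion
# -> "Sentence: We use BERT.\nExtractions:\n"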


def generate_instructions(schema: dict, kind: str = DEFAULT_KIND) -> str:
    """Render the tagging schema as an instruction block, either as JSON
    or as "tag: description" lines ("readable")."""
    instruction_parts = [
        "The following schema is provided to tag the title and abstract of a given scientific paper as shown in the examples:\n"
    ]
    if kind == "json":
        instruction_parts.append(f"{json.dumps(schema, indent=2)}\n\n")
    elif kind == "readable":
        readable_schema = "".join(
            f"{tag}: {description}\n" for tag, description in schema.items()
        )
        instruction_parts.append(f"{readable_schema}\n")
    else:
        raise ValueError(f"Invalid kind: {kind}")
    return "".join(instruction_parts)


def generate_demonstrations(
    examples: List[dict],
    kind: str = DEFAULT_KIND,
    num_examples: int = DEFAULT_FEW_SHOT_NUM,
    selection: str = DEFAULT_FEW_SHOT_SELECTION,
) -> str:
    """Build few-shot demonstrations from example papers.

    Each example dict must carry an "abstract" and its "tagged_abstract"; both
    are sentence-split and paired one-to-one (zip(strict=True) raises if the
    splits disagree), then up to num_examples pairs are chosen per paper
    according to the selection strategy.
    """
    demonstration_parts = []
    for example in examples:
        sentences = get_sentences(example["abstract"])
        tagged_sentences = get_sentences(example["tagged_abstract"])
        paired_sentences = list(zip(sentences, tagged_sentences, strict=True))
        if selection == "random":
            selected_pairs = random.sample(
                paired_sentences, min(num_examples, len(paired_sentences))
            )
        elif selection == "first":
            selected_pairs = paired_sentences[:num_examples]
        elif selection == "last":
            selected_pairs = paired_sentences[-num_examples:]
        elif selection == "middle":
            start = max(0, (len(paired_sentences) - num_examples) // 2)
            selected_pairs = paired_sentences[start : start + num_examples]
        elif selection == "distributed":
            # Take evenly spaced pairs across the abstract.
            step = max(1, len(paired_sentences) // num_examples)
            selected_pairs = paired_sentences[::step][:num_examples]
        elif selection == "longest":
            selected_pairs = sorted(
                paired_sentences, key=lambda x: len(x[0]), reverse=True
            )[:num_examples]
        elif selection == "shortest":
            selected_pairs = sorted(paired_sentences, key=lambda x: len(x[0]))[
                :num_examples
            ]
        else:
            raise ValueError(f"Invalid selection method: {selection}")
        for sentence, tagged_sentence in selected_pairs:
            tag_to_phrases = extract_all_tagged_phrases(tagged_sentence)
            if kind == "json":
                extractions = f"{json.dumps(tag_to_phrases, indent=2)}\n"
            elif kind == "readable":
                extractions = "".join(
                    f"{tag}: {', '.join(phrases)}\n"
                    for tag, phrases in tag_to_phrases.items()
                )
            else:
                raise ValueError(f"Invalid kind: {kind}")
            demonstration_parts.append(format_instance(sentence, extractions))
    return "".join(demonstration_parts)


def generate_prefix(instructions: str, demonstrations: str) -> str:
    """Concatenate the instruction block and the demonstrations into the prompt prefix."""
    return instructions + demonstrations


def generate_prediction(
    model,
    tokenizer,
    prefix: str,
    input_text: str,
    kind: str,
    system_prompt: str = "You are an assistant who tags papers according to the given schema and "
    "only returns the tagged phrases in the format as provided in the examples "
    "without repeating anything else.",
    temperature: float = DEFAULT_TEMPERATURE,
    top_p: float = DEFAULT_TOP_P,
) -> str:
    """Generate a tagging completion for a single input using a chat model."""
    prompt = prefix + input_text
    messages = [
        {
            "role": "system",
            "content": system_prompt,
        },
        {"role": "user", "content": prompt},
    ]
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,  # open the assistant turn so the model completes it
        return_tensors="pt",
    ).to(model.device)
    # Stop on either the generic EOS token or the Llama-3-style end-of-turn token.
    terminators = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|eot_id|>"),
    ]
    outputs = model.generate(
        input_ids,
        max_new_tokens=1200,
        eos_token_id=terminators,
        # num_beams=8,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
    )
    # Decode only the newly generated tokens, skipping the prompt.
    response = outputs[0][input_ids.shape[-1] :]
    prediction_response = tokenizer.decode(response, skip_special_tokens=True)
    return prediction_response
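
# Example (illustrative sketch; the checkpoint name is a placeholder — any chat
# model whose template uses <|eot_id|>, e.g. a Llama-3 Instruct variant, fits):
# from transformers import AutoModelForCausalLM, AutoTokenizer
# tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
# model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
# prefix = generate_prefix(generate_instructions(schema), generate_demonstrations(examples))
# prediction = generate_prediction(
#     model, tokenizer, prefix, format_instance(sentence, None), kind="readable"
# )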


def batch_generate_prediction(
    model,
    tokenizer,
    prefix: str,
    input_ids: torch.Tensor,
    kind: str,
    system_prompt: str = "You are an assistant who tags papers according to the given schema and "
    "only returns the tagged phrases in the format as provided in the examples "
    "without repeating anything else.",
    temperature: float = DEFAULT_TEMPERATURE,
    top_p: float = DEFAULT_TOP_P,
    max_new_tokens: int = 1200,
    batch_size: int = 1,
    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
) -> List[str]:
    """Generate tagging completions for a batch of pre-tokenized inputs.

    Assumes tokenizer.pad_token is set (e.g. tokenizer.pad_token = tokenizer.eos_token
    for Llama-style tokenizers), since padding and the attention mask rely on it.
    """
    all_predictions = []
    # Prepare system message once; it is shared across all conversations.
    system_message = {"role": "system", "content": system_prompt}
    for i in range(0, input_ids.size(0), batch_size):
        # Decode each raw input back to text and wrap it in a chat conversation.
        raw_batch_ids = input_ids[i : i + batch_size]
        batch_messages = [
            [
                system_message,
                {
                    "role": "user",
                    "content": prefix + tokenizer.decode(ids, skip_special_tokens=True),
                },
            ]
            for ids in raw_batch_ids
        ]
        batch_input_ids = tokenizer.apply_chat_template(
            batch_messages,
            add_generation_prompt=True,  # open the assistant turn, as in generate_prediction
            return_tensors="pt",
            padding=True,
            truncation=True,
        ).to(device)
        with torch.no_grad():
            outputs = model.generate(
                batch_input_ids,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                pad_token_id=tokenizer.pad_token_id,
                attention_mask=batch_input_ids.ne(tokenizer.pad_token_id),
            )
        for output in outputs:
            # Decode only the tokens generated beyond the (padded) prompt length.
            response = output[batch_input_ids.size(1) :]
            prediction_response = tokenizer.decode(response, skip_special_tokens=True)
            all_predictions.append(prediction_response)
        # Release cached GPU memory between batches.
        torch.cuda.empty_cache()
    return all_predictions
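
# Example (illustrative sketch; inputs are pre-tokenized query instances, and the
# tokenizer is assumed to need a pad token configured first):
# tokenizer.pad_token = tokenizer.eos_token
# encoded = tokenizer(
#     [format_instance(s, None) for s in get_sentences(abstract)],
#     return_tensors="pt", padding=True,
# ).input_ids
# predictions = batch_generate_prediction(
#     model, tokenizer, prefix, encoded, kind="readable",
#     batch_size=4, device=model.device,
# )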