# multi-classifier / llama_dataset_maker.py
# Provenance: Hugging Face repo "multi-classifier", commit 0cdb887
# ("feat: UD is back, LlaMA play", author: veryfansome). Web-listing
# artifacts (raw/history/blame links, file size) removed.
from abc import ABC, abstractmethod
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, Pipeline, pipeline
import logging
import torch
from utils import get_torch_device
logger = logging.getLogger(__name__)
class ChatModel(ABC):
    """Minimal chat-completion interface.

    Implementations take a conversation (a list of ``{"role": ..., "content": ...}``
    turns) and produce the next assistant turn as a single dict.
    """

    @abstractmethod
    def generate(self, messages: list[dict[str, str]]) -> dict[str, str]:
        """Return the assistant's next message for the given conversation."""
        ...
class AdjLabeler:
    """Labels, token by token, which tokens of an example are adjectives,
    by asking a ChatModel yes/no questions after a few-shot grammar primer."""

    def __init__(self, model: "ChatModel"):
        self.model = model

    def label_example(self, exp, feature_name):
        """Ask the model, for each token in ``exp["tokens"]``, whether it is
        an adjective.

        :param exp: mapping with a ``"tokens"`` key holding a list of word
            strings (a ``"text"`` key may also be present but is not read here).
        :param feature_name: name of the feature being labeled; currently
            unused — kept for interface compatibility with multi-feature use.
        :returns: list of the model's answer strings, one per token, in order.
        """
        log = logging.getLogger(__name__)
        # Few-shot primer establishing what counts as an adjective.
        messages = [
            {"role": "system",
             "content": "You are a helpful Grammar tutor."},
            {"role": "user",
             "content": "An adjective is a word that describes a noun?"},
            {"role": "assistant",
             "content": "Yes, that's correct! An adjective relates to, modifies, or describes nouns."},
            {"role": "user",
             "content": "Are they always used with nouns?"},
            {"role": "assistant",
             "content": ("No, adjectives often appear directly before nouns (e.g. \"a red apple\") "
                         "but they can also follow linking verbs to describe the subject (e.g. \"The sky is blue\"). "
                         "Sometimes, adjectives are used as complements in certain constructions or phrases "
                         "(e.g. \"the rich\" or \"well-known author\").")},
            {"role": "user",
             "content": "They can have comparative or superlative forms too, right?"},
            {"role": "assistant",
             "content": ("Yes, that's right! The word \"fast\" can take a comparative form as in \"faster\" "
                         "or a superlative form as in \"fastest\". Some adjectives don't have comparative or "
                         "superlative forms but use the word \"more\" or \"most\" to become comparative or "
                         "superlative.")},
            {"role": "user",
             "content": f"How about this example: {exp['tokens']}"},
        ]
        token_labels = []
        for idx, token in enumerate(exp["tokens"]):
            question = {"role": "user",
                        "content": f"Is '{token}' at position {idx} an adjective? Answer 'yes' or 'no'."}
            assistant_message = self.model.generate(messages + [question])
            log.info(f"{assistant_message} - {token}")
            # Bug fix: the original never appended to token_labels and always
            # returned an empty list; record each answer's content here.
            token_labels.append(assistant_message.get("content", ""))
            # Bug fix: the original did `messages += token_messages`, which
            # re-appended a full copy of the conversation every iteration
            # (quadratic prompt growth with duplicated turns). Extend the
            # history by exactly one question/answer pair per token instead.
            messages.append(question)
            messages.append(assistant_message)
        return token_labels
class LlamaPipeline(ChatModel):
    """ChatModel implementation backed by the Hugging Face
    ``text-generation`` pipeline."""

    def __init__(self, model_name: str):
        # NOTE(review): self.device is stored but the pipeline places the
        # model itself via device_map="auto".
        self.device = get_torch_device()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.pipeline = pipeline(
            "text-generation",
            model=model_name,
            model_kwargs={"torch_dtype": torch.bfloat16},
            device_map="auto",
        )

    def generate(self, messages, max_new_tokens=1):
        """Run the pipeline on the chat and return the last generated turn."""
        generation = self.pipeline(
            messages,
            max_new_tokens=max_new_tokens,
            pad_token_id=self.tokenizer.eos_token_id,
            temperature=0.6,
            top_p=0.9,
        )
        full_chat = generation[0]["generated_text"]
        return full_chat[-1]
class LlamaModel(ChatModel):
    """
    A wrapper around a Llama model checkpoint using Hugging Face Transformers,
    calling ``AutoModelForCausalLM.generate`` directly (no pipeline layer).
    """

    def __init__(self, model_name: str):
        device = get_torch_device()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map=str(device),
            torch_dtype=torch.float16,
        )
        self.model.to(device)
        self.model.eval()
        # Sampling settings tuned for short (single-token) answers;
        # adjust generation parameters as needed.
        self.generation_config = GenerationConfig(
            max_new_tokens=1,
            pad_token_id=self.tokenizer.eos_token_id,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
        )

    def generate(self, prompt: str) -> str:
        """
        Generate text from the model given a prompt, returning only the
        newly generated continuation (the prompt prefix is stripped).
        """
        encoded = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            generated_ids = self.model.generate(
                **encoded,
                generation_config=self.generation_config
            )
        decoded = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        # NOTE(review): assumes the decoded text starts with the literal
        # prompt; special-token stripping could shift the offset — verify
        # against the chosen tokenizer.
        return decoded[len(prompt):]
# ----------------------------------
# Putting It All Together
# ----------------------------------
if __name__ == "__main__":
    import logging.config
    from utils import default_logging_config

    logging.config.dictConfig(default_logging_config)

    # Wire the pipeline-backed chat model into the adjective labeler.
    adj_labeler = AdjLabeler(LlamaPipeline(
        model_name="meta-llama/Llama-3.2-3B-Instruct",
        #model_name="meta-llama/Llama-3.1-8B-Instruct",
    ))

    # Hand-tokenized examples; commented-out cases retained for future runs.
    examples = [
        #{"text": "Joan has a nice dog.",
        # "tokens": ["Joan", "has", "a", "nice", "dog."]},
        #{"text": "Bob is the most agile person I have ever met.",
        # "tokens": ["Bob", "is", "the", "most", "agile", "person", "I", "have", "ever", "met."]},
        #{"text": "He's a total shit head",
        # "tokens": ["He's", "a", "total", "shit", "head"]},
        #{"text": "The old, creaky house stood on the quiet street.",
        # "tokens": ["The", "old,", "creaky", "house", "stood", "on", "the", "quiet", "street."]},
        #{"text": "The sky turned brilliant blue as the sun emerged.",
        # "tokens": ["The", "sky", "turned", "brilliant", "blue", "as", "the", "sun", "emerged."]},
        #{"text": "They admired the well-behaved and enthusiastic children at the party.",
        # "tokens": ["They", "admired", "the", "well-behaved", "and", "enthusiastic", "children", "at", "the",
        #            "party."]},
        #{"text": "After dinner, she felt tired and content.",
        # "tokens": ["After", "dinner,", "she", "felt", "tired", "and", "content."]},
        #{"text": "The resourceful team devised a clever plan.",
        # "tokens": ["The", "resourceful", "team", "devised", "a", "clever", "plan."]},
        #{"text": "He handed over the thick book to the eager student.",
        # "tokens": ["He", "handed", "over", "the", "thick", "book", "to", "the", "eager", "student."]},
        #{"text": "We appreciated the delicious, handmade pie from our neighbor.",
        # "tokens": ["We", "appreciated", "the", "delicious,", "handmade", "pie", "from", "our", "neighbor."]},
        #{"text": "In the enchanted forest, sparkling fairies danced under the moonlight.",
        # "tokens": ["In", "the", "enchanted", "forest,", "sparkling", "fairies", "danced", "under", "the", "moonlight."]},
        #{"text": "The stray cats, hungry and dirty, roamed the narrow alley.",
        # "tokens": ["The", "stray", "cats,", "hungry", "and", "dirty,", "roamed", "the", "narrow", "alley."]},
        #{"text": "The challenging puzzle left the determined young boy both frustrated and excited.",
        # "tokens": ["The", "challenging", "puzzle", "left", "the", "determined", "young", "boy", "both", "frustrated",
        #            "and", "excited."]},
        {"text": "Big cars use a lot more gas.",
         "tokens": ["Big", "cars", "use", "a", "lot", "more", "gas."]},
        {"text": "My car is faster than my bicycle.",
         "tokens": ["My", "car", "is", "faster", "than", "my", "bicycle."]},
        #{"text": "This puzzle is more challenging than the one we solved yesterday.",
        # "tokens": ["This", "puzzle", "is", "more", "challenging", "than", "the", "one", "we", "solved", "yesterday."]},
        #{"text": "Among all the students, Lara is the most diligent.",
        # "tokens": ["Among", "all", "the", "students,", "Lara", "is", "the", "most", "diligent."]},
        #{"text": "That building is taller than the one next to it.",
        # "tokens": ["That", "building", "is", "taller", "than", "the", "one", "next", "to", "it."]},
        #{"text": "This book is more interesting than the movie adaptation.",
        # "tokens": ["This", "book", "is", "more", "interesting", "than", "the", "movie", "adaptation."]},
        #{"text": "Of all the fruits, mangoes are the sweetest.",
        # "tokens": ["Of", "all", "the", "fruits,", "mangoes", "are", "the", "sweetest."]},
        #{"text": "His running speed is quicker than anyone else's on the team.",
        # "tokens": ["His", "running", "speed", "is", "quicker", "than", "anyone", "else's", "on", "the", "team."]},
        #{"text": "The exam was easier than I had anticipated.",
        # "tokens": ["The", "exam", "was", "easier", "than", "I", "had", "anticipated."]},
        #{"text": "Among all the flavors, vanilla is the mildest.",
        # "tokens": ["Among", "all", "the", "flavors,", "vanilla", "is", "the", "mildest."]},
        #{"text": "The new smartphone is lighter than the previous version.",
        # "tokens": ["The", "new", "smartphone", "is", "lighter", "than", "the", "previous", "version."]},
    ]
    for example in examples:
        labels = adj_labeler.label_example(example, "adj")
        logger.info(f"\ntokens:\t{example['tokens']}\nadj:\t{labels}")