| from abc import ABC, abstractmethod |
| from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, Pipeline, pipeline |
| import logging |
| import torch |
|
|
| from utils import get_torch_device |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
class ChatModel(ABC):
    """Interface for chat-style text generators.

    Implementations consume a conversation — a list of
    ``{"role": ..., "content": ...}`` dicts — and return the next
    assistant turn as a single message dict of the same shape.
    """

    @abstractmethod
    def generate(self, messages: list[dict[str, str]]) -> dict[str, str]:
        """Produce the next assistant message for *messages*."""
|
|
|
|
class AdjLabeler:
    """Labels each token of an example as adjective / not-adjective by
    querying a chat model inside a few-shot grammar-tutor conversation.
    """

    def __init__(self, model: "ChatModel"):
        # Forward-ref annotation: ChatModel is defined elsewhere in this
        # module; quoting avoids evaluating the name at class-body time.
        self.model = model

    def label_example(self, exp: dict, feature_name: str) -> list[str]:
        """Return one 'yes'/'no' label per token in ``exp['tokens']``.

        Args:
            exp: Example dict; only ``exp['tokens']`` (list of token
                strings) is read here.
            feature_name: Currently unused; kept for interface
                compatibility with callers that pass a feature label.

        Returns:
            A list the same length as ``exp['tokens']``, where each entry
            is ``'yes'`` if the model's reply starts with "yes" (case-
            insensitive), else ``'no'``.
        """
        # Few-shot priming conversation establishing what an adjective is.
        messages = [
            {"role": "system",
             "content": "You are a helpful Grammar tutor."},
            {"role": "user",
             "content": "An adjective is a word that describes a noun?"},
            {"role": "assistant",
             "content": "Yes, that's correct! An adjective relates to, modifies, or describes nouns."},
            {"role": "user",
             "content": "Are they always used with nouns?"},
            {"role": "assistant",
             "content": ("No, adjectives often appear directly before nouns (e.g. \"a red apple\") "
                         "but they can also follow linking verbs to describe the subject (e.g. \"The sky is blue\"). "
                         "Sometimes, adjectives are used as complements in certain constructions or phrases "
                         "(e.g. \"the rich\" or \"well-known author\").")},
            {"role": "user",
             "content": "They can have comparative or superlative forms too, right?"},
            {"role": "assistant",
             "content": ("Yes, that's right! The word \"fast\" can take a comparative form as in \"faster\" "
                         "or a superlative form as in \"fastest\". Some adjectives don't have comparative or "
                         "superlative forms but use the word \"more\" or \"most\" to become comparative or "
                         "superlative.")},
            {"role": "user",
             "content": f"How about this example: {exp['tokens']}"},
        ]

        token_labels: list[str] = []
        for idx, token in enumerate(exp["tokens"]):
            question = {
                "role": "user",
                "content": f"Is '{token}' at position {idx} an adjective? Answer 'yes' or 'no'.",
            }
            assistant_message = self.model.generate(messages + [question])
            logging.getLogger(__name__).info(f"{assistant_message} - {token}")
            # BUGFIX: the previous version never appended to token_labels
            # (it always returned []). Normalize the model's free-form
            # reply to a canonical 'yes'/'no' label.
            answer = str(assistant_message.get("content", "")).strip().lower()
            token_labels.append("yes" if answer.startswith("yes") else "no")
            # BUGFIX: only the new Q/A pair joins the running conversation.
            # The old `messages += token_messages` re-appended a full copy
            # of the entire history on every token, duplicating the
            # conversation quadratically.
            messages.append(question)
            messages.append(assistant_message)
        return token_labels
|
|
|
|
class LlamaPipeline(ChatModel):
    """Chat model backed by a Hugging Face ``text-generation`` pipeline.

    The pipeline is loaded in bfloat16 with automatic device placement;
    a separate tokenizer is kept so its EOS token id can serve as the
    pad token during generation.
    """

    def __init__(self, model_name: str):
        """Load the tokenizer and build the generation pipeline for *model_name*."""
        self.device = get_torch_device()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        pipeline_kwargs = {
            "model": model_name,
            "model_kwargs": {"torch_dtype": torch.bfloat16},
            "device_map": "auto",
        }
        self.pipeline = pipeline("text-generation", **pipeline_kwargs)

    def generate(self, messages, max_new_tokens=1):
        """Run the pipeline on *messages* and return the last generated turn."""
        sampling = {"temperature": 0.6, "top_p": 0.9}
        outputs = self.pipeline(
            messages,
            max_new_tokens=max_new_tokens,
            pad_token_id=self.tokenizer.eos_token_id,
            **sampling,
        )
        # The pipeline returns the whole conversation; the final entry is
        # the newly generated assistant message.
        last_turn = outputs[0]["generated_text"][-1]
        return last_turn
|
|
|
|
class LlamaModel(ChatModel):
    """
    A wrapper around a Llama model checkpoint using Hugging Face Transformers.

    NOTE(review): ``generate`` takes a raw prompt string rather than the
    messages list declared by ``ChatModel.generate`` — confirm whether this
    class is meant to satisfy that interface.
    """

    def __init__(self, model_name: str):
        torch_device = get_torch_device()

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # device_map already places the weights on the target device, so the
        # previous extra `self.model.to(torch_device)` call was redundant
        # (and can conflict with accelerate-managed placement); removed.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map=str(torch_device),
            torch_dtype=torch.float16,
        )
        self.model.eval()

        # Sampling configuration shared by every generate() call:
        # one new token per call, padded with the EOS token.
        self.generation_config = GenerationConfig(
            max_new_tokens=1,
            pad_token_id=self.tokenizer.eos_token_id,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
        )

    def generate(self, prompt: str) -> str:
        """
        Generate text from the model given a prompt.

        Returns only the newly generated continuation; the prompt itself
        is stripped from the output.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            output_ids = self.model.generate(
                **inputs,
                generation_config=self.generation_config,
            )
        # BUGFIX: strip the prompt at the *token* level. The old code
        # decoded the full sequence and sliced by `len(prompt)` characters,
        # which breaks whenever the tokenizer does not round-trip the
        # prompt text exactly (normalization, special tokens, whitespace).
        prompt_token_count = inputs["input_ids"].shape[1]
        new_token_ids = output_ids[0][prompt_token_count:]
        return self.tokenizer.decode(new_token_ids, skip_special_tokens=True)
|
|
|
|
| |
| |
| |
|
|
if __name__ == "__main__":
    import logging.config

    from utils import default_logging_config

    logging.config.dictConfig(default_logging_config)

    # Build the labeler around an instruction-tuned Llama pipeline.
    chat_model = LlamaPipeline(model_name="meta-llama/Llama-3.2-3B-Instruct")
    labeler = AdjLabeler(chat_model)

    # Small smoke-test examples: pre-tokenized sentences to label.
    basic_cases = [
        {"text": "Big cars use a lot more gas.",
         "tokens": ["Big", "cars", "use", "a", "lot", "more", "gas."]},
        {"text": "My car is faster than my bicycle.",
         "tokens": ["My", "car", "is", "faster", "than", "my", "bicycle."]},
    ]

    for case in basic_cases:
        adj_labels = labeler.label_example(case, "adj")
        logger.info(f"\ntokens:\t{case['tokens']}\nadj:\t{adj_labels}")
|
|