| | |
| | |
| | |
| | |
| |
|
| | """Module to generate OpenELM output given a model and an input prompt.""" |
import argparse
import logging
import os
import shutil
import time
from typing import Optional, Tuple, Union

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
| |
|
| |
|
def generate(
    prompt: str,
    model: Union[str, AutoModelForCausalLM],
    hf_access_token: Optional[str] = None,
    tokenizer: Union[str, AutoTokenizer] = 'meta-llama/Llama-2-7b-hf',
    device: Optional[str] = None,
    max_length: int = 1024,
    assistant_model: Optional[Union[str, AutoModelForCausalLM]] = None,
    generate_kwargs: Optional[dict] = None,
) -> Tuple[str, float]:
    """ Generates output given a prompt.

    Args:
        prompt: The string prompt.
        model: The LLM Model. If a string is passed, it should be the path to
            the hf converted checkpoint.
        hf_access_token: Hugging face access token.
        tokenizer: Tokenizer instance. If model is set as a string path,
            the tokenizer will be loaded from the checkpoint.
        device: String representation of device to run the model on. If None
            and cuda available it would be set to cuda:0 else cpu.
        max_length: Maximum length of tokens, input prompt + generated tokens.
        assistant_model: If set, this model will be used for
            speculative generation. If a string is passed, it should be the
            path to the hf converted checkpoint.
        generate_kwargs: Extra kwargs passed to the hf generate function.

    Returns:
        output_text: output generated as a string.
        generation_time: generation time in seconds.

    Raises:
        ValueError: If device is set to CUDA but no CUDA device is detected.
        ValueError: If tokenizer is not set.
        ValueError: If hf_access_token is not specified.
    """
    # Default device selection: prefer the first CUDA device when present.
    if not device:
        if torch.cuda.is_available() and torch.cuda.device_count():
            device = "cuda:0"
            logging.warning(
                'inference device is not set, using cuda:0, %s',
                torch.cuda.get_device_name(0)
            )
        else:
            device = 'cpu'
            logging.warning(
                (
                    'No CUDA device detected, using cpu, '
                    'expect slower speeds.'
                )
            )

    if 'cuda' in device and not torch.cuda.is_available():
        raise ValueError('CUDA device requested but no CUDA device detected.')

    if not tokenizer:
        raise ValueError('Tokenizer is not set in the generate function.')

    if not hf_access_token:
        raise ValueError((
            'Hugging face access token needs to be specified. '
            'Please refer to https://huggingface.co/docs/hub/security-tokens'
            ' to obtain one.'
            )
        )

    # Load the model when given as a checkpoint path; the checkpoint may
    # carry custom model code, hence trust_remote_code.
    if isinstance(model, str):
        checkpoint_path = model
        model = AutoModelForCausalLM.from_pretrained(
            checkpoint_path,
            trust_remote_code=True
        )
    model.to(device).eval()
    if isinstance(tokenizer, str):
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer,
            token=hf_access_token,
        )

    # Optional draft model used by HF assisted (speculative) generation.
    draft_model = None
    if assistant_model:
        draft_model = assistant_model
        if isinstance(assistant_model, str):
            draft_model = AutoModelForCausalLM.from_pretrained(
                assistant_model,
                trust_remote_code=True
            )
        draft_model.to(device).eval()

    # Tokenize the prompt and add the leading batch dimension expected by
    # model.generate().
    tokenized_prompt = tokenizer(prompt)
    tokenized_prompt = torch.tensor(
        tokenized_prompt['input_ids'],
        device=device
    )

    tokenized_prompt = tokenized_prompt.unsqueeze(0)

    # Time the generation with a monotonic clock (immune to wall-clock
    # adjustments, unlike time.time()).
    stime = time.perf_counter()
    output_ids = model.generate(
        tokenized_prompt,
        max_length=max_length,
        pad_token_id=0,
        assistant_model=draft_model,
        **(generate_kwargs if generate_kwargs else {}),
    )
    generation_time = time.perf_counter() - stime

    output_text = tokenizer.decode(
        output_ids[0].tolist(),
        skip_special_tokens=True
    )

    return output_text, generation_time
| |
|
| |
|
def openelm_generate_parser():
    """Build the CLI argument parser and parse ``sys.argv``.

    Returns:
        argparse.Namespace holding model path, access token, prompt,
        device, max_length, assistant model and extra generate kwargs.
    """

    class KwargsParser(argparse.Action):
        """Parser action class to parse kwargs of form key=value"""

        def __call__(self, parser, namespace, values, option_string=None):
            setattr(namespace, self.dest, dict())
            for val in values:
                if '=' not in val:
                    raise ValueError(
                        (
                            'Argument parsing error, kwargs are expected in'
                            ' the form of key=value.'
                        )
                    )
                # Split on the FIRST '=' only, so values may themselves
                # contain '=' (a plain split() would raise on 'key=a=b').
                kwarg_k, kwarg_v = val.split('=', 1)
                # Coerce the value to int, then float, falling back to
                # the raw string.
                try:
                    converted_v = int(kwarg_v)
                except ValueError:
                    try:
                        converted_v = float(kwarg_v)
                    except ValueError:
                        converted_v = kwarg_v
                getattr(namespace, self.dest)[kwarg_k] = converted_v

    parser = argparse.ArgumentParser('OpenELM Generate Module')
    parser.add_argument(
        '--model',
        dest='model',
        help='Path to the hf converted model.',
        required=True,
        type=str,
    )
    parser.add_argument(
        '--hf_access_token',
        dest='hf_access_token',
        help='Hugging face access token, starting with "hf_".',
        type=str,
    )
    parser.add_argument(
        '--prompt',
        dest='prompt',
        help='Prompt for LLM call.',
        default='',
        type=str,
    )
    parser.add_argument(
        '--device',
        dest='device',
        help='Device used for inference.',
        type=str,
    )
    parser.add_argument(
        '--max_length',
        dest='max_length',
        help='Maximum length of tokens.',
        default=256,
        type=int,
    )
    parser.add_argument(
        '--assistant_model',
        dest='assistant_model',
        help=(
            (
                'If set, this is used as a draft model '
                'for assisted speculative generation.'
            )
        ),
        type=str,
    )
    parser.add_argument(
        '--generate_kwargs',
        dest='generate_kwargs',
        help='Additional kwargs passed to the HF generate function.',
        type=str,
        nargs='*',
        action=KwargsParser,
    )
    return parser.parse_args()
| |
|
| |
|
if __name__ == '__main__':
    args = openelm_generate_parser()
    prompt = args.prompt

    # Fixed typo: 'genertaion_time' -> 'generation_time'.
    output_text, generation_time = generate(
        prompt=prompt,
        model=args.model,
        device=args.device,
        max_length=args.max_length,
        assistant_model=args.assistant_model,
        generate_kwargs=args.generate_kwargs,
        hf_access_token=args.hf_access_token,
    )

    # shutil.get_terminal_size() falls back to 80x24 when stdout is not a
    # tty (e.g. output piped to a file), unlike os.get_terminal_size()
    # which raises OSError. Queried once and reused for every rule line.
    columns = shutil.get_terminal_size().columns

    print_txt = (
        f'\r\n{"=" * columns}\r\n'
        '\033[1m Prompt + Generated Output\033[0m\r\n'
        f'{"-" * columns}\r\n'
        f'{output_text}\r\n'
        f'{"-" * columns}\r\n'
        '\r\nGeneration took'
        f'\033[1m\033[92m {round(generation_time, 2)} \033[0m'
        'seconds.\r\n'
    )
    print(print_txt)
| |
|