Spaces:
Runtime error
Runtime error
| """ Functions for loading email generator.""" | |
| import torch | |
| import transformers | |
# Device index handed to transformers.pipeline: 0 (the first CUDA device)
# when torch reports CUDA as available, otherwise -1, which the pipeline
# treats as "run on CPU".
DEVICE = 0 if torch.cuda.is_available() else -1

# Hugging Face model tags that may be loaded as email generators.
# Each tag appears exactly once so the list can be used directly as a set
# of selectable choices without duplicate entries.
AVAILABLE_GENERATORS = [
    'pszemraj/opt-350m-email-generation',
    'postbot/gpt2-medium-emailgen',
    'sagorsarker/emailgenerator',
]

# Tag loaded by set_global_generator() when it is called without arguments.
DEFAULT_GENERATOR = 'pszemraj/opt-350m-email-generation'
class EmailGenerator:
    """Wrap a Hugging Face ``text-generation`` pipeline used to draft emails.

    Attributes:
        tag: Model name the pipeline was built from.
        generator: The ``transformers`` pipeline object that produces text.
    """

    def __init__(self, model_tag: str) -> None:
        """Build the text-generation pipeline for ``model_tag``.

        The pipeline is requested with a fast tokenizer, with sampling
        disabled (``do_sample=False``), and placed on ``DEVICE``.

        Args:
            model_tag (str): Model name passed to ``transformers.pipeline``.
        """
        pipeline_options = {
            'use_fast': True,
            'do_sample': False,  # no sampling -> repeatable output
            'device': DEVICE,
        }
        self.tag = model_tag
        self.generator = transformers.pipeline(
            task='text-generation',
            model=model_tag,
            **pipeline_options,
        )

    def generate(self, prompt: str, max_tokens: int) -> str:
        """Run the pipeline on ``prompt`` and return the first result's text.

        Args:
            prompt (str): Text the model continues from.
            max_tokens (int): Forwarded to the pipeline as ``max_length``.
                NOTE(review): in transformers ``max_length`` counts the
                prompt tokens as well as the newly generated ones.

        Returns:
            str: The ``generated_text`` field of the first candidate the
            pipeline returns.
        """
        candidates = self.generator(prompt, max_length=max_tokens)
        first = candidates[0]
        return first['generated_text']

    def __str__(self):
        """Return a short label naming the wrapped model."""
        label = self.tag
        return f'EmailGenerator({label})'
def set_global_generator(model_tag: str = DEFAULT_GENERATOR):
    """Create (or replace) the module-level ``generator``.

    Args:
        model_tag (str): Model name to load; defaults to DEFAULT_GENERATOR.
    """
    global generator
    fresh = EmailGenerator(model_tag=model_tag)
    generator = fresh
def generator_exists():
    """Report whether the module-level name ``generator`` has been bound.

    Returns:
        bool: True when ``generator`` is present in this module's global
        namespace, False otherwise.
    """
    module_namespace = globals()
    return 'generator' in module_namespace
def generate_email(model_tag: str, prompt: str, max_tokens: int):
    """Generate email text with the model named ``model_tag``.

    The module-level generator is rebuilt through set_global_generator when
    none exists yet or when it was loaded from a different model tag;
    otherwise the already-loaded generator is reused.

    Args:
        model_tag (str): Model name the generator must be loaded from.
        prompt (str): Prompt passed to ``EmailGenerator.generate``.
        max_tokens (int): Length limit passed to ``EmailGenerator.generate``.

    Returns:
        The text returned by ``EmailGenerator.generate``.
    """
    # Short-circuit: the tag is only read once a generator is known to exist.
    needs_reload = not generator_exists() or generator.tag != model_tag
    if needs_reload:
        set_global_generator(model_tag=model_tag)
    return generator.generate(prompt, max_tokens=max_tokens)