Spaces:
Runtime error
Runtime error
| import abc | |
| import logging | |
| import re | |
| import httpx | |
| from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig | |
| from base import JobInput | |
# Module-level logger for this pipeline; DEBUG so generation details show up
# in the application logs.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# Single shared seq2seq model + tokenizer, loaded once at import time and
# reused by Summarizer and Tagger below.
# NOTE(review): loading at import blocks startup and assumes the weights are
# cached locally or downloadable — confirm for the deployment target.
MODEL_NAME = "google/flan-t5-large"
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
class Summarizer:
    """Summarize arbitrary text with the shared flan-t5 model.

    Uses a fixed two-sentence-summary prompt and a generation config tuned
    for summaries of roughly 100-200 new tokens.
    """

    def __init__(self) -> None:
        self.template = "Summarize the text below in two sentences:\n\n{}"
        config = GenerationConfig.from_pretrained(MODEL_NAME)
        config.max_new_tokens = 200
        config.min_new_tokens = 100
        config.top_k = 5
        config.repetition_penalty = 1.5
        self.generation_config = config

    def __call__(self, x: str) -> str:
        """Return a model-generated summary of *x*."""
        prompt = self.template.format(x)
        encoded = tokenizer(prompt, return_tensors="pt")
        generated = model.generate(
            **encoded, generation_config=self.generation_config
        )
        decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
        result = decoded[0]
        assert isinstance(result, str)
        return result

    def get_name(self) -> str:
        """Human-readable identifier for this processor step."""
        return f"Summarizer({MODEL_NAME})"
class Tagger:
    """Generate hashtag-style tags for a text with the shared flan-t5 model."""

    def __init__(self) -> None:
        self.template = (
            "Create a list of tags for the text below. The tags should be high level "
            "and specific. Prefix each tag with a hashtag.\n\n{}\n\nTags: #general"
        )
        self.generation_config = GenerationConfig.from_pretrained(MODEL_NAME)
        self.generation_config.max_new_tokens = 50
        self.generation_config.min_new_tokens = 25
        # increase the temperature to make the model more creative
        self.generation_config.temperature = 1.5
        # BUG FIX: temperature only takes effect when sampling is enabled;
        # without this flag, generate() defaults to greedy decoding and
        # silently ignores the temperature set above.
        self.generation_config.do_sample = True

    def _extract_tags(self, text: str) -> list[str]:
        """Return the sorted, de-duplicated, lower-cased '#'-prefixed tokens of *text*."""
        tags = {token.lower() for token in text.split() if token.startswith("#")}
        return sorted(tags)

    def __call__(self, x: str) -> list[str]:
        """Generate text for *x* and return the hashtags found in the output."""
        text = self.template.format(x)
        inputs = tokenizer(text, return_tensors="pt")
        outputs = model.generate(**inputs, generation_config=self.generation_config)
        output = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
        return self._extract_tags(output)

    def get_name(self) -> str:
        """Human-readable identifier for this processor step."""
        return f"Tagger({MODEL_NAME})"
class Processor(abc.ABC):
    """Base class for job processors: match() decides applicability, process() runs."""

    def __call__(self, job: JobInput) -> str:
        """Process *job* and return the resulting string, with entry/exit logging."""
        _id = job.id
        # BUG FIX: the original f-string interpolated `{input}`, which in this
        # scope is the Python *builtin* `input` function (the parameter is
        # named `job`), so the log line printed "<built-in function input>".
        logger.info(f"Processing job with {self.__class__.__name__} (id={_id[:8]})")
        result = self.process(job)
        logger.info(f"Finished processing input (id={_id[:8]})")
        return result

    def process(self, input: JobInput) -> str:
        """Transform the job input into a string; subclasses must override."""
        raise NotImplementedError

    def match(self, input: JobInput) -> bool:
        """Return True when this processor can handle *input*; subclasses must override."""
        raise NotImplementedError

    def get_name(self) -> str:
        """Human-readable identifier for this processor; subclasses must override."""
        raise NotImplementedError
class RawProcessor(Processor):
    """Catch-all processor that returns the job content untouched."""

    def match(self, input: JobInput) -> bool:
        """Accept every job; this is the fallback of last resort."""
        return True

    def process(self, input: JobInput) -> str:
        """Pass the raw job content straight through."""
        return input.content

    def get_name(self) -> str:
        """Human-readable identifier for this processor."""
        return type(self).__name__
class PlainUrlProcessor(Processor):
    """Processor for jobs whose content contains exactly one URL: fetch it.

    match() remembers the detected URL on the instance so that the
    subsequent process() call can fetch it.
    """

    def __init__(self) -> None:
        self.client = httpx.Client()
        self.regex = re.compile(r"(https?://[^\s]+)")
        self.url = None
        self.template = "{url}\n\n{content}"

    def match(self, input: JobInput) -> bool:
        """Return True iff the content holds exactly one URL (stored for later)."""
        found = list(self.regex.findall(input.content))
        if len(found) != 1:
            return False
        self.url = found[0]
        return True

    def process(self, input: JobInput) -> str:
        """Get content of website and return it as string"""
        assert isinstance(self.url, str)
        body = self.client.get(self.url).text
        assert isinstance(body, str)
        return self.template.format(url=self.url, content=body)

    def get_name(self) -> str:
        """Human-readable identifier for this processor."""
        return type(self).__name__
class ProcessorRegistry:
    """Ordered collection of processors; the first one whose match() accepts wins."""

    def __init__(self) -> None:
        self.registry: list[Processor] = []
        self.default_registry: list[Processor] = []
        self.set_default_processors()

    def set_default_processors(self) -> None:
        """Install the built-in fallbacks: URL fetcher first, then raw pass-through."""
        self.default_registry += [PlainUrlProcessor(), RawProcessor()]

    def register(self, processor: Processor) -> None:
        """Add a custom processor; custom processors take priority over defaults."""
        self.registry.append(processor)

    def dispatch(self, input: JobInput) -> Processor:
        """Return the first registered (then default) processor matching *input*."""
        for candidate in self.registry + self.default_registry:
            if candidate.match(input):
                return candidate
        # Should never be reached (RawProcessor matches everything),
        # but keep a safe fallback just in case.
        return RawProcessor()