import contextlib
from typing import Any, Dict, List

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer


class EndpointHandler:
    """Inference endpoint handler for a seq2seq propositionizer model.

    Accepts a batch of {id, chunk, title, date} records, formats each into the
    model's expected prompt, runs beam-search generation, and returns the
    decoded text keyed by the original ids.
    """

    def __init__(self, path: str = "chentong00/propositionizer-wiki-flan-t5-large"):
        """
        Load the model and tokenizer and place the model on the best device.

        Args:
            path: HuggingFace model id or local path to load from.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        model = AutoModelForSeq2SeqLM.from_pretrained(path).to(self.device)
        # fp16 only on CUDA: half precision on CPU is unsupported for many
        # ops and would crash or silently degrade generation quality.
        if self.device.type == "cuda":
            model = model.half()
        self.model = model.eval()

    def process_chunks(
        self, chunks: List[str], titles: List[str], dates: List[str]
    ) -> List[str]:
        """
        Run batched generation over multiple text chunks.

        Args:
            chunks: Text content to process.
            titles: Document titles corresponding to the chunks.
            dates: Document dates corresponding to the chunks.

        Returns:
            List of generated output texts, in input order.
        """
        if not chunks:
            # Avoid calling the tokenizer/model on an empty batch.
            return []

        input_texts = [
            f"Title: {t}. Date: {d}. Content: {c}"
            for c, t, d in zip(chunks, titles, dates)
        ]
        encoded = self.tokenizer(
            input_texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=1024,
        ).to(self.device)

        gen_kwargs = dict(
            max_new_tokens=512,
            no_repeat_ngram_size=5,
            length_penalty=1.2,
            num_beams=5,
        )

        # Pre-bind so the cleanup in `finally` never hits an unbound name
        # when generate() raises before assignment.
        outputs = None
        try:
            with torch.no_grad():
                # Mixed precision only makes sense on CUDA; nullcontext keeps
                # a single generate() call for both device types.
                amp_ctx = (
                    torch.autocast(device_type="cuda")
                    if self.device.type == "cuda"
                    else contextlib.nullcontext()
                )
                with amp_ctx:
                    # attention_mask is required with padding=True, otherwise
                    # padded batch members attend to pad tokens.
                    outputs = self.model.generate(
                        input_ids=encoded.input_ids,
                        attention_mask=encoded.attention_mask,
                        **gen_kwargs,
                    )
            predictions = self.tokenizer.batch_decode(
                outputs, skip_special_tokens=True
            )
        finally:
            # Explicit memory cleanup between requests.
            del encoded, outputs
            if self.device.type == "cuda":
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
        return predictions

    def __call__(self, data: Dict[str, Any]) -> Dict[str, List[Dict[str, Any]]]:
        """
        Handle an inference request.

        Args:
            data: Payload with an "inputs" key holding a list of dicts, each
                with "id", "chunk", "title", and "date".

        Returns:
            Dict with a "results" list of {"id", "generated_text"} entries.

        Raises:
            ValueError: If inputs is not a list of dicts or a required key
                is missing from any item.
        """
        inputs = data.get("inputs", [])

        # Ensure inputs is a list of dictionaries
        if not isinstance(inputs, list) or not all(isinstance(i, dict) for i in inputs):
            raise ValueError("The inputs must be a list of dictionaries.")

        chunks, titles, dates, ids = [], [], [], []
        for item in inputs:
            for key in ["id", "chunk", "title", "date"]:
                if key not in item:
                    raise ValueError(f"Each input must contain the key: {key}.")
            ids.append(item["id"])
            chunks.append(item["chunk"])
            titles.append(item["title"])
            dates.append(item["date"])

        predictions = self.process_chunks(chunks, titles, dates)
        result = [
            {"id": id_, "generated_text": prediction}
            for id_, prediction in zip(ids, predictions)
        ]
        return {"results": result}