# memory cleanup (commit 3cc9791, author: slapmack)
from typing import Dict, Any, List
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch
from torch.cuda.amp import autocast
class EndpointHandler:
    """Batch inference handler for a seq2seq 'propositionizer' model.

    Loads the model/tokenizer once at startup, then serves batched
    generation requests via ``__call__``.
    """

    def __init__(self, path="chentong00/propositionizer-wiki-flan-t5-large"):
        """
        Initialize the handler by loading the model, tokenizer, and setting the device.

        Args:
            path (str): Hugging Face repo id or local path of the
                seq2seq model to load.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        model = AutoModelForSeq2SeqLM.from_pretrained(path).to(self.device)
        # fp16 halves GPU memory, but half-precision generation on CPU is
        # slow and several CPU kernels don't support it -- only cast on CUDA.
        if self.device.type == "cuda":
            model = model.half()
        self.model = model

    def process_chunks(
        self, chunks: List[str], titles: List[str], dates: List[str]
    ) -> List[str]:
        """
        Process multiple text chunks with the model in a single batch.

        Args:
            chunks (list): List of text content to process.
            titles (list): List of document titles corresponding to the chunks.
            dates (list): List of document dates corresponding to the chunks.

        Returns:
            list: List of generated output texts, in input order.
        """
        input_texts = [
            f"Title: {t}. Date: {d}. Content: {c}"
            for c, t, d in zip(chunks, titles, dates)
        ]
        encoded = self.tokenizer(
            input_texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=1024,
        ).to(self.device)
        # Shared generation settings for both device branches.
        generate_kwargs = dict(
            # Pass the attention mask so padded positions are ignored;
            # dropping it makes batched results diverge from unbatched ones.
            attention_mask=encoded.attention_mask,
            max_new_tokens=512,
            no_repeat_ngram_size=5,
            length_penalty=1.2,
            num_beams=5,
        )
        # Pre-bind so the `finally` cleanup never hits a NameError if
        # generate() raises before `outputs` is assigned.
        outputs = None
        try:
            with torch.no_grad():
                # Use autocast for mixed precision on CUDA devices only.
                if self.device.type == "cuda":
                    with autocast():
                        outputs = self.model.generate(
                            encoded.input_ids, **generate_kwargs
                        )
                else:
                    outputs = self.model.generate(
                        encoded.input_ids, **generate_kwargs
                    )
            predictions = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        finally:
            # Explicit memory cleanup; CUDA cache/sync calls are GPU-only.
            del encoded, outputs
            if self.device.type == "cuda":
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
        return predictions

    def __call__(self, data: Dict[str, Any]) -> Dict[str, List[Dict[str, Any]]]:
        """
        Handle the inference request.

        Args:
            data (dict): Payload with an "inputs" key holding a list of
                dicts, each with "id", "chunk", "title", and "date".

        Returns:
            dict: {"results": [{"id": ..., "generated_text": ...}, ...]}
                matching each input by its ID.

        Raises:
            ValueError: If ``inputs`` is not a list of dicts or an item
                is missing a required key.
        """
        inputs = data.get("inputs", [])
        # Ensure inputs is a list of dictionaries
        if not isinstance(inputs, list) or not all(isinstance(i, dict) for i in inputs):
            raise ValueError("The inputs must be a list of dictionaries.")
        chunks, titles, dates, ids = [], [], [], []
        for item in inputs:
            for key in ["id", "chunk", "title", "date"]:
                if key not in item:
                    raise ValueError(f"Each input must contain the key: {key}.")
            ids.append(item["id"])
            chunks.append(item["chunk"])
            titles.append(item["title"])
            dates.append(item["date"])
        predictions = self.process_chunks(chunks, titles, dates)
        result = [
            {"id": id_, "generated_text": prediction}
            for id_, prediction in zip(ids, predictions)
        ]
        return {"results": result}