from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import os


class EndpointHandler:
    """Hugging Face Inference Endpoint handler for a T5Gemma seq2seq model.

    Loads the tokenizer and model once at startup (``__init__``) and serves
    chat-style generation requests via ``__call__``.
    """

    def __init__(self, path=""):
        """Load tokenizer and model from a local directory.

        Args:
            path: Filesystem path of the model repository snapshot provided
                by the endpoint runtime.
        """
        # Explicitly prevent sentence-transformers auto-detection
        os.environ["TRANSFORMERS_OFFLINE"] = "1"
        print(f"Loading T5Gemma model from: {path}")
        self.tokenizer = AutoTokenizer.from_pretrained(
            path,
            trust_remote_code=True,
        )
        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            device_map="auto",
        )
        # Inference-only service: disable dropout etc.
        self.model.eval()
        print("T5Gemma model loaded successfully")

    def __call__(self, data):
        """Handle one inference request.

        Args:
            data: Request payload dict. The prompt is read from
                ``data["inputs"]``; if that key is absent, the whole payload
                is used as the prompt (standard HF handler convention).

        Returns:
            A dict with a single key ``"generated_text"`` containing the
            decoded model output.
        """
        inputs = data.pop("inputs", data)
        messages = [{"role": "user", "content": inputs}]
        input_ids = self.tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt",
        )
        # BUG FIX: with device_map="auto" the model may live on GPU while the
        # tokenized tensor is created on CPU; generate() would raise a
        # device-mismatch error. Move inputs to the model's device first.
        input_ids = input_ids.to(self.model.device)
        # Pure inference — skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            outputs = self.model.generate(
                input_ids,
                max_new_tokens=1024,
                temperature=0.1,
                do_sample=True,
            )
        return {
            "generated_text": self.tokenizer.decode(
                outputs[0], skip_special_tokens=True
            )
        }