| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
| import torch | |
| import os | |
class EndpointHandler:
    """Hugging Face Inference Endpoints handler for a T5Gemma seq2seq model.

    Loads the tokenizer and model once at startup, then serves chat-style
    generation requests through ``__call__``.
    """

    def __init__(self, path: str = ""):
        """Load tokenizer and model from the local model directory.

        Args:
            path: Directory containing the model files (supplied by the
                Inference Endpoints runtime).
        """
        # Explicitly prevent sentence-transformers auto-detection and any
        # accidental hub downloads: everything must come from `path`.
        os.environ["TRANSFORMERS_OFFLINE"] = "1"
        print(f"Loading T5Gemma model from: {path}")
        self.tokenizer = AutoTokenizer.from_pretrained(
            path,
            trust_remote_code=True,
        )
        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            device_map="auto",  # place/shard on the available accelerator(s)
        )
        print("T5Gemma model loaded successfully")

    def __call__(self, data: dict) -> dict:
        """Generate a response for one request payload.

        Args:
            data: Request body. Expects ``{"inputs": <prompt str>}`` and an
                optional ``{"parameters": {...}}`` dict of ``generate()``
                overrides (standard Inference Endpoints payload shape).

        Returns:
            ``{"generated_text": <decoded model output>}``
        """
        inputs = data.pop("inputs", data)
        # Defensive: when "inputs" is absent, the fallback is the whole dict;
        # the chat template expects a string `content`, so coerce.
        if not isinstance(inputs, str):
            inputs = str(inputs)

        # Backward-compatible generalization: honor the standard optional
        # "parameters" field; defaults below are unchanged from before.
        parameters = data.pop("parameters", None) or {}

        messages = [{"role": "user", "content": inputs}]
        input_ids = self.tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt",
        )
        # BUG FIX: with device_map="auto" the model lives on the accelerator
        # while the tokenizer output stays on CPU; without this move,
        # generate() fails with a cross-device tensor error on GPU endpoints.
        input_ids = input_ids.to(self.model.device)

        gen_kwargs = {
            "max_new_tokens": 1024,
            "temperature": 0.1,
            "do_sample": True,
        }
        gen_kwargs.update(parameters)

        outputs = self.model.generate(input_ids, **gen_kwargs)
        return {
            "generated_text": self.tokenizer.decode(
                outputs[0], skip_special_tokens=True
            )
        }