from typing import Any, Dict

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class EndpointHandler:
    """Custom handler for Hugging Face Inference Endpoints serving a causal LM."""

    def __init__(self, path: str = "/repository"):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        print(f"Loading tokenizer from {path}...")
        self.tokenizer = AutoTokenizer.from_pretrained(path)
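
        # Many causal-LM tokenizers ship without a pad token; reuse EOS so the
        # padding-aware tokenizer call below does not fail.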
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

print(f"Loading model from {path} on device: {self.device}...") |
|
|
self.model = AutoModelForCausalLM.from_pretrained( |
|
|
path, |
|
|
torch_dtype=torch.float16, |
|
|
trust_remote_code=True, |
|
|
device_map="auto", |
|
|
low_cpu_mem_usage=True |
|
|
|
|
|
) |
|
|
self.model.eval() |
|
|
print("✅ Model loaded successfully!") |
|
|
|
|
|
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        inputs = data.get("inputs", "")
        parameters = data.get("parameters", {}) or {}

        # Reject missing or blank prompts early.
        if not isinstance(inputs, str) or not inputs.strip():
            return {"generated_text": ""}

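        # Merge caller-supplied sampling parameters with conservative defaults;
        # max_new_tokens is capped at 512 to bound per-request latency.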
        gen_kwargs = {
            "max_new_tokens": min(parameters.get("max_new_tokens", 256), 512),
            "temperature": parameters.get("temperature", 0.3),
            "top_p": parameters.get("top_p", 0.95),
            "top_k": parameters.get("top_k", 50),
            "do_sample": parameters.get("do_sample", True),
            "repetition_penalty": parameters.get("repetition_penalty", 1.1),
            "eos_token_id": self.tokenizer.eos_token_id,
            "pad_token_id": self.tokenizer.pad_token_id,
        }

        print(f"Generating with parameters: {gen_kwargs}")

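        # Tokenize the prompt, truncating to a 2048-token budget so oversized
        # inputs cannot exhaust the model's context window.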
        inputs = inputs.strip()
        tokenized = self.tokenizer(
            inputs,
            return_tensors="pt",
            truncation=True,
            max_length=2048,
            padding=True,
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model.generate(
                input_ids=tokenized.input_ids,
                attention_mask=tokenized.attention_mask,
                **gen_kwargs,
                use_cache=True,
            )

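        # The output sequence echoes the prompt; slice it off and decode only
        # the newly generated tokens.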
        new_tokens = outputs[0][tokenized.input_ids.shape[1]:]
        generated = self.tokenizer.decode(
            new_tokens,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )

        return {"generated_text": generated.strip()}
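

# A minimal local smoke test (a sketch: the "./model" checkpoint path is a
# placeholder, not part of the deployed handler). In production, Inference
# Endpoints instantiate EndpointHandler("/repository") and POST JSON payloads
# of the form {"inputs": "...", "parameters": {...}}.
if __name__ == "__main__":
    handler = EndpointHandler(path="./model")  # hypothetical local checkpoint
    result = handler({
        "inputs": "Write a haiku about GPUs.",
        "parameters": {"max_new_tokens": 64, "temperature": 0.7},
    })
    print(result["generated_text"])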