import os
from typing import Any, Dict, Union

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig


class EndpointHandler:
    """Callable wrapper around a 4-bit quantized causal language model.

    The class loads a tokenizer and a causal LM once at construction time and
    generates a short continuation for each request.

    NOTE(review): the name and the ``__init__``/``__call__`` shape look like a
    Hugging Face Inference Endpoints custom handler -- confirm against the
    deployment. ``__call__`` therefore accepts both a plain prompt string and
    the ``{"inputs": ..., "parameters": {...}}`` request payload.
    """

    # Generation length used when the request does not override it.
    DEFAULT_MAX_NEW_TOKENS = 20

    def __init__(self, model_id=""):
        """Load the tokenizer and the quantized model.

        Args:
            model_id: Identifier or path handed to ``from_pretrained`` for
                both the tokenizer and the model.
        """
        # Tokenized inputs are moved here before generation; it matches the
        # device the model is mapped to below.
        self.device = "cuda:0"
        # 4-bit NF4 weight quantization with bfloat16 compute.
        self.bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map={"": 0},  # place the whole model on device index 0
            quantization_config=self.bnb_config,
        )

    def __call__(self, input: Union[str, Dict[str, Any]]) -> str:
        """Generate text for a prompt and return the decoded result.

        Args:
            input: Either the prompt itself, or a request payload dict with
                the prompt under ``"inputs"`` and optional ``generate``
                keyword arguments under ``"parameters"`` (for example
                ``{"max_new_tokens": 64}``). The payload is not modified.

        Returns:
            The first generated sequence decoded to text with special tokens
            removed.

        Raises:
            ValueError: If a dict payload has no ``"inputs"`` entry.
        """
        generation_kwargs = {"max_new_tokens": self.DEFAULT_MAX_NEW_TOKENS}
        prompt = input
        if isinstance(input, dict):
            if "inputs" not in input:
                raise ValueError('request payload must contain an "inputs" entry')
            prompt = input["inputs"]
            # ``parameters`` may be absent or None; request values override
            # the defaults above.
            generation_kwargs.update(input.get("parameters") or {})

        encoded = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        generated = self.model.generate(**encoded, **generation_kwargs)
        # Only the first sequence of the batch is returned.
        return self.tokenizer.decode(generated[0], skip_special_tokens=True)