"""Custom Hugging Face Inference Endpoints handler for krisoei/timgpt."""

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline


class EndpointHandler:
    def __init__(self, path="krisoei/timgpt"):
        if not path:
            raise ValueError("A valid model path or name must be provided.")

        # Load the tokenizer and the model in half precision, letting
        # accelerate place layers across the available devices.
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.float16,
            device_map="auto",
        )

        # Reusable text-generation pipeline with sampling defaults.
        self.pipe = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
            top_p=0.95,
        )

    def __call__(self, data):
        # Inference Endpoints deliver the request body as a dict.
        if not isinstance(data, dict):
            return {"error": "Input must be a JSON object."}

        prompt = data.get("inputs", "")
        if not prompt:
            return {"error": "No input provided."}

        try:
            outputs = self.pipe(prompt)
            if outputs:
                # The pipeline echoes the prompt before the completion,
                # so slice it off and return only the newly generated text.
                response = outputs[0]["generated_text"]
                response = response[len(prompt):].strip()
                return {"generated_text": response}
            return {"error": "No output generated."}
        except Exception as e:
            return {"error": str(e)}
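
# A minimal local smoke test -- a sketch, not part of the handler contract.
# It assumes the krisoei/timgpt weights are downloadable and that enough
# GPU/CPU memory is available; the payload shape {"inputs": ...} mirrors
# what Inference Endpoints send to __call__. The prompt is a placeholder.
if __name__ == "__main__":
    handler = EndpointHandler()
    result = handler({"inputs": "Write a one-sentence greeting."})
    print(result)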