from transformers import AutoTokenizer, AutoModelForCausalLM, TextGenerationPipeline


class EndpointHandler:
    def __init__(self, model_path="djangodevloper/llama3-70b-4bit-medqa"):
        # Load the tokenizer and model once at startup; device_map="auto"
        # places the weights on whatever accelerators are available.
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map="auto",
            trust_remote_code=True,
        )
        self.pipeline = TextGenerationPipeline(model=self.model, tokenizer=self.tokenizer)
    def __call__(self, data):
        try:
            # Validate the input format: the payload must carry a non-empty
            # "inputs" field with the prompt text.
            messages = data.get("inputs")
            if not messages:
                return {"error": "Missing 'inputs' field in request payload."}
            # Static generation parameters. do_sample=True is needed for the
            # temperature setting to take effect; Llama 3 emits <|eot_id|>
            # as its end-of-turn marker, so treat it as a second stop token
            # alongside the regular EOS.
            params = {
                "max_new_tokens": 512,
                "do_sample": True,
                "temperature": 0.1,
                "eos_token_id": [
                    self.tokenizer.eos_token_id,
                    self.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
                ],
            }
            # Generate the response; the pipeline echoes the prompt at the
            # start of generated_text.
            outputs = self.pipeline(messages, **params)
            generated_text = outputs[0]["generated_text"]
            # To return only the completion, strip the echoed prompt:
            # reply = generated_text[len(messages):].strip()
            return {"generated_text": generated_text}
        except Exception as e:
            return {"error": f"An error occurred during inference: {e}"}