# WazapSplitter-LLM / handler.py
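"""
Custom handler for a Hugging Face Inference Endpoint. Loads an 8-bit
quantized Llama-3.3-70B-Instruct base model, attaches the repo's LoRA
adapter, and splits input text into a JSON array of WhatsApp-style messages.
"""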
import json
import os

# hf_transfer must be enabled before transformers/huggingface_hub are imported;
# huggingface_hub reads this flag at import time, so setting it afterwards has
# no effect.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

import torch
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
class EndpointHandler:
    def __init__(self, path=""):
        """
        Initialize the handler: load the 8-bit quantized base model, then
        attach the LoRA adapter stored at the given path.
        """
        model_name = "meta-llama/Llama-3.3-70B-Instruct"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        base_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
            # Passing load_in_8bit directly is deprecated; use a quantization config.
            quantization_config=BitsAndBytesConfig(load_in_8bit=True),
            low_cpu_mem_usage=True,
        )
        # Attach the fine-tuned adapter; fall back to the bare base model if
        # the adapter weights cannot be loaded.
        try:
            self.model = PeftModel.from_pretrained(
                base_model,
                path,
                is_trainable=False,
            )
            print("Successfully loaded adapter with base model")
        except Exception as e:
            print(f"Error loading adapter: {e}")
            print("Falling back to base model without adapter")
            self.model = base_model
        # Load an optional custom chat template shipped alongside the adapter.
        try:
            with open(f"{path}/chat_template.jinja", "r") as f:
                self.chat_template = f.read()
        except OSError:
            self.chat_template = None
    def __call__(self, data):
        """
        Process the request payload and return the model's response as a
        JSON-encoded array of message strings.
        """
        inputs = data.get("inputs", "")
        parameters = data.get("parameters", {})
        default_prompt = (
            "Break this text into WhatsApp messages like a real person would send them. "
            "Split where you'd naturally pause: after greetings, before/after questions, "
            "between different thoughts, when changing topics. Preserve exact wording - "
            "just divide where someone would actually hit 'send' and start a new message. "
            "Output JSON array."
        )
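        # Hypothetical illustration of the intended split (not model output):
        #   inputs: "Hey! How's it going? I wanted to ask about tomorrow."
        #   output: ["Hey!", "How's it going?", "I wanted to ask about tomorrow."]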
        custom_prompt = parameters.get("prompt", default_prompt)
        messages = [
            {"role": "system", "content": custom_prompt},
            {"role": "user", "content": inputs},
        ]
        if self.chat_template:
            # Pass the custom template explicitly; otherwise apply_chat_template
            # would silently use the tokenizer's built-in template and the file
            # read in __init__ would never be applied.
            text = self.tokenizer.apply_chat_template(
                messages,
                chat_template=self.chat_template,
                tokenize=False,
                add_generation_prompt=True,
            )
        else:
            text = f"{custom_prompt}\nUser: {inputs}\nAssistant:"
        # Tokenize
        model_inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)

        # Generate response
        with torch.no_grad():
            outputs = self.model.generate(
                **model_inputs,
                max_new_tokens=parameters.get("max_new_tokens", 100),
                temperature=parameters.get("temperature", 0.3),
                top_p=parameters.get("top_p", 0.9),
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
                repetition_penalty=1.1,
            )
        # Strip the prompt tokens and decode only the newly generated text.
        response = self.tokenizer.decode(
            outputs[0][model_inputs.input_ids.shape[-1]:],
            skip_special_tokens=True,
        ).strip()

        # Normalize the output to a JSON array of messages; if the model did
        # not produce valid JSON, fall back to the unsplit input text.
        try:
            if response.startswith('[') and response.endswith(']'):
                parsed = json.loads(response)
                formatted_response = response if isinstance(parsed, list) else json.dumps([response])
            else:
                formatted_response = json.dumps([response])
        except json.JSONDecodeError:
            formatted_response = json.dumps([inputs])

        return [{"content": formatted_response}]
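
# Minimal local smoke test: a sketch of how the handler is invoked, not part
# of the Inference Endpoints contract. The adapter path "." and the sample
# payload below are assumptions for illustration only.
if __name__ == "__main__":
    handler = EndpointHandler(path=".")
    result = handler({
        "inputs": "Hey! How's it going? I wanted to ask about the meeting tomorrow.",
        "parameters": {"max_new_tokens": 100, "temperature": 0.3},
    })
    print(result)  # e.g. [{"content": "[\"Hey!\", \"How's it going?\", ...]"}]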