# handler.py — custom Inference Endpoints handler for
# algorythmtechnologies/Warren-8B-Uncensored-2000.
# (Web-page scrape residue removed; this header replaces it so the file parses.)
from typing import Dict, List, Any, Optional
import torch
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
TextIteratorStreamer,
)
class EndpointHandler:
    """
    Custom Inference Endpoints handler for
    algorythmtechnologies/Warren-8B-Uncensored-2000.

    Expected JSON payload:
        {
            "inputs": "user prompt or message",
            "max_new_tokens": 256,       # optional
            "temperature": 0.7,          # optional; <= 0 switches to greedy decoding
            "top_p": 0.9,                # optional
            "top_k": 50,                 # optional
            "repetition_penalty": 1.05,  # optional
            "stop_sequences": ["</s>"]   # optional
        }

    Returns:
        [
            {
                "generated_text": "...",
                "finish_reason": "length|stop"
            }
        ]
        or [{"error": "..."}] when the payload is missing 'inputs'.
    """

    def __init__(self, path: str = ""):
        """Load tokenizer and model from *path* (or the current directory).

        Args:
            path: Repository path passed in by the Endpoints runtime;
                  falls back to "." when empty.
        """
        # Prefer GPU when available; fp16 halves memory on CUDA.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(path or ".")
        # generate() needs a pad token id; fall back to EOS when the
        # tokenizer defines none (common for causal-LM tokenizers).
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(
            path or ".",
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
            # device_map="auto" lets accelerate place/shard the model on GPU.
            device_map="auto" if self.device == "cuda" else None,
        )
        # Inference only — disable dropout etc.
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate a completion for the prompt in ``data["inputs"]``.

        Args:
            data: Decoded JSON payload; see the class docstring for keys.

        Returns:
            A single-element list with ``generated_text`` and
            ``finish_reason`` ("length" when the token budget was
            exhausted, "stop" when EOS or a stop sequence ended it),
            or ``[{"error": ...}]`` when 'inputs' is absent.
        """
        prompt: Optional[str] = data.get("inputs")
        if prompt is None:
            return [{"error": "Missing 'inputs' field in payload."}]

        max_new_tokens: int = int(data.get("max_new_tokens", 256))
        temperature: float = float(data.get("temperature", 0.7))
        top_p: float = float(data.get("top_p", 0.9))
        top_k: int = int(data.get("top_k", 50))
        repetition_penalty: float = float(data.get("repetition_penalty", 1.05))
        stop_sequences = data.get("stop_sequences", None)

        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            padding=False,
            truncation=True,
        ).to(self.device)

        # BUG FIX: temperature <= 0 makes generate() raise when
        # do_sample=True; fall back to greedy decoding in that case and
        # only pass sampling knobs when sampling is actually on.
        do_sample = temperature > 0.0
        gen_kwargs: Dict[str, Any] = dict(
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            repetition_penalty=repetition_penalty,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
        )
        if do_sample:
            gen_kwargs.update(temperature=temperature, top_p=top_p, top_k=top_k)

        with torch.no_grad():
            output_ids = self.model.generate(**inputs, **gen_kwargs)

        # BUG FIX: decode only the newly generated tokens instead of
        # string-stripping the prompt. Tokenization round-trips are lossy,
        # so the old `full_text.startswith(prompt)` check often failed and
        # leaked the prompt back into the response.
        prompt_len = inputs["input_ids"].shape[1]
        new_tokens = output_ids[0][prompt_len:]
        generated_text = self.tokenizer.decode(
            new_tokens,
            skip_special_tokens=True,
        ).lstrip()

        # BUG FIX: the original reported "length" unconditionally. If the
        # model emitted fewer than max_new_tokens, it hit EOS -> "stop".
        finish_reason = "length" if new_tokens.shape[0] >= max_new_tokens else "stop"

        # Truncate at the first client-supplied stop sequence, if any.
        if stop_sequences:
            for stop in stop_sequences:
                idx = generated_text.find(stop)
                if idx != -1:
                    generated_text = generated_text[:idx]
                    finish_reason = "stop"
                    break

        return [
            {
                "generated_text": generated_text,
                "finish_reason": finish_reason,
            }
        ]