File size: 4,377 Bytes
199308a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
from typing import Dict, List, Any, Optional
import torch
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
TextIteratorStreamer,
)
class EndpointHandler:
    """
    Custom Inference Endpoints handler for algorythmtechnologies/Warren-8B-Uncensored-2000.

    Expected JSON payload (generation options may be given at the top level,
    or nested under a "parameters" object as in the standard Inference
    Endpoints payload; top-level values take precedence, and JSON null means
    "use the default"):
    {
        "inputs": "user prompt or message",
        "max_new_tokens": 256,        # optional, must be >= 1
        "temperature": 0.7,           # optional, <= 0 selects greedy decoding
        "top_p": 0.9,                 # optional
        "top_k": 50,                  # optional
        "repetition_penalty": 1.05,   # optional
        "stop_sequences": ["</s>"]    # optional, a single string is also accepted
    }

    Returns:
    [
        {
            "generated_text": "...",
            "finish_reason": "length|stop|error"
        }
    ]
    Error responses additionally carry an "error" message and have an empty
    "generated_text".
    """

    def __init__(self, path: str = "") -> None:
        """Load the tokenizer and model from the repository directory.

        Args:
            path: Directory of the model repository; "" means the current
                working directory.
        """
        model_path = path or "."
        use_cuda = torch.cuda.is_available()
        self.device = "cuda" if use_cuda else "cpu"

        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        # generate() needs a pad token id; fall back to EOS when the tokenizer has none.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float16 if use_cuda else torch.float32,
            device_map="auto" if use_cuda else None,
        )
        self.model.eval()

    @staticmethod
    def _error(message: str) -> List[Dict[str, Any]]:
        """Build the documented response shape for a rejected request."""
        return [{"generated_text": "", "finish_reason": "error", "error": message}]

    @staticmethod
    def _param(data: Dict[str, Any], name: str, default: Any) -> Any:
        """Read option *name* from the top level of *data*, then data["parameters"], else *default*."""
        value = data.get(name)
        if value is None:
            nested = data.get("parameters")
            if isinstance(nested, dict):
                value = nested.get(name)
        return default if value is None else value

    @staticmethod
    def _normalize_stop_sequences(value: Any) -> List[str]:
        """Coerce the 'stop_sequences' option into a list of non-empty strings.

        Raises:
            TypeError: if *value* is not None, a string, or a list/tuple of strings.
        """
        if value is None:
            return []
        if isinstance(value, str):
            # A bare string would otherwise be iterated character by character.
            value = [value]
        if not isinstance(value, (list, tuple)) or not all(isinstance(s, str) for s in value):
            raise TypeError("'stop_sequences' must be a string or a list of strings.")
        # An empty stop string matches at index 0 and would erase the whole completion.
        return [s for s in value if s]

    @staticmethod
    def _earliest_stop(text: str, stop_sequences: List[str]) -> Optional[int]:
        """Return the index of the earliest occurrence of any stop sequence in *text*, or None."""
        hits = [pos for pos in (text.find(stop) for stop in stop_sequences if stop) if pos != -1]
        return min(hits) if hits else None

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate a completion for data["inputs"].

        Args:
            data: Request payload. "inputs" (str) is required. The optional
                options max_new_tokens, temperature, top_p, top_k,
                repetition_penalty and stop_sequences are read from the top
                level or from a nested "parameters" dict.

        Returns:
            A one-element list {"generated_text": str, "finish_reason": str}.
            finish_reason is "stop" when generation ended before the token
            budget or a stop sequence was found, and "length" when the
            max_new_tokens budget was used up. An invalid payload yields
            {"generated_text": "", "finish_reason": "error", "error": str}
            instead of raising.
        """
        prompt: Optional[str] = data.get("inputs")
        if prompt is None:
            return self._error("Missing 'inputs' field in payload.")
        if not isinstance(prompt, str):
            return self._error("'inputs' must be a string.")

        try:
            max_new_tokens = int(self._param(data, "max_new_tokens", 256))
            temperature = float(self._param(data, "temperature", 0.7))
            top_p = float(self._param(data, "top_p", 0.9))
            top_k = int(self._param(data, "top_k", 50))
            repetition_penalty = float(self._param(data, "repetition_penalty", 1.05))
            stop_sequences = self._normalize_stop_sequences(
                self._param(data, "stop_sequences", None)
            )
        except (TypeError, ValueError) as err:
            return self._error(f"Invalid generation parameter: {err}")
        if max_new_tokens < 1:
            return self._error("'max_new_tokens' must be a positive integer.")

        # Place inputs on the model's own device; with device_map="auto" this
        # is where its first weights were loaded, not necessarily plain "cuda".
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            padding=False,
            truncation=True,
        ).to(self.model.device)
        prompt_len = inputs["input_ids"].shape[-1]

        eos_token_id = self.tokenizer.eos_token_id
        gen_kwargs: Dict[str, Any] = dict(
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=eos_token_id,
        )
        if temperature > 0:
            gen_kwargs.update(
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
            )
        else:
            # Sampling rejects a non-positive temperature; use greedy decoding instead.
            gen_kwargs["do_sample"] = False

        with torch.no_grad():
            output_ids = self.model.generate(**inputs, **gen_kwargs)

        # Drop the prompt by token count. Comparing decoded strings breaks
        # whenever decoding does not reproduce the prompt exactly, or the
        # prompt was truncated.
        new_ids = output_ids[0, prompt_len:]
        generated_text = self.tokenizer.decode(new_ids, skip_special_tokens=True).lstrip()

        # With only EOS as a stopping condition, ending under budget (or on EOS) is a natural stop.
        n_new = new_ids.shape[-1]
        ended_on_eos = n_new < max_new_tokens or (
            n_new > 0 and eos_token_id is not None and int(new_ids[-1]) == eos_token_id
        )
        finish_reason = "stop" if ended_on_eos else "length"

        # Cut at the earliest match across all stop sequences, not the first in list order.
        cut = self._earliest_stop(generated_text, stop_sequences)
        if cut is not None:
            generated_text = generated_text[:cut]
            finish_reason = "stop"

        return [
            {
                "generated_text": generated_text,
                "finish_reason": finish_reason,
            }
        ]
|