|
|
from typing import Dict, List, Any, Optional |
|
|
|
|
|
import torch |
|
|
from transformers import ( |
|
|
AutoTokenizer, |
|
|
AutoModelForCausalLM, |
|
|
TextIteratorStreamer, |
|
|
) |
|
|
|
|
|
|
|
|
class EndpointHandler:
    """
    Custom Inference Endpoints handler for algorythmtechnologies/Warren-8B-Uncensored-2000.

    Expected JSON payload:
    {
        "inputs": "user prompt or message",
        "max_new_tokens": 256,        # optional
        "temperature": 0.7,           # optional; <= 0 selects greedy decoding
        "top_p": 0.9,                 # optional
        "top_k": 50,                  # optional
        "repetition_penalty": 1.05,   # optional
        "stop_sequences": ["</s>"]    # optional
    }

    Returns:
    [
        {
            "generated_text": "...",
            "finish_reason": "length|stop|error"
        }
    ]
    """

    def __init__(self, path: str = ""):
        """Load tokenizer and model from *path* (or the current directory).

        Args:
            path: Model repository/checkpoint directory handed in by the
                Inference Endpoints runtime; falls back to ".".
        """
        # Prefer GPU when available; fall back to CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.tokenizer = AutoTokenizer.from_pretrained(path or ".")
        # Many causal-LM tokenizers ship without a pad token; reuse EOS so
        # generate() always has a valid pad_token_id.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            path or ".",
            # fp16 on GPU for memory/speed; fp32 on CPU where fp16 is slow.
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
            device_map="auto" if self.device == "cuda" else None,
        )
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run one generation request.

        Args:
            data: JSON payload. Keys:
                inputs (str): user text prompt (required).
                max_new_tokens (int, optional): generation budget, default 256.
                temperature (float, optional): default 0.7; values <= 0
                    disable sampling (greedy decoding).
                top_p (float, optional): nucleus sampling mass, default 0.9.
                top_k (int, optional): default 50.
                repetition_penalty (float, optional): default 1.05.
                stop_sequences (List[str], optional): strings that truncate
                    the completion where first found.

        Return:
            A list with one dict:
            [
                {
                    "generated_text": str,
                    "finish_reason": "length" | "stop" | "error"
                }
            ]
            or [{"error": ...}] when the required "inputs" field is missing.
        """
        prompt: Optional[str] = data.get("inputs")
        if prompt is None:
            return [{"error": "Missing 'inputs' field in payload."}]

        max_new_tokens: int = int(data.get("max_new_tokens", 256))
        temperature: float = float(data.get("temperature", 0.7))
        top_p: float = float(data.get("top_p", 0.9))
        top_k: int = int(data.get("top_k", 50))
        repetition_penalty: float = float(data.get("repetition_penalty", 1.05))
        stop_sequences = data.get("stop_sequences", None)

        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            padding=False,
            truncation=True,
        ).to(self.device)
        # Remember the prompt length in *tokens* so the completion can be
        # separated from the echo robustly (see below).
        prompt_len = inputs["input_ids"].shape[1]

        gen_kwargs: Dict[str, Any] = dict(
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
        )
        # temperature <= 0 is invalid for sampling (divides logits by zero in
        # transformers); treat it as a request for greedy decoding instead of
        # crashing the endpoint.
        if temperature > 0:
            gen_kwargs.update(
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
            )
        else:
            gen_kwargs["do_sample"] = False

        try:
            with torch.no_grad():
                output_ids = self.model.generate(
                    **inputs,
                    **gen_kwargs,
                )
        except Exception as exc:
            # Surface generation failures through the documented contract
            # ("finish_reason": "error") instead of a 500 with no body.
            return [
                {
                    "generated_text": "",
                    "finish_reason": "error",
                    "error": str(exc),
                }
            ]

        # Strip the prompt by token count rather than by string prefix:
        # decoding can normalize whitespace/special tokens, which made the old
        # startswith() check silently fail and echo the prompt back to callers.
        new_token_ids = output_ids[0][prompt_len:]
        generated_text = self.tokenizer.decode(
            new_token_ids,
            skip_special_tokens=True,
        ).lstrip()

        # EOS emitted by the model itself -> "stop"; otherwise the generation
        # budget (max_new_tokens) was exhausted -> "length".
        finish_reason = "length"
        if (
            new_token_ids.numel() > 0
            and new_token_ids[-1].item() == self.tokenizer.eos_token_id
        ):
            finish_reason = "stop"

        if stop_sequences:
            for stop in stop_sequences:
                idx = generated_text.find(stop)
                if idx != -1:
                    generated_text = generated_text[:idx]
                    finish_reason = "stop"
                    break

        return [
            {
                "generated_text": generated_text,
                "finish_reason": finish_reason,
            }
        ]
|
|
|