from typing import Dict, List, Any, Optional

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TextIteratorStreamer,
)


class EndpointHandler:
    """
    Custom Inference Endpoints handler for
    algorythmtechnologies/Warren-8B-Uncensored-2000.

    Expected JSON payload:
    {
      "inputs": "user prompt or message",
      "max_new_tokens": 256,        # optional
      "temperature": 0.7,           # optional
      "top_p": 0.9,                 # optional
      "top_k": 50,                  # optional
      "repetition_penalty": 1.05,   # optional
      "stop_sequences": ["..."]     # optional
    }

    Returns:
    [
      {
        "generated_text": "...",
        "finish_reason": "length|stop|error"
      }
    ]
    """

    def __init__(self, path: str = ""):
        """Load tokenizer and model from *path* (or the current directory).

        Uses fp16 + device_map="auto" on CUDA, fp32 on CPU.
        """
        self.device: str = "cuda" if torch.cuda.is_available() else "cpu"

        # Load tokenizer and model from the repository path.
        self.tokenizer = AutoTokenizer.from_pretrained(path or ".")

        # generate() needs a pad token; fall back to EOS when none is defined.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            path or ".",
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
            device_map="auto" if self.device == "cuda" else None,
        )
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run one generation request.

        Args:
            data: JSON payload. Required key ``inputs`` (the prompt string);
                optional sampling knobs ``max_new_tokens``, ``temperature``,
                ``top_p``, ``top_k``, ``repetition_penalty`` and a
                ``stop_sequences`` list applied post-hoc to the decoded text.

        Returns:
            A single-element list: ``[{"generated_text": str,
            "finish_reason": "length"|"stop"|"error"}]``, or
            ``[{"error": ...}]`` when the payload is missing ``inputs``.
        """
        prompt: Optional[str] = data.get("inputs")
        if prompt is None:
            return [{"error": "Missing 'inputs' field in payload."}]

        max_new_tokens: int = int(data.get("max_new_tokens", 256))
        temperature: float = float(data.get("temperature", 0.7))
        top_p: float = float(data.get("top_p", 0.9))
        top_k: int = int(data.get("top_k", 50))
        repetition_penalty: float = float(data.get("repetition_penalty", 1.05))
        stop_sequences = data.get("stop_sequences", None)

        # Tokenize the prompt and move tensors to the model's device.
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            padding=False,
            truncation=True,
        ).to(self.device)
        prompt_len: int = inputs["input_ids"].shape[-1]

        # temperature <= 0 is invalid for sampling; degrade to greedy decoding
        # instead of crashing inside generate().
        do_sample = temperature > 0.0
        gen_kwargs: Dict[str, Any] = dict(
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            repetition_penalty=repetition_penalty,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
        )
        if do_sample:
            gen_kwargs.update(temperature=temperature, top_p=top_p, top_k=top_k)

        try:
            with torch.no_grad():
                output_ids = self.model.generate(**inputs, **gen_kwargs)
        except Exception as exc:  # surface the promised "error" finish_reason
            return [
                {
                    "generated_text": "",
                    "finish_reason": "error",
                    "error": str(exc),
                }
            ]

        # Strip the prompt by slicing the generated token ids rather than by
        # string-prefix matching: decode() may normalize whitespace/special
        # tokens, which made startswith(prompt) unreliable.
        gen_ids = output_ids[0][prompt_len:]
        generated_text = self.tokenizer.decode(
            gen_ids, skip_special_tokens=True
        ).lstrip()

        # Fewer generated tokens than the budget means the model emitted EOS.
        finish_reason = "length" if gen_ids.shape[-1] >= max_new_tokens else "stop"

        # Truncate at the earliest user-supplied stop sequence, if any.
        if stop_sequences:
            for stop in stop_sequences:
                if not stop:  # empty string would truncate everything at idx 0
                    continue
                idx = generated_text.find(stop)
                if idx != -1:
                    generated_text = generated_text[:idx]
                    finish_reason = "stop"
                    break

        return [
            {
                "generated_text": generated_text,
                "finish_reason": finish_reason,
            }
        ]