from typing import Any, Dict, List

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig


class EndpointHandler:
    """Inference handler that serves a 4-bit (NF4) quantized causal LM.

    Loaded once with the repository path, then called once per request with
    the parsed JSON payload.
    """

    def __init__(self, path=""):
        """
        Initializes the model and tokenizer.

        `path` is automatically provided by Hugging Face (it points to your repo files).
        """
        print("🚀 Initializing PropagationShield Handler...")
        self.tokenizer = AutoTokenizer.from_pretrained(path)

        # 1. Configure 4-bit quantization to prevent OOM and System RAM limits
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )

        # 2. Load the model safely
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            quantization_config=bnb_config,
            device_map="auto",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,  # Crucial to prevent the 30GB RAM crash during boot
        )
        print("✅ PropagationShield Loaded Successfully!")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run text generation for a single request.

        Args:
            data: Request payload.
                ``data["inputs"]`` is either a list of chat messages, which is
                rendered with the tokenizer's chat template, or anything else,
                which is converted with ``str()`` and used as the prompt. If
                ``"inputs"`` is absent, the rest of the payload (minus
                ``"parameters"``) is used as the input.
                ``data["parameters"]`` may set ``max_new_tokens`` (default
                512) and ``temperature`` (default 0.1; a value ``<= 0``
                selects greedy decoding).
                The caller's dict is not modified.

        Returns:
            ``[{"generated_text": text}]``, where ``text`` holds only the newly
            generated tokens, decoded without special tokens and stripped of
            surrounding whitespace.
        """
        # Work on a shallow copy so the caller's payload is not mutated. The
        # pop order keeps the original fallback: with no "inputs" key, the
        # remaining payload (without "parameters") becomes the input.
        payload = dict(data)
        inputs = payload.pop("inputs", payload)
        # `or {}` also covers an explicit JSON null for "parameters".
        parameters = payload.pop("parameters", None) or {}
        # JSON clients sometimes send numbers as strings; coerce explicitly.
        max_new_tokens = int(parameters.get("max_new_tokens", 512))
        temperature = float(parameters.get("temperature", 0.1))

        # 3. Format the prompt
        if isinstance(inputs, list):
            # Chat messages: [{"role": "system", "content": "..."}, ...]
            prompt = self.tokenizer.apply_chat_template(
                inputs, tokenize=False, add_generation_prompt=True
            )
            # The rendered template already contains the model's special
            # tokens (e.g. BOS); letting the tokenizer add them again would
            # duplicate them.
            add_special_tokens = False
        else:
            # Raw, already-formatted prompt string.
            prompt = str(inputs)
            add_special_tokens = True

        # 4. Tokenize
        encoded = self.tokenizer(
            prompt, return_tensors="pt", add_special_tokens=add_special_tokens
        )
        input_ids = encoded["input_ids"].to(self.model.device)
        attention_mask = encoded.get("attention_mask")
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.model.device)

        # 5. Generate
        gen_kwargs: Dict[str, Any] = {
            "max_new_tokens": max_new_tokens,
            "pad_token_id": self.tokenizer.eos_token_id,
        }
        if attention_mask is not None:
            # Pass the mask explicitly: pad_token_id == eos_token_id here, so
            # generate() cannot reliably infer it from the ids.
            gen_kwargs["attention_mask"] = attention_mask
        if temperature > 0.0:
            gen_kwargs["do_sample"] = True
            gen_kwargs["temperature"] = temperature
        else:
            # Greedy decoding; temperature only applies when sampling.
            gen_kwargs["do_sample"] = False

        with torch.no_grad():
            output_ids = self.model.generate(input_ids, **gen_kwargs)

        # 6. Isolate and decode only the newly generated tokens
        generated_ids = output_ids[0][input_ids.shape[-1]:]
        generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)

        # Return in standard HF API format
        return [{"generated_text": generated_text.strip()}]