File size: 2,794 Bytes
4591223
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from typing import Dict, List, Any

class EndpointHandler:
    """Custom inference handler for a 4-bit quantized causal language model.

    The tokenizer and model are loaded once at construction time. Each call
    builds a prompt from the request payload (chat messages or a raw string),
    generates a completion, and returns only the newly generated text.
    """

    def __init__(self, path=""):
        """
        Initializes the model and tokenizer.

        Args:
            path: Directory of the model repository. `path` is automatically
                provided by Hugging Face (it points to your repo files).
        """
        print("🚀 Initializing PropagationShield Handler...")

        self.tokenizer = AutoTokenizer.from_pretrained(path)

        # 1. Configure 4-bit quantization to prevent OOM and System RAM limits
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )

        # 2. Load the model safely
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            quantization_config=bnb_config,
            device_map="auto",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,  # Crucial to prevent the 30GB RAM crash during boot
        )
        print("✅ PropagationShield Loaded Successfully!")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Run inference on one incoming request.

        Args:
            data: Request payload. ``"inputs"`` and ``"parameters"`` are
                popped from it (the dict is mutated).

                - ``data["inputs"]``: a list of chat messages
                  (``[{"role": ..., "content": ...}, ...]``), which is
                  rendered with the tokenizer's chat template. Any other
                  value is converted with ``str()`` and used as the raw
                  prompt. If the key is missing, the rest of the payload
                  itself is used as ``inputs``.
                - ``data["parameters"]`` (optional, may be ``None``):
                  ``max_new_tokens`` (default 512) and ``temperature``
                  (default 0.1). A temperature ``<= 0`` selects greedy
                  decoding.

        Returns:
            ``[{"generated_text": <completion>}]``. The completion contains
            only the newly generated tokens, decoded without special tokens
            and whitespace-stripped.
        """
        # Parse incoming data
        inputs = data.pop("inputs", data)
        # `or {}` also covers an explicit `"parameters": null` in the request.
        parameters = data.pop("parameters", None) or {}

        # Treat missing/None values as "use the default"; coerce numeric strings.
        max_new_tokens = parameters.get("max_new_tokens")
        max_new_tokens = 512 if max_new_tokens is None else int(max_new_tokens)
        temperature = parameters.get("temperature")
        temperature = 0.1 if temperature is None else float(temperature)
        # Sampling is only meaningful for a strictly positive temperature.
        do_sample = temperature > 0.0

        # 3. Format the prompt
        if isinstance(inputs, list):
            # Chat messages: the rendered template already contains the
            # model's special tokens (e.g. BOS), so they must not be added a
            # second time when the text is tokenized below.
            prompt = self.tokenizer.apply_chat_template(
                inputs, tokenize=False, add_generation_prompt=True
            )
            add_special_tokens = False
        else:
            # Raw string: let the tokenizer add its special tokens as usual.
            prompt = str(inputs)
            add_special_tokens = True

        # 4. Tokenize
        encoded = self.tokenizer(
            prompt, return_tensors="pt", add_special_tokens=add_special_tokens
        )
        input_ids = encoded["input_ids"].to(self.model.device)

        generate_kwargs: Dict[str, Any] = {
            "max_new_tokens": max_new_tokens,
            "do_sample": do_sample,
            "pad_token_id": self.tokenizer.eos_token_id,
        }
        # Pass the attention mask explicitly. pad_token_id is set to the EOS
        # id above, so generate() cannot reliably infer the mask on its own.
        attention_mask = encoded.get("attention_mask")
        if attention_mask is not None:
            generate_kwargs["attention_mask"] = attention_mask.to(self.model.device)
        if do_sample:
            # temperature only affects sampling. Passing it to greedy
            # decoding merely triggers a generation-config warning.
            generate_kwargs["temperature"] = temperature

        # 5. Generate
        with torch.no_grad():
            output_ids = self.model.generate(input_ids, **generate_kwargs)

        # 6. Isolate and decode only the newly generated tokens
        generated_ids = output_ids[0][input_ids.shape[-1]:]
        generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)

        # Return in standard HF API format
        return [{"generated_text": generated_text.strip()}]