File size: 3,643 Bytes
c0bc97c
 
 
 
4bcb484
 
 
c0bc97c
 
 
 
 
 
538a0aa
 
 
50e9683
c0bc97c
538a0aa
c0bc97c
538a0aa
 
9715f6f
538a0aa
c0bc97c
c4c2bd8
 
 
 
 
9715f6f
c4c2bd8
538a0aa
c4c2bd8
 
 
 
c0bc97c
 
 
 
 
 
 
 
 
 
 
 
 
 
50e9683
 
 
 
c0bc97c
50e9683
c0bc97c
 
 
 
 
 
 
 
 
 
50e9683
c0bc97c
 
 
 
 
 
 
 
c164643
9715f6f
c0bc97c
 
c164643
 
c0bc97c
 
9715f6f
c0bc97c
50e9683
c0bc97c
50e9683
 
 
9715f6f
50e9683
 
 
 
 
 
 
9715f6f
50e9683
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch
import json
import os

# Enable the Rust-based `hf_transfer` downloader for faster model pulls
# from the Hugging Face Hub (must be set before the first hub download).
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

class EndpointHandler:
    """Custom Inference Endpoints handler: Llama-3.3-70B base model (8-bit)
    plus an optional PEFT/LoRA adapter, used to split text into
    WhatsApp-style messages returned as a JSON array string."""

    def __init__(self, path=""):
        """
        Initialize the handler with the model from the given path.

        Args:
            path: Directory containing the PEFT adapter weights and,
                optionally, a custom `chat_template.jinja` file.
        """
        model_name = "meta-llama/Llama-3.3-70B-Instruct"

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        # NOTE(review): `load_in_8bit=True` is deprecated in recent
        # transformers releases in favor of
        # `quantization_config=BitsAndBytesConfig(load_in_8bit=True)`;
        # kept as-is for compatibility with the pinned environment.
        base_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
            load_in_8bit=True,
            low_cpu_mem_usage=True
        )

        # Best-effort adapter load: serve the plain base model if the
        # adapter at `path` is missing or corrupt rather than failing boot.
        try:
            self.model = PeftModel.from_pretrained(
                base_model,
                path,
                is_trainable=False  # inference only
            )
            print("Successfully loaded adapter with base model")
        except Exception as e:
            print(f"Error loading adapter: {e}")
            print("Falling back to base model without adapter")
            self.model = base_model

        # Optional custom chat template shipped alongside the adapter.
        try:
            with open(f"{path}/chat_template.jinja", "r", encoding="utf-8") as f:
                self.chat_template = f.read()
        except OSError:
            # Narrowed from a bare `except:` — only a missing/unreadable
            # file should silently disable the custom template.
            self.chat_template = None

    def __call__(self, data):
        """
        Process the input data and return the model's response.

        Args:
            data: dict with "inputs" (text to split) and optional
                "parameters" (prompt, max_new_tokens, temperature, top_p).

        Returns:
            A list with a single dict: [{"content": <JSON array string>}].
        """
        inputs = data.get("inputs", "")
        parameters = data.get("parameters", {})

        default_prompt = "Break this text into WhatsApp messages like a real person would send them. Split where you'd naturally pause: after greetings, before/after questions, between different thoughts, when changing topics. Preserve exact wording - just divide where someone would actually hit 'send' and start a new message. Output JSON array."

        custom_prompt = parameters.get("prompt", default_prompt)

        messages = [
            {"role": "system", "content": custom_prompt},
            {"role": "user", "content": inputs}
        ]

        if self.chat_template:
            # BUG FIX: the custom template loaded in __init__ was never
            # passed through, so the tokenizer silently used its default
            # template. Pass it explicitly.
            text = self.tokenizer.apply_chat_template(
                messages,
                chat_template=self.chat_template,
                tokenize=False,
                add_generation_prompt=True
            )
        else:
            text = f"{custom_prompt}\nUser: {inputs}\nAssistant:"

        # Tokenize and move tensors to the model's device.
        model_inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)

        # Generate response (no gradient tracking needed at inference time).
        with torch.no_grad():
            outputs = self.model.generate(
                **model_inputs,
                max_new_tokens=parameters.get("max_new_tokens", 100),
                temperature=parameters.get("temperature", 0.3),
                top_p=parameters.get("top_p", 0.9),
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
                repetition_penalty=1.1
            )

        # Decode only the newly generated tokens (skip the echoed prompt).
        response = self.tokenizer.decode(
            outputs[0][model_inputs.input_ids.shape[-1]:],
            skip_special_tokens=True
        ).strip()

        # Normalize the model output into a JSON array string.
        try:
            if response.startswith('[') and response.endswith(']'):
                parsed = json.loads(response)
                if isinstance(parsed, list):
                    formatted_response = response
                else:
                    # Valid JSON but not an array — wrap it.
                    formatted_response = json.dumps([response])
            else:
                formatted_response = json.dumps([response])
        except json.JSONDecodeError:
            # Narrowed from a bare `except:`. Unparseable output: fall back
            # to the untouched input as a single message so exact wording
            # is preserved.
            formatted_response = json.dumps([inputs])

        return [{"content": formatted_response}]