File size: 6,670 Bytes
d2659b8
3d5eb81
75fa31b
 
d2659b8
 
 
 
 
 
 
 
e2d63b2
d2659b8
 
 
 
671b968
d2659b8
1deae37
d2659b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dd7e51e
 
d2659b8
 
 
 
 
412fa1c
d2659b8
 
 
 
 
 
 
 
6a9b6f1
d2659b8
dd45a31
d2659b8
 
1deae37
 
d2659b8
 
d336446
d2659b8
 
 
 
 
 
 
 
1fff0da
 
d2659b8
71ee6de
 
 
d2659b8
d336446
d2659b8
 
1fff0da
d2659b8
 
 
 
 
 
 
 
 
 
 
 
 
d38a401
d2659b8
 
 
 
75fa31b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dd45a31
75fa31b
 
 
 
dd45a31
 
 
75fa31b
 
1fff0da
dd45a31
 
 
221fe9f
a929d20
221fe9f
 
 
 
 
 
 
 
 
 
dd45a31
 
 
 
 
 
428f41a
dd45a31
 
 
 
 
 
 
221fe9f
dd45a31
 
221fe9f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import logging
import re
from typing import Dict, Any
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from peft import PeftConfig, PeftModel
import torch.cuda


LOGGER = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# Resolved once at import time.
# NOTE(review): `device` is not referenced elsewhere in the visible code;
# model placement comes from device_map='auto' when the base model is loaded.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Prompts longer than this many tokens are trimmed from the left before generation.
MAX_INPUT_TOKEN_LENGTH = 16_000

class EndpointHandler():
    """Request handler serving a PEFT (LoRA) adapter on top of its base causal LM.

    ``__init__`` reads the adapter config from ``path``, loads the base model in
    8-bit with ``device_map='auto'``, installs an ``[INST]``/``<<SYS>>`` chat
    template on the tokenizer and attaches the adapter weights.
    ``__call__`` validates a request dict and delegates to ``generate``.
    """

    def __init__(self, path=""):
        peft_config = PeftConfig.from_pretrained(path)
        base_id = peft_config.base_model_name_or_path

        # Base weights are loaded 8-bit quantized with automatic device placement.
        base_model = AutoModelForCausalLM.from_pretrained(
            base_id,
            load_in_8bit=True,
            trust_remote_code=True,
            device_map='auto',
        )

        tokenizer = AutoTokenizer.from_pretrained(base_id)
        # Replaces whatever chat template the base tokenizer ships with.
        tokenizer.chat_template = "{% for message in messages %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ bos_token + '<<SYS>>\\n' + message['content'] + '\\n<</SYS>>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' '  + message['content'] + ' ' + eos_token }}{% endif %}{% endfor %}"
        self.tokenizer = tokenizer

        # Wrap the base model with the LoRA adapter stored at ``path``.
        self.model = PeftModel.from_pretrained(base_model, path)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run one chat turn for a request payload.

        Args:
            data: request dict. ``message`` is required; ``system_prompt``
                defaults to ``""`` but must not be explicitly ``None``.
                Optional keys: ``chat_history`` (default ``[]``),
                ``instruction``, ``conclusions``, ``context`` and the sampling
                settings ``max_new_tokens`` (1024), ``temperature`` (0.6),
                ``top_p`` (0.9), ``top_k`` (50), ``repetition_penalty`` (1.2).

        Returns:
            Whatever ``generate`` returns for this request.

        Raises:
            ValueError: if ``message`` is missing/None or ``system_prompt`` is None.
        """
        message = data.get("message")
        system_prompt = data.get("system_prompt", "")
        if any(required is None for required in (message, system_prompt)):
            raise ValueError("Missing required parameters.")

        sampling_options = {
            name: data.get(name, fallback)
            for name, fallback in (
                ("max_new_tokens", 1024),
                ("temperature", 0.6),
                ("top_p", 0.9),
                ("top_k", 50),
                ("repetition_penalty", 1.2),
            )
        }

        result = generate(
            tokenizer=self.tokenizer,
            model=self.model,
            message=message,
            chat_history=data.get("chat_history", []),
            system_prompt=system_prompt,
            instruction=data.get("instruction"),
            conclusions=data.get("conclusions"),
            context=data.get("context"),
            **sampling_options,
        )
        LOGGER.info(f"Generated text: {result}")
        return result

def generate(
    tokenizer,
    model,
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str = "",
    instruction: str = None,
    conclusions: list[tuple[str, str]] = None,
    context: list[str] = None,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
    end_sequences: list[str] = None,
) -> dict:
    """Generate one assistant reply for ``message`` and parse it.

    The prompt (system prompt, prior ``chat_history`` turns, then ``message``)
    is rendered with ``tokenizer.apply_chat_template``. Sampling stops as soon
    as any end sequence appears in the newly generated text, or on EOS, or at
    ``max_new_tokens``. The reply is then cut at the first end sequence or
    EOS marker before being handed to ``parse``.

    Args:
        tokenizer: tokenizer whose chat template renders the conversation.
        model: model exposing ``generate`` and ``device``.
        message: the new user message.
        chat_history: prior ``(user, assistant)`` pairs; a ``None`` user entry
            emits only the assistant turn. ``None`` is treated as empty.
        system_prompt: explicit system prompt. When empty and ``instruction``,
            ``conclusions`` and ``context`` are all given, one is built from them.
        instruction, conclusions, context: pieces for the built system prompt.
        max_new_tokens, temperature, top_p, top_k, repetition_penalty:
            sampling settings forwarded to ``model.generate``.
        end_sequences: stop markers; ``None`` means ``["[INST]", "[/INST]", "\\n"]``.

    Returns:
        The dict produced by ``parse`` (``generated_text``, ``conclusion``,
        ``context``).
    """
    # transformers is already a dependency of this module; imported locally so
    # the module's top-level import list stays unchanged.
    from transformers import StoppingCriteria, StoppingCriteriaList

    if end_sequences is None:
        end_sequences = ["[INST]", "[/INST]", "\n"]
    # An empty marker is a substring of everything and would stop generation at once.
    stop_markers = [seq for seq in end_sequences if seq]

    LOGGER.info("instruction: %s", instruction)
    LOGGER.info("conclusions: %s", conclusions)
    LOGGER.info("context: %s", context)

    # Build the system prompt from structured pieces only when no explicit
    # prompt was given and all three pieces are present.
    if not system_prompt and instruction is not None and conclusions is not None and context is not None:
        # NOTE(review): each pair is written as "<first>: <second>", while parse()
        # looks for the FIRST element in the reply — confirm which element the
        # model is expected to echo.
        parts = [f"Instruction: {instruction}\nConclusions:\n"]
        parts.extend(f"{first}: {second}\n" for first, second in conclusions)
        parts.append("\nContext:\n")
        # Passages are labelled [1], [2], ...; parse() extracts such [N] markers.
        parts.extend(f"{ctx}: [{number}]\n" for number, ctx in enumerate(context, start=1))
        system_prompt = "".join(parts)

    conversation = [{"role": "system", "content": system_prompt}]
    for user, assistant in chat_history or ():
        if user is not None:
            conversation.append({"role": "user", "content": user})
        conversation.append({"role": "assistant", "content": assistant})
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # NOTE(review): keeps only the most recent tokens, so very long
        # conversations lose the start of the prompt (BOS and system prompt).
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
    input_ids = input_ids.to(model.device)
    prompt_length = input_ids.shape[1]

    class _StopOnEndSequence(StoppingCriteria):
        """Stop sampling once a stop marker appears in the generated continuation."""

        def __call__(self, ids, scores, **kwargs):
            continuations = tokenizer.batch_decode(ids[:, prompt_length:], skip_special_tokens=False)
            done = [any(marker in text for marker in stop_markers) for text in continuations]
            return torch.tensor(done, dtype=torch.bool, device=ids.device)

    # A stopping criterion is what actually halts sampling at an end sequence;
    # inspecting streamed text only after a synchronous generate() returns
    # cannot save any compute.
    output_ids = model.generate(
        input_ids=input_ids,
        attention_mask=torch.ones_like(input_ids),
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
        stopping_criteria=StoppingCriteriaList([_StopOnEndSequence()]),
    )

    generated_text = tokenizer.decode(output_ids[0, prompt_length:], skip_special_tokens=False)

    # Truncate at the earliest stop marker. Also cut at the EOS text, which is
    # decoded verbatim because special tokens are not skipped.
    cut_markers = stop_markers + ([tokenizer.eos_token] if tokenizer.eos_token else [])
    cut_at = min(
        (pos for pos in (generated_text.find(marker) for marker in cut_markers) if pos != -1),
        default=len(generated_text),
    )
    return parse(generated_text[:cut_at], conclusions, end_sequences)

def parse(generated_text: str, conclusions: list[tuple[str, str]], end_sequences: list[str]) -> dict:
    """Clean a model reply and extract its conclusion key and context numbers.

    Args:
        generated_text: raw reply text.
        conclusions: pairs whose FIRST element is the key searched for in the
            reply, or None/empty to skip conclusion detection.
        end_sequences: markers removed from the text wherever they occur.

    Returns:
        dict with ``generated_text`` (cleaned reply with matched conclusion
        keys removed), ``conclusion`` (the matched key or None; when several
        keys match, the one listed last in ``conclusions`` wins) and
        ``context`` (the integers of every ``[N]`` marker, in order of
        appearance, duplicates kept).
    """
    for end_sequence in end_sequences:
        generated_text = generated_text.replace(end_sequence, "")
    generated_text = generated_text.strip()

    conclusion_found = None
    for conclusion_key, _ in conclusions or ():
        # Skip empty keys: "" is a substring of every string and would always match.
        if conclusion_key and conclusion_key in generated_text:
            conclusion_found = conclusion_key
            generated_text = generated_text.replace(conclusion_key, "")
    # Removing a key can leave leading/trailing whitespace behind.
    generated_text = generated_text.strip()

    context_numbers = [int(number) for number in re.findall(r"\[(\d+)\]", generated_text)]

    return {
        "generated_text": generated_text,
        "conclusion": conclusion_found,
        "context": context_numbers,
    }