import logging
import re
from threading import Thread
from typing import Any, Dict, Optional

import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

LOGGER = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

device = "cuda" if torch.cuda.is_available() else "cpu"
MAX_INPUT_TOKEN_LENGTH = 16000
class EndpointHandler():
    def __init__(self, path=""):
        config = PeftConfig.from_pretrained(path)
        model = AutoModelForCausalLM.from_pretrained(
            config.base_model_name_or_path,
            load_in_8bit=True,
            trust_remote_code=True,
            device_map="auto",
        )
        self.tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
        self.tokenizer.chat_template = "{% for message in messages %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ bos_token + '<<SYS>>\\n' + message['content'] + '\\n<</SYS>>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + ' ' + eos_token }}{% endif %}{% endfor %}"
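        # Rendering note (illustrative, assuming Llama-2-style "<s>"/"</s>"
        # bos/eos tokens): a [system, user, assistant] conversation renders as
        #   <s><<SYS>>\n{system}\n<</SYS>>\n\n<s>[INST] {user} [/INST] {assistant} </s>
        # Unlike the stock Llama-2 template, the system block is emitted
        # standalone rather than folded into the first [INST] segment.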
        # Load the LoRA adapter on top of the 8-bit base model
        self.model = PeftModel.from_pretrained(model, path)
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle an inference request. The payload must contain "message";
        all other keys extracted below are optional."""
        # Extract required parameters from the request payload
        message = data.get("message")
        chat_history = data.get("chat_history", [])
        system_prompt = data.get("system_prompt", "")
        # Optional parameters used to build the system prompt
        instruction = data.get("instruction")
        conclusions = data.get("conclusions")
        context = data.get("context")
        # Sampling parameters with default values
        max_new_tokens = data.get("max_new_tokens", 1024)
        temperature = data.get("temperature", 0.6)
        top_p = data.get("top_p", 0.9)
        top_k = data.get("top_k", 50)
        repetition_penalty = data.get("repetition_penalty", 1.2)
        if message is None or system_prompt is None:
            raise ValueError("Missing required parameters: 'message' and 'system_prompt' must not be null.")
        # Delegate to the module-level generate function
        output = generate(
            tokenizer=self.tokenizer,
            model=self.model,
            message=message,
            chat_history=chat_history,
            system_prompt=system_prompt,
            instruction=instruction,
            conclusions=conclusions,
            context=context,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
        )
        LOGGER.info(f"Generated text: {output}")
        return output
def generate(
    tokenizer,
    model,
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str = "",
    instruction: Optional[str] = None,
    conclusions: Optional[list[tuple[str, str]]] = None,
    context: Optional[list[str]] = None,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
    end_sequences: Optional[list[str]] = None,
) -> dict:
    # Avoid a mutable default argument for the stop sequences
    if end_sequences is None:
        end_sequences = ["[INST]", "[/INST]", "\n"]
    LOGGER.info(f"instruction: {instruction}")
    LOGGER.info(f"conclusions: {conclusions}")
    LOGGER.info(f"context: {context}")
    # If no system_prompt is given, construct one from instruction, conclusions, and context
    if not system_prompt and instruction is not None and conclusions is not None and context is not None:
        system_prompt = "Instruction: {}\nConclusions:\n".format(instruction)
        for conclusion, conclusion_key in conclusions:
            system_prompt += "{}: {}\n".format(conclusion, conclusion_key)
        system_prompt += "\nContext:\n"
        for idx, ctx in enumerate(context):
            system_prompt += "{}: [{}]\n".format(ctx, idx + 1)
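    # Illustrative only: with instruction="Classify the claim",
    # conclusions=[("Supported", "A"), ("Refuted", "B")] and
    # context=["passage one", "passage two"], the constructed prompt reads:
    #   Instruction: Classify the claim
    #   Conclusions:
    #   Supported: A
    #   Refuted: B
    #
    #   Context:
    #   passage one: [1]
    #   passage two: [2]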
    # Build the conversation in the chat-template message format
    conversation = [{"role": "system", "content": system_prompt}]
    for user, assistant in chat_history:
        if user is not None:
            conversation.append({"role": "user", "content": user})
        conversation.append({"role": "assistant", "content": assistant})
    conversation.append({"role": "user", "content": message})
    # Tokenize the conversation; left-truncate to the most recent tokens if it
    # is too long (note this can drop the system prompt from the input)
    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
    input_ids = input_ids.to(model.device)
    # Stream tokens out of generate(); generation runs in a background thread
    # so this thread can consume the streamer as tokens arrive
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=False
    )
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()
    outputs = []
    generated_text = ""
    for text in streamer:
        outputs.append(text)
        generated_text = "".join(outputs)
        # Stop as soon as any end sequence appears in the accumulated text
        for end_sequence in end_sequences:
            if end_sequence in generated_text:
                generated_text = generated_text.replace(end_sequence, "")
                return parse(generated_text, conclusions, end_sequences)
    return parse(generated_text, conclusions, end_sequences)
def parse(generated_text: str, conclusions: Optional[list[tuple[str, str]]], end_sequences: list[str]) -> dict:
    conclusion_found = None
    # Remove any remaining end sequences and clean the text
    for end_sequence in end_sequences:
        generated_text = generated_text.replace(end_sequence, "")
    generated_text = generated_text.strip()
    # Check for conclusion keys in the generated text
    if conclusions:
        for conclusion_key, _ in conclusions:
            if conclusion_key in generated_text:
                conclusion_found = conclusion_key
                generated_text = generated_text.replace(conclusion_key, "")
    # Extract bracketed context references such as "[1]" from the text
    context_pattern = r"\[\d+\]"
    context_matches = re.findall(context_pattern, generated_text)
    context_numbers = [int(match.strip("[]")) for match in context_matches]
    return {
        "generated_text": generated_text,
        "conclusion": conclusion_found,
        "context": context_numbers,
    }
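

# Minimal local smoke test (a sketch, not part of the deployed handler). The
# adapter path "./adapter" and the payload fields are illustrative assumptions;
# on Hugging Face Inference Endpoints the toolkit instantiates EndpointHandler
# and calls it with the parsed request body instead. Requires a GPU plus
# bitsandbytes for the 8-bit base model.
if __name__ == "__main__":
    handler = EndpointHandler(path="./adapter")  # hypothetical adapter directory
    result = handler({
        "message": "Does the context support conclusion A?",
        "system_prompt": "Answer using the numbered context passages.",
        "chat_history": [],
        "max_new_tokens": 256,
    })
    print(result)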