| import re | |
| import copy | |
| from transformers import pipeline | |
class ImageCaption:
    """Chat-style image captioning built on a Hugging Face image-to-text pipeline.

    Maintains a ``USER:``/``ASSISTANT:`` transcript format: ``generate`` renders
    the running history plus the new user message into a prompt, runs the
    pipeline, and parses the model output back into [user, assistant] pairs.
    """

    def __init__(self, model_id, quantization_config):
        """Load the image-to-text pipeline.

        Args:
            model_id: Hugging Face model identifier for an image-to-text model.
            quantization_config: quantization settings forwarded to the model
                loader via ``model_kwargs``.
        """
        self.pipe = pipeline(
            "image-to-text",
            model=model_id,
            model_kwargs={"quantization_config": quantization_config},
        )

    def infer(self, image, prompt, temperature, length_penalty,
              repetition_penalty, max_length, min_length, top_p):
        """Run one generation pass and return the generated text.

        Args:
            image: image input accepted by the pipeline (path/URL/PIL image).
            prompt: full prompt string, including any chat-history prefix.
            temperature, length_penalty, repetition_penalty, max_length,
            min_length, top_p: generation hyperparameters passed through
                unchanged via ``generate_kwargs``.

        Returns:
            The ``generated_text`` string of the first pipeline output.
        """
        outputs = self.pipe(
            images=image,
            prompt=prompt,
            generate_kwargs={
                "temperature": temperature,
                "length_penalty": length_penalty,
                "repetition_penalty": repetition_penalty,
                "max_length": max_length,
                "min_length": min_length,
                "top_p": top_p,
            },
        )
        return outputs[0]["generated_text"]

    def extract_response_pairs(self, text):
        """Parse a ``USER:``/``ASSISTANT:`` transcript into [user, assistant] pairs.

        A trailing user turn with no matching assistant reply is dropped.

        Args:
            text: raw model output containing role markers.

        Returns:
            List of two-element lists ``[user_text, assistant_text]``.
        """
        # Splitting on a capture group keeps the role markers in the result:
        # even slots are "USER:"/"ASSISTANT:" markers, odd slots are utterances.
        turns = re.split(r'(USER:|ASSISTANT:)', text)[1:]
        turns = [turn.strip() for turn in turns if turn.strip()]
        # Slice the utterances once; the original rebuilt turns[1::2] three
        # times per loop iteration.
        contents = turns[1::2]
        conv_list = []
        # Step by 2 and stop one short so each user turn has a paired reply
        # (replaces the original `if i + 1 < len(...)` guard).
        for i in range(0, len(contents) - 1, 2):
            # lstrip(":") guards against stray leading colons (e.g. "USER:: hi").
            conv_list.append([contents[i].lstrip(":"), contents[i + 1].lstrip(":")])
        return conv_list

    def add_text(self, history, text):
        """Append the user's message (with no reply yet) to the chat history.

        Args:
            history: mutable list of [user, assistant] pairs; mutated in place.
            text: the new user message.

        Returns:
            Tuple of the (mutated) history and the original text.
        """
        history.append([text, None])
        return history, text

    def generate(self, history_chat, text_input, image, temperature, length_penalty,
                 repetition_penalty, max_length, min_length, top_p):
        """Run one chat turn: build the prompt, infer, and parse the reply.

        Args:
            history_chat: prior transcript fragments joined into the prompt.
                NOTE(review): `" ".join` assumes this is a flat list of
                strings — confirm against the caller, since ``add_text``
                produces nested [text, None] pairs.
            text_input: the new user message.
            image: image input forwarded to the pipeline.
            temperature, length_penalty, repetition_penalty, max_length,
            min_length, top_p: generation hyperparameters.

        Returns:
            Deep-copied list of [user, assistant] pairs parsed from the output.
        """
        chat_history = " ".join(history_chat)
        # <image> placeholder marks where the vision tokens go in the prompt.
        chat_history += f"USER: <image>\n{text_input}\nASSISTANT:"
        inference_result = self.infer(image, chat_history, temperature,
                                      length_penalty, repetition_penalty,
                                      max_length, min_length, top_p)
        chat_val = self.extract_response_pairs(inference_result)
        # Deep-copy so callers (e.g. UI state holders) cannot mutate the
        # inner pair lists out from under each other.
        chat_state_list = copy.deepcopy(chat_val)
        return chat_state_list