import re
import copy

from transformers import pipeline


class ImageCaption:
    """Chat-style image captioning built on a transformers image-to-text pipeline.

    Holds a single pipeline instance and provides helpers to run generation,
    parse ``USER:``/``ASSISTANT:`` transcripts, and maintain a chat history.
    """

    def __init__(self, model_id, quantization_config):
        """Load the image-to-text pipeline for *model_id*.

        *quantization_config* is forwarded to the model via ``model_kwargs``.
        """
        self.pipe = pipeline(
            "image-to-text",
            model=model_id,
            model_kwargs={"quantization_config": quantization_config},
        )

    def infer(self, image, prompt, temperature, length_penalty,
              repetition_penalty, max_length, min_length, top_p):
        """Run one generation pass and return the generated text string.

        All sampling/length knobs are passed straight through as
        ``generate_kwargs``; only the first pipeline output is used.
        """
        outputs = self.pipe(
            images=image,
            prompt=prompt,
            generate_kwargs={
                "temperature": temperature,
                "length_penalty": length_penalty,
                "repetition_penalty": repetition_penalty,
                "max_length": max_length,
                "min_length": min_length,
                "top_p": top_p,
            },
        )
        return outputs[0]["generated_text"]

    def extract_response_pairs(self, text):
        """Parse a ``USER: ... ASSISTANT: ...`` transcript into turn pairs.

        Returns a list of ``[user_text, assistant_text]`` two-element lists.
        A trailing user turn with no assistant reply is dropped.
        """
        # Splitting on a capture group keeps the role markers as their own
        # elements; [1:] discards any text preceding the first marker.
        turns = re.split(r'(USER:|ASSISTANT:)', text)[1:]
        turns = [turn.strip() for turn in turns if turn.strip()]
        # Even positions are role markers, odd positions are the spoken
        # contents. Hoisted once — the original re-sliced turns[1::2] on
        # every loop test and access.
        contents = turns[1::2]
        conv_list = []
        for i in range(0, len(contents), 2):
            if i + 1 < len(contents):  # guard: skip an unanswered final user turn
                conv_list.append(
                    [contents[i].lstrip(":"), contents[i + 1].lstrip(":")]
                )
        return conv_list

    def add_text(self, history, text):
        """Append a new ``[user_text, None]`` turn to *history* in place.

        Returns the (mutated) history together with the original text,
        matching the two-output convention of chat UIs.
        """
        history.append([text, None])
        return history, text

    def generate(self, history_chat, text_input, image, temperature,
                 length_penalty, repetition_penalty, max_length,
                 min_length, top_p):
        """Run one chat round: build the prompt, infer, re-parse the transcript.

        *history_chat* is joined with spaces, so it is presumably a list of
        already-formatted transcript strings — TODO confirm against caller.
        Returns the parsed conversation as a list of [user, assistant] pairs.
        """
        chat_history = " ".join(history_chat)
        chat_history += f"USER: \n{text_input}\nASSISTANT:"
        inference_result = self.infer(
            image, chat_history, temperature, length_penalty,
            repetition_penalty, max_length, min_length, top_p,
        )
        # The model echoes the whole transcript; recover structured turns.
        chat_val = self.extract_response_pairs(inference_result)
        # Deep copy so in-place edits to the returned state cannot alias
        # the freshly parsed pairs.
        chat_state_list = copy.deepcopy(chat_val)
        return chat_state_list