# Source: Hugging Face Space by ishworrsubedii, commit 32a0eda
# ("Added new features and improved code formatting")
import re
import copy
from transformers import pipeline
class ImageCaption:
    """Multi-turn image captioning/chat on top of a HF image-to-text pipeline."""

    def __init__(self, model_id, quantization_config):
        """Build the image-to-text pipeline.

        Args:
            model_id: model repo id or local path passed to ``pipeline``.
            quantization_config: forwarded to the model via ``model_kwargs``
                (e.g. a bitsandbytes quantization config).
        """
        self.pipe = pipeline(
            "image-to-text",
            model=model_id,
            model_kwargs={"quantization_config": quantization_config},
        )

    def infer(self, image, prompt, temperature, length_penalty,
              repetition_penalty, max_length, min_length, top_p):
        """Run one generation pass and return the first output's generated text.

        All sampling/length knobs are forwarded as ``generate_kwargs``.
        """
        outputs = self.pipe(
            images=image,
            prompt=prompt,
            generate_kwargs={
                "temperature": temperature,
                "length_penalty": length_penalty,
                "repetition_penalty": repetition_penalty,
                "max_length": max_length,
                "min_length": min_length,
                "top_p": top_p,
            },
        )
        return outputs[0]["generated_text"]

    def extract_response_pairs(self, text):
        """Parse a ``USER: ... ASSISTANT: ...`` transcript into [user, assistant] pairs.

        A trailing unpaired turn (transcript ending mid-pair) is dropped,
        matching the original behavior.
        """
        turns = re.split(r"(USER:|ASSISTANT:)", text)[1:]
        turns = [turn.strip() for turn in turns if turn.strip()]
        # Odd indices hold message content; even indices hold the
        # 'USER:'/'ASSISTANT:' markers themselves. Hoisted into a local:
        # the original re-computed turns[1::2] on every loop iteration.
        contents = turns[1::2]
        conv_list = []
        for i in range(0, len(contents) - 1, 2):
            conv_list.append([contents[i].lstrip(":"), contents[i + 1].lstrip(":")])
        return conv_list

    def add_text(self, history, text):
        """Append a new [user_text, None] turn to history (mutates it in place).

        Returns the (mutated) history and the text, for chained UI callbacks.
        """
        history.append([text, None])
        return history, text

    def generate(self, history_chat, text_input, image, temperature, length_penalty,
                 repetition_penalty, max_length, min_length, top_p):
        """Run inference on the accumulated chat plus the new user message.

        Returns the parsed list of [user, assistant] pairs.

        NOTE(review): history_chat is joined as a list of strings here, but
        add_text appends [text, None] pairs — confirm callers pass strings.
        """
        chat_history = " ".join(history_chat)
        chat_history += f"USER: <image>\n{text_input}\nASSISTANT:"
        inference_result = self.infer(image, chat_history, temperature, length_penalty,
                                      repetition_penalty, max_length, min_length, top_p)
        # The pair list is freshly built by extract_response_pairs, so the
        # original's copy.deepcopy of it was redundant and is dropped.
        return self.extract_response_pairs(inference_result)