File size: 1,982 Bytes
32a0eda
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import re
import copy
from transformers import pipeline


class ImageCaption:
    """Wrap a HuggingFace image-to-text pipeline with helpers for keeping a
    USER/ASSISTANT chat transcript around it.

    The transcript format is a flat string of alternating ``USER:`` and
    ``ASSISTANT:`` turns; ``extract_response_pairs`` parses it back into
    ``[user, assistant]`` pairs.
    """

    def __init__(self, model_id, quantization_config):
        # Build the captioning pipeline once; quantization_config is passed
        # through to the underlying model via model_kwargs.
        self.pipe = pipeline("image-to-text", model=model_id, model_kwargs={"quantization_config": quantization_config})

    def infer(self, image, prompt, temperature, length_penalty, repetition_penalty, max_length, min_length, top_p):
        """Run one generation pass over *image* with *prompt* and return the
        generated text string (first candidate only)."""
        outputs = self.pipe(images=image, prompt=prompt,
                            generate_kwargs={
                                "temperature": temperature,
                                "length_penalty": length_penalty,
                                "repetition_penalty": repetition_penalty,
                                "max_length": max_length,
                                "min_length": min_length,
                                "top_p": top_p})
        return outputs[0]["generated_text"]

    def extract_response_pairs(self, text):
        """Parse a transcript into a list of ``[user, assistant]`` pairs.

        *text* contains alternating ``USER:`` / ``ASSISTANT:`` turns. A
        trailing turn with no reply is dropped. Empty turns are filtered
        out before pairing.
        """
        turns = re.split(r'(USER:|ASSISTANT:)', text)[1:]
        turns = [turn.strip() for turn in turns if turn.strip()]
        # After the capturing split, delimiters sit at even indexes and the
        # spoken content at odd indexes. Slice the contents out ONCE — the
        # previous version recomputed turns[1::2] three times per loop
        # iteration, which is accidentally quadratic.
        contents = turns[1::2]
        conv_list = []
        # range bound of len-1 replaces the old `if i + 1 < len(...)` guard:
        # both skip a dangling unpaired turn.
        for i in range(0, len(contents) - 1, 2):
            conv_list.append([contents[i].lstrip(":"), contents[i + 1].lstrip(":")])
        return conv_list

    def add_text(self, history, text):
        """Append a new user turn (assistant answer pending as None) and
        return the updated history together with the raw text."""
        history.append([text, None])
        return history, text

    def generate(self, history_chat, text_input, image, temperature, length_penalty, repetition_penalty, max_length,
                 min_length, top_p):
        """Extend the chat with the new user turn, run inference, and return
        the parsed conversation as a list of ``[user, assistant]`` pairs.

        ``history_chat`` is assumed to be a list of transcript strings that
        are joined into the prompt prefix — TODO confirm against caller.
        """
        chat_history = " ".join(history_chat)
        chat_history += f"USER: <image>\n{text_input}\nASSISTANT:"

        inference_result = self.infer(image, chat_history, temperature, length_penalty, repetition_penalty, max_length,
                                      min_length, top_p)
        chat_val = self.extract_response_pairs(inference_result)

        # Deep-copy so callers can mutate the returned state without
        # aliasing the parsed result.
        chat_state_list = copy.deepcopy(chat_val)
        return chat_state_list