File size: 2,339 Bytes
c6a12ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import nltk
from config import tokenizer
import torch
from utils.image_processing import get_clip_embeddings

# Fetch the Punkt sentence/word tokenizer data that nltk.word_tokenize
# requires. NOTE(review): these run at import time of this module;
# presumably nltk skips the download once the data is cached locally.
nltk.download('punkt')
nltk.download('punkt_tab')

def remove_punctuation(text):
    """Drop every character that is not alphanumeric or whitespace,
    then collapse all runs of whitespace into single spaces (trimming ends)."""
    kept_chars = []
    for ch in text:
        if ch.isalnum() or ch.isspace():
            kept_chars.append(ch)
    # split()/join normalizes tabs/newlines/multiple spaces to one space
    return ' '.join(''.join(kept_chars).split())

def preprocess_text(text):
    """Normalize raw text for downstream matching: strip punctuation
    and collapse whitespace (delegates to remove_punctuation)."""
    return remove_punctuation(text)

def getStringAfter(output, start_str):
    """Return the cleaned text that follows the first occurrence of start_str.

    If start_str is not present, the whole output is cleaned and returned.

    Fix: the original used output.split(start_str)[1], which returns only the
    segment between the first and second occurrence of start_str — any text
    after a repeated marker was silently dropped. str.partition keeps the
    entire tail after the first occurrence.
    """
    _, sep, tail = output.partition(start_str)
    # sep is empty iff start_str was not found; fall back to the full output
    answer = tail if sep else output
    return preprocess_text(answer)

def getAnswerPart(output):
    """Strip chat-template prompt words out of a model output string.

    Tokenizes the fixed prompt template and the output, then removes any
    output token that (case-insensitively) also occurs in the template.

    Fix: the original rebuilt the lowered template word list inside the
    comprehension for every output token (O(n*m) list scans); the lowered
    vocabulary is now built once as a set for O(1) membership tests.
    """
    template_words = nltk.word_tokenize("<|system|> \n You are an assistant good at understanding the context. <|end|> \n <|user|> \n") + nltk.word_tokenize("\n Describe the objects and their relationship in the given context.<|end|> \n <|assistant|> \n")
    template_vocab = {w.lower() for w in template_words}
    output_words = nltk.word_tokenize(output)
    filtered_words = [word for word in output_words if word.lower() not in template_vocab]
    return ' '.join(filtered_words)

def getInputs(image_path, question, answer=""):
    """Assemble model inputs for the describe-the-context prompt.

    Tokenizes the fixed system/user prefix and the instruction suffix
    (with ``answer`` appended), optionally embeds the image via CLIP, and
    builds a combined attention mask covering prefix + image tokens + suffix.

    Returns (start_input_ids, end_input_ids, image_features, attention_mask);
    image_features is None when no image path is given.
    NOTE(review): ``question`` is accepted but never used — kept for
    interface compatibility.
    """
    if image_path is None:
        image_features = None
        num_image_tokens = 0
    else:
        image_features = get_clip_embeddings(image_path)
        # assumes embeddings are shaped (1, num_tokens, dim) — TODO confirm
        num_image_tokens = image_features.shape[1]

    start_text = "<|system|>\nYou are an assistant good at understanding the context.<|end|>\n<|user|>\n "
    end_text = f" .\n  Describe the objects and their relationship from the context. <|end|>\n<|assistant|>\n {answer}"

    tok_kwargs = dict(padding=True, truncation=True, max_length=512, return_tensors="pt")
    start_tokens = tokenizer(start_text, **tok_kwargs)
    end_tokens = tokenizer(end_text, **tok_kwargs)

    # The attention mask mirrors the final sequence layout:
    # [prefix tokens] [image tokens (all attended)] [suffix tokens]
    mask_parts = [start_tokens['attention_mask']]
    if image_path is not None:
        mask_parts.append(torch.ones((1, num_image_tokens), dtype=torch.long))
    mask_parts.append(end_tokens['attention_mask'])
    attention_mask = torch.cat(mask_parts, dim=1)

    return start_tokens['input_ids'], end_tokens['input_ids'], image_features, attention_mask