File size: 2,844 Bytes
68fed24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78ce12c
68fed24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import hashlib
import os
import random
import string
import open_clip
import requests
import torch
import shutil
from PIL import Image

# Load the CLIP ViT-B/32 model, its image preprocessing transform, and the
# matching tokenizer once at import time.
# NOTE(review): this is a module-level side effect — importing this module
# downloads/loads pretrained weights, so import is slow and (on first run)
# network-dependent.
model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
tokenizer = open_clip.get_tokenizer('ViT-B-32')


def generate_random_string_and_hash(length=8):
    """Return the SHA-256 hex digest of a freshly drawn random letter string.

    Args:
        length: Number of random ASCII letters to draw before hashing.

    Returns:
        A 64-character lowercase hexadecimal digest, used elsewhere in this
        module as a collision-resistant temporary file/directory name.
    """
    # Draw `length` random letters one at a time, then hash the result so the
    # returned identifier has a fixed width and is filesystem-safe.
    alphabet = string.ascii_letters
    token = ''.join([random.choice(alphabet) for _ in range(length)])
    return hashlib.sha256(token.encode()).hexdigest()


def process_img(image_input, text_inputs, classes):
    """Classify a preprocessed image against tokenized text prompts with CLIP.

    Args:
        image_input: Preprocessed image tensor of shape (1, C, H, W), as
            produced by the module-level ``preprocess`` transform.
        text_inputs: Tokenized prompts, one row per class, from ``tokenizer``.
        classes: Sequence of class names aligned row-for-row with
            ``text_inputs``.

    Returns:
        The class name whose text embedding is most similar to the image.
    """
    # Encode and normalize without tracking gradients — normalization and the
    # similarity softmax are pure inference work, so keep them in the
    # no_grad block as well (the original did them outside it).
    with torch.no_grad():
        image_features = model.encode_image(image_input)
        text_features = model.encode_text(text_inputs)
        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        # Scaled cosine similarities -> probability distribution over classes.
        similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
    # Top-1 prediction. The probability is unused (original bound it to an
    # unused `value`); .item() converts the index tensor to a plain int rather
    # than relying on implicit single-element-tensor __index__ conversion.
    _, index = similarity[0].topk(1)
    return classes[index.item()]


def get_result(question, data, example=None):
    """Answer an image-classification question by classifying image URLs.

    Fetches the answer specification for ``question`` from a remote
    ``data.json``, classifies every image URL in ``data`` with CLIP, and
    reports which images belong to the expected answer set.

    Args:
        question: Question text, used as a key into the remote data.json.
        data: Iterable of image URLs to classify.
        example: Optional URL of an example image; required when the remote
            spec for this question has ``need_example`` set.

    Returns:
        A list of booleans (True where the predicted class is in the answer
        set); an empty list when ``question`` is unknown; ``None`` when an
        example is required but not supplied.
    """
    # NOTE(review): the original re-created the global model/tokenizer on
    # every call; the identical pretrained weights are already loaded at
    # module import, so the expensive reload was removed.
    sess = requests.session()
    result = []
    dir_path = generate_random_string_and_hash()
    os.makedirs(f"temp/{dir_path}", exist_ok=True)
    try:
        raw_answer = sess.get("https://yundisk.de/d/OneDrive_5G/Pic/data.json").json()
        if question in raw_answer:
            raw_answer = raw_answer[question]
            classes = raw_answer["classes"]
            text_inputs = torch.cat([tokenizer(f"a photo of {c}") for c in classes])
            if raw_answer["need_example"]:
                # Guard clause: this question needs a reference image.
                if not example:
                    print(question)
                    return None
                example_tensor = _fetch_image(sess, example, dir_path)
                answer = [process_img(example_tensor, text_inputs, classes)]
            else:
                answer = raw_answer["answer"]
            for img_url in data:
                img_tensor = _fetch_image(sess, img_url, dir_path)
                result.append(process_img(img_tensor, text_inputs, classes) in answer)
    finally:
        # Always remove the per-call scratch directory — the original leaked
        # it on any exception and on the early `return None` path.
        shutil.rmtree(f"temp/{dir_path}", ignore_errors=True)
    return result


def _fetch_image(sess, url, dir_path):
    """Download ``url`` into the scratch dir and return a (1,C,H,W) tensor."""
    file_name = f"{generate_random_string_and_hash()}.png"
    path = f"temp/{dir_path}/{file_name}"
    with open(path, "wb") as f:
        f.write(sess.get(url).content)
    return preprocess(Image.open(path)).unsqueeze(0)