File size: 1,855 Bytes
6f3e563
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32cb5ae
 
6f3e563
 
 
32cb5ae
 
6f3e563
 
 
32cb5ae
 
6f3e563
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import requests
from PIL import Image

import evaluate


# Load the CLIP-score metric from a local metric script; `compute` expects
# `predictions` (captions) and `references` (PIL images) and returns a dict
# with a "clip_score" float.
metric = evaluate.load("./clip_score.py")


def download_image(image_path):
    """Open an image from a URL or a local file path.

    Args:
        image_path: An http(s) URL or a filesystem path to an image.

    Returns:
        A PIL ``Image`` (lazily loaded by ``Image.open``).

    Raises:
        requests.HTTPError: if the URL responds with an error status.
        OSError: if the local file cannot be opened as an image.
    """
    # Match only real URL schemes; a bare startswith("http") would also
    # treat a local file named e.g. "httpsnapshot.png" as a URL.
    if image_path.startswith(("http://", "https://")):
        # requests has no default timeout — without one a stalled server
        # hangs the script indefinitely.
        response = requests.get(image_path, stream=True, timeout=30)
        # Fail loudly on 4xx/5xx instead of handing an error page to PIL.
        response.raise_for_status()
        image = Image.open(response.raw)
    else:
        image = Image.open(image_path)
    return image


def compute_clip_score(image, text):
    """Score how well *text* describes *image* using the CLIP metric.

    Accepts either single items or lists for both arguments; singletons
    are wrapped into one-element lists before calling the metric.

    Args:
        image: A PIL image or a list of PIL images.
        text: A caption string or a list of caption strings.

    Returns:
        The float CLIP score reported by the metric.
    """
    references = image if isinstance(image, list) else [image]
    predictions = text if isinstance(text, list) else [text]
    results = metric.compute(predictions=predictions, references=references)
    return results["clip_score"]


# Demo captions paired with two Unsplash photos (a cat, a mountain sunset).
predictions = ["A cat sitting on a couch", "A scenic view of mountains during sunset"]
references = [
    "https://images.unsplash.com/photo-1720539222585-346e73f01536",
    "https://images.unsplash.com/photo-1694253987647-4eebcf679974",
]
references = [download_image(url) for url in references]

# Expected scores: matching caption/image pairs score ~0.30, mismatched
# pairs score much lower (~0.11–0.13).
test_cases = [
    {
        "predictions": predictions,
        "references": references,
        "result": {"clip_score": 0.307},
    },
    {
        "predictions": predictions[0],
        "references": references[0],
        "result": {"clip_score": 0.304},
    },
    {
        "predictions": predictions[1],
        "references": references[1],
        "result": {"clip_score": 0.310},
    },
    {
        "predictions": predictions[0],
        "references": references[1],
        "result": {"clip_score": 0.106},
    },
    {
        "predictions": predictions[1],
        "references": references[0],
        "result": {"clip_score": 0.134},
    },
]

for i, test_case in enumerate(test_cases):
    result = compute_clip_score(test_case["references"], test_case["predictions"])
    expected = test_case["result"]["clip_score"]
    error = abs(result - expected)
    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O`, and the message should say what went wrong.
    if error >= 0.1:
        raise AssertionError(
            f"Test case {i} failed: got {result:.3f}, "
            f"expected ~{expected:.3f} (error {error:.3f} >= 0.1)"
        )