ks415 committed on
Commit
19dd951
·
verified ·
1 Parent(s): 7376086

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +42 -0
  2. requirements.txt +68 -0
app.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import clip
4
+ from PIL import Image
5
+
6
# Run inference on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
7
+
8
def calculate_similarity(image, text, model_name):
    """Return the CLIP cosine similarity between an image and a text prompt.

    Args:
        image: The input picture; it is passed to the ``preprocess``
            transform returned by ``load_model``. ``None`` means that no
            image was supplied.
        text: The prompt to compare against the image. ``None`` is treated
            as an empty prompt.
        model_name: Name of the CLIP model handed to ``load_model``.

    Returns:
        The cosine similarity between the image embedding and the text
        embedding, as a built-in ``float``.

    Raises:
        ValueError: If ``image`` is ``None`` or ``model_name`` is empty.
    """
    # Fail early with a readable message instead of an opaque error raised
    # from inside the preprocessing transform or ``clip.load``.
    if image is None:
        raise ValueError("An image is required to compute the similarity.")
    if not model_name:
        raise ValueError("A CLIP model must be selected.")

    model, preprocess = load_model(model_name)

    # Image preprocessing: add a batch dimension and move to the device.
    image_batch = preprocess(image).unsqueeze(0).to(device)

    # Text preprocessing. ``truncate=True`` cuts prompts that exceed the
    # model's context length instead of letting ``clip.tokenize`` raise.
    text_tokens = clip.tokenize([text or ""], truncate=True).to(device)

    # Similarity computation; gradients are not needed for inference.
    with torch.no_grad():
        image_features = model.encode_image(image_batch)
        text_features = model.encode_text(text_tokens)
        similarity = torch.cosine_similarity(image_features, text_features)

    # The batch holds exactly one pair, so ``item()`` yields that single
    # score as a plain Python float rather than a NumPy scalar.
    return similarity.item()
25
+
26
# Single-slot cache mapping the most recently requested model name to its
# ``(model, preprocess)`` pair. Only one entry is kept at a time so that
# switching models does not accumulate several sets of weights in memory.
_model_cache = {}


def load_model(model_name):
    """Return the CLIP model and preprocessing transform for ``model_name``.

    The most recently loaded model is cached, so repeated requests for the
    same model do not call ``clip.load`` again.

    Args:
        model_name: Name of the CLIP model, passed through to ``clip.load``.

    Returns:
        The ``(model, preprocess)`` pair produced by ``clip.load`` for the
        module-level ``device``.
    """
    cached = _model_cache.get(model_name)
    if cached is None:
        # Drop the reference to the previously cached model before loading
        # the new one, so the old weights can be reclaimed.
        _model_cache.clear()
        cached = clip.load(model_name, device=device)
        _model_cache[model_name] = cached
    return cached
29
+
30
# Gradio UI: an image, a text prompt and a model selector go in; the
# similarity score computed by ``calculate_similarity`` comes out as a number.
iface = gr.Interface(
    fn=calculate_similarity,
    inputs=[
        gr.Image(type="pil"),
        gr.Textbox(lines=2, placeholder="A photo of a ..."),
        gr.Radio(
            ["ViT-B/32", "ViT-B/16", "ViT-L/14", "ViT-L/14@336px"],
            # Preselect a model so that submitting without touching the
            # selector does not pass ``None`` as the model name.
            value="ViT-B/32",
            label="モデル選択",
        ),
    ],
    outputs="number",
    title="CLIPによる画像とテキストの類似度計算",
    description="類似度を計算したい画像とテキストを入力し,使用するCLIPモデルを選択してください.",
)

# Start the web server for the interface.
iface.launch()
requirements.txt ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1
2
+ annotated-types==0.7.0
3
+ anyio==4.8.0
4
+ certifi==2025.1.31
5
+ charset-normalizer==3.4.1
6
+ click==8.1.8
7
+ clip @ git+https://github.com/openai/CLIP.git@dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
8
+ contourpy==1.3.0
9
+ cycler==0.12.1
10
+ exceptiongroup==1.2.2
11
+ fastapi==0.115.8
12
+ ffmpy==0.5.0
13
+ filelock==3.17.0
14
+ fonttools==4.56.0
15
+ fsspec==2025.2.0
16
+ ftfy==6.3.1
17
+ gradio==4.44.1
18
+ gradio_client==1.3.0
19
+ h11==0.14.0
20
+ httpcore==1.0.7
21
+ httpx==0.28.1
22
+ huggingface-hub==0.28.1
23
+ idna==3.10
24
+ importlib_resources==6.5.2
25
+ Jinja2==3.1.5
26
+ kiwisolver==1.4.7
27
+ markdown-it-py==3.0.0
28
+ MarkupSafe==2.1.5
29
+ matplotlib==3.9.4
30
+ mdurl==0.1.2
31
+ mpmath==1.3.0
32
+ networkx==3.2.1
33
+ numpy==2.0.2
34
+ orjson==3.10.15
35
+ packaging==24.2
36
+ pandas==2.2.3
37
+ pillow==10.4.0
38
+ pydantic==2.10.6
39
+ pydantic_core==2.27.2
40
+ pydub==0.25.1
41
+ Pygments==2.19.1
42
+ pyparsing==3.2.1
43
+ python-dateutil==2.9.0.post0
44
+ python-multipart==0.0.20
45
+ pytz==2025.1
46
+ PyYAML==6.0.2
47
+ regex==2024.11.6
48
+ requests==2.32.3
49
+ rich==13.9.4
50
+ ruff==0.9.5
51
+ semantic-version==2.10.0
52
+ shellingham==1.5.4
53
+ six==1.17.0
54
+ sniffio==1.3.1
55
+ starlette==0.45.3
56
+ sympy==1.13.1
57
+ tomlkit==0.12.0
58
+ torch==2.6.0
59
+ torchvision==0.21.0
60
+ tqdm==4.67.1
61
+ typer==0.15.1
62
+ typing_extensions==4.12.2
63
+ tzdata==2025.1
64
+ urllib3~=2.0
65
+ uvicorn==0.34.0
66
+ wcwidth==0.2.13
67
+ websockets==12.0
68
+ zipp==3.21.0