File size: 4,282 Bytes
c6bf520
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
# ==============================
# 1. ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ import
# ==============================

import gradio as gr  # ์›น UI ์ƒ์„ฑ์„ ์œ„ํ•œ Gradio ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ
import torch  # PyTorch (๋”ฅ๋Ÿฌ๋‹ ๋ชจ๋ธ ์‹คํ–‰ ๋ฐ ํ…์„œ ์—ฐ์‚ฐ)
from PIL import Image  # ์ด๋ฏธ์ง€ ์ฒ˜๋ฆฌ (numpy โ†” PIL ๋ณ€ํ™˜)

# ViT ๋ชจ๋ธ (์ด๋ฏธ์ง€ ๋ถ„๋ฅ˜)
from transformers import ViTImageProcessor, ViTForImageClassification

# BLIP ๋ชจ๋ธ (์ด๋ฏธ์ง€ ์„ค๋ช… ์ƒ์„ฑ)
from transformers import BlipProcessor, BlipForConditionalGeneration


# ==============================
# 2. Load the ViT model (image classification)
# ==============================

model_name = "google/vit-base-patch16-224"  
# Hugging Face model id for the Vision Transformer classifier

processor = ViTImageProcessor.from_pretrained(model_name)  
# Image preprocessor (performs resizing and normalization automatically)

model = ViTForImageClassification.from_pretrained(model_name)  
# Classification model with pretrained weights (downloaded on first run)


# ==============================
# 3. Load the BLIP model (image captioning)
# ==============================

caption_processor = BlipProcessor.from_pretrained(
    "Salesforce/blip-image-captioning-base"
)  
# Preprocessor for the image -> text conversion

caption_model = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-base"
)  
# Caption-generation model


# ==============================
# 4. ์ด๋ฏธ์ง€ ์„ค๋ช… ํ•จ์ˆ˜ (์—๋Ÿฌ ์ˆ˜์ • ํ•ต์‹ฌ)
# ==============================

def generate_caption(img):
    """Generate a natural-language caption for *img* using the BLIP model.

    Accepts either a ``PIL.Image`` or a numpy array and returns the decoded
    caption string with special tokens stripped.
    """
    # Normalize the input to a PIL image (Gradio may hand us a numpy array);
    # skip the conversion if it is already one.
    pil_img = img if isinstance(img, Image.Image) else Image.fromarray(img)

    # Preprocess the image into PyTorch tensors suitable for BLIP.
    model_inputs = caption_processor(images=pil_img, return_tensors="pt")

    # Inference only: disable gradient tracking for speed and lower memory use.
    with torch.no_grad():
        # Produce a sequence of text-token ids describing the image.
        token_ids = caption_model.generate(**model_inputs)

    # Decode the token ids into a human-readable string (special tokens removed).
    return caption_processor.decode(token_ids[0], skip_special_tokens=True)


# ==============================
# 5. ์ด๋ฏธ์ง€ ๋ถ„๋ฅ˜ + ์„ค๋ช… ํ•จ์ˆ˜
# ==============================

def classify_image(img, top_k=3):
    """Classify *img* with ViT and generate a BLIP caption for it.

    Args:
        img: Input image as a numpy array (Gradio's default) or a PIL.Image.
        top_k: Number of top classes to include in the result
            (default 3, matching the Gradio ``Label`` widget's setting).

    Returns:
        A ``(results, caption)`` tuple: ``results`` maps the top_k class
        labels to their softmax probabilities, and ``caption`` is the
        BLIP-generated description string.
    """
    # Normalize to PIL (avoids a redundant conversion if already PIL).
    if not isinstance(img, Image.Image):
        img = Image.fromarray(img)

    # ViT preprocessing -> PyTorch tensors.
    inputs = processor(images=img, return_tensors="pt")

    # Inference without gradient tracking (faster, less memory).
    with torch.no_grad():
        logits = model(**inputs).logits  # raw, unnormalized class scores

    # Convert raw logits into a probability distribution.
    probs = torch.nn.functional.softmax(logits, dim=-1)[0]

    # Extract the top_k classes and their probabilities.
    top_prob, top_indices = torch.topk(probs, top_k)

    # Map class indices to human-readable labels.
    results = {
        model.config.id2label[idx.item()]: float(p)
        for p, idx in zip(top_prob, top_indices)
    }

    # Generate the natural-language caption (PIL image passed directly).
    caption = generate_caption(img)

    # Classification results + caption for the two Gradio outputs.
    return results, caption


# ==============================
# 6. Gradio UI configuration
# ==============================

demo = gr.Interface(

    fn=classify_image,  # function executed on each submission

    inputs=gr.Image(
        type="numpy",  # deliver the image as a numpy array
        sources=["upload"]  # upload-only input
    ),

    outputs=[
        gr.Label(num_top_classes=3),  # classification results (top 3)
        gr.Textbox(label="์ด๋ฏธ์ง€ ์„ค๋ช…")  # generated image caption
    ],

    title="ViT ์ด๋ฏธ์ง€ ๋ถ„๋ฅ˜ + BLIP ์ด๋ฏธ์ง€ ์„ค๋ช…",  
    # page title shown in the web UI

    description="์ด๋ฏธ์ง€๋ฅผ ์—…๋กœ๋“œํ•˜๋ฉด ๋ถ„๋ฅ˜ ๊ฒฐ๊ณผ์™€ ์„ค๋ช…์„ ํ•จ๊ป˜ ์ œ๊ณตํ•ฉ๋‹ˆ๋‹ค."
    # short service description shown under the title
)


# ==============================
# 7. Run the server
# ==============================

if __name__ == "__main__":  # only when executed directly, not on import

    demo.launch(inbrowser=True)  
    # start the Gradio server and open the browser automatically