# sentiment-app / app.py
# author: mdeokk — deploy commit c6bf520
# NOTE(review): these four lines were pasted in as bare text from the hosting
# page header and broke the script at import time; kept here as comments.
# ==============================
# 1. ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ import
# ==============================
import gradio as gr # ์›น UI ์ƒ์„ฑ์„ ์œ„ํ•œ Gradio ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ
import torch # PyTorch (๋”ฅ๋Ÿฌ๋‹ ๋ชจ๋ธ ์‹คํ–‰ ๋ฐ ํ…์„œ ์—ฐ์‚ฐ)
from PIL import Image # ์ด๋ฏธ์ง€ ์ฒ˜๋ฆฌ (numpy โ†” PIL ๋ณ€ํ™˜)
# ViT ๋ชจ๋ธ (์ด๋ฏธ์ง€ ๋ถ„๋ฅ˜)
from transformers import ViTImageProcessor, ViTForImageClassification
# BLIP ๋ชจ๋ธ (์ด๋ฏธ์ง€ ์„ค๋ช… ์ƒ์„ฑ)
from transformers import BlipProcessor, BlipForConditionalGeneration
# ==============================
# 2. Load the ViT model (image classification)
# ==============================
model_name = "google/vit-base-patch16-224"  # Vision Transformer checkpoint id

# Image preprocessor: performs resizing and normalisation automatically.
processor = ViTImageProcessor.from_pretrained(model_name)
# Classification model with its pretrained weights.
model = ViTForImageClassification.from_pretrained(model_name)

# ==============================
# 3. Load the BLIP model (image captioning)
# ==============================
_BLIP_CHECKPOINT = "Salesforce/blip-image-captioning-base"

# Preprocessor for the image -> text pipeline.
caption_processor = BlipProcessor.from_pretrained(_BLIP_CHECKPOINT)
# Caption-generation model.
caption_model = BlipForConditionalGeneration.from_pretrained(_BLIP_CHECKPOINT)
# ==============================
# 4. ์ด๋ฏธ์ง€ ์„ค๋ช… ํ•จ์ˆ˜ (์—๋Ÿฌ ์ˆ˜์ • ํ•ต์‹ฌ)
# ==============================
def generate_caption(img):
    """Generate a natural-language caption for *img* with the BLIP model.

    Accepts either a PIL image or a numpy array (Gradio supplies numpy);
    returns the decoded caption string.
    """
    # Normalise the input to a PIL image (avoids a redundant conversion
    # when the caller already passed one).
    if not isinstance(img, Image.Image):
        img = Image.fromarray(img)
    # Preprocess into PyTorch ("pt") tensors for the BLIP model.
    blip_inputs = caption_processor(images=img, return_tensors="pt")
    # Inference only — disable gradient tracking for speed and memory.
    with torch.no_grad():
        token_ids = caption_model.generate(**blip_inputs)
    # Decode the generated token ids into readable text, dropping
    # special tokens, and return the caption.
    return caption_processor.decode(token_ids[0], skip_special_tokens=True)
# ==============================
# 5. ์ด๋ฏธ์ง€ ๋ถ„๋ฅ˜ + ์„ค๋ช… ํ•จ์ˆ˜
# ==============================
def classify_image(img, top_k=3):
    """Classify *img* with ViT and caption it with BLIP.

    Parameters
    ----------
    img : PIL.Image.Image | numpy.ndarray
        Input image (Gradio delivers a numpy array).
    top_k : int, optional
        Number of top classes to include in the result (default 3,
        matching the UI's ``num_top_classes=3``). Added as a
        backward-compatible generalization of the hard-coded 3.

    Returns
    -------
    tuple[dict, str]
        ``{label: probability}`` for the top-k classes, plus the
        generated caption string.
    """
    # Normalise to PIL once; both models reuse the same object.
    if not isinstance(img, Image.Image):
        img = Image.fromarray(img)
    # ViT preprocessing (resize/normalise) into PyTorch tensors.
    inputs = processor(images=img, return_tensors="pt")
    # Inference only — no gradient bookkeeping needed.
    with torch.no_grad():
        outputs = model(**inputs)  # forward pass
        logits = outputs.logits    # raw per-class scores
    # Softmax turns logits into a probability distribution;
    # [0] drops the batch dimension.
    probs = torch.nn.functional.softmax(logits, dim=-1)[0]
    # Top-k probabilities/indices -> {human-readable label: probability}.
    top_prob, top_idx = torch.topk(probs, top_k)
    results = {
        model.config.id2label[idx.item()]: float(p)
        for p, idx in zip(top_prob, top_idx)
    }
    # Generate the image caption from the same PIL image.
    caption = generate_caption(img)
    # Classification results + caption, in the order the UI expects.
    return results, caption
# ==============================
# 6. Gradio UI ๊ตฌ์„ฑ
# ==============================
# ==============================
# 6. Gradio UI
# ==============================
# Upload-only image input, handed to the handler as a numpy array.
_image_input = gr.Image(
    type="numpy",
    sources=["upload"],
)
# Two outputs: top-3 class probabilities and the generated caption text.
_result_outputs = [
    gr.Label(num_top_classes=3),
    gr.Textbox(label="이미지 설명"),
]
# Wire the handler and components into a single-page interface.
demo = gr.Interface(
    fn=classify_image,
    inputs=_image_input,
    outputs=_result_outputs,
    title="ViT 이미지 분류 + BLIP 이미지 설명",
    description="이미지를 업로드하면 분류 결과와 설명을 함께 제공합니다.",
)
# ==============================
# 7. ์„œ๋ฒ„ ์‹คํ–‰
# ==============================
if __name__ == "__main__":  # run only when executed directly, not on import
    demo.launch(inbrowser=True)
    # Start the Gradio server and auto-open the default browser