# Spaces:
# Running
# Running
# ==============================
# 1. Library imports
# ==============================
import gradio as gr  # Gradio: builds the web UI
import torch  # PyTorch: model execution and tensor operations
from PIL import Image  # Pillow: image handling (numpy -> PIL conversion)
# ViT model (image classification)
from transformers import ViTImageProcessor, ViTForImageClassification
# BLIP model (image captioning)
from transformers import BlipProcessor, BlipForConditionalGeneration

# ==============================
# 2. Load the ViT model (image classification)
# ==============================
# Hugging Face Hub checkpoint name for the Vision Transformer classifier.
model_name = "google/vit-base-patch16-224"
# Image preprocessor for ViT (resizing and normalization are done for us).
processor = ViTImageProcessor.from_pretrained(model_name)
# Image classification model, loaded with its pretrained weights.
model = ViTForImageClassification.from_pretrained(model_name)

# ==============================
# 3. Load the BLIP model (image captioning)
# ==============================
# Single source of truth for the captioning checkpoint, mirroring `model_name`
# above, so the processor and the model can never be loaded from different
# checkpoints by accident.
caption_model_name = "Salesforce/blip-image-captioning-base"
# Preprocessor that turns images into BLIP model inputs (image -> text).
caption_processor = BlipProcessor.from_pretrained(caption_model_name)
# Caption generation model.
caption_model = BlipForConditionalGeneration.from_pretrained(caption_model_name)
# ==============================
# 4. Image captioning function
# ==============================
def generate_caption(img):
    """Generate a text caption for an image with the BLIP model.

    Args:
        img: A ``PIL.Image.Image``, or a numpy array accepted by
            ``PIL.Image.fromarray`` (e.g. what ``gr.Image(type="numpy")``
            delivers). ``None`` -- what Gradio passes when no image was
            uploaded -- is accepted.

    Returns:
        The decoded caption with special tokens removed, or ``""`` when
        ``img`` is ``None``.
    """
    # Without this guard Image.fromarray(None) fails with an unhelpful error.
    if img is None:
        return ""
    # Only wrap non-PIL inputs, so an existing PIL image is not re-converted.
    if not isinstance(img, Image.Image):
        img = Image.fromarray(img)
    # Preprocess into PyTorch tensors ("pt") that the BLIP model accepts.
    inputs = caption_processor(images=img, return_tensors="pt")
    # Inference only: disabling autograd skips gradient bookkeeping, which
    # saves memory and time.
    with torch.no_grad():
        # Generate the caption as a sequence of token ids.
        output_ids = caption_model.generate(**inputs)
    # Decode the first (only) generated sequence into human-readable text.
    return caption_processor.decode(output_ids[0], skip_special_tokens=True)
# ==============================
# 5. Image classification + captioning function
# ==============================
def classify_image(img, top_k=3):
    """Classify an image with ViT and describe it with BLIP.

    Args:
        img: A ``PIL.Image.Image``, or a numpy array accepted by
            ``PIL.Image.fromarray`` (``gr.Image(type="numpy")`` supplies the
            latter). ``None`` -- no image uploaded -- yields empty outputs.
        top_k: How many highest-probability classes to report. Defaults to 3
            to match ``gr.Label(num_top_classes=3)`` in the UI; it is capped at
            the number of classes the model produces.

    Returns:
        A ``(label_probs, caption)`` tuple. ``label_probs`` maps class names
        to softmax probabilities (plain floats, inserted highest first);
        ``caption`` is the text returned by ``generate_caption``.
    """
    if img is None:
        return {}, ""
    # Only wrap non-PIL inputs, so an existing PIL image is not re-converted.
    if not isinstance(img, Image.Image):
        img = Image.fromarray(img)
    # Grayscale, RGBA or palette images (possible for direct callers) are
    # normalized to 3-channel RGB before ViT preprocessing.
    if img.mode != "RGB":
        img = img.convert("RGB")

    # ViT preprocessing into PyTorch tensors.
    inputs = processor(images=img, return_tensors="pt")
    # Inference only: no gradients needed.
    with torch.no_grad():
        logits = model(**inputs).logits  # raw, unnormalized class scores
    # Softmax -> probabilities; row 0 because a single image was processed.
    probs = torch.nn.functional.softmax(logits, dim=-1)[0]

    k = max(0, min(top_k, probs.numel()))
    top_probs, top_indices = torch.topk(probs, k)

    results = {}
    for prob, idx in zip(top_probs.tolist(), top_indices.tolist()):
        label = model.config.id2label[idx]
        # Label maps can repeat a name for distinct ids (ImageNet-1k has two
        # "crane" classes). Plain assignment would let the lower-ranked
        # duplicate overwrite the higher probability, so disambiguate by id.
        if label in results:
            label = f"{label} ({idx})"
        results[label] = prob

    # Caption the same (already PIL, RGB) image.
    caption = generate_caption(img)
    return results, caption
# ==============================
# 6. Gradio UI
# ==============================
demo = gr.Interface(
    fn=classify_image,  # called with the uploaded image on each submit
    inputs=gr.Image(
        type="numpy",  # hand the image to fn as a numpy array
        sources=["upload"],  # accept file uploads only
    ),
    outputs=[
        gr.Label(num_top_classes=3),  # classification results (top 3)
        gr.Textbox(label="이미지 설명"),  # "Image description" (BLIP caption)
    ],
    # Page title: "ViT image classification + BLIP image description".
    title="ViT 이미지 분류 + BLIP 이미지 설명",
    # Service description: "Upload an image to get its classification
    # results together with a description."
    description="이미지를 업로드하면 분류 결과와 설명을 함께 제공합니다.",
)
# ==============================
# 7. Entry point
# ==============================
if __name__ == "__main__":
    # Only when this file is executed directly: start the Gradio server and
    # automatically open the app in a browser tab.
    demo.launch(inbrowser=True)