--- |
|
|
library_name: transformers |
|
|
tags: |
|
|
- image-classification |
|
|
- multi-head-classification |
|
|
- room-classification |
|
|
- dinov2 |
|
|
- computer-vision |
|
|
- scene-classification |
|
|
license: apache-2.0 |
|
|
language: |
|
|
- en |
|
|
pipeline_tag: image-classification |
|
|
base_model: |
|
|
- facebook/dinov2-large |
|
|
--- |
|
|
|
|
|
# Room Scene Classifier |
|
|
|
|
|
A DINOv2-based multi-head hotel image scene classification model.
|
|
|
|
|
## Model Overview
|
|
|
|
|
This multi-head deep-learning model classifies hotel images from three perspectives at once: **Scene**, **Concept**, and **Object**. A DINOv2 backbone extracts strong visual features, and each head performs a specialized classification on top of them.
|
|
|
|
|
## Model Information
|
|
|
|
|
- **Model name**: `image_classifier_model_0.2`
|
|
- **Base model**: `facebook/dinov2-large`
|
|
- **Image size**: 224x224
|
|
- **Channels**: RGB (3 channels)
|
|
- **Total parameters**: 303,252,502 (backbone frozen)
|
|
- **Trainable parameters**: 24,598
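The trainable-parameter count is consistent with a small classification stack on top of the frozen backbone: a LayerNorm over the 1024-dimensional shared features plus one linear layer per head. This layout is inferred from the numbers above, not confirmed elsewhere in the card:

```python
FEATURE_DIM = 1024
HEAD_CLASSES = {"scene": 6, "concept": 3, "object": 13}

layernorm_params = 2 * FEATURE_DIM  # weight + bias of a LayerNorm over the shared features
linear_params = sum(FEATURE_DIM * n + n for n in HEAD_CLASSES.values())  # weights + biases

print(layernorm_params + linear_params)  # 24598, matching the reported trainable count
```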
|
|
|
|
|
## Classification Heads
|
|
|
|
|
### Scene head (6 classes)
|
|
- Guest room, bathroom, swimming pool, lobby, restaurant, other
|
|
|
|
|
### Concept head (3 classes)
|
|
- Indoor, outdoor, close-up
|
|
|
|
|
|
### Object head (13 classes)
|
|
- Bed, sofa, shower head, bathtub, chair, table, TV, refrigerator, sink, vanity, mirror, other, unclassified
|
|
|
|
|
## Usage
|
|
|
|
|
### Using the model from Python
|
|
|
|
|
```python |
|
|
import torch |
|
|
import onnxruntime as ort |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
from torchvision import transforms |
|
|
import json |
|
|
|
|
|
# Load model metadata
|
|
with open('image_classifier_model_0.2_model_info.json', 'r') as f: |
|
|
model_info = json.load(f) |
|
|
|
|
|
# Load the PyTorch model
|
|
model = torch.load('image_classifier_model_0.2.pth', map_location='cpu')  # full pickled module; PyTorch >= 2.6 needs weights_only=False
|
|
model.eval() |
|
|
|
|
|
# Or use the ONNX model (faster inference)
|
|
onnx_session = ort.InferenceSession('image_classifier_model_0.2.onnx') |
|
|
|
|
|
# Image preprocessing
|
|
transform = transforms.Compose([ |
|
|
transforms.Resize((224, 224)), |
|
|
    transforms.CenterCrop(224),  # no-op after the exact resize above
|
|
transforms.ToTensor(), |
|
|
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) |
|
|
]) |
|
|
|
|
|
def classify_image_pytorch(image_path): |
|
|
"""PyTorch ๋ชจ๋ธ์ ์ฌ์ฉํ ์ด๋ฏธ์ง ๋ถ๋ฅ""" |
|
|
    image = transform(Image.open(image_path).convert('RGB')).unsqueeze(0)
|
|
|
|
|
with torch.no_grad(): |
|
|
outputs = model(image) |
|
|
predictions = {} |
|
|
|
|
|
for head_name, logits in outputs.items(): |
|
|
probabilities = torch.softmax(logits, dim=1) |
|
|
predicted_class = torch.argmax(probabilities, dim=1).item() |
|
|
confidence = probabilities[0, predicted_class].item() |
|
|
|
|
|
predictions[head_name] = { |
|
|
'class_id': predicted_class, |
|
|
'confidence': confidence, |
|
|
'probabilities': probabilities[0].tolist() |
|
|
} |
|
|
|
|
|
return predictions |
|
|
|
|
|
def classify_image_onnx(image_path): |
|
|
"""ONNX ๋ชจ๋ธ์ ์ฌ์ฉํ ์ด๋ฏธ์ง ๋ถ๋ฅ (๊ถ์ฅ)""" |
|
|
    image = transform(Image.open(image_path).convert('RGB')).unsqueeze(0).numpy()  # add the batch dimension
|
|
|
|
|
    # Run ONNX inference
|
|
    input_name = onnx_session.get_inputs()[0].name  # avoid hardcoding the input name
    input_feed = {input_name: image.astype(np.float32)}
|
|
outputs = onnx_session.run(None, input_feed) |
|
|
|
|
|
predictions = {} |
|
|
    head_names = ['scene', 'concept', 'object']  # must match the ONNX export's output order
|
|
|
|
|
for i, head_name in enumerate(head_names): |
|
|
logits = outputs[i] |
|
|
probabilities = torch.softmax(torch.tensor(logits), dim=1) |
|
|
predicted_class = torch.argmax(probabilities, dim=1).item() |
|
|
confidence = probabilities[0, predicted_class].item() |
|
|
|
|
|
predictions[head_name] = { |
|
|
'class_id': predicted_class, |
|
|
'confidence': confidence, |
|
|
'probabilities': probabilities[0].tolist() |
|
|
} |
|
|
|
|
|
return predictions |
|
|
|
|
|
# Example usage
|
|
predictions = classify_image_onnx("hotel_room.jpg") |
|
|
print("๋ถ๋ฅ ๊ฒฐ๊ณผ:") |
|
|
for head, result in predictions.items(): |
|
|
print(f"{head}: ํด๋์ค {result['class_id']}, ์ ๋ขฐ๋ {result['confidence']:.4f}") |
|
|
``` |
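If you only run the ONNX model, the `torch` dependency in the post-processing above can be replaced with a small, numerically stable numpy softmax:

```python
import numpy as np

def softmax_np(logits, axis=-1):
    """Numerically stable softmax over raw logits."""
    shifted = logits - logits.max(axis=axis, keepdims=True)  # subtract the max to avoid overflow in exp
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

probs = softmax_np(np.array([[0.5, 2.0, 1.0]]))
print(probs.argmax(axis=1))  # index of the most likely class
```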
|
|
|
|
|
### Converting class IDs to actual class names
|
|
|
|
|
```python |
|
|
def get_class_names(predictions, model_info): |
|
|
"""ํด๋์ค ID๋ฅผ ์ค์ ํด๋์ค๋ช
์ผ๋ก ๋ณํ""" |
|
|
class_mappings = model_info['class_mappings'] |
|
|
|
|
|
results = {} |
|
|
for head, result in predictions.items(): |
|
|
class_id = result['class_id'] |
|
|
if head in class_mappings: |
|
|
actual_class_id = class_mappings[head][str(class_id)] |
|
|
results[head] = { |
|
|
'class_id': actual_class_id, |
|
|
'confidence': result['confidence'] |
|
|
} |
|
|
|
|
|
return results |
|
|
|
|
|
# Class-name conversion example
|
|
class_names = get_class_names(predictions, model_info) |
|
|
print("์ค์ ํด๋์ค ID:") |
|
|
for head, result in class_names.items(): |
|
|
print(f"{head}: {result['class_id']}") |
|
|
``` |
|
|
|
|
|
### Batch processing
|
|
|
|
|
```python |
|
|
def classify_batch_images(image_paths): |
|
|
"""์ฌ๋ฌ ์ด๋ฏธ์ง๋ฅผ ํ ๋ฒ์ ๋ถ๋ฅ""" |
|
|
results = [] |
|
|
|
|
|
for image_path in image_paths: |
|
|
predictions = classify_image_onnx(image_path) |
|
|
results.append({ |
|
|
'image_path': image_path, |
|
|
'predictions': predictions |
|
|
}) |
|
|
|
|
|
return results |
|
|
|
|
|
# Example
|
|
image_paths = ["room1.jpg", "bathroom1.jpg", "lobby1.jpg"] |
|
|
batch_results = classify_batch_images(image_paths) |
|
|
|
|
|
for result in batch_results: |
|
|
print(f"\n์ด๋ฏธ์ง: {result['image_path']}") |
|
|
for head, pred in result['predictions'].items(): |
|
|
print(f" {head}: ํด๋์ค {pred['class_id']}, ์ ๋ขฐ๋ {pred['confidence']:.4f}") |
|
|
``` |
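The loop above runs one image per `session.run` call. If the ONNX export has a dynamic batch axis (not verified here), the preprocessed images can instead be stacked into a single `(N, 3, 224, 224)` input; `transform`, `image_paths`, and `onnx_session` are the names from the examples above:

```python
import numpy as np

def stack_batch(image_tensors):
    """Stack per-image (3, 224, 224) arrays into one (N, 3, 224, 224) float32 batch."""
    return np.stack(image_tensors).astype(np.float32)

# batch = stack_batch([transform(Image.open(p).convert('RGB')).numpy() for p in image_paths])
# outputs = onnx_session.run(None, {onnx_session.get_inputs()[0].name: batch})
```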
|
|
|
|
|
## Model Files
|
|
|
|
|
- `image_classifier_model_0.2.pth`: PyTorch model file
|
|
- `image_classifier_model_0.2.onnx`: ONNX model file (optimized for inference)
|
|
- `image_classifier_model_0.2_model_info.json`: model metadata
|
|
- `image_classifier_model_0.2_inference_example.py`: inference example code
|
|
|
|
|
|
|
|
## Model Architecture
|
|
|
|
|
### Multi-head classification system
|
|
|
``` |
|
|
Input image (224x224)
        ↓
DINOv2 backbone (frozen)
        ↓
Shared features (1024-dim)
        ├──→ Scene head   → 6 classes
        ├──→ Concept head → 3 classes
        └──→ Object head  → 13 classes
|
|
``` |
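A minimal sketch of the head structure, assuming a LayerNorm plus one linear layer per head (inferred from the trainable-parameter count; the real implementation may differ). `backbone` stands in for the frozen DINOv2 feature extractor; the dummy below is only for shape-checking:

```python
import torch
import torch.nn as nn

class MultiHeadClassifier(nn.Module):
    def __init__(self, backbone, feature_dim=1024):
        super().__init__()
        self.backbone = backbone
        for p in self.backbone.parameters():
            p.requires_grad = False  # freeze the backbone
        self.norm = nn.LayerNorm(feature_dim)
        self.heads = nn.ModuleDict({
            "scene": nn.Linear(feature_dim, 6),
            "concept": nn.Linear(feature_dim, 3),
            "object": nn.Linear(feature_dim, 13),
        })

    def forward(self, pixel_values):
        features = self.norm(self.backbone(pixel_values))
        return {name: head(features) for name, head in self.heads.items()}

# Shape check with a tiny dummy backbone standing in for DINOv2's pooled 1024-dim output
dummy = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(3, 1024))
model = MultiHeadClassifier(dummy)
out = model(torch.randn(2, 3, 224, 224))
print({k: tuple(v.shape) for k, v in out.items()})
```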
|
|
|
|
|
### Key features
|
|
- **DINOv2 backbone**: strong vision-transformer feature extraction
|
|
- **Frozen backbone**: reuses pretrained features to prevent overfitting
|
|
- **Multi-head design**: three heads analyze each image from complementary angles
|
|
- **Class weighting**: compensates for imbalanced training data
|
|
|
|
|
## Preprocessing Requirements
|
|
|
|
|
1. **Image size**: 224x224 pixels
|
|
|
2. **Color space**: RGB
|
|
3. **Normalization**: ImageNet statistics (mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
|
4. **Crop**: center crop
|
|
5. **Supported formats**: JPG, JPEG, PNG
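The normalization step above can be reproduced without torchvision; a sketch for a `(3, H, W)` array already scaled to `[0, 1]` (illustrative, not the exact training pipeline):

```python
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(3, 1, 1)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(3, 1, 1)

def normalize_chw(image):
    """Apply ImageNet normalization to a (3, H, W) float array in [0, 1]."""
    return (image - IMAGENET_MEAN) / IMAGENET_STD

x = normalize_chw(np.full((3, 224, 224), 0.5, dtype=np.float32))
```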
|
|
|
|
|
## Use Cases
|
|
|
|
|
### Direct use
|
|
|
|
|
- **Automatic hotel-image classification**: sorts images into scenes such as guest room, bathroom, and lobby
|
|
- **Image metadata generation**: automatically extracts scene, concept, and object information from images
|
|
- **Image database management**: automatic tagging of large hotel-image collections
|
|
|
- **Quality control**: verifies image-classification consistency
|
|
|
|
|
### Downstream use
|
|
|
|
|
- **Hotel management systems**: automatic classification and management of room images
|
|
- **Travel platforms**: image filtering by room type
|
|
- **Real-estate platforms**: automatic extraction of accommodation and facility information
|
|
- **Image search engines**: multi-attribute image retrieval
|
|
|
|
|
## Limitations
|
|
|
|
|
1. **Domain specificity**: the model is specialized for hotel/accommodation images; performance on other domains is limited.
|
|
2. **Image quality**: performance may degrade on low-quality or noisy images.
|
|
3. **Viewpoint sensitivity**: performance can vary with the angle at which an image was shot.
|
|
4. **Class imbalance**: some classes may perform worse than others.
|
|
|
|
|
## License
|
|
|
|
|
Apache 2.0 License |
|
|
|
|
|
## Notes
|
|
|
|
|
This model was developed as part of the Room Clusterer project. For more details, see the [project repository](https://github.com/tportio/content-ml-trainer).