---
library_name: transformers
tags:
- image-classification
- multi-head-classification
- room-classification
- dinov2
- computer-vision
- scene-classification
license: apache-2.0
language:
- en
pipeline_tag: image-classification
base_model:
- facebook/dinov2-large
---
# Room Scene Classifier
A DINOv2-based multi-head model for classifying hotel image scenes.
## Model overview
This is a multi-head deep learning model that classifies a hotel image from three perspectives at once: **Scene**, **Concept**, and **Object**. A DINOv2 backbone extracts the visual features, and each head performs its own specialized classification on top of them.
## Model information
- **Model name**: `image_classifier_model_0.2`
- **Base model**: `facebook/dinov2-large`
- **Image size**: 224x224
- **Channels**: RGB (3 channels)
- **Total parameters**: 303,252,502 (backbone frozen)
- **Trainable parameters**: 24,598
## Classification heads
### Scene head (6 classes)
- Guest room, bathroom, swimming pool, lobby, restaurant, other
### Concept head (3 classes)
- Indoor, outdoor, close-up
### Object head (13 classes)
- Bed, sofa, shower, bathtub, chair, table, TV, refrigerator, sink, vanity, mirror, other, unclassified
## Usage
### Using the model from Python
```python
import json

import numpy as np
import onnxruntime as ort
import torch
from PIL import Image
from torchvision import transforms

# Load the model metadata
with open('image_classifier_model_0.2_model_info.json', 'r') as f:
    model_info = json.load(f)

# Load the PyTorch model. The file is loaded as a full pickled model object, so
# PyTorch 2.6+ needs weights_only=False and the model class must be importable.
model = torch.load('image_classifier_model_0.2.pth', map_location='cpu', weights_only=False)
model.eval()

# Load the ONNX model (faster inference)
onnx_session = ort.InferenceSession('image_classifier_model_0.2.onnx')

# Image preprocessing
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.CenterCrop(224),  # no-op after the fixed 224x224 resize
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])


def classify_image_pytorch(image_path):
    """Classify an image with the PyTorch model."""
    # convert('RGB') guards against grayscale or RGBA inputs
    image = transform(Image.open(image_path).convert('RGB')).unsqueeze(0)
    with torch.no_grad():
        outputs = model(image)
    predictions = {}
    for head_name, logits in outputs.items():
        probabilities = torch.softmax(logits, dim=1)
        predicted_class = torch.argmax(probabilities, dim=1).item()
        confidence = probabilities[0, predicted_class].item()
        predictions[head_name] = {
            'class_id': predicted_class,
            'confidence': confidence,
            'probabilities': probabilities[0].tolist()
        }
    return predictions


def classify_image_onnx(image_path):
    """Classify an image with the ONNX model (recommended)."""
    # Add the batch dimension before converting to NumPy: shape (1, 3, 224, 224)
    image = transform(Image.open(image_path).convert('RGB')).unsqueeze(0).numpy()
    # Run ONNX inference
    input_feed = {'input': image.astype(np.float32)}
    outputs = onnx_session.run(None, input_feed)
    predictions = {}
    head_names = ['scene', 'concept', 'object']
    for i, head_name in enumerate(head_names):
        logits = outputs[i]
        probabilities = torch.softmax(torch.tensor(logits), dim=1)
        predicted_class = torch.argmax(probabilities, dim=1).item()
        confidence = probabilities[0, predicted_class].item()
        predictions[head_name] = {
            'class_id': predicted_class,
            'confidence': confidence,
            'probabilities': probabilities[0].tolist()
        }
    return predictions


# Example usage
predictions = classify_image_onnx("hotel_room.jpg")
print("Classification results:")
for head, result in predictions.items():
    print(f"{head}: class {result['class_id']}, confidence {result['confidence']:.4f}")
```
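Each result also carries the full `probabilities` list, which the example above does not use. A minimal sketch of ranking the most likely classes for one head follows. `top_k` is a hypothetical helper, not part of the released files, and the probabilities are made-up values.

```python
def top_k(probabilities, k=3):
    """Return the k most probable (class_id, probability) pairs, best first."""
    ranked = sorted(enumerate(probabilities), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


# Made-up probabilities for a 6-class head (illustration only).
example_probabilities = [0.05, 0.70, 0.02, 0.15, 0.03, 0.05]
for class_id, probability in top_k(example_probabilities, k=2):
    print(f"class {class_id}: {probability:.2f}")
```

In practice the input would be `predictions['scene']['probabilities']` (or the list from any other head).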
### Converting class IDs to actual class names
```python
def get_class_names(predictions, model_info):
    """Map each predicted class index to its actual class ID."""
    class_mappings = model_info['class_mappings']
    results = {}
    for head, result in predictions.items():
        class_id = result['class_id']
        if head in class_mappings:
            # JSON object keys are strings, so look up the stringified index
            actual_class_id = class_mappings[head][str(class_id)]
            results[head] = {
                'class_id': actual_class_id,
                'confidence': result['confidence']
            }
    return results


# Example: converting class IDs
class_names = get_class_names(predictions, model_info)
print("Actual class IDs:")
for head, result in class_names.items():
    print(f"{head}: {result['class_id']}")
```
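The lookup above indexes `model_info['class_mappings']` first by head name and then by the stringified class index. The sketch below shows the shape that lookup implies. The IDs in `example_mappings` are invented for illustration; the real values live in `image_classifier_model_0.2_model_info.json`. `lookup_class_id` is a hypothetical helper that returns `None` instead of raising when an entry is missing.

```python
# Invented mapping with the shape the lookup implies:
# head name -> {stringified model index -> actual class ID}.
example_mappings = {
    "scene": {"0": 101, "1": 102},
    "concept": {"0": 201},
}


def lookup_class_id(class_mappings, head, class_index):
    """Return the mapped class ID, or None if the head or index is unknown."""
    return class_mappings.get(head, {}).get(str(class_index))


print(lookup_class_id(example_mappings, "scene", 1))   # known entry
print(lookup_class_id(example_mappings, "object", 0))  # unknown head
```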
### Batch processing
```python
def classify_batch_images(image_paths):
    """Classify several images, one after another."""
    results = []
    for image_path in image_paths:
        predictions = classify_image_onnx(image_path)
        results.append({
            'image_path': image_path,
            'predictions': predictions
        })
    return results


# Example
image_paths = ["room1.jpg", "bathroom1.jpg", "lobby1.jpg"]
batch_results = classify_batch_images(image_paths)
for result in batch_results:
    print(f"\nImage: {result['image_path']}")
    for head, pred in result['predictions'].items():
        print(f"  {head}: class {pred['class_id']}, confidence {pred['confidence']:.4f}")
```
## Model files
- `image_classifier_model_0.2.pth`: PyTorch model file
- `image_classifier_model_0.2.onnx`: ONNX model file (optimized for inference)
- `image_classifier_model_0.2_model_info.json`: model metadata
- `image_classifier_model_0.2_inference_example.py`: inference example code
## Model architecture
### Multi-head classification system
```
Input image (224×224)
↓
DINOv2 backbone (frozen)
↓
Shared features (1024-dim)
├─── Scene head → 6 classes
├─── Concept head → 3 classes
└─── Object head → 13 classes
```
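The head sizes in the diagram can be sanity-checked with a little arithmetic. The sketch below assumes each head is a single linear layer (weights plus bias) on the 1024-dimensional shared feature. The source does not confirm that layout, and `linear_params` is a helper introduced here for illustration.

```python
FEATURE_DIM = 1024  # shared feature size from the diagram
HEAD_CLASSES = {"scene": 6, "concept": 3, "object": 13}


def linear_params(in_features, out_features):
    """Parameter count of one linear layer: weight matrix plus bias vector."""
    return in_features * out_features + out_features


# ASSUMPTION: one linear layer per head; the real heads may be built differently.
head_params = {name: linear_params(FEATURE_DIM, n) for name, n in HEAD_CLASSES.items()}
total = sum(head_params.values())

for name, count in head_params.items():
    print(f"{name}: {count:,} parameters")
print(f"total: {total:,} parameters")
```

Under that assumption the three heads add up to 22,550 parameters, which is 2,048 fewer than the 24,598 trainable parameters reported earlier in this README. The released heads therefore presumably contain additional trainable components that are not described here.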
### Key features
- **DINOv2 backbone**: feature extraction based on a strong vision transformer
- **Frozen backbone**: relies on pretrained features to guard against overfitting
- **Multi-head**: three heads analyze each image from different angles
- **Class weights**: automatic correction for imbalanced data
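The list above mentions class weights without saying how they are computed. One common scheme, shown here purely as an illustration and not as the project's confirmed method, weights each class by `n_samples / (n_classes * class_count)`, so rarer classes get larger weights. The label names and counts are hypothetical.

```python
from collections import Counter


def balanced_class_weights(labels):
    """Weight each class by n_samples / (n_classes * class_count)."""
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    return {cls: n_samples / (n_classes * count) for cls, count in counts.items()}


# Hypothetical, imbalanced scene labels (not the real training data).
labels = ["room"] * 6 + ["bathroom"] * 3 + ["lobby"] * 1
weights = balanced_class_weights(labels)
for cls, weight in weights.items():
    print(f"{cls}: {weight:.3f}")
```

Weights like these are typically passed to the loss function during training, for example through the `weight` argument of `torch.nn.CrossEntropyLoss`.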
## Preprocessing requirements
1. **Image size**: 224x224 pixels
2. **Color space**: RGB
3. **Normalization**: ImageNet standard values (mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
4. **Crop**: center crop
5. **Supported formats**: JPG/JPEG, PNG
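To make the normalization step concrete, the sketch below applies it to a single 8-bit RGB pixel in plain Python. It mirrors what `ToTensor` followed by `Normalize` does per channel: scale to [0, 1], subtract the mean, divide by the standard deviation. `normalize_pixel` is a helper introduced here for illustration.

```python
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb):
    """Scale an 8-bit RGB pixel to [0, 1], then apply per-channel mean/std."""
    return tuple(
        (value / 255.0 - mean) / std
        for value, mean, std in zip(rgb, IMAGENET_MEAN, IMAGENET_STD)
    )


# A pure red pixel: the red channel ends up positive, green and blue negative.
red, green, blue = normalize_pixel((255, 0, 0))
print(f"R={red:.4f}, G={green:.4f}, B={blue:.4f}")
```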
## Use cases
### Direct use
- **Automatic hotel image classification**: sort images by scene, such as guest room, bathroom, or lobby
- **Image metadata generation**: extract scene, concept, and object information from each image
- **Image database management**: tag large volumes of hotel images automatically
- **Quality control**: verify that image classifications are consistent
### Downstream use
- **Hotel management systems**: automatic classification and management of guest room images
- **Travel platforms**: filtering images by room type
- **Real estate platforms**: automatic extraction of accommodation facility information
- **Image search engines**: image search based on multiple attributes
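As an illustration of the filtering and tagging use cases, the sketch below groups image paths by predicted scene and drops low-confidence predictions. It takes results in the shape returned by `classify_batch_images`. `group_by_scene`, the `min_confidence` threshold of 0.5, and the example results are assumptions made for this sketch.

```python
from collections import defaultdict


def group_by_scene(batch_results, min_confidence=0.5):
    """Group image paths by scene class ID, skipping low-confidence predictions."""
    grouped = defaultdict(list)
    for result in batch_results:
        scene = result["predictions"]["scene"]
        if scene["confidence"] >= min_confidence:
            grouped[scene["class_id"]].append(result["image_path"])
    return dict(grouped)


# Hand-written stand-ins for real model output (illustration only).
example_results = [
    {"image_path": "room1.jpg", "predictions": {"scene": {"class_id": 0, "confidence": 0.91}}},
    {"image_path": "room2.jpg", "predictions": {"scene": {"class_id": 0, "confidence": 0.42}}},
    {"image_path": "lobby1.jpg", "predictions": {"scene": {"class_id": 3, "confidence": 0.77}}},
]
print(group_by_scene(example_results))
```

Images that fall below the threshold could be routed to manual review instead of being discarded.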
## Limitations
1. **Domain-specific**: the model is specialized for hotel and accommodation images, so performance in other domains is limited.
2. **Image quality**: performance may degrade on low-resolution or noisy images.
3. **Angle dependence**: performance may vary for images taken from particular angles.
4. **Class imbalance**: some classes may perform worse than others.
## License
Apache 2.0 License
## Notes
This model was developed as part of the Room Clusterer project. For more details, see the [project repository](https://github.com/tportio/content-ml-trainer).