mobilevit_model / README.md
gihakkk's picture
Update README.md
735544b verified
---
license: apache-2.0
---
์‚ฌ์šฉ์˜ˆ์‹œ
```python
import onnxruntime as ort
import numpy as np
from transformers import AutoFeatureExtractor
from PIL import Image
# ONNX ๋ชจ๋ธ ๊ฒฝ๋กœ
onnx_model_path = r'C:\mobilevit_model.onnx'
# ONNX ๋Ÿฐํƒ€์ž„ ์„ธ์…˜ ์ดˆ๊ธฐํ™”
ort_session = ort.InferenceSession(onnx_model_path)
# ์ƒˆ๋กœ์šด ์ด๋ฏธ์ง€ ์˜ˆ์ธก ํ•จ์ˆ˜ ์ •์˜
def predict_image(image_path):
# MobileViT ๋ชจ๋ธ์— ๋งž๋Š” ํŠน์ง• ์ถ”์ถœ๊ธฐ ๋กœ๋“œ
feature_extractor = AutoFeatureExtractor.from_pretrained("apple/mobilevit-small")
# ์ด๋ฏธ์ง€๋ฅผ ๋กœ๋“œํ•˜๊ณ  RGB๋กœ ๋ณ€ํ™˜
image = Image.open(image_path).convert("RGB")
# ์ด๋ฏธ์ง€๋ฅผ ํŠน์ง• ์ถ”์ถœ๊ธฐ๋กœ ์ „์ฒ˜๋ฆฌ
inputs = feature_extractor(images=image, return_tensors="np")
input_array = inputs['pixel_values'] # ONNX๋Š” Numpy ํ˜•์‹์„ ์‚ฌ์šฉ
# ONNX ๋ชจ๋ธ์— ์ž…๋ ฅ ์ „๋‹ฌ ๋ฐ ์ถ”๋ก 
ort_inputs = {ort_session.get_inputs()[0].name: input_array}
ort_outputs = ort_session.run(None, ort_inputs)
# ๊ฒฐ๊ณผ ํ•ด์„
logits = ort_outputs[0]
predicted_class = np.argmax(logits, axis=-1).item()
return "๊ทธ๋ƒฅ ์‚ฌ์ง„" if predicted_class == 1 else "๋กœ๋งจ์Šค ์Šค์บ  ์‚ฌ์ง„"
# ์˜ˆ์ธก ์˜ˆ์‹œ
image_path = r'C:\1234567.jpg'
result = predict_image(image_path)
print(result)
```