import gradio as gr from transformers import AutoImageProcessor, AutoModelForImageClassification from PIL import Image import torch import torch.nn.functional as F model_path = "./" processor = AutoImageProcessor.from_pretrained(model_path) model = AutoModelForImageClassification.from_pretrained(model_path) def classify_image(image): if image is None: return None inputs = processor(images=image, return_tensors="pt") with torch.no_grad(): outputs = model(**inputs) probabilities = F.softmax(outputs.logits, dim=1)[0] results = {} for i, prob in enumerate(probabilities): label_name = model.config.id2label[i] results[label_name] = float(prob) return results guide_text = """ ### F1-Score 0.92 성능 확인하기 (필독!) 이 모델은 산업용 엣지 디바이스(단순 배경) 환경을 가정하여 만든 경량화 모델입니다. 검증된 성능을 확인하시려면, 제가 학습에 실제 사용한 이미지를 넣어보세요. 테스트용 샘플 이미지 다운로드 (클릭) """ dev_summary = """ --- ### 개발 로그: f1: (0.68 → 0.96) 단순한 튜닝이 아닌, **데이터 품질 개선**을 통해 문제를 해결한 10단계의 실험 과정입니다. | 단계 | 주요 시도 (Experiment) | F1-Score | 분석 및 결과 (Key Insight) | | :---: | :--- | :---: | :--- | | 1 | Baseline (MobileViT) | 0.68 | 낮은 성능, 클래스 불균형 확인 | | 2~3 | 증강(Augmentation) 재검증 | 0.67 | 학습률/증강 조절했으나 성능 정체 (효과 미미) | | 4~5 | Class Weight 적용 | 0.65 | 노이즈 데이터에 과적합되어 성능 오히려 하락 | | 6 | 파라미터 재조정 | 0.73 | 전처리 변경 없이는 한계임을 확인 | | 7 | 데이터 2차 전처리 (Cleaning) | 0.82 | 불량 데이터 50 삭제 → 성능 비약적 상승 | | 8 | 모델 변경 (EfficientFormer) | 0.92 | 정제된 데이터에 최신 경량 모델 도입 | | 9~10 | 해상도/정규화 추가 실험 | 0.92 | 성능 수렴 (추가 개선폭 미미) | | 11 | 3차 전처리, 모델변경: google/vit-base-patch16-224 | 0.96 | 손실 0.6->0.03으로 감소 | > 결과적으로 트랜스포머 고성능 모델보다 데이터 품질이 중요함. 허나 오염데이터 완전히 제거하진 않았음 자체적인 증강효과를 위해서 ### Classification Report | Class | Precision | Recall | F1-score | Support | |----------|----------:|-------:|---------:|--------:| | PET | 0.96 | 0.94 | 0.95 | 218 | | Can | 0.99 | 0.97 | 0.98 | 283 | | Glass | 0.96 | 0.97 | 0.97 | 221 | | Paper | 0.98 | 0.98 | 0.98 | 315 | | Plastic | 0.95 | 0.95 | 0.95 | 308 | | Vinyl | 0.95 | 0.97 | 0.96 | 282 | | **Accuracy** | | | **0.96** | | | **Macro Avg** | 0.96 | 0.96 | 0.96 | | | **Weighted Avg** | 0.97 | 0.96 | 0.96 | | 사용 데이터 : jms0923/tod: Trash_Object_Detection_Dataset_v1.0(zenodo) """ interface = gr.Interface( fn=classify_image, inputs=gr.Image(type="pil", label="여기에 이미지를 드래그하세요"), outputs=gr.Label(num_top_classes=3, label="분류 결과"), title=" 경량화 재활용품 분류기", description=guide_text, article=dev_summary ) interface.launch()