mosshoon commited on
Commit
06dc977
·
1 Parent(s): b778b87

feat: ResNet50 변경

Browse files
app.py CHANGED
@@ -28,8 +28,9 @@ except ImportError:
28
  return 100
29
  return 20 * np.log10(255.0 / np.sqrt(mse))
30
  import io
 
31
  import torch
32
- import lpips
33
  import torchvision.transforms as transforms
34
  import ssl
35
 
@@ -45,6 +46,7 @@ class ImageSimilarityLeaderboard:
45
 
46
  # 메모리 최적화를 위한 캐시 및 락 (먼저 초기화)
47
  self._ref_image_cache = None
 
48
  self._cache_loaded = False
49
  self._file_lock = threading.Lock() # 파일 I/O 동시성 제어
50
  self._processing_lock = threading.Lock() # 처리 동시성 제어
@@ -53,15 +55,25 @@ class ImageSimilarityLeaderboard:
53
  self.leaderboard_data = self.load_leaderboard()
54
  self.last_modified = self.get_file_modified_time()
55
 
56
- # LPIPS 모델 초기화 (한 번만 로드)
57
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
58
  try:
59
- self.lpips_model = lpips.LPIPS(net='vgg').to(self.device)
60
- self.lpips_model.eval() # 평가 모드
61
- print(f"✅ LPIPS 모델 로드 완료 (Device: {self.device})")
 
 
 
 
 
 
 
 
 
 
62
  except Exception as e:
63
- print(f"⚠️ LPIPS 모델 로드 실패: {e}")
64
- self.lpips_model = None
65
 
66
  # macOS 호환성을 위한 경고 억제
67
  import warnings
@@ -116,6 +128,17 @@ class ImageSimilarityLeaderboard:
116
  ref_image = cv2.resize(ref_image, (new_width, new_height))
117
 
118
  self._ref_image_cache = cv2.cvtColor(ref_image, cv2.COLOR_BGR2RGB)
 
 
 
 
 
 
 
 
 
 
 
119
  self._cache_loaded = True
120
  # 메모리 정리
121
  del ref_image
@@ -141,33 +164,33 @@ class ImageSimilarityLeaderboard:
141
 
142
  def calculate_similarity(self, image1, image2):
143
  try:
144
- # 1) LPIPS 계산 (Perceptual Similarity) - 가장 중요
145
- lpips_score = 0.0
146
- if self.lpips_model is not None:
147
  try:
148
- # LPIPS를 위한 전처리 (-1 ~ 1 사이 값으로 정규화, RGB)
149
- # 이미지 크기는 224x224 이상 권장, 여기서는 256x256으로 리사이즈
150
- lpips_size = (256, 256)
151
-
152
- # numpy -> tensor
153
- img1_t = cv2.resize(image1, lpips_size).astype(np.float32) / 127.5 - 1.0
154
- img2_t = cv2.resize(image2, lpips_size).astype(np.float32) / 127.5 - 1.0
155
-
156
- img1_t = torch.from_numpy(img1_t).permute(2, 0, 1).unsqueeze(0).to(self.device)
157
- img2_t = torch.from_numpy(img2_t).permute(2, 0, 1).unsqueeze(0).to(self.device)
158
 
159
  with torch.no_grad():
160
- d = self.lpips_model(img1_t, img2_t)
161
- dist = d.item()
162
 
163
- # LPIPS 거리를 점수로 변환 (0이 동일, 보통 0.5 이상이면 꽤 다름)
164
- # 거리가 0이면 100점, 0.5면 50점, 1.0이면 0점
165
- # 조금 더 관대하게: max(0, (1 - dist) * 100)
166
- lpips_score = max(0, (1 - dist) * 100)
 
 
 
 
 
 
 
 
167
 
168
  except Exception as e:
169
- print(f"LPIPS 계산 오류: {e}")
170
- lpips_score = 0.0
171
 
172
  # 2) 그레이스케일 변환 (기존 로직 유지)
173
  if image1.ndim == 3:
@@ -209,14 +232,13 @@ class ImageSimilarityLeaderboard:
209
  hist_corr = cv2.compareHist(hist1, hist2, cv2.HISTCMP_CORREL)
210
  hist_score = (hist_corr + 1) / 2 # -1~1 → 0~1
211
 
212
- # 7) 최종 점수 계산 (LPIPS 비중 대폭 강화)
213
- # LPIPS 모델이 있으면 LPIPS 80%, SSIM 10%, Hist 10%
214
- # 없으면 기존 방식 (SSIM 70%, Hist 30%)
215
 
216
- if self.lpips_model is not None:
217
- final_score = (lpips_score * 0.8) + (ssim_score * 100 * 0.1) + (hist_score * 100 * 0.1)
218
  else:
219
- print(f"LPIPS 모델이 없어서 SSIM 70%, Hist 30%로 계산")
220
  final_score = (ssim_score * 0.7 + hist_score * 0.3) * 100
221
 
222
  # 8) PSNR이 높으면 약간의 보너스 (최대 5점)
@@ -228,12 +250,12 @@ class ImageSimilarityLeaderboard:
228
  'ssim': float(ssim_score),
229
  'psnr': float(psnr_score * 100),
230
  'histogram': float(hist_score),
231
- 'lpips': float(lpips_score), # 결과에 포함
232
  'final_score': float(final_score)
233
  }
234
  except Exception as e:
235
  print(f"유사도 계산 오류: {e}")
236
- return {'ssim':0.0,'psnr':0.0,'histogram':0.0,'lpips':0.0,'final_score':0.0}
237
 
238
  def process_image(self, uploaded_image, username):
239
  """업로드된 이미지를 처리하고 점수를 계산합니다."""
@@ -298,7 +320,7 @@ class ImageSimilarityLeaderboard:
298
  'ssim': float(round(similarity_scores['ssim'], 4)),
299
  'psnr': float(round(similarity_scores['psnr'], 2)),
300
  'histogram': float(round(similarity_scores['histogram'], 4)),
301
- 'lpips': float(round(similarity_scores.get('lpips', 0.0), 2))
302
  }
303
 
304
  # 같은 이름의 기존 기록이 있는지 확인하고, 더 높은 점수만 유지
@@ -489,7 +511,7 @@ def create_interface():
489
  """Gradio 인터페이스 생성"""
490
  with gr.Blocks(title="비슷한 이미지를 만들어주세요!", theme=gr.themes.Soft()) as demo:
491
  gr.Markdown("""
492
- # 🏆 참조 이미지와 얼마나 유사한지 측정하여 리더보드에 등록해보세요!
493
  """)
494
 
495
 
@@ -616,83 +638,83 @@ def create_interface():
616
  outputs=[result_output_challenge, leaderboard_output_challenge]
617
  )
618
 
619
- # # Originality반 탭
620
- # with gr.Tab("🎨 Originality반"):
621
- # with gr.Row():
622
- # with gr.Column(scale=1):
623
- # gr.Markdown("### 📤 이미지 업로드 (Originality반)")
624
- # image_input_originality = gr.Image(
625
- # label="비교할 이미지를 업로드하세요",
626
- # type="pil",
627
- # height=300
628
- # )
629
- # username_input_originality = gr.Textbox(
630
- # label="사용자 이름",
631
- # placeholder="이름을 입력하세요",
632
- # max_lines=1
633
- # )
634
- # submit_btn_originality = gr.Button("🚀 점수 계산 및 등록", variant="primary", size="lg")
635
-
636
- # with gr.Column(scale=1):
637
- # gr.Markdown("### 📊 결과 (Originality반)")
638
- # result_output_originality = gr.Textbox(
639
- # label="계산 결과",
640
- # lines=10,
641
- # interactive=False
642
- # )
643
-
644
- # gr.Markdown("### 🏅 Originality반 리더보드")
645
- # leaderboard_output_originality = gr.Dataframe(
646
- # headers=["순위", "사용자명", "점수", "날짜"],
647
- # datatype=["number", "str", "number", "str"],
648
- # interactive=False
649
- # )
650
-
651
- # # Originality반 이벤트 핸들러
652
- # submit_btn_originality.click(
653
- # fn=process_user_image_originality,
654
- # inputs=[image_input_originality, username_input_originality],
655
- # outputs=[result_output_originality, leaderboard_output_originality]
656
- # )
657
-
658
- # # BCE반 탭
659
- # with gr.Tab("🔥 BCE반"):
660
- # with gr.Row():
661
- # with gr.Column(scale=1):
662
- # gr.Markdown("### 📤 이미지 업로드 (BCE반)")
663
- # image_input_bce = gr.Image(
664
- # label="비교할 이미지를 업로드하세요",
665
- # type="pil",
666
- # height=300
667
- # )
668
- # username_input_bce = gr.Textbox(
669
- # label="사용자 이름",
670
- # placeholder="이름을 입력하세요",
671
- # max_lines=1
672
- # )
673
- # submit_btn_bce = gr.Button("🚀 점수 계산 및 등록", variant="primary", size="lg")
674
-
675
- # with gr.Column(scale=1):
676
- # gr.Markdown("### 📊 결과 (BCE반)")
677
- # result_output_bce = gr.Textbox(
678
- # label="계산 결과",
679
- # lines=10,
680
- # interactive=False
681
- # )
682
-
683
- # gr.Markdown("### 🏅 BCE반 리더보드")
684
- # leaderboard_output_bce = gr.Dataframe(
685
- # headers=["순위", "사용자명", "점수", "날짜"],
686
- # datatype=["number", "str", "number", "str"],
687
- # interactive=False
688
- # )
689
-
690
- # # BCE반 이벤트 핸들러
691
- # submit_btn_bce.click(
692
- # fn=process_user_image_bce,
693
- # inputs=[image_input_bce, username_input_bce],
694
- # outputs=[result_output_bce, leaderboard_output_bce]
695
- # )
696
 
697
  # 페이지 로드 시 모든 리더보드 표시
698
  demo.load(
 
28
  return 100
29
  return 20 * np.log10(255.0 / np.sqrt(mse))
30
  import io
31
+ from PIL import Image
32
  import torch
33
+ import torchvision.models as models
34
  import torchvision.transforms as transforms
35
  import ssl
36
 
 
46
 
47
  # 메모리 최적화를 위한 캐시 및 락 (먼저 초기화)
48
  self._ref_image_cache = None
49
+ self._ref_embedding = None # ResNet 임베딩 캐시
50
  self._cache_loaded = False
51
  self._file_lock = threading.Lock() # 파일 I/O 동시성 제어
52
  self._processing_lock = threading.Lock() # 처리 동시성 제어
 
55
  self.leaderboard_data = self.load_leaderboard()
56
  self.last_modified = self.get_file_modified_time()
57
 
58
+ # ResNet 모델 초기화 (한 번만 로드)
59
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
60
  try:
61
+ # ResNet50 (ImageNet weights) - 마지막 FC 레이어 제외
62
+ resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
63
+ self.resnet_model = torch.nn.Sequential(*(list(resnet.children())[:-1])).to(self.device)
64
+ self.resnet_model.eval()
65
+
66
+ # 전처리 파이프라인
67
+ self.preprocess = transforms.Compose([
68
+ transforms.Resize(256),
69
+ transforms.CenterCrop(224),
70
+ transforms.ToTensor(),
71
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
72
+ ])
73
+ print(f"✅ ResNet 모델 로드 완료 (Device: {self.device})")
74
  except Exception as e:
75
+ print(f"⚠️ ResNet 모델 로드 실패: {e}")
76
+ self.resnet_model = None
77
 
78
  # macOS 호환성을 위한 경고 억제
79
  import warnings
 
128
  ref_image = cv2.resize(ref_image, (new_width, new_height))
129
 
130
  self._ref_image_cache = cv2.cvtColor(ref_image, cv2.COLOR_BGR2RGB)
131
+
132
+ # ResNet 임베딩 계산 및 캐시
133
+ if self.resnet_model is not None:
134
+ try:
135
+ pil_img = Image.fromarray(self._ref_image_cache)
136
+ img_t = self.preprocess(pil_img).unsqueeze(0).to(self.device)
137
+ with torch.no_grad():
138
+ self._ref_embedding = self.resnet_model(img_t).flatten()
139
+ except Exception as e:
140
+ print(f"참조 이미지 임베딩 실패: {e}")
141
+
142
  self._cache_loaded = True
143
  # 메모리 정리
144
  del ref_image
 
164
 
165
  def calculate_similarity(self, image1, image2):
166
  try:
167
+ # 1) ResNet Feature Similarity (Semantic Similarity) - 가장 중요
168
+ resnet_score = 0.0
169
+ if self.resnet_model is not None and self._ref_embedding is not None:
170
  try:
171
+ # 사용자 이미지 전처리
172
+ pil_img = Image.fromarray(image2) # image2 is RGB numpy array
173
+ img_t = self.preprocess(pil_img).unsqueeze(0).to(self.device)
 
 
 
 
 
 
 
174
 
175
  with torch.no_grad():
176
+ user_emb = self.resnet_model(img_t).flatten()
 
177
 
178
+ # Cosine Similarity
179
+ cos_sim = torch.nn.functional.cosine_similarity(
180
+ self._ref_embedding.unsqueeze(0),
181
+ user_emb.unsqueeze(0)
182
+ ).item()
183
+
184
+ # Sigmoid Scoring Formula
185
+ # Sim 0.61 (Bad) -> Score 14
186
+ # Sim 0.77 (Good) -> Score 80
187
+ # Sim 0.92 (Perfect) -> Score 99
188
+ # Formula: 100 / (1 + exp(-20 * (sim - 0.7)))
189
+ resnet_score = 100 / (1 + np.exp(-20 * (cos_sim - 0.7)))
190
 
191
  except Exception as e:
192
+ print(f"ResNet 계산 오류: {e}")
193
+ resnet_score = 0.0
194
 
195
  # 2) 그레이스케일 변환 (기존 로직 유지)
196
  if image1.ndim == 3:
 
232
  hist_corr = cv2.compareHist(hist1, hist2, cv2.HISTCMP_CORREL)
233
  hist_score = (hist_corr + 1) / 2 # -1~1 → 0~1
234
 
235
+ # 7) 최종 점수 계산 (ResNet 비중 대폭 강화)
236
+ # ResNet 모델이 있으면 ResNet 80%, SSIM 10%, Hist 10%
 
237
 
238
+ if self.resnet_model is not None:
239
+ final_score = (resnet_score * 0.8) + (ssim_score * 100 * 0.1) + (hist_score * 100 * 0.1)
240
  else:
241
+ print(f"ResNet 모델이 없어서 SSIM 70%, Hist 30%로 계산")
242
  final_score = (ssim_score * 0.7 + hist_score * 0.3) * 100
243
 
244
  # 8) PSNR이 높으면 약간의 보너스 (최대 5점)
 
250
  'ssim': float(ssim_score),
251
  'psnr': float(psnr_score * 100),
252
  'histogram': float(hist_score),
253
+ 'resnet': float(resnet_score), # 결과에 포함
254
  'final_score': float(final_score)
255
  }
256
  except Exception as e:
257
  print(f"유사도 계산 오류: {e}")
258
+ return {'ssim':0.0,'psnr':0.0,'histogram':0.0,'resnet':0.0,'final_score':0.0}
259
 
260
  def process_image(self, uploaded_image, username):
261
  """업로드된 이미지를 처리하고 점수를 계산합니다."""
 
320
  'ssim': float(round(similarity_scores['ssim'], 4)),
321
  'psnr': float(round(similarity_scores['psnr'], 2)),
322
  'histogram': float(round(similarity_scores['histogram'], 4)),
323
+ 'resnet': float(round(similarity_scores.get('resnet', 0.0), 2))
324
  }
325
 
326
  # 같은 이름의 기존 기록이 있는지 확인하고, 더 높은 점수만 유지
 
511
  """Gradio 인터페이스 생성"""
512
  with gr.Blocks(title="비슷한 이미지를 만들어주세요!", theme=gr.themes.Soft()) as demo:
513
  gr.Markdown("""
514
+ # 🏆 롯데 비전 스튜디오 리더보드
515
  """)
516
 
517
 
 
638
  outputs=[result_output_challenge, leaderboard_output_challenge]
639
  )
640
 
641
+ # Originality반 탭
642
+ with gr.Tab("🎨 Originality반"):
643
+ with gr.Row():
644
+ with gr.Column(scale=1):
645
+ gr.Markdown("### 📤 이미지 업로드 (Originality반)")
646
+ image_input_originality = gr.Image(
647
+ label="비교할 이미지를 업로드하세요",
648
+ type="pil",
649
+ height=300
650
+ )
651
+ username_input_originality = gr.Textbox(
652
+ label="사용자 이름",
653
+ placeholder="이름을 입력하세요",
654
+ max_lines=1
655
+ )
656
+ submit_btn_originality = gr.Button("🚀 점수 계산 및 등록", variant="primary", size="lg")
657
+
658
+ with gr.Column(scale=1):
659
+ gr.Markdown("### 📊 결과 (Originality반)")
660
+ result_output_originality = gr.Textbox(
661
+ label="계산 결과",
662
+ lines=10,
663
+ interactive=False
664
+ )
665
+
666
+ gr.Markdown("### 🏅 Originality반 리더보드")
667
+ leaderboard_output_originality = gr.Dataframe(
668
+ headers=["순위", "사용자명", "점수", "날짜"],
669
+ datatype=["number", "str", "number", "str"],
670
+ interactive=False
671
+ )
672
+
673
+ # Originality반 이벤트 핸들러
674
+ submit_btn_originality.click(
675
+ fn=process_user_image_originality,
676
+ inputs=[image_input_originality, username_input_originality],
677
+ outputs=[result_output_originality, leaderboard_output_originality]
678
+ )
679
+
680
+ # BCE반 탭
681
+ with gr.Tab("���� BCE반"):
682
+ with gr.Row():
683
+ with gr.Column(scale=1):
684
+ gr.Markdown("### 📤 이미지 업로드 (BCE반)")
685
+ image_input_bce = gr.Image(
686
+ label="비교할 이미지를 업로드하세요",
687
+ type="pil",
688
+ height=300
689
+ )
690
+ username_input_bce = gr.Textbox(
691
+ label="사용자 이름",
692
+ placeholder="이름을 입력하세요",
693
+ max_lines=1
694
+ )
695
+ submit_btn_bce = gr.Button("🚀 점수 계산 및 등록", variant="primary", size="lg")
696
+
697
+ with gr.Column(scale=1):
698
+ gr.Markdown("### 📊 결과 (BCE반)")
699
+ result_output_bce = gr.Textbox(
700
+ label="계산 결과",
701
+ lines=10,
702
+ interactive=False
703
+ )
704
+
705
+ gr.Markdown("### 🏅 BCE반 리더보드")
706
+ leaderboard_output_bce = gr.Dataframe(
707
+ headers=["순위", "사용자명", "점수", "날짜"],
708
+ datatype=["number", "str", "number", "str"],
709
+ interactive=False
710
+ )
711
+
712
+ # BCE반 이벤트 핸들러
713
+ submit_btn_bce.click(
714
+ fn=process_user_image_bce,
715
+ inputs=[image_input_bce, username_input_bce],
716
+ outputs=[result_output_bce, leaderboard_output_bce]
717
+ )
718
 
719
  # 페이지 로드 시 모든 리더보드 표시
720
  demo.load(
test_images/bad1.png ADDED

Git LFS Details

  • SHA256: fd9a8d320d05e09656955fff1fcff6f5b0a25565bf4bb7dd6f38804057cef52b
  • Pointer size: 129 Bytes
  • Size of remote file: 6.21 kB
test_images/bad2.png ADDED

Git LFS Details

  • SHA256: d65a7743bc140c9df4e7051799873a16b0da6ec8043352712c668818a1a0edef
  • Pointer size: 132 Bytes
  • Size of remote file: 1.05 MB
test_images/bad3.jpeg ADDED
test_images/bad4.png ADDED

Git LFS Details

  • SHA256: ede1c81b3e7f91d9f5370532e46849001cfa6b600ab183c9640ec3645a091e06
  • Pointer size: 131 Bytes
  • Size of remote file: 674 kB
test_images/bad5.png ADDED

Git LFS Details

  • SHA256: 53b5ae45229ce9feb5e6699a026878208ae9d678337e2c6a6f0570a5d87b72e0
  • Pointer size: 131 Bytes
  • Size of remote file: 588 kB
test_images/good1.png ADDED

Git LFS Details

  • SHA256: 86a539d9cd4357e5a7e65bd559d3a5c6ef25ced1e7f0cb2bfb367b6ac0e522fa
  • Pointer size: 131 Bytes
  • Size of remote file: 670 kB
test_images/good2.png ADDED

Git LFS Details

  • SHA256: 5c04fa785e43fcc2d33727ef3786de896333cd7facb83b36d40e313849da6115
  • Pointer size: 131 Bytes
  • Size of remote file: 361 kB
test_images/good3.png ADDED

Git LFS Details

  • SHA256: cf7466599fb2ee95c1c70b1f1b69a34c88b7abd46fce4e86711182e863b883ff
  • Pointer size: 132 Bytes
  • Size of remote file: 1.19 MB
test_images/soso1.png ADDED

Git LFS Details

  • SHA256: 24a1dc571b29737d7b3c45b2dcc351089a0b57a7637a5a160ffee02ca4b8cf5c
  • Pointer size: 132 Bytes
  • Size of remote file: 1.11 MB
test_images/soso2.png ADDED

Git LFS Details

  • SHA256: 087f071f8f2a3a970acf492a23a8b546c217cc9f1389d109a5ec00b7208cac89
  • Pointer size: 131 Bytes
  • Size of remote file: 980 kB