minjune121 commited on
Commit
46301e8
ยท
verified ยท
1 Parent(s): 276019d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +299 -57
app.py CHANGED
@@ -1,20 +1,19 @@
1
  """
2
  Boolook - ์Œ์„ฑ ๊ธฐ๋ฐ˜ ๊ฐ์ • ๋ถ„์„ ์ฑ… ์ถ”์ฒœ (HuggingFace Spaces)
3
  ์ˆ˜์ •์‚ฌํ•ญ:
4
- - ์ž„๋ฒ ๋”ฉ ๋กœ๋”ฉ์„ ๋ฐฑ๊ทธ๋ผ์šด๋“œ ์Šค๋ ˆ๋“œ๋กœ ๋ถ„๋ฆฌ (ํƒ€์ž„์•„์›ƒ ๋ฐฉ์ง€)
5
- - ๋ฐฐ์น˜ ํฌ๊ธฐ 128๋กœ ์ฆ๊ฐ€ (์†๋„ ํ–ฅ์ƒ)
6
- - ์„œ๋ฒ„๊ฐ€ ๋จผ์ € ์—ด๋ฆฐ ๋’ค ๋ฐ์ดํ„ฐ ๋กœ๋”ฉ ์ง„ํ–‰
7
- - ์ถ”์ฒœ ๊ฒฐ๊ณผ ์ถœ๋ ฅ์„ JSON ํ˜•์‹์œผ๋กœ ๋‹จ์ˆœํ™”
8
- - emotion_score: ์ฃผ๊ฐ์ • ๋‹จ์ผ ์ˆ˜์น˜
9
- - user_input / recommendation_books ํ‚ค ์‚ฌ์šฉ
10
- - ์˜ค๋””์˜ค type="filepath" + soundfile ๋ถ„๊ธฐ ์ฒ˜๋ฆฌ
11
- - ํ”ผ๋“œ๋ฐฑ UI ์ œ๊ฑฐ โ†’ /api/feedback ์—”๋“œํฌ์ธํŠธ๋กœ ๋Œ€์ฒด
12
  """
13
 
14
  import gradio as gr
15
  import pandas as pd
16
  import numpy as np
17
  import torch
 
 
18
  import pickle
19
  import csv
20
  import json
@@ -35,13 +34,21 @@ logger = logging.getLogger(__name__)
35
  # ============================================================
36
  # ์„ค์ •
37
  # ============================================================
38
- BOOK_DB_PATH = Path("book_db_final.csv")
39
- FEEDBACK_PATH = Path("user_feedback.csv")
40
- SBERT_CACHE_PATH = Path("book_embeddings.pkl")
41
- SAMPLE_RATE = 16000
 
42
  MAX_EMBEDDING_BATCH = 128
43
 
 
 
 
 
 
 
44
  device = 0 if torch.cuda.is_available() else -1
 
45
  logger.info(f"๋””๋ฐ”์ด์Šค: {'GPU' if device == 0 else 'CPU'}")
46
 
47
  # ============================================================
@@ -53,7 +60,245 @@ _data_ready = False
53
  _data_lock = threading.Lock()
54
 
55
  # ============================================================
56
- # ๋ชจ๋ธ ๋กœ๋”ฉ
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  # ============================================================
58
  logger.info("๋ชจ๋ธ ๋กœ๋”ฉ ์ค‘...")
59
 
@@ -79,17 +324,6 @@ try:
79
  except Exception as e:
80
  logger.error(f"SBERT ๋กœ๋“œ ์‹คํŒจ: {e}")
81
 
82
- audio_emotion_pipeline = None
83
- try:
84
- audio_emotion_pipeline = hf_pipeline(
85
- "audio-classification",
86
- model="superb/wav2vec2-base-superb-er",
87
- device=device,
88
- )
89
- logger.info("์Œ์„ฑ ๊ฐ์ • ๋ชจ๋ธ ๋กœ๋“œ ์™„๋ฃŒ")
90
- except Exception as e:
91
- logger.warning(f"์Œ์„ฑ ๊ฐ์ • ๋ชจ๋ธ ์Šคํ‚ต (ํ…์ŠคํŠธ๋งŒ ์‚ฌ์šฉ): {e}")
92
-
93
  logger.info("๋ชจ๋ธ ๋กœ๋”ฉ ์™„๋ฃŒ!")
94
 
95
  # ============================================================
@@ -118,7 +352,13 @@ if sbert_model:
118
  except Exception as e:
119
  logger.error(f"๊ฐ์ • ๋ ˆ์ด๋ธ” ์ž„๋ฒ ๋”ฉ ์‹คํŒจ: {e}")
120
 
121
- _AUDIO_LABEL_MAP = {"hap": "๊ธฐ์จ", "neu": "์‹ ๋ขฐ", "sad": "์Šฌํ””", "ang": "๋ถ„๋…ธ"}
 
 
 
 
 
 
122
 
123
  _KEYWORD_BOOSTS = {
124
  "์Šฌํ””": ["์Šฌํ”„", "์šฐ์šธ", "๋ˆˆ๋ฌผ", "ํž˜๋“ค", "์™ธ๋กœ"],
@@ -246,25 +486,23 @@ def text_emotion_scores(text: str) -> Dict[str, float]:
246
  return scores
247
 
248
 
249
- def audio_emotion_scores(audio_array: np.ndarray, sr: int) -> Dict[str, float]:
250
- scores = {emo: 0.0 for emo in _EMOTION_LABELS}
251
- if audio_emotion_pipeline is None:
252
- return scores
 
 
253
 
254
- try:
255
- import scipy.io.wavfile as wav_io
256
- import tempfile
257
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
258
- wav_io.write(tmp.name, sr, (audio_array * 32767).astype(np.int16))
259
- results = audio_emotion_pipeline(tmp.name)
260
- Path(tmp.name).unlink(missing_ok=True)
261
- for item in results:
262
- mapped = _AUDIO_LABEL_MAP.get(item["label"])
263
- if mapped:
264
- scores[mapped] += item["score"]
265
- except Exception as e:
266
- logger.warning(f"์Œ์„ฑ ๊ฐ์ • ์‹คํŒจ: {e}")
267
- return scores
268
 
269
 
270
  def fused_emotion(t_scores: Dict[str, float], a_scores: Dict[str, float]) -> Tuple[str, Dict[str, float]]:
@@ -280,6 +518,7 @@ def fused_emotion(t_scores: Dict[str, float], a_scores: Dict[str, float]) -> Tup
280
  top_emotion = max(combined, key=combined.get)
281
  return top_emotion, combined
282
 
 
283
  # ============================================================
284
  # ์ถ”์ฒœ
285
  # ============================================================
@@ -324,6 +563,7 @@ def get_recommendations(user_input: str, emotion: str, top_n: int = 3) -> List[D
324
  logger.error(f"์ถ”์ฒœ ์‹คํŒจ: {e}")
325
  return []
326
 
 
327
  # ============================================================
328
  # ์ถ”์ฒœ ๊ฒฐ๊ณผ โ†’ JSON ๋ Œ๋”๋ง
329
  # ============================================================
@@ -349,6 +589,7 @@ def _render_books_json(user_input: str, emotion: str, combined: Dict[str, float]
349
  }
350
  return json.dumps(output, ensure_ascii=False, indent=2)
351
 
 
352
  # ============================================================
353
  # ํ”ผ๋“œ๋ฐฑ
354
  # ============================================================
@@ -423,12 +664,12 @@ def api_feedback(feedback_data) -> str:
423
 
424
  def get_feedback_stats() -> str:
425
  if not FEEDBACK_PATH.exists():
426
- return "๐Ÿ“Š ์•„์ง ํ”ผ๋“œ๋ฐฑ์ด ์—†์Šต๋‹ˆ๋‹ค."
427
  try:
428
  fb_df = pd.read_csv(FEEDBACK_PATH, encoding="utf-8-sig", on_bad_lines="skip")
429
  total = len(fb_df)
430
  if total == 0:
431
- return "๐Ÿ“Š ์•„์ง ํ”ผ๋“œ๋ฐฑ์ด ์—†์Šต๋‹ˆ๋‹ค."
432
  emo_counts = fb_df.groupby("emotion")["accepted"].agg(["count", "sum"])
433
  lines = [f"**์ด ํ”ผ๋“œ๋ฐฑ: {total}๊ฑด**\n"]
434
  for emo, row_s in emo_counts.iterrows():
@@ -440,6 +681,7 @@ def get_feedback_stats() -> str:
440
  except Exception as e:
441
  return f"ํ†ต๊ณ„ ๋กœ๋“œ ์‹คํŒจ: {e}"
442
 
 
443
  # ============================================================
444
  # ๋ฉ”์ธ ์ฒ˜๋ฆฌ
445
  # ============================================================
@@ -477,7 +719,7 @@ def process_voice(audio_input):
477
  return json.dumps({"error": "์Œ์„ฑ์ด ์ธ์‹๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค."}, ensure_ascii=False, indent=2), []
478
 
479
  t_scores = text_emotion_scores(user_input)
480
- a_scores = audio_emotion_scores(y, sr)
481
  top_label, combined = fused_emotion(t_scores, a_scores)
482
  books = get_recommendations(user_input, top_label, top_n=3)
483
  books_json = _render_books_json(user_input, top_label, combined, books)
@@ -493,40 +735,40 @@ def run_analysis(audio):
493
  books_json, books = process_voice(audio)
494
  return books_json, books
495
 
 
496
  # ============================================================
497
  # Gradio UI
498
  # ============================================================
499
- with gr.Blocks(theme=gr.themes.Soft(), title="Boolook ๐Ÿ“š") as demo:
500
  gr.Markdown("""
501
- # ๐Ÿ“š Boolook โ€” ์Œ์„ฑ ๊ธฐ๋ฐ˜ ๊ฐ์ • ๋ถ„์„ ์ฑ… ์ถ”์ฒœ
502
  ๋‹น์‹ ์˜ ๊ฐ์ •์„ ๋ง๋กœ ํ‘œํ˜„ํ•˜๋ฉด, AI๊ฐ€ ๋”ฑ ๋งž๋Š” ์ฑ…์„ ์ถ”์ฒœํ•ด๋“œ๋ฆฝ๋‹ˆ๋‹ค.
503
-
504
- ๐ŸŽค **์‚ฌ์šฉ๋ฒ•:** ๋งˆ์ดํฌ๋กœ ๊ฐ์ • ํ‘œํ˜„ โ†’ ๋ถ„์„ํ•˜๊ธฐ
505
  """)
506
 
507
  state_books = gr.State([])
508
 
509
  with gr.Row():
510
  with gr.Column(scale=1):
511
- gr.Markdown("### ๐ŸŽค ์Œ์„ฑ ์ž…๋ ฅ")
512
  audio_in = gr.Audio(
513
  sources=["microphone", "upload"],
514
  type="filepath",
515
  label="๋งˆ์ดํฌ ๋˜๋Š” ํŒŒ์ผ ์—…๋กœ๋“œ",
516
  )
517
- analyze_btn = gr.Button("๐Ÿ” ๋ถ„์„ํ•˜๊ธฐ", variant="primary", size="lg")
518
- gr.Markdown("๐Ÿ’ก ์˜ˆ: '์˜ค๋Š˜ ๋„ˆ๋ฌด ์Šฌํผ์š”', 'ํ–‰๋ณตํ•œ ๊ธฐ๋ถ„์ด์—์š”'")
519
 
520
  with gr.Column(scale=1):
521
  out_books_json = gr.Code(
522
- label="๐Ÿ“Š ๋ถ„์„ ๊ฒฐ๊ณผ & ๐Ÿ“– ์ถ”์ฒœ ๋„์„œ",
523
  language="json",
524
  interactive=False,
525
  )
526
 
527
- with gr.Accordion("๐Ÿ“ˆ ํ†ต๊ณ„", open=False):
528
  stats_md = gr.Markdown("์ƒˆ๋กœ๊ณ ์นจ์„ ๋ˆŒ๋Ÿฌ์ฃผ์„ธ์š”.")
529
- refresh_btn = gr.Button("๐Ÿ”„ ํ†ต๊ณ„ ์ƒˆ๋กœ๊ณ ์นจ")
530
  refresh_btn.click(fn=get_feedback_stats, outputs=stats_md)
531
 
532
  # ํ”ผ๋“œ๋ฐฑ API ์—”๋“œํฌ์ธํŠธ (ํด๋ผ์ด์–ธํŠธ ์ „์šฉ, UI ๋ฏธ๋…ธ์ถœ)
@@ -548,4 +790,4 @@ with gr.Blocks(theme=gr.themes.Soft(), title="Boolook ๐Ÿ“š") as demo:
548
  )
549
 
550
  if __name__ == "__main__":
551
- demo.launch()
 
1
  """
2
  Boolook - ์Œ์„ฑ ๊ธฐ๋ฐ˜ ๊ฐ์ • ๋ถ„์„ ์ฑ… ์ถ”์ฒœ (HuggingFace Spaces)
3
  ์ˆ˜์ •์‚ฌํ•ญ:
4
+ - final_emotion_model_v3.pth (ResNet-SE + BiLSTM + Attention) ์ปค์Šคํ…€ ๋ชจ๋ธ ํ†ตํ•ฉ
5
+ - superb/wav2vec2-base-superb-er ๋Œ€์‹  ์ปค์Šคํ…€ ๋ชจ๋ธ๋กœ ์Œ์„ฑ ๊ฐ์ • ๋ถ„๋ฅ˜
6
+ - ๋ชจ๋ธ ํด๋ž˜์Šค ์ •์˜ (SEBlock, ResBlock, AttentionPooling, EmotionResNet) ํฌํ•จ
7
+ - Mel-spectrogram ์ „์ฒ˜๋ฆฌ + TTA(n_tta=8) ์ถ”๋ก  + temperature scaling ์ ์šฉ
8
+ - 4ํด๋ž˜์Šค(Angry/Happy/Neutral/Sad) โ†’ ํ•œ๊ตญ์–ด ๊ฐ์ • ๋ ˆ์ด๋ธ” ๋งคํ•‘
 
 
 
9
  """
10
 
11
  import gradio as gr
12
  import pandas as pd
13
  import numpy as np
14
  import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
  import pickle
18
  import csv
19
  import json
 
34
  # ============================================================
35
  # ์„ค์ •
36
  # ============================================================
37
+ BOOK_DB_PATH = Path("book_db_final.csv")
38
+ FEEDBACK_PATH = Path("user_feedback.csv")
39
+ SBERT_CACHE_PATH = Path("book_embeddings.pkl")
40
+ EMOTION_MODEL_PATH = Path("final_emotion_model_v3.pth")
41
+ SAMPLE_RATE = 16000
42
  MAX_EMBEDDING_BATCH = 128
43
 
44
+ # Mel-spectrogram ํŒŒ๋ผ๋ฏธํ„ฐ (ํ•™์Šต ์‹œ ์‚ฌ์šฉํ•œ ๊ฐ’๊ณผ ๋™์ผํ•˜๊ฒŒ ๋งž์ถœ ๊ฒƒ)
45
+ N_MELS = 64
46
+ N_FFT = 1024
47
+ HOP_LEN = 512
48
+ MAX_FRAMES = 128 # ์‹œ๊ฐ„ ์ถ• ๊ณ ์ • ๊ธธ์ด
49
+
50
  device = 0 if torch.cuda.is_available() else -1
51
+ torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
52
  logger.info(f"๋””๋ฐ”์ด์Šค: {'GPU' if device == 0 else 'CPU'}")
53
 
54
  # ============================================================
 
60
  _data_lock = threading.Lock()
61
 
62
  # ============================================================
63
+ # โ‘  ์ปค์Šคํ…€ ๊ฐ์ • ๋ชจ๋ธ ์•„ํ‚คํ…์ฒ˜ ์ •์˜
64
+ # ============================================================
65
+
66
+ class SEBlock(nn.Module):
67
+ """Squeeze-and-Excitation Block"""
68
+ def __init__(self, channels: int, reduction: int = 16):
69
+ super().__init__()
70
+ self.excitation = nn.Sequential(
71
+ nn.Linear(channels, channels // reduction, bias=False),
72
+ nn.ReLU(inplace=True),
73
+ nn.Linear(channels // reduction, channels, bias=False),
74
+ nn.Sigmoid(),
75
+ )
76
+
77
+ def forward(self, x):
78
+ # x: (B, C, H, W)
79
+ b, c, _, _ = x.shape
80
+ w = x.mean(dim=[2, 3]) # global avg pool
81
+ w = self.excitation(w).view(b, c, 1, 1)
82
+ return x * w
83
+
84
+
85
+ class ResBlock(nn.Module):
86
+ """ResNet Basic Block with SE"""
87
+ def __init__(self, in_ch: int, out_ch: int, stride: int = 1):
88
+ super().__init__()
89
+ self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1, bias=False)
90
+ self.bn1 = nn.BatchNorm2d(out_ch)
91
+ self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False)
92
+ self.bn2 = nn.BatchNorm2d(out_ch)
93
+ self.se = SEBlock(out_ch, reduction=max(1, out_ch // 16))
94
+
95
+ self.shortcut = nn.Sequential()
96
+ if stride != 1 or in_ch != out_ch:
97
+ self.shortcut = nn.Sequential(
98
+ nn.Conv2d(in_ch, out_ch, 1, stride=stride, bias=False),
99
+ nn.BatchNorm2d(out_ch),
100
+ )
101
+
102
+ def forward(self, x):
103
+ out = F.relu(self.bn1(self.conv1(x)), inplace=True)
104
+ out = self.bn2(self.conv2(out))
105
+ out = self.se(out)
106
+ out = F.relu(out + self.shortcut(x), inplace=True)
107
+ return out
108
+
109
+
110
+ class AttentionPooling(nn.Module):
111
+ """Temporal Attention Pooling"""
112
+ def __init__(self, hidden: int):
113
+ super().__init__()
114
+ self.attn = nn.Linear(hidden, 1)
115
+
116
+ def forward(self, x):
117
+ # x: (B, T, H)
118
+ w = torch.softmax(self.attn(x), dim=1) # (B, T, 1)
119
+ return (x * w).sum(dim=1) # (B, H)
120
+
121
+
122
+ class EmotionResNet(nn.Module):
123
+ """
124
+ ResNet-SE + 2-layer BiLSTM + Attention Pooling + Classifier
125
+ ์ž…๋ ฅ: (B, 1, N_MELS, T) Mel-spectrogram
126
+ ์ถœ๋ ฅ: (B, num_classes) logits
127
+ """
128
+ def __init__(self, num_classes: int = 4):
129
+ super().__init__()
130
+ # CNN stem
131
+ self.conv1 = nn.Sequential(
132
+ nn.Conv2d(1, 64, 7, stride=2, padding=3, bias=False),
133
+ nn.BatchNorm2d(64),
134
+ )
135
+ # ResNet layers
136
+ self.layer1 = nn.Sequential(ResBlock(64, 64), ResBlock(64, 64))
137
+ self.layer2 = nn.Sequential(ResBlock(64, 128, stride=2), ResBlock(128, 128))
138
+ self.layer3 = nn.Sequential(ResBlock(128, 256, stride=2), ResBlock(256, 256))
139
+
140
+ # BiLSTM (2 layers)
141
+ self.bilstm = nn.LSTM(
142
+ input_size=256, hidden_size=256,
143
+ num_layers=2, batch_first=True,
144
+ bidirectional=True, dropout=0.3,
145
+ )
146
+
147
+ # Attention
148
+ self.attention = AttentionPooling(hidden=512)
149
+
150
+ # Classifier
151
+ self.classifier = nn.Sequential(
152
+ nn.Linear(512, 256),
153
+ nn.BatchNorm1d(256),
154
+ nn.ReLU(inplace=True),
155
+ nn.Dropout(0.5),
156
+ nn.Linear(256, num_classes),
157
+ )
158
+
159
+ def forward(self, x):
160
+ # CNN
161
+ x = F.relu(self.conv1(x), inplace=True)
162
+ x = self.layer1(x)
163
+ x = self.layer2(x)
164
+ x = self.layer3(x)
165
+
166
+ # (B, C, H, W) โ†’ temporal sequence: global-avg over freq axis
167
+ x = x.mean(dim=2) # (B, C, W)
168
+ x = x.permute(0, 2, 1) # (B, T, C)
169
+
170
+ # BiLSTM
171
+ x, _ = self.bilstm(x) # (B, T, 512)
172
+
173
+ # Attention pooling
174
+ x = self.attention(x) # (B, 512)
175
+
176
+ return self.classifier(x)
177
+
178
+
179
+ # ============================================================
180
+ # โ‘ก ์ปค์Šคํ…€ ๊ฐ์ • ๋ชจ๋ธ ๋กœ๋“œ
181
+ # ============================================================
182
+ _emotion_model = None
183
+ _emotion_classes = ["Angry", "Happy", "Neutral", "Sad"]
184
+ _emotion_label_enc = None
185
+ _emotion_temp = 1.0
186
+ _emotion_n_tta = 1
187
+
188
+ def _load_emotion_model():
189
+ global _emotion_model, _emotion_classes, _emotion_label_enc, _emotion_temp, _emotion_n_tta
190
+ if not EMOTION_MODEL_PATH.exists():
191
+ logger.error(f"{EMOTION_MODEL_PATH} ํŒŒ์ผ์ด ์—†์Šต๋‹ˆ๋‹ค. ์ปค์Šคํ…€ ๊ฐ์ • ๋ชจ๋ธ์„ ์‚ฌ์šฉํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค.")
192
+ return
193
+
194
+ try:
195
+ ckpt = torch.load(EMOTION_MODEL_PATH, map_location="cpu", weights_only=False)
196
+
197
+ _emotion_classes = [str(c) for c in ckpt.get("classes", _emotion_classes)]
198
+ _emotion_label_enc = ckpt.get("label_encoder", None)
199
+ _emotion_temp = float(ckpt.get("temperature", 1.0))
200
+ _emotion_n_tta = int(ckpt.get("n_tta", 1))
201
+
202
+ model = EmotionResNet(num_classes=len(_emotion_classes))
203
+ model.load_state_dict(ckpt["model_state_dict"])
204
+ model.to(torch_device)
205
+ model.eval()
206
+
207
+ _emotion_model = model
208
+ logger.info(
209
+ f"์ปค์Šคํ…€ ๊ฐ์ • ๋ชจ๋ธ ๋กœ๋“œ ์™„๋ฃŒ | "
210
+ f"ํด๋ž˜์Šค: {_emotion_classes} | "
211
+ f"val_acc: {ckpt.get('val_accuracy', 'N/A')} | "
212
+ f"val_f1: {ckpt.get('best_val_f1', 'N/A'):.4f} | "
213
+ f"temp: {_emotion_temp} | TTA: {_emotion_n_tta}"
214
+ )
215
+ except Exception as e:
216
+ logger.error(f"์ปค์Šคํ…€ ๊ฐ์ • ๋ชจ๋ธ ๋กœ๋“œ ์‹คํŒจ: {e}")
217
+
218
+ _load_emotion_model()
219
+
220
+ # ============================================================
221
+ # โ‘ข Mel-spectrogram ์ „์ฒ˜๋ฆฌ
222
+ # ============================================================
223
+ def _compute_melspec(y: np.ndarray, sr: int) -> torch.Tensor:
224
+ """
225
+ ์˜ค๋””์˜ค ๋ฐฐ์—ด โ†’ (1, 1, N_MELS, MAX_FRAMES) ํ…์„œ
226
+ librosa ์—†์ด torch๋งŒ ์‚ฌ์šฉํ•˜๋Š” ๊ฐ„์ด ๊ตฌํ˜„
227
+ """
228
+ try:
229
+ import librosa
230
+ mel = librosa.feature.melspectrogram(
231
+ y=y, sr=sr,
232
+ n_mels=N_MELS, n_fft=N_FFT, hop_length=HOP_LEN,
233
+ )
234
+ mel_db = librosa.power_to_db(mel, ref=np.max)
235
+ except ImportError:
236
+ # librosa ์—†์„ ๋•Œ torch STFT๋กœ ๋Œ€์ฒด
237
+ wav = torch.tensor(y, dtype=torch.float32)
238
+ window = torch.hann_window(N_FFT)
239
+ stft = torch.stft(wav, N_FFT, HOP_LEN, return_complex=True, window=window)
240
+ power = stft.abs() ** 2 # (freq, T)
241
+ # ๊ฐ„์ด mel filterbank (์‚ผ๊ฐํ˜• ๊ทผ์‚ฌ)
242
+ mel_fb = torch.zeros(N_MELS, power.shape[0])
243
+ for m in range(N_MELS):
244
+ mel_fb[m, m * (power.shape[0] // N_MELS):
245
+ (m + 1) * (power.shape[0] // N_MELS)] = 1.0
246
+ mel = mel_fb @ power # (N_MELS, T)
247
+ mel_db = (mel + 1e-6).log().numpy()
248
+
249
+ # ์ •๊ทœํ™”
250
+ mel_db = (mel_db - mel_db.mean()) / (mel_db.std() + 1e-6)
251
+
252
+ # ์‹œ๊ฐ„ ์ถ• ํŒจ๋”ฉ/์ž๋ฅด๊ธฐ
253
+ T = mel_db.shape[1]
254
+ if T < MAX_FRAMES:
255
+ mel_db = np.pad(mel_db, ((0, 0), (0, MAX_FRAMES - T)), mode="constant")
256
+ else:
257
+ mel_db = mel_db[:, :MAX_FRAMES]
258
+
259
+ # (1, 1, N_MELS, MAX_FRAMES)
260
+ tensor = torch.tensor(mel_db, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
261
+ return tensor.to(torch_device)
262
+
263
+
264
+ # ============================================================
265
+ # โ‘ฃ TTA ์ถ”๋ก 
266
+ # ============================================================
267
+ def _tta_augment(spec: torch.Tensor) -> torch.Tensor:
268
+ """๋‹จ์ˆœ ์‹œ๊ฐ„ ์ด๋™ augmentation for TTA"""
269
+ shift = np.random.randint(-MAX_FRAMES // 8, MAX_FRAMES // 8)
270
+ return torch.roll(spec, shift, dims=-1)
271
+
272
+
273
+ def _infer_emotion_model(y: np.ndarray, sr: int) -> Dict[str, float]:
274
+ """์ปค์Šคํ…€ ๋ชจ๋ธ ์ถ”๋ก  โ†’ ํด๋ž˜์Šค๋ณ„ ํ™•๋ฅ  dict (์›๋ณธ ์˜๏ฟฝ๏ฟฝ ๋ ˆ์ด๋ธ”)"""
275
+ if _emotion_model is None:
276
+ return {c: 0.0 for c in _emotion_classes}
277
+
278
+ try:
279
+ spec = _compute_melspec(y, sr) # (1, 1, N_MELS, T)
280
+
281
+ logits_list = []
282
+ with torch.no_grad():
283
+ n = max(1, _emotion_n_tta)
284
+ for i in range(n):
285
+ inp = _tta_augment(spec) if i > 0 else spec
286
+ logits = _emotion_model(inp) # (1, num_classes)
287
+ logits_list.append(logits)
288
+
289
+ avg_logits = torch.stack(logits_list).mean(dim=0) # (1, C)
290
+ probs = torch.softmax(avg_logits / _emotion_temp, dim=-1) # temperature scaling
291
+ probs = probs[0].cpu().numpy()
292
+
293
+ return {cls: float(p) for cls, p in zip(_emotion_classes, probs)}
294
+
295
+ except Exception as e:
296
+ logger.error(f"์ปค์Šคํ…€ ๋ชจ๋ธ ์ถ”๋ก  ์‹คํŒจ: {e}")
297
+ return {c: 0.0 for c in _emotion_classes}
298
+
299
+
300
+ # ============================================================
301
+ # ๋ชจ๋ธ ๋กœ๋”ฉ (STT, SBERT)
302
  # ============================================================
303
  logger.info("๋ชจ๋ธ ๋กœ๋”ฉ ์ค‘...")
304
 
 
324
  except Exception as e:
325
  logger.error(f"SBERT ๋กœ๋“œ ์‹คํŒจ: {e}")
326
 
 
 
 
 
 
 
 
 
 
 
 
327
  logger.info("๋ชจ๋ธ ๋กœ๋”ฉ ์™„๋ฃŒ!")
328
 
329
  # ============================================================
 
352
  except Exception as e:
353
  logger.error(f"๊ฐ์ • ๋ ˆ์ด๋ธ” ์ž„๋ฒ ๋”ฉ ์‹คํŒจ: {e}")
354
 
355
+ # ์ปค์Šคํ…€ ๋ชจ๋ธ ์˜์–ด ๋ ˆ์ด๋ธ” โ†’ ํ•œ๊ตญ์–ด ๋งคํ•‘
356
+ _CUSTOM_LABEL_MAP: Dict[str, str] = {
357
+ "Happy": "๊ธฐ์จ",
358
+ "Sad": "์Šฌํ””",
359
+ "Angry": "๋ถ„๋…ธ",
360
+ "Neutral": "์‹ ๋ขฐ",
361
+ }
362
 
363
  _KEYWORD_BOOSTS = {
364
  "์Šฌํ””": ["์Šฌํ”„", "์šฐ์šธ", "๋ˆˆ๋ฌผ", "ํž˜๋“ค", "์™ธ๋กœ"],
 
486
  return scores
487
 
488
 
489
+ def audio_emotion_scores(y: np.ndarray, sr: int) -> Dict[str, float]:
490
+ """
491
+ ์ปค์Šคํ…€ ๋ชจ๋ธ(final_emotion_model_v3.pth)๋กœ ์Œ์„ฑ ๊ฐ์ • ์ ์ˆ˜ ๋ฐ˜ํ™˜.
492
+ ์˜์–ด 4ํด๋ž˜์Šค ํ™•๋ฅ ์„ ํ•œ๊ตญ์–ด 8ํด๋ž˜์Šค ๊ณต๊ฐ„์œผ๋กœ ๋งคํ•‘.
493
+ """
494
+ base = {emo: 0.0 for emo in _EMOTION_LABELS}
495
 
496
+ raw = _infer_emotion_model(y, sr) # {"Happy": 0.6, "Sad": 0.2, ...}
497
+ if not raw or all(v == 0 for v in raw.values()):
498
+ return base
499
+
500
+ for eng_label, prob in raw.items():
501
+ kor_label = _CUSTOM_LABEL_MAP.get(eng_label)
502
+ if kor_label and kor_label in base:
503
+ base[kor_label] += prob
504
+
505
+ return base
 
 
 
 
506
 
507
 
508
  def fused_emotion(t_scores: Dict[str, float], a_scores: Dict[str, float]) -> Tuple[str, Dict[str, float]]:
 
518
  top_emotion = max(combined, key=combined.get)
519
  return top_emotion, combined
520
 
521
+
522
  # ============================================================
523
  # ์ถ”์ฒœ
524
  # ============================================================
 
563
  logger.error(f"์ถ”์ฒœ ์‹คํŒจ: {e}")
564
  return []
565
 
566
+
567
  # ============================================================
568
  # ์ถ”์ฒœ ๊ฒฐ๊ณผ โ†’ JSON ๋ Œ๋”๋ง
569
  # ============================================================
 
589
  }
590
  return json.dumps(output, ensure_ascii=False, indent=2)
591
 
592
+
593
  # ============================================================
594
  # ํ”ผ๋“œ๋ฐฑ
595
  # ============================================================
 
664
 
665
  def get_feedback_stats() -> str:
666
  if not FEEDBACK_PATH.exists():
667
+ return "์•„์ง ํ”ผ๋“œ๋ฐฑ์ด ์—†์Šต๋‹ˆ๋‹ค."
668
  try:
669
  fb_df = pd.read_csv(FEEDBACK_PATH, encoding="utf-8-sig", on_bad_lines="skip")
670
  total = len(fb_df)
671
  if total == 0:
672
+ return "์•„์ง ํ”ผ๋“œ๋ฐฑ์ด ์—†์Šต๋‹ˆ๋‹ค."
673
  emo_counts = fb_df.groupby("emotion")["accepted"].agg(["count", "sum"])
674
  lines = [f"**์ด ํ”ผ๋“œ๋ฐฑ: {total}๊ฑด**\n"]
675
  for emo, row_s in emo_counts.iterrows():
 
681
  except Exception as e:
682
  return f"ํ†ต๊ณ„ ๋กœ๋“œ ์‹คํŒจ: {e}"
683
 
684
+
685
  # ============================================================
686
  # ๋ฉ”์ธ ์ฒ˜๋ฆฌ
687
  # ============================================================
 
719
  return json.dumps({"error": "์Œ์„ฑ์ด ์ธ์‹๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค."}, ensure_ascii=False, indent=2), []
720
 
721
  t_scores = text_emotion_scores(user_input)
722
+ a_scores = audio_emotion_scores(y, sr) # โ† ์ปค์Šคํ…€ ๋ชจ๋ธ ์‚ฌ์šฉ
723
  top_label, combined = fused_emotion(t_scores, a_scores)
724
  books = get_recommendations(user_input, top_label, top_n=3)
725
  books_json = _render_books_json(user_input, top_label, combined, books)
 
735
  books_json, books = process_voice(audio)
736
  return books_json, books
737
 
738
+
739
  # ============================================================
740
  # Gradio UI
741
  # ============================================================
742
+ with gr.Blocks(theme=gr.themes.Soft(), title="Boolook") as demo:
743
  gr.Markdown("""
744
+ # Boolook โ€” ์Œ์„ฑ ๊ธฐ๋ฐ˜ ๊ฐ์ • ๋ถ„์„ ์ฑ… ์ถ”์ฒœ
745
  ๋‹น์‹ ์˜ ๊ฐ์ •์„ ๋ง๋กœ ํ‘œํ˜„ํ•˜๋ฉด, AI๊ฐ€ ๋”ฑ ๋งž๋Š” ์ฑ…์„ ์ถ”์ฒœํ•ด๋“œ๋ฆฝ๋‹ˆ๋‹ค.
746
+ **์‚ฌ์šฉ๋ฒ•:** ๋งˆ์ดํฌ๋กœ ๊ฐ์ • ํ‘œํ˜„ โ†’ ๋ถ„์„ํ•˜๊ธฐ
 
747
  """)
748
 
749
  state_books = gr.State([])
750
 
751
  with gr.Row():
752
  with gr.Column(scale=1):
753
+ gr.Markdown("### ์Œ์„ฑ ์ž…๋ ฅ")
754
  audio_in = gr.Audio(
755
  sources=["microphone", "upload"],
756
  type="filepath",
757
  label="๋งˆ์ดํฌ ๋˜๋Š” ํŒŒ์ผ ์—…๋กœ๋“œ",
758
  )
759
+ analyze_btn = gr.Button("๋ถ„์„ํ•˜๊ธฐ", variant="primary", size="lg")
760
+ gr.Markdown("์˜ˆ: '์˜ค๋Š˜ ๋„ˆ๋ฌด ์Šฌํผ์š”', 'ํ–‰๋ณตํ•œ ๊ธฐ๋ถ„์ด์—์š”'")
761
 
762
  with gr.Column(scale=1):
763
  out_books_json = gr.Code(
764
+ label="๋ถ„์„ ๊ฒฐ๊ณผ & ์ถ”์ฒœ ๋„์„œ",
765
  language="json",
766
  interactive=False,
767
  )
768
 
769
+ with gr.Accordion("ํ†ต๊ณ„", open=False):
770
  stats_md = gr.Markdown("์ƒˆ๋กœ๊ณ ์นจ์„ ๋ˆŒ๋Ÿฌ์ฃผ์„ธ์š”.")
771
+ refresh_btn = gr.Button("ํ†ต๊ณ„ ์ƒˆ๋กœ๊ณ ์นจ")
772
  refresh_btn.click(fn=get_feedback_stats, outputs=stats_md)
773
 
774
  # ํ”ผ๋“œ๋ฐฑ API ์—”๋“œํฌ์ธํŠธ (ํด๋ผ์ด์–ธํŠธ ์ „์šฉ, UI ๋ฏธ๋…ธ์ถœ)
 
790
  )
791
 
792
  if __name__ == "__main__":
793
+ demo.launch()