File size: 33,100 Bytes
755eb09
6a13bbe
 
7d80d37
755eb09
 
6a13bbe
 
 
7d80d37
e204bd0
 
6a13bbe
755eb09
 
6a13bbe
e204bd0
 
6a13bbe
1981810
7cc64b6
 
 
 
c4a5cbc
7d80d37
7cc64b6
7d80d37
6a13bbe
1981810
825303f
 
 
755eb09
 
 
 
7d80d37
ebe3b6f
 
 
 
 
 
7d80d37
7cc64b6
e204bd0
6a13bbe
e204bd0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d80d37
1981810
6a13bbe
7d80d37
 
1981810
7d80d37
1981810
7d80d37
1981810
7d80d37
 
 
 
 
 
 
 
 
1981810
7d80d37
 
 
 
 
 
 
 
 
 
 
 
 
 
1981810
 
7d80d37
6a13bbe
1981810
6a13bbe
 
 
 
31e6ef1
7d80d37
31e6ef1
6a13bbe
 
7d80d37
1981810
c36c080
 
 
 
 
 
 
 
1981810
7d80d37
31e6ef1
 
d5cddb7
02424f0
7d80d37
 
 
 
 
 
 
 
 
 
 
 
 
31e6ef1
 
7d80d37
31e6ef1
 
 
 
7d80d37
 
31e6ef1
7d80d37
 
 
 
 
 
 
 
 
 
 
31e6ef1
7d80d37
31e6ef1
 
7d80d37
 
 
 
 
 
31e6ef1
1981810
 
7d80d37
1981810
 
7d80d37
1981810
31e6ef1
7d80d37
 
1981810
 
 
7d80d37
 
1981810
7d80d37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d5cddb7
6a13bbe
 
 
 
7d80d37
e204bd0
6a13bbe
e204bd0
7d80d37
 
 
 
6a13bbe
e204bd0
6a13bbe
7d80d37
6a13bbe
 
 
 
 
 
 
 
 
 
755eb09
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d80d37
7cc64b6
 
 
 
7d80d37
7cc64b6
 
 
 
7d80d37
7cc64b6
 
 
 
7d80d37
7cc64b6
 
 
 
 
7d80d37
7cc64b6
 
7d80d37
7cc64b6
 
7d80d37
 
7cc64b6
 
 
 
 
6a13bbe
7cc64b6
 
7d80d37
 
 
 
 
6a13bbe
7d80d37
7cc64b6
e204bd0
1981810
 
 
6a13bbe
 
 
e204bd0
d5cddb7
e204bd0
6a13bbe
 
 
e204bd0
6a13bbe
 
 
 
7d80d37
6a13bbe
 
 
7d80d37
6a13bbe
7d80d37
d5cddb7
 
6a13bbe
 
7d80d37
84e84dd
1803050
7d80d37
 
 
 
84e84dd
acd8897
7d80d37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
acd8897
7d80d37
 
 
 
 
 
 
 
 
 
 
1981810
e204bd0
6a13bbe
1981810
7d80d37
 
 
1981810
6a13bbe
 
 
 
 
 
7cc64b6
 
 
 
 
 
7d80d37
 
f446a3d
7cc64b6
8a8b25f
 
 
7cc64b6
 
6a13bbe
8a8b25f
7cc64b6
7d80d37
c159888
 
7d80d37
 
 
f446a3d
c159888
7d80d37
 
c159888
8a8b25f
6a13bbe
7cc64b6
 
6a13bbe
c75c2da
 
7cc64b6
 
7d80d37
 
7cc64b6
7d80d37
6a13bbe
 
7d80d37
7cc64b6
 
7d80d37
6a13bbe
7d80d37
6a13bbe
 
 
 
 
 
 
 
7d80d37
7cc64b6
671de8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5144e07
 
 
 
 
 
 
 
 
 
 
 
671de8b
 
5144e07
 
 
 
 
 
671de8b
 
 
5144e07
 
 
671de8b
 
5144e07
671de8b
5144e07
 
 
 
755eb09
671de8b
 
 
 
 
 
 
 
 
 
755eb09
 
 
5144e07
671de8b
6a13bbe
e204bd0
 
5cae9df
e204bd0
 
7cc64b6
 
 
 
 
 
e204bd0
 
c75c2da
 
7cc64b6
c75c2da
7cc64b6
e204bd0
 
 
7cc64b6
e204bd0
7cc64b6
 
755eb09
7cc64b6
e204bd0
7d80d37
e204bd0
 
7cc64b6
7d80d37
7cc64b6
 
 
 
 
c75c2da
7cc64b6
 
 
c75c2da
 
e204bd0
 
7cc64b6
 
 
 
7d80d37
e204bd0
755eb09
e204bd0
7cc64b6
e204bd0
7cc64b6
 
 
 
 
 
 
 
e204bd0
 
7cc64b6
 
e204bd0
7cc64b6
e204bd0
 
c75c2da
6a13bbe
e204bd0
7cc64b6
 
6a13bbe
 
c75c2da
7cc64b6
 
 
 
 
 
 
 
6a13bbe
 
e204bd0
7cc64b6
 
 
755eb09
 
 
 
 
 
7cc64b6
e204bd0
 
 
7cc64b6
5388b96
7d80d37
e204bd0
7cc64b6
 
 
 
755eb09
7cc64b6
755eb09
 
 
 
 
 
c75c2da
7cc64b6
7d80d37
 
e204bd0
 
755eb09
7cc64b6
 
7d80d37
e204bd0
 
755eb09
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
# app_updated.py
"""
Voice→Place Recommender (Streamlit / Hugging Face Spaces)
- 日本語音声感情認識:S3PRL(HuBERT base) + HFの下流(.ckpt)を用いてJTES(4感情)推定
- 音声波形表示機能を追加
- SNS共有ボタンを追加
"""

# ===== 基本インポート =====
import io, base64, os, random
import numpy as np
import soundfile as sf
from pydub import AudioSegment
import urllib.parse
from datetime import datetime

import streamlit as st
from audiorecorder import audiorecorder

# Matplotlib
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib import rcParams
import japanize_matplotlib
import matplotlib.font_manager as fm

# Torch / HF Hub / S3PRL
import torch
import torch.nn as nn
from huggingface_hub import HfApi, hf_hub_download
from s3prl.nn import S3PRLUpstream, Featurizer

# Librosa for waveform
import librosa
import librosa.display

# ===== Font setup (Japanese) =====
# Use the first installed font whose name contains one of these candidates so
# Japanese labels render; fall back to DejaVu Sans when none is installed.
jp_candidates = ["IPAexGothic", "IPAGothic", "Noto Sans CJK JP", "Noto Sans CJK"]
_installed_font_names = [entry.name for entry in fm.fontManager.ttflist]
rcParams["font.family"] = next(
    (
        cand
        for cand in jp_candidates
        if any(cand in font_name for font_name in _installed_font_names)
    ),
    "DejaVu Sans",
)
# Draw negative tick labels with an ASCII hyphen instead of the Unicode minus sign.
rcParams["axes.unicode_minus"] = False

# ===== Fictional place data =====
# Each entry: place_id, display name, descriptive tags, emo_key (the mood the
# place suits; compared against the emotion priors in score_places), and an
# image path (relative; presumably resolved against the app's working
# directory — confirm the files ship with the Space).
PLACES = [
    {"place_id":"lib_silent", "name":"無音図書館", "tags":["静けさ","集中","屋内"], "emo_key":"calm", "image":"images/lib_silent.png"},
    {"place_id":"aqua_museum", "name":"深海ガラス館", "tags":["発見","学習","ひんやり","屋内"], "emo_key":"surprise", "image":"images/aqua_museum.png"},
    {"place_id":"roof_garden", "name":"雨上がりの屋上庭園", "tags":["開放","共有","屋外","緑"], "emo_key":"joy", "image":"images/roof_garden.png"},
    {"place_id":"boulder_warehouse", "name":"影のボルダリング倉庫", "tags":["発散","身体活動","屋内"], "emo_key":"release", "image":"images/shade_bol.png"},
    {"place_id":"atelier_mono", "name":"静寂のアトリエ", "tags":["創作","集中","屋内"], "emo_key":"calm", "image":"images/silent_atlier.png"},
    {"place_id":"wind_birch", "name":"風鳴りの白樺道", "tags":["自然","散歩","屋外","緑"], "emo_key":"joy", "image":"images/wind_root.png"},
    {"place_id":"forest_walk", "name":"霧の森プロムナード", "tags":["自然","散歩","静けさ","屋外"], "emo_key":"calm", "image":"images/forest_walk.png"},
    {"place_id":"river_bank", "name":"川辺のデッキテラス", "tags":["水辺","開放","屋外","休憩"], "emo_key":"joy", "image":"images/river_bank.png"},
    {"place_id":"sound_lab", "name":"サウンドラボ実験室", "tags":["体験","学習","没入","屋内"], "emo_key":"surprise", "image":"images/sound_lab.png"},
    {"place_id":"maker_space", "name":"メイカーズガレージ", "tags":["創作","体験","交流","屋内"], "emo_key":"joy", "image":"images/maker_space.png"},
    {"place_id":"bamboo_garden", "name":"竹林の回廊", "tags":["静けさ","緑","内省","屋外"], "emo_key":"calm", "image":"images/bamboo_garden.png"},
    {"place_id":"light_gallery", "name":"光のギャラリー", "tags":["発見","没入","展示","屋内"], "emo_key":"surprise", "image":"images/light_gallery.png"},
    {"place_id":"clay_studio", "name":"陶芸スタジオ", "tags":["創作","集中","屋内"], "emo_key":"calm", "image":"images/clay_studio.png"},
    {"place_id":"urban_track", "name":"アーバントラック", "tags":["身体活動","発散","屋外"], "emo_key":"release", "image":"images/urban_track.png"},
]
# Choices offered for the "reason tags" multiselect in the rating step of main().
REASON_TAGS = ["静けさ","緑","水辺","発散","創作","交流","体験","学習","屋内","屋外","没入","回復"]

# ===== Model definition =====
# Hugging Face repo holding the downstream speech-emotion checkpoint
# (4-emotion JTES classifier, per the module docstring).
KUSHINADA_REPO = "imprt/kushinada-hubert-base-jtes-er"

# ---- Downstream head (single Linear, or projection Linear + final Linear) ----
class DownstreamHead(nn.Module):
    """Classifier head whose parameters are copied from checkpoint tensors.

    Data flow: x -> [proj Linear, only when given] -> fc Linear -> logits.
    No activation is applied between the two layers; whether the original
    training used one is unknown.

    Args:
        in_dim: input width of ``fc`` when there is no projection. When a
            projection is given, it is replaced by the projection's output width.
        out_dim: number of output classes (``fc`` output width).
        W_final, b_final: weight ([out, in]) and bias copied into ``fc``.
        proj_W, proj_b: optional weight ([out, in]) and bias for the projection;
            the projection is built only when both are provided.
    """

    def __init__(self, in_dim, out_dim, W_final, b_final, proj_W=None, proj_b=None):
        super().__init__()
        self.proj = None
        if proj_W is not None and proj_b is not None:
            n_out, n_in = proj_W.shape  # nn.Linear layout: [out, in]
            projection = nn.Linear(n_in, n_out)
            with torch.no_grad():
                projection.weight.copy_(proj_W)
                projection.bias.copy_(proj_b)
            self.proj = projection
            in_dim = n_out  # fc consumes the projection's output
        self.fc = nn.Linear(in_dim, out_dim)
        with torch.no_grad():
            self.fc.weight.copy_(W_final)
            self.fc.bias.copy_(b_final)

    @property
    def expected_in(self):
        """Feature width the head accepts (the first layer's ``in_features``)."""
        first_layer = self.fc if self.proj is None else self.proj
        return first_layer.in_features

    def forward(self, x):
        """Map features ``[B, expected_in]`` to logits ``[B, out_dim]``."""
        hidden = x if self.proj is None else self.proj(x)
        return self.fc(hidden)

# ====== KUSHINADA loader (upstream + featurizer + downstream head) ======
@st.cache_resource(show_spinner=False)
def load_kushinada_s3prl():
    """Build and cache the speech-emotion pipeline used by predict_emotion_ai.

    Steps: load the S3PRL upstream and a Featurizer, pick a downstream
    checkpoint in KUSHINADA_REPO, scan its state dict for (weight, bias)
    pairs, and heuristically rebuild a one- or two-layer linear head
    (DownstreamHead). The checkpoint's own model class is never instantiated.

    Environment:
        HF_TOKEN (required): Hugging Face token used to list/download the repo.
        KUSHINADA_REVISION (default "main"): repo revision.
        KUSHINADA_FILENAME (optional): checkpoint path, exact or suffix match.

    Returns:
        (featurizer, head, id2label, device). The upstream model is also
        attached as ``featurizer.upstream``. The featurizer combines the
        upstream's hidden states and does not accept raw audio, so inference
        needs both stages.

    Raises:
        RuntimeError: HF_TOKEN missing, checkpoint format not understood, or
            no suitable linear layers found.
        FileNotFoundError: no downstream checkpoint candidate in the repo.
    """
    import pickle  # for the error raised when the restricted torch.load rejects a file

    token = os.getenv("HF_TOKEN")
    if not token:
        raise RuntimeError("環境変数 HF_TOKEN が見つかりません。SpacesのSettings→Secretsで設定してください。")

    revision = os.getenv("KUSHINADA_REVISION", "main")
    prefer_filename = os.getenv("KUSHINADA_FILENAME", "").strip()

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # 1) Upstream SSL model + Featurizer.
    # NOTE(review): the head in KUSHINADA_REPO was presumably trained on the
    # kushinada (Japanese) HuBERT upstream; S3PRL's "hubert_base" may be a
    # different checkpoint, and any trained featurizer weights stored in the
    # downstream checkpoint are not loaded here — confirm features match.
    upstream = S3PRLUpstream("hubert_base").to(device).eval()
    try:
        featurizer = Featurizer(upstream)
    except TypeError:
        # Fallbacks for Featurizer signatures that take a feature-selection keyword.
        try:
            featurizer = Featurizer(upstream, upstream_feature_selection="last_hidden_state")
        except TypeError:
            featurizer = Featurizer(upstream, feature_selection="last_hidden_state")
    featurizer = featurizer.to(device).eval()
    # Keep the upstream reachable from the cached featurizer so inference can
    # turn waveforms into the hidden states the featurizer consumes.
    featurizer.upstream = upstream

    # 2) Choose a checkpoint (downstream only; upstream/converted files excluded).
    api = HfApi()
    info = api.model_info(KUSHINADA_REPO, token=token, revision=revision)
    all_files = [s.rfilename for s in (info.siblings or [])]

    def is_ckpt(path):
        p = path.lower()
        if not p.endswith((".pt", ".ckpt", ".pth", ".bin")):
            return False
        # Skip upstream / converted weights.
        bad = ["upstream", "converted", "hubert_base", "s3prl/converted", "wav2vec", "espnet"]
        return not any(b in p for b in bad)

    candidates = [f for f in all_files if is_ckpt(f)]

    # Priority: explicit setting > result/downstream > dev-best > best > fold > others.
    filename = None
    if prefer_filename:
        # Accept an exact path or a suffix of one.
        if prefer_filename in all_files:
            filename = prefer_filename
        else:
            matches = [f for f in all_files if f.endswith(prefer_filename)]
            if matches:
                filename = matches[0]
    if filename is None and candidates:
        def rank_score(f):
            f_lower = f.lower()
            score = 0
            if "result/downstream" in f_lower: score += 100
            if "dev-best" in f_lower: score += 50
            if "best" in f_lower: score += 20
            if "fold" in f_lower: score += 10
            if "kushinada" in f_lower: score += 5
            return -score, len(f)  # highest score first; ties go to the shorter path
        filename = sorted(candidates, key=rank_score)[0]
    if filename is None:
        raise FileNotFoundError("下流チェックポイントが見つかりません。KUSHINADA_FILENAME を Secrets に設定してください。")

    ckpt_path = hf_hub_download(
        repo_id=KUSHINADA_REPO,
        filename=filename,
        revision=revision,
        token=token,
        repo_type="model",
    )

    # Checkpoints are pickles. Prefer the restricted loader; checkpoints that
    # also pickle non-tensor objects (e.g. configs) are rejected by it, so fall
    # back to the full unpickler.
    # NOTE(security): weights_only=False can execute code embedded in the file;
    # acceptable only because KUSHINADA_REPO is treated as a trusted source.
    try:
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    except TypeError:
        # torch builds without the weights_only argument always use the full unpickler.
        ckpt = torch.load(ckpt_path, map_location="cpu")
    except pickle.UnpicklingError:
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)

    # 3) Locate the state dict (possibly nested under a well-known key).
    state = None
    if isinstance(ckpt, dict):
        for key in ["state_dict", "Downstream", "model", "downstream", "net", "weights"]:
            if key in ckpt and isinstance(ckpt[key], dict):
                state = ckpt[key]
                break
        if state is None:
            state = ckpt
    if not isinstance(state, dict):
        raise RuntimeError("チェックポイント形式を解釈できませんでした。")

    # 4) Collect linear-layer candidates: every "<base>.bias" (1-D) with a
    #    matching 2-D "<base>.weight". Orientation is decided by which weight
    #    axis matches the bias length; W is normalised to [out, in].
    layers = []
    for key, bias in state.items():
        if not (key.endswith(".bias") and isinstance(bias, torch.Tensor) and bias.ndim == 1):
            continue
        base = key[: -len(".bias")]
        weight = state.get(base + ".weight")
        if not (isinstance(weight, torch.Tensor) and weight.ndim == 2):
            continue
        weight = weight.float()
        n_bias = bias.numel()
        if weight.shape[0] == n_bias:
            W = weight  # already [out, in] (nn.Linear layout)
        elif weight.shape[1] == n_bias:
            W = weight.t()  # stored as [in, out]
        else:
            continue  # bias fits neither axis: not a linear layer
        layers.append({
            "name": base,
            "W": W, "b": bias.float(),
            "out": W.shape[0], "in": W.shape[1],
        })

    if not layers:
        raise RuntimeError("線形層の (weight, bias) が見つかりませんでした。")

    # 5) Final-layer candidates: small class count.
    finals = [L for L in layers if 2 <= L["out"] <= 16]
    if not finals:
        raise RuntimeError("最終分類層らしき小クラス数の線形層が見つかりませんでした。")

    # Prefer input widths near the common 256/768 and names like classifier/out/fc.
    def final_rank(L):
        score = 0
        if "class" in L["name"].lower() or "out" in L["name"].lower() or "fc" in L["name"].lower():
            score += 3
        score -= abs(L["in"] - 256) / 256.0
        score -= abs(L["in"] - 768) / 768.0
        return -score
    final = sorted(finals, key=final_rank)[0]

    # 6) Optional projection in front of the final layer: out == final.in,
    #    and never the final layer itself.
    proj = None
    proj_candidates = [L for L in layers if L["out"] == final["in"] and L is not final]
    if proj_candidates:
        def proj_rank(L):
            score = 0
            if "proj" in L["name"].lower() or "linear" in L["name"].lower() or "fc" in L["name"].lower():
                score += 2
            score -= abs(L["in"] - 768) / 768.0
            return -score
        proj = sorted(proj_candidates, key=proj_rank)[0]

    # 7) Build the head.
    if proj is not None:
        head = DownstreamHead(
            in_dim=proj["in"], out_dim=final["out"],
            W_final=final["W"], b_final=final["b"],
            proj_W=proj["W"], proj_b=proj["b"]
        )
    else:
        head = DownstreamHead(
            in_dim=final["in"], out_dim=final["out"],
            W_final=final["W"], b_final=final["b"]
        )
    head = head.to(device).eval()

    # 8) Labels: assumes JTES 4-class order when the head has 4 outputs.
    default_labels = ["angry", "happy", "neutral", "sad"]
    id2label = {i: (default_labels[i] if head.fc.out_features == 4 and i < 4 else f"class_{i}") for i in range(head.fc.out_features)}

    st.info(f"✅ ckpt: `{filename}`(rev: {revision})")
    st.info(f"✅ head.expected_in={head.expected_in}, final_out={head.fc.out_features}")
    return featurizer, head, id2label, device

# ===== Utilities =====
def to_wav_bytes(any_bytes: bytes, target_sr=16000, mono=True) -> bytes:
    """Decode audio bytes with pydub and re-encode them as WAV.

    Args:
        any_bytes: encoded audio in any format pydub/ffmpeg can read.
        target_sr: output sample rate; a falsy value keeps the original rate.
        mono: downmix to a single channel when True.

    Returns:
        WAV-encoded bytes.

    On empty input or a decode failure, an error is shown in the UI and the
    script run is halted with st.stop().
    """
    if not any_bytes:
        st.error("音声が空です。録音やアップロードを確認してください。")
        st.stop()
    try:
        segment = AudioSegment.from_file(io.BytesIO(any_bytes))
    except Exception as err:
        st.error(f"音声読込エラー: {err}")
        st.stop()
    if mono:
        segment = segment.set_channels(1)
    if target_sr:
        segment = segment.set_frame_rate(target_sr)
    out = io.BytesIO()
    segment.export(out, format="wav")
    return out.getvalue()

def audio_player_bytes(b: bytes, mime="audio/wav"):
    """Embed an HTML5 audio player for in-memory audio.

    The audio is inlined as a base64 ``data:`` URI in raw HTML, so nothing has
    to be served from disk. Renders nothing when ``b`` is empty or None.

    Args:
        b: encoded audio bytes.
        mime: MIME type used for both the data URI and the <source> type.
    """
    if b:
        payload = base64.b64encode(b).decode("utf-8")
        st.markdown(
            f"""
        <audio controls preload="metadata" style="width:100%">
          <source src="data:{mime};base64,{payload}" type="{mime}">
        </audio>
        """,
            unsafe_allow_html=True,
        )

# ===== Waveform visualisation =====
def create_waveform_visualization(audio_bytes):
    """Draw the waveform (top) and a dB-scaled spectrogram (bottom) of a clip.

    Args:
        audio_bytes: encoded audio. Formats soundfile can read are decoded
            directly; anything else (e.g. the mp3/m4a files the uploader
            accepts) is first decoded with pydub and re-read as WAV.
            Multi-channel audio is averaged to mono.

    Returns:
        A matplotlib Figure, or None for empty input or on failure (the error
        is shown in the UI).
    """
    from matplotlib.figure import Figure

    if not audio_bytes:
        return None

    try:
        try:
            y, sr = sf.read(io.BytesIO(audio_bytes), dtype="float32")
        except Exception:
            # libsndfile could not decode this container; go through pydub/ffmpeg.
            segment = AudioSegment.from_file(io.BytesIO(audio_bytes))
            wav_buf = io.BytesIO()
            segment.export(wav_buf, format="wav")
            y, sr = sf.read(io.BytesIO(wav_buf.getvalue()), dtype="float32")
        # soundfile returns (frames, channels) for multi-channel audio; librosa
        # expects mono or channels-first, so downmix.
        if y.ndim > 1:
            y = y.mean(axis=1)

        # Build the Figure directly instead of via pyplot: pyplot-created figures
        # stay registered until explicitly closed, and nothing here closes them.
        fig = Figure(figsize=(12, 6), dpi=100)
        ax1, ax2 = fig.subplots(2, 1)

        # Waveform
        librosa.display.waveshow(y, sr=sr, ax=ax1, color='#4169E1', alpha=0.8)
        ax1.set_title('Audio Waveform', fontsize=14, fontweight='bold')
        ax1.set_xlabel('Time (s)')
        ax1.set_ylabel('Amplitude')
        ax1.grid(True, alpha=0.3)

        # Spectrogram in dB relative to the peak magnitude
        D = librosa.stft(y)
        DB = librosa.amplitude_to_db(np.abs(D), ref=np.max)
        img = librosa.display.specshow(DB, sr=sr, x_axis='time', y_axis='hz', ax=ax2)
        ax2.set_title('Spectrogram', fontsize=14, fontweight='bold')
        fig.colorbar(img, ax=ax2, format='%+2.0f dB')

        fig.tight_layout()
        return fig

    except Exception as e:
        st.error(f"波形表示エラー: {e}")
        return None

# ===== Fallback (simple hand-crafted features) =====
def extract_features(y, sr):
    """Compute a few cheap prosodic/spectral features from a waveform.

    Leading/trailing samples quieter than 1% of the peak are trimmed first.

    Args:
        y: waveform samples; a 2-D array is treated as (frames, channels)
            and averaged to mono.
        sr: sample rate in Hz.

    Returns:
        dict with "f0_mean" (Hz; crude autocorrelation pitch searched in
        80–600 Hz, 0.0 when not estimable), "energy_mean" (RMS),
        "spec_centroid" (Hz), "zcr_mean" (fraction of adjacent-sample pairs
        that change sign) and "duration" (seconds, after trimming).
        An empty waveform yields all 0.0.
    """
    y = np.asarray(y, dtype=float)
    if y.ndim > 1:
        y = y.mean(axis=1)
    if y.size == 0:
        return {"f0_mean": 0.0, "energy_mean": 0.0, "spec_centroid": 0.0,
                "zcr_mean": 0.0, "duration": 0.0}

    # Trim near-silence at both ends (threshold relative to the peak).
    abs_y = np.abs(y)
    thr = 0.01 * (abs_y.max() + 1e-9)
    idx = np.flatnonzero(abs_y > thr)
    if idx.size >= 2:
        y = y[idx[0]:idx[-1] + 1]

    n = len(y)
    energy_mean = float(np.sqrt(np.mean(y ** 2) + 1e-12))
    win = np.hanning(n) if n >= 512 else np.ones_like(y)
    mag = np.abs(np.fft.rfft(y * win)) + 1e-12
    freqs = np.fft.rfftfreq(n, d=1.0 / sr)
    sc_mean = float((freqs * mag).sum() / mag.sum())
    zc = y[:-1] * y[1:] < 0
    zcr_mean = float(zc.mean()) if zc.size else 0.0

    # Crude F0: lag with the strongest autocorrelation among lags for 80–600 Hz.
    # Only those lags are needed, so compute them directly (O(n * lags)) rather
    # than the full O(n^2) np.correlate, which stalls on long recordings.
    fmin, fmax = 80.0, 600.0
    f0_est = 0.0
    if n >= int(sr / fmin) + 2:
        lmin = max(1, int(sr / fmax))
        lmax = min(n - 1, int(sr / fmin))
        if lmax > lmin:
            lags = np.arange(lmin, lmax)
            corr = np.array([np.dot(y[:n - lag], y[lag:]) for lag in lags])
            f0_est = float(sr / int(lags[np.argmax(corr)]))

    return {"f0_mean": float(f0_est), "energy_mean": energy_mean, "spec_centroid": sc_mean,
            "zcr_mean": zcr_mean, "duration": n / sr}

def predict_emotion_features(audio_bytes):
    """Heuristic emotion estimate from hand-crafted features (no ML model).

    Energy and zero-crossing rate form an "arousal" score; pitch and energy
    form a "valence" score; both are thresholded into one of four labels.

    Returns:
        (label, scores, "Features"). ``scores`` has keys happiness, anger,
        sadness and neutral; the chosen label gets 0.7 and "neutral"
        additionally gets +0.3.
    """
    wav_16k = to_wav_bytes(audio_bytes, target_sr=16000)
    samples, rate = sf.read(io.BytesIO(wav_16k), dtype="float32")
    feats = extract_features(samples, rate)
    pitch = feats["f0_mean"]
    rms = feats["energy_mean"]
    zcr = feats["zcr_mean"]

    arousal = float(np.tanh(160*rms + 4*zcr))
    valence = float(np.tanh(((pitch-170)/120) + 15*rms))

    if valence >= 0.22:
        label = "happiness" if arousal >= 0.22 else "neutral"
    elif valence < 0.10:
        if arousal >= 0.30:
            label = "anger"
        elif arousal < 0.18:
            label = "sadness"
        else:
            label = "neutral"
    else:
        label = "neutral"

    scores = dict.fromkeys(("happiness", "anger", "sadness", "neutral"), 0.0)
    scores[label] = 0.7
    scores["neutral"] += 0.3
    return label, scores, "Features"

# ===== AI estimation (S3PRL) =====
def _normalize_label(lbl: str) -> str:
    """Map checkpoint label names (e.g. "happy") to the app's vocabulary; unknown labels pass through."""
    m = {"happy": "happiness", "angry": "anger", "sad": "sadness", "neutral": "neutral"}
    return m.get(lbl.lower(), lbl)


def _match_feature_dim(pooled, expected_in):
    """Adapt pooled features ``[B, H]`` to the head's expected input width.

    - equal widths: returned unchanged;
    - H a multiple of expected_in: adjacent groups of H // expected_in values are averaged;
    - otherwise: truncated or zero-padded (identity-style projection).
    """
    B, H_feat = pooled.shape
    if H_feat == expected_in:
        return pooled
    if H_feat % expected_in == 0:
        g = H_feat // expected_in
        st.info(f"ℹ️ 特徴次元を {H_feat}→{expected_in} にグループ平均で整合 (group={g})")
        return pooled.view(B, expected_in, g).mean(dim=2)
    # Deterministic identity-style fit. A freshly constructed nn.Linear would
    # leave weights outside the identity block at random init, so the
    # prediction would change from run to run.
    if H_feat > expected_in:
        fitted = pooled[:, :expected_in]
    else:
        fitted = torch.cat([pooled, pooled.new_zeros(B, expected_in - H_feat)], dim=1)
    st.info(f"ℹ️ 線形射影で {H_feat}→{expected_in} に適合")
    return fitted


def predict_emotion_ai(audio_bytes):
    """Estimate emotion with the S3PRL upstream + featurizer + downstream head.

    Falls back to predict_emotion_features when the model cannot be loaded
    or inference fails; a UI message says which happened.

    Returns:
        (label, scores, "AI(S3PRL)"): ``label`` is normalised with
        _normalize_label; ``scores`` maps each normalised class label to its
        softmax probability.
    """
    try:
        featurizer, head, id2label, device = load_kushinada_s3prl()
    except Exception as e:
        st.error(f"モデルのロードに失敗しました: {e}")
        st.info("音声特徴量ベースの分析に切り替えます。")
        return predict_emotion_features(audio_bytes)

    try:
        wav_bytes_16k = to_wav_bytes(audio_bytes, target_sr=16000)
        y, sr = sf.read(io.BytesIO(wav_bytes_16k), dtype="float32")
        if y.ndim > 1:  # soundfile returns (frames, channels) for multi-channel audio
            y = y.mean(axis=1)

        # Analyse at most the first 30 seconds.
        max_duration = 30
        max_samples = int(sr * max_duration)
        if len(y) > max_samples:
            y = y[:max_samples]
            st.warning("音声が30秒を超えたため、最初の30秒のみ分析します。")

        with torch.no_grad():
            # The featurizer combines the upstream's per-layer hidden states; it
            # does not take raw audio. Reuse an upstream attached to the cached
            # featurizer, or build one (same "hubert_base" as the loader) and
            # attach it so later calls reuse it.
            upstream = getattr(featurizer, "upstream", None)
            if upstream is None:
                upstream = S3PRLUpstream("hubert_base").to(device).eval()
                featurizer.upstream = upstream

            wavs = torch.from_numpy(np.ascontiguousarray(y, dtype=np.float32)).unsqueeze(0).to(device)  # [1, T]
            wavs_len = torch.tensor([wavs.shape[1]], dtype=torch.long, device=device)  # [1]
            all_hs, all_hs_len = upstream(wavs, wavs_len)
            reps, reps_len = featurizer(all_hs, all_hs_len)
            if not isinstance(reps, torch.Tensor):
                raise RuntimeError(f"Unexpected reps type: {type(reps)}")
            # Bring reps to [B, T, H].
            if reps.dim() == 1:
                reps = reps.unsqueeze(0).unsqueeze(0)
            elif reps.dim() == 2:
                reps = reps.unsqueeze(0)
            elif reps.dim() != 3:
                raise RuntimeError(f"Unexpected reps.dim(): {reps.dim()}")

            B, T, H = reps.shape

            # Normalise reps_len to a list of B valid lengths in [1, T].
            if reps_len is None: reps_len_list = [T]*B
            elif isinstance(reps_len, int): reps_len_list = [int(reps_len)]*B
            elif isinstance(reps_len, (list, tuple)): reps_len_list = [int(x) for x in reps_len]
            elif isinstance(reps_len, torch.Tensor): reps_len_list = reps_len.view(-1).tolist()
            else: reps_len_list = [T]*B
            if len(reps_len_list) != B: reps_len_list = [T]*B
            reps_len_list = [max(1, min(int(li), T)) for li in reps_len_list]

            # Mean over the valid frames of each item -> [B, H]
            pooled = torch.stack([reps[i, :reps_len_list[i]].mean(dim=0) for i in range(B)], dim=0)

            pooled_in = _match_feature_dim(pooled, head.expected_in)
            logits = head(pooled_in.to(device))      # [B, C]
            probs = torch.softmax(logits, dim=-1)[0].detach().cpu().numpy()

        pred_id = int(np.argmax(probs))
        label = _normalize_label(id2label[pred_id])
        scores = {
            _normalize_label(id2label[i]): max(0.0, min(1.0, float(probs[i])))
            for i in range(len(probs))
        }
        return label, scores, "AI(S3PRL)"

    except Exception as e:
        st.warning(f"AI予測中にエラーが発生: {e}")
        return predict_emotion_features(audio_bytes)

# ===== Recommendation =====
def score_places(emo_label, top_k=4, diversity=True):
    """Rank PLACES for an emotion label and return recommended places.

    Each place starts at 0.5, gains +0.5 when its emo_key equals the label's
    first prior and +0.25 when it equals the second, plus random jitter in
    [-0.02, 0.02]. Only the best ``max(top_k, 4)`` places form the pool.
    Without ``diversity`` the first ``top_k`` of the pool are returned. With
    it, the first place of each distinct emo_key is taken (stopping once
    ``top_k`` are picked), then remaining slots are filled in pool order.
    Unknown labels use the priors ["calm", "joy", "surprise"].
    """
    EMO_MAP_PRIORS = {
        "happiness": ["joy", "surprise"],
        "anger": ["release", "calm"],
        "sadness": ["calm", "joy"],
        "neutral": ["calm", "surprise", "joy"],
        "joy": ["joy","surprise"], "calm": ["calm","joy"],
        "surprise": ["surprise","joy"], "release": ["release","calm"],
    }
    priors = EMO_MAP_PRIORS.get(emo_label, ["calm","joy","surprise"])
    first_prior = priors[0]
    has_second = len(priors) > 1

    def place_score(place):
        value = 0.5
        if place["emo_key"] == first_prior:
            value += 0.5
        if has_second and place["emo_key"] == priors[1]:
            value += 0.25
        return value + random.uniform(-0.02, 0.02)

    ranked = sorted(((place_score(p), p) for p in PLACES), key=lambda pair: pair[0], reverse=True)
    pool = [place for _, place in ranked[:max(top_k, 4)]]
    if not diversity:
        return pool[:top_k]

    chosen = []
    used_keys = set()
    for place in pool:
        if place["emo_key"] not in used_keys:
            chosen.append(place)
            used_keys.add(place["emo_key"])
        if len(chosen) >= top_k:
            break
    if len(chosen) < top_k:
        for place in pool:
            if place not in chosen:
                chosen.append(place)
            if len(chosen) >= top_k:
                break
    return chosen

# ===== Visualisation =====
def plot_emotion_map(emotion_label, scores, method="AI"):
    """Plot emotion scores as a bar chart (left) and a pie chart (right).

    Args:
        emotion_label: predicted label, shown in the pie title (in Japanese
            when a translation is known).
        scores: mapping label -> score; the bar axis spans [0, 1] and only
            scores above 0.05 appear in the pie.
        method: estimator name shown in the bar-chart title.

    Returns:
        matplotlib Figure.
    """
    from matplotlib.figure import Figure

    # Build the Figure directly instead of via pyplot: pyplot-created figures
    # stay registered until explicitly closed, and nothing here closes them,
    # so they would accumulate across Streamlit reruns.
    fig = Figure(figsize=(12, 5), dpi=150)
    ax1, ax2 = fig.subplots(1, 2)
    emotion_jp = {
        'happiness': '幸せ', 'anger': '怒り', 'sadness': '悲しみ', 'neutral': '中立',
        'joy': '喜び', 'calm': '落ち着き', 'surprise': '驚き', 'release': '発散'
    }
    color_map = {
        'happiness': '#FF6B6B','anger': '#FFA94D','sadness': '#868E96','neutral': '#51CF66',
        'joy': '#FF6B6B','calm': '#51CF66','surprise': '#74C0FC','release': '#FFD43B'
    }

    # Bar chart of all scores
    labels = list(scores.keys())
    values = [scores[k] for k in labels]
    colors = [color_map.get(k, '#74C0FC') for k in labels]
    bars = ax1.bar([emotion_jp.get(k, k) for k in labels], values, color=colors, alpha=0.85)
    ax1.set_ylim(0, 1)
    ax1.set_ylabel('Score', fontsize=12)
    ax1.set_title(f'Emotion Scores ({method})', fontsize=14, fontweight='bold')
    ax1.grid(axis='y', alpha=0.3)
    for bar, v in zip(bars, values):
        ax1.text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.01,
                 f'{v:.2f}', ha='center', va='bottom', fontsize=10)

    # Pie chart of scores above 0.05
    pairs = [(k, v) for k, v in scores.items() if v > 0.05]
    sizes = [v for _, v in pairs]
    labels_pie = [emotion_jp.get(k, k) for k, _ in pairs]
    colors_pie = [color_map.get(k, '#74C0FC') for k, _ in pairs]
    ax2.pie(sizes, labels=labels_pie, colors=colors_pie,
            autopct='%1.0f%%', startangle=90, textprops={'fontsize': 11})
    ax2.set_title(f'Result: {emotion_jp.get(emotion_label, emotion_label)}',
                  fontsize=14, fontweight='bold')
    fig.tight_layout()
    return fig

import urllib.parse

# Fixed share-target URLs: the app page, plus the faculty page and personal
# page that create_share_buttons links in the share text.
PAGE_URL   = "https://huggingface.co/spaces/ayaka68/voice2place"
FAC_URL    = "https://www.hokusei.ac.jp/informatics/"
PERSON_URL = "https://aonoa68.github.io/"

def _with_utm(url: str, content: str):
    """Return *url* with UTM tracking parameters set for share attribution.

    Existing query parameters are preserved, including ones with empty
    values. utm_source / utm_medium / utm_campaign / utm_content are added
    or overwritten. *content* identifies where the link is placed
    (e.g. "twitter", "facebook", "line", "body").
    """
    parts = urllib.parse.urlsplit(url)
    # keep_blank_values: without it parse_qs silently drops "a=" style params.
    query = urllib.parse.parse_qs(parts.query, keep_blank_values=True)
    query["utm_source"] = ["hub"]
    query["utm_medium"] = ["social"]
    query["utm_campaign"] = ["oc2025"]
    query["utm_content"] = [content]
    encoded = urllib.parse.urlencode(query, doseq=True)
    return urllib.parse.urlunsplit((parts.scheme, parts.netloc, parts.path, encoded, parts.fragment))

# ===== SNS共有ボタン機能を追加(改訂版) =====
import requests  # 追加(ファイル冒頭でもOK)

def shorten(url: str, timeout: float = 3.0):
    """Shorten *url* with the is.gd API, falling back to the original URL.

    Best-effort: a network error, a non-200 status, or a response body that
    is not a URL returns *url* unchanged. The timeout bounds each request so
    a slow or unreachable shortener cannot stall the Streamlit script run.

    Args:
        url: URL to shorten.
        timeout: seconds to wait for is.gd before giving up.
    """
    try:
        r = requests.get(
            "https://is.gd/create.php",
            params={"format": "simple", "url": url},
            timeout=timeout,
        )
    except requests.RequestException:
        return url
    short = r.text.strip()
    if r.status_code == 200 and short.startswith("http"):
        return short
    return url

def create_share_buttons(emotion_label: str, place_name: str):
    """Build the HTML for X(Twitter) / Facebook / LINE share buttons.

    The share text names the estimated emotion and the chosen place, and
    links the faculty and personal pages. Each network gets an app-page link
    tagged with its own utm_content ("twitter" / "facebook" / "line") so
    shares can be attributed per channel. Only the X link is shortened,
    because it appears in X's length-limited text.

    Returns:
        HTML string for ``st.markdown(..., unsafe_allow_html=True)``.
        All dynamic values are percent-encoded, so they cannot inject quotes
        or tags into the href attributes.
    """
    # --- Shortened, UTM-tagged links shared by every message body ---
    fac_short = shorten(_with_utm(FAC_URL, "body"))
    pers_short = shorten(_with_utm(PERSON_URL, "body"))

    def compose_text(page_link):
        return (
            f"Voice × Place Labで「{emotion_label}」と推定。"
            f"おすすめの場所は「{place_name}」。\n"
            f"🎓 情報学部: {fac_short}\n"
            f"🌐 Onohara: {pers_short}\n"
            f"#Voice2Place #AI体験 {page_link}"
        )

    # --- Per-network share links ---
    twitter_page = shorten(_with_utm(PAGE_URL, "twitter"))
    twitter_url = f"https://twitter.com/intent/tweet?text={urllib.parse.quote(compose_text(twitter_page))}"
    facebook_page = _with_utm(PAGE_URL, "facebook")
    facebook_url = f"https://www.facebook.com/sharer/sharer.php?u={urllib.parse.quote(facebook_page)}"
    line_page = _with_utm(PAGE_URL, "line")
    line_url = f"https://line.me/R/msg/text/?{urllib.parse.quote(compose_text(line_page))}"

    share_html = f"""
    <div style='display:flex;flex-wrap:wrap;gap:10px;margin:20px 0;'>
      <a href='{twitter_url}' target='_blank' rel='noopener noreferrer' style='text-decoration:none'>
        <div style='background:#1DA1F2;color:#fff;padding:10px 16px;border-radius:8px;font-weight:700'>🐦 X(Twitter)で共有</div>
      </a>
      <a href='{facebook_url}' target='_blank' rel='noopener noreferrer' style='text-decoration:none'>
        <div style='background:#4267B2;color:#fff;padding:10px 16px;border-radius:8px;font-weight:700'>📘 Facebookで共有</div>
      </a>
      <a href='{line_url}' target='_blank' rel='noopener noreferrer' style='text-decoration:none'>
        <div style='background:#00B900;color:#fff;padding:10px 16px;border-radius:8px;font-weight:700'>💬 LINEで共有</div>
      </a>
    </div>
    """
    return share_html

    
# ===== Main =====
def main():
    """Streamlit page: record/upload → emotion estimate → place recommendations → rating/share.

    Streamlit re-executes this function on every widget interaction, so
    results that must survive reruns (audio, estimate, recommendations) are
    kept in st.session_state.
    """
    st.set_page_config(page_title="Voice→Place Recommender", page_icon="🎙️", layout="centered")
    st.title("Voice × Place Lab - Speak, See, Recommend")
    st.caption("録音→AI感情推定→上位スポット→評価→CSV保存(匿名)")

    # Initialise session keys once per browser session.
    for key, default in [
        ("wav_bytes", None), ("recs", None), ("feat", None),
        ("emotion_label", None), ("scores", None), ("method", None),
        ("rec_key", 0),
    ]:
        if key not in st.session_state:
            st.session_state[key] = default

    st.subheader("1) 録音またはアップロード")
    # Two separate elements: used as a `with` context, the alert acts as a
    # single-element placeholder and the markdown would replace the warning.
    st.warning("アップロードで403が出る場合は、録音機能をご利用ください。")
    st.markdown("**録音** → 直接話す or 端末で音声再生しながら録音")

    tab_rec, tab_upload = st.tabs(["録音する(推奨)", "ファイルを使う"])

    with tab_rec:
        # rec_key changes on every reset, so a fresh recorder widget is created.
        audio = audiorecorder("録音開始 ▶", "録音停止 ■", key=f"rec_{st.session_state['rec_key']}")
        if len(audio) > 0:
            buf = io.BytesIO()
            audio.export(buf, format="wav")
            st.session_state["wav_bytes"] = buf.getvalue()
            audio_player_bytes(st.session_state["wav_bytes"], mime="audio/wav")
            st.caption(f"録音サイズ: {len(st.session_state['wav_bytes']) / 1024:.1f} KB")
        if st.button("🧹 クリアして新しく録音", key="clear_rec"):
            for k in ["wav_bytes","recs","feat","emotion_label","scores","method"]:
                st.session_state[k] = None
            st.session_state["rec_key"] += 1
            st.rerun()

    with tab_upload:
        uploaded_file = st.file_uploader(
            "音声ファイルを選択(WAV推奨)", type=["wav", "mp3", "m4a"], accept_multiple_files=False
        )
        if uploaded_file is not None:
            try:
                bytes_data = uploaded_file.getvalue()
                st.session_state["wav_bytes"] = bytes_data
                st.success(f"読み込み成功: {uploaded_file.name}")
                st.caption(f"ファイルサイズ: {len(bytes_data) / 1024:.1f} KB")
                # Use the browser-reported MIME type: the file may be mp3/m4a, not WAV.
                audio_player_bytes(bytes_data, mime=uploaded_file.type or "audio/wav")
            except Exception as e:
                st.error("ファイル読み込みエラー")
                st.exception(e)
                st.info("代わりに録音機能をお試しください。")

    st.subheader("2) 同意")
    consent = st.radio("研究利用の同意(匿名IDで特徴量と評価を保存します)",
                       ["保存しない(体験のみ)", "匿名で保存する"], horizontal=True)
    save_audio = st.checkbox("音声ファイルも保存する(任意)", value=False)

    analysis_method = st.radio("分析方法", ["AIモデル(推奨)", "音声特徴量ベース"], horizontal=True)

    if st.button("🔍 推定 & レコメンド", type="primary",
                 disabled=(st.session_state["wav_bytes"] is None)):
        with st.spinner('感情を分析中...'):
            raw_bytes = st.session_state["wav_bytes"]
            if analysis_method == "AIモデル(推奨)":
                emotion_label, scores, method = predict_emotion_ai(raw_bytes)
            else:
                emotion_label, scores, method = predict_emotion_features(raw_bytes)
            st.session_state["emotion_label"] = emotion_label
            st.session_state["scores"] = scores
            st.session_state["method"] = method
            st.session_state["recs"] = score_places(emotion_label, top_k=4, diversity=True)
        st.success("分析が完了しました!")

    if st.session_state["recs"] is not None:
        emotion_label = st.session_state["emotion_label"]
        scores = st.session_state["scores"]
        method = st.session_state["method"]
        recs = st.session_state["recs"]
        emotion_japanese = {
            'happiness': '幸せ', 'anger': '怒り', 'sadness': '悲しみ', 'neutral': '中立',
            'joy': '喜び', 'calm': '落ち着き', 'surprise': '驚き', 'release': '発散'
        }
        display_emotion = emotion_japanese.get(emotion_label, emotion_label)
        st.success(f"推定感情: **{display_emotion}**")

        explanations = {
            "happiness": "幸せを感じています",
            "joy": "喜びや楽しさを感じています",
            "calm": "落ち着いて穏やかな状態です",
            "surprise": "驚きや興奮を感じています",
            "anger": "怒りやイライラを感じています",
            "sadness": "悲しみや元気のない状態です",
            "neutral": "特に強い感情はない中立状態です",
            "release": "発散や解放を求めています"
        }
        if emotion_label in explanations:
            st.info(f"💡 {explanations[emotion_label]}")

        st.subheader("感情分析結果")
        fig = plot_emotion_map(emotion_label, scores, method)
        st.pyplot(fig, clear_figure=True)

        # Waveform / spectrogram of the current audio
        st.subheader("音声波形分析")
        waveform_fig = create_waveform_visualization(st.session_state["wav_bytes"])
        if waveform_fig:
            st.pyplot(waveform_fig, clear_figure=True)

        st.subheader("3) おすすめ(上位4件)")
        cols = st.columns(4)
        for i, p in enumerate(recs[:4]):
            with cols[i % 4]:
                if "image" in p:
                    st.image(p["image"], use_container_width=True)
                st.markdown(f"**{p['name']}**")
                st.caption(f"タグ: {', '.join(p['tags'])}")

        st.subheader("4) 評価")
        choice_name = st.selectbox("第一候補を選んでください", [p["name"] for p in recs[:4]])
        rating_like = st.slider("行ってみたい度(★)", 1, 5, 4)
        rating_vibe = st.slider("気分に合う度(🎯)", 1, 5, 4)
        reasons = st.multiselect("理由タグ(1—3個)", REASON_TAGS, max_selections=3)
        comment = st.text_input("ひとことコメント(任意・20字)", max_chars=20)

        # SNS share buttons
        st.subheader("5) SNSで共有")
        share_html = create_share_buttons(display_emotion, choice_name)
        st.markdown(share_html, unsafe_allow_html=True)

        if st.button("ログ保存", key="save_log"):
            consent_research = (consent == "匿名で保存する")
            if not consent_research:
                st.info("体験のみモードです。研究ログは保存しません。")
            else:
                st.success("保存機能は開発中です。")

    st.divider()
    if st.button("▶ 次の人を録音する(状態をクリア)", key="next_person"):
        # Clear the same keys as the recorder's clear button (including "feat").
        for k in ["wav_bytes","recs","feat","emotion_label","scores","method"]:
            st.session_state[k] = None
        st.session_state["rec_key"] += 1
        st.rerun()

# Entry point when this file is executed as the main script.
if __name__ == "__main__":
    main()