Spaces:
Sleeping
Sleeping
Upload 2 files
Browse files- utils/emotion_utils.py +82 -0
- utils/prediction_utils.py +83 -0
utils/emotion_utils.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# utils/emotion_utils.py
|
| 2 |
+
import re
|
| 3 |
+
import math
|
| 4 |
+
import numpy as np
|
| 5 |
+
from collections import defaultdict
|
| 6 |
+
|
| 7 |
+
# Emotion label vocabulary (44 classes). The literals are kept in Korean
# because they are runtime data that must match the classifier's label
# strings exactly.
# NOTE(review): restored from an encoding-corrupted copy; the order matches
# the KOTE dataset label list — confirm against the model's config.
LABELS = [
    '불평/불만', '환영/호의', '감동/감탄', '지긋지긋', '고마움', '슬픔',
    '화남/분노', '존경', '기대감', '우쭐댐/무시함', '안타까움/실망', '비장함',
    '의심/불신', '뿌듯함', '편안/쾌적', '신기함/관심', '아껴주는', '부끄러움',
    '공포/무서움', '절망', '한심함', '역겨움/징그러움', '짜증', '어이없음',
    '없음', '패배/자기혐오', '귀찮음', '힘듦/지침', '즐거움/신남', '깨달음',
    '죄책감', '증오/혐오', '흐뭇함(귀여움/예쁨)', '당황/난처', '경악',
    '부담/안_내킴', '서러움', '재미없음', '불쌍함/연민', '놀람', '행복',
    '불안/걱정', '기쁨', '안심/신뢰',
]

# Subset of LABELS that this module treats as negative emotions.
NEGATIVE_EMOTIONS = [
    '불평/불만', '지긋지긋', '슬픔', '화남/분노', '의심/불신', '공포/무서움', '절망', '한심함',
    '역겨움/징그러움', '짜증', '어이없음', '패배/자기혐오', '귀찮음', '힘듦/지침', '죄책감',
    '증오/혐오', '당황/난처', '부담/안_내킴', '서러움', '재미없음',
]
|
| 21 |
+
|
| 22 |
+
def parse_dialogue(text: str):
    """Parse a line-oriented dialogue written as ``speaker:utterance``.

    Each non-blank line is split at its first colon. Lines without a colon,
    or whose speaker or utterance is empty, are skipped.

    Args:
        text: Raw dialogue, one turn per line. ``None`` or an empty string
            yields an empty result.

    Returns:
        A list of ``(speaker, utterance)`` tuples in input order, each part
        stripped of surrounding whitespace.
    """
    if not text:
        return []

    pairs = []
    for raw_line in text.split("\n"):
        line = raw_line.strip()
        if not line:
            continue
        # Split on the first colon only, so utterances may contain colons
        # themselves (e.g. a time such as "10:30").
        speaker, sep, utterance = line.partition(":")
        speaker = speaker.strip()
        utterance = utterance.strip()
        if sep and speaker and utterance:
            pairs.append((speaker, utterance))
    return pairs
|
| 33 |
+
|
| 34 |
+
def adjusted_score(raw_prob: float, k: float = 5.0) -> float:
    """Map a probability onto a 0-100 score with a logistic stretch.

    The score is ``100 / (1 + exp(-k * (raw_prob - 0.5)))``: 0.5 maps to 50
    and ``k`` controls how steeply values away from 0.5 are pushed apart.

    Args:
        raw_prob: Input probability, nominally in [0, 1]. Values outside
            that range are accepted and saturate towards 0 or 100.
        k: Steepness of the logistic curve.

    Returns:
        The stretched score in the range [0, 100].
    """
    x = k * (raw_prob - 0.5)
    if x >= 0:
        return 100.0 / (1.0 + math.exp(-x))
    # For negative x, exp(-x) can overflow; use the algebraically equivalent
    # form whose exponent is always non-positive.
    e = math.exp(x)
    return 100.0 * e / (1.0 + e)
|
| 40 |
+
|
| 41 |
+
def apply_ema(series, alpha=0.4):
    """Smooth a numeric series with an exponential moving average.

    The first output equals the first input; each later output is
    ``alpha * current + (1 - alpha) * previous_output``.

    Args:
        series: Iterable of numbers (list, tuple, NumPy array, generator).
            ``None`` or an empty input yields an empty list.
        alpha: Weight given to the current value.

    Returns:
        A list of smoothed values with the same length as the input.
    """
    if series is None:
        return []
    # Materialise once: truth-testing a NumPy array is ambiguous and
    # iterators cannot be indexed.
    values = list(series)
    if not values:
        return []

    smoothed = [values[0]]
    for value in values[1:]:
        smoothed.append(alpha * value + (1 - alpha) * smoothed[-1])
    return smoothed
|
| 48 |
+
|
| 49 |
+
def detect_emotion_spikes(emotion_series, z_thresh=1.8, min_len=5):
    """Detect abrupt upward jumps in a score series using z-scores.

    Each value is standardised as ``z = (x - mean) / std`` over the whole
    series (population standard deviation). Only values at least
    ``z_thresh`` standard deviations *above* the mean are reported.

    Args:
        emotion_series: Iterable of numeric scores.
        z_thresh: Minimum z-score for a value to count as a spike. The
            original author suggests 1.8-2.5 depending on data volatility.
        min_len: Series shorter than this are not analysed.

    Returns:
        A list of ``(index, value, z)`` tuples, with ``z`` rounded to two
        decimals; empty when the series is empty or too short.
    """
    values = list(emotion_series)
    if not values or len(values) < min_len:
        return []

    mean = float(np.mean(values))
    # The small epsilon keeps a constant series from dividing by zero.
    std = float(np.std(values)) + 1e-6

    spikes = []
    for idx, value in enumerate(values):
        z = float((value - mean) / std)
        if z >= z_thresh:
            spikes.append((idx, value, round(z, 2)))
    return spikes
|
| 65 |
+
|
| 66 |
+
def infer_conflict_initiator(dialogue, spikes_by_speaker):
    """Estimate which speaker most often triggers emotion spikes in others.

    Simple rule: when a speaker's emotion spikes at turn ``idx`` and the
    turn immediately before it (``dialogue[idx - 1]``) belongs to a
    different speaker, that previous speaker is blamed once. The speaker
    with the most blame is returned.

    Args:
        dialogue: Sequence of turns whose first element is the speaker
            (e.g. ``(speaker, utterance)`` tuples).
        spikes_by_speaker: Mapping of speaker to a list of
            ``(idx, value, z)`` spike tuples; ``idx`` is used as a position
            in ``dialogue``.

    Returns:
        The most-blamed speaker, or ``None`` when no spike can be attributed.
        Ties go to the speaker blamed first.
    """
    blame = {}
    turn_count = len(dialogue)
    for speaker, spikes in spikes_by_speaker.items():
        for idx, _value, _z in spikes:
            # Skip spikes with no valid preceding turn: the first turn, a
            # negative index (would wrap around), or an index past the end.
            if idx <= 0 or idx - 1 >= turn_count:
                continue
            prev_speaker = dialogue[idx - 1][0]
            if prev_speaker != speaker:
                blame[prev_speaker] = blame.get(prev_speaker, 0) + 1

    if not blame:
        return None
    # max() returns the first maximal entry, matching the stable-sort
    # tie-break of the previous implementation.
    return max(blame.items(), key=lambda item: item[1])[0]
|
utils/prediction_utils.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# utils/prediction_utils.py
|
| 2 |
+
import warnings
|
| 3 |
+
warnings.filterwarnings("ignore")
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
# Try Prophet first; if it is unavailable, fall back to Holt (exponential smoothing).
|
| 8 |
+
try:
|
| 9 |
+
from prophet import Prophet
|
| 10 |
+
_HAS_PROPHET = True
|
| 11 |
+
except Exception:
|
| 12 |
+
_HAS_PROPHET = False
|
| 13 |
+
|
| 14 |
+
try:
|
| 15 |
+
from statsmodels.tsa.holtwinters import ExponentialSmoothing
|
| 16 |
+
_HAS_HOLT = True
|
| 17 |
+
except Exception:
|
| 18 |
+
_HAS_HOLT = False
|
| 19 |
+
|
| 20 |
+
import pandas as pd
|
| 21 |
+
|
| 22 |
+
def prophet_emotion_forecast(scores, steps=3):
    """Forecast future scores with Prophet.

    The scores are placed on a synthetic daily timeline starting at
    2023-01-01, i.e. they are treated as equally spaced observations. All
    seasonality components are disabled.

    Reference: Taylor & Letham (2018), "Forecasting at Scale".

    Args:
        scores: Sequence of scores (expected range 0-100).
        steps: Number of future points to predict.

    Returns:
        A list of ``steps`` floats clipped to [0, 100], or ``None`` when
        Prophet is unavailable, fewer than 5 scores are given, fitting or
        prediction fails, or the forecast contains non-finite values.
    """
    if len(scores) < 5 or not _HAS_PROPHET:
        return None

    try:
        df = pd.DataFrame({
            'ds': pd.date_range(start='2023-01-01', periods=len(scores), freq='D'),
            'y': scores,
        })
        model = Prophet(
            daily_seasonality=False,
            weekly_seasonality=False,
            yearly_seasonality=False,
            changepoint_prior_scale=0.2,  # changepoint sensitivity
        )
        model.fit(df)
        future = model.make_future_dataframe(periods=steps)
        forecast = model.predict(future)
        tail = forecast['yhat'].tail(steps).to_list()
    except Exception:
        # Best-effort backend: any failure is reported as None so the caller
        # can move on to the next forecaster, as holt_fallback_forecast does.
        return None

    if not np.all(np.isfinite(tail)):
        return None
    # Safety clipping to the valid score range.
    return [float(np.clip(x, 0, 100)) for x in tail]
|
| 45 |
+
|
| 46 |
+
def holt_fallback_forecast(scores, steps=3):
    """Forecast future scores with Holt's additive-trend exponential smoothing.

    Intended as a simple fallback for environments where Prophet cannot be
    used.

    Args:
        scores: Sequence of scores (expected range 0-100).
        steps: Number of future points to predict.

    Returns:
        A list of ``steps`` floats clipped to [0, 100], or ``None`` when
        statsmodels is unavailable, fewer than 4 scores are given, fitting
        fails, or the forecast contains non-finite values.
    """
    if len(scores) < 4 or not _HAS_HOLT:
        return None

    try:
        fitted = ExponentialSmoothing(
            np.array(scores, dtype=float),
            trend='add',
            seasonal=None,
        ).fit()
        predicted = fitted.forecast(steps).tolist()
    except Exception:
        # Best-effort backend: report failure as None so the caller can
        # fall back further.
        return None

    if not np.all(np.isfinite(predicted)):
        return None
    return [float(np.clip(x, 0, 100)) for x in predicted]
|
| 60 |
+
except Exception:
|
| 61 |
+
return None
|
| 62 |
+
|
| 63 |
+
def forecast_scores(scores, steps=3):
    """Forecast the next ``steps`` scores, trying backends in order.

    1. Prophet, when it produces a forecast.
    2. Otherwise the Holt exponential-smoothing fallback.
    3. If neither works, linear extrapolation from the slope of the last
       two points (last-resort fallback).

    Args:
        scores: Sequence of scores (expected range 0-100).
        steps: Number of future points to predict.

    Returns:
        A list of ``steps`` floats clipped to [0, 100]. With fewer than two
        scores the last score (or 0.0 for empty input) is repeated.
    """
    if len(scores) < 2:
        last = float(np.clip(scores[-1], 0, 100)) if len(scores) else 0.0
        return [last] * steps

    backends = (
        lambda: prophet_emotion_forecast(scores, steps=steps),
        lambda: holt_fallback_forecast(scores, steps=steps),
    )
    for run_backend in backends:
        try:
            yhat = run_backend()
        except Exception:
            # The fallback chain is the contract here: a backend that
            # raises counts as having no forecast.
            continue
        if yhat is not None:
            return yhat

    # Extrapolate along the slope of the last two points.
    slope = scores[-1] - scores[-2]
    return [float(np.clip(scores[-1] + slope * (i + 1), 0, 100)) for i in range(steps)]
|