kshs33_emotion_predict / utils /emotion_utils.py
leewatson's picture
Upload 2 files
d445415 verified
# utils/emotion_utils.py
import re
import math
import numpy as np
from collections import defaultdict
# Full list of label strings (emotion categories, judging by the sibling
# NEGATIVE_EMOTIONS constant).
# NOTE(review): presumably the order matches the classifier's output indices
# -- confirm before reordering or inserting entries.
# NOTE(review): the text renders as mojibake in this copy (looks like UTF-8
# Korean decoded with the wrong code page) -- verify the file encoding.
LABELS = ['๋ถˆํ‰/๋ถˆ๋งŒ', 'ํ™˜์˜/ํ˜ธ์˜', '๊ฐ๋™/๊ฐํƒ„', '์ง€๊ธ‹์ง€๊ธ‹', '๊ณ ๋งˆ์›€', '์Šฌํ””',
'ํ™”๋‚จ/๋ถ„๋…ธ', '์กด๊ฒฝ', '๊ธฐ๋Œ€๊ฐ', '์šฐ์ญ๋Œ/๋ฌด์‹œํ•จ', '์•ˆํƒ€๊นŒ์›€/์‹ค๋ง', '๋น„์žฅํ•จ',
'์˜์‹ฌ/๋ถˆ์‹ ', '๋ฟŒ๋“ฏํ•จ', 'ํŽธ์•ˆ/์พŒ์ ', '์‹ ๊ธฐํ•จ/๊ด€์‹ฌ', '์•„๊ปด์ฃผ๋Š”', '๋ถ€๋„๋Ÿฌ์›€',
'๊ณตํฌ/๋ฌด์„œ์›€', '์ ˆ๋ง', 'ํ•œ์‹ฌํ•จ', '์—ญ๊ฒจ์›€/์ง•๊ทธ๋Ÿฌ์›€', '์งœ์ฆ', '์–ด์ด์—†์Œ',
'์—†์Œ', 'ํŒจ๋ฐฐ/์ž๊ธฐํ˜์˜ค', '๊ท€์ฐฎ์Œ', 'ํž˜๋“ฆ/์ง€์นจ', '์ฆ๊ฑฐ์›€/์‹ ๋‚จ', '๊นจ๋‹ฌ์Œ',
'์ฃ„์ฑ…๊ฐ', '์ฆ์˜ค/ํ˜์˜ค', 'ํ๋ญ‡ํ•จ(๊ท€์—ฌ์›€/์˜ˆ์จ)', '๋‹นํ™ฉ/๋‚œ์ฒ˜', '๊ฒฝ์•…',
'๋ถ€๋‹ด/์•ˆ_๋‚ดํ‚ด', '์„œ๋Ÿฌ์›€', '์žฌ๋ฏธ์—†์Œ', '๋ถˆ์Œํ•จ/์—ฐ๋ฏผ', '๋†€๋žŒ', 'ํ–‰๋ณต',
'๋ถˆ์•ˆ/๊ฑฑ์ •', '๊ธฐ์จ', '์•ˆ์‹ฌ/์‹ ๋ขฐ']
# Label strings that this module classifies as negative emotions.
# NOTE(review): entries are presumably expected to match items of LABELS
# exactly -- verify whenever either list is edited.
NEGATIVE_EMOTIONS = [
'๋ถˆํ‰/๋ถˆ๋งŒ', '์ง€๊ธ‹์ง€๊ธ‹', '์Šฌํ””', 'ํ™”๋‚จ/๋ถ„๋…ธ', '์˜์‹ฌ/๋ถˆ์‹ ', '๊ณตํฌ/๋ฌด์„œ์›€', '์ ˆ๋ง', 'ํ•œ์‹ฌํ•จ',
'์—ญ๊ฒจ์›€/์ง•๊ทธ๋Ÿฌ์›€', '์งœ์ฆ', '์–ด์ด์—†์Œ', 'ํŒจ๋ฐฐ/์ž๊ธฐํ˜์˜ค', '๊ท€์ฐฎ์Œ', 'ํž˜๋“ฆ/์ง€์นจ', '์ฃ„์ฑ…๊ฐ',
'์ฆ์˜ค/ํ˜์˜ค', '๋‹นํ™ฉ/๋‚œ์ฒ˜', '๋ถ€๋‹ด/์•ˆ_๋‚ดํ‚ด', '์„œ๋Ÿฌ์›€', '์žฌ๋ฏธ์—†์Œ'
]
def parse_dialogue(text: str):
    """Split a line-oriented "speaker:utterance" transcript into pairs.

    Blank lines are ignored. Every remaining line is cut at its first
    colon; a line with no colon, or with nothing before or after that
    colon, is dropped silently.

    Args:
        text: The whole transcript, one turn per line.

    Returns:
        A list of ``(speaker, utterance)`` tuples in input order, both
        fields stripped of surrounding whitespace.
    """
    turns = []
    for raw in text.strip().split("\n"):
        entry = raw.strip()
        if not entry:
            continue
        # Only the first colon separates the fields, so an utterance may
        # itself contain colons.
        speaker, colon, utterance = entry.partition(":")
        if colon and speaker and utterance:
            turns.append((speaker.strip(), utterance.strip()))
    return turns
def adjusted_score(raw_prob: float, k: float = 5.0) -> float:
    """Stretch a probability onto a 0-100 score with a logistic curve.

    The curve is centred on 0.5, so ``raw_prob == 0.5`` always maps to 50.
    The output lies between 0 and 100, but the ends are only approached
    asymptotically: with the default ``k`` a probability of 0 gives about
    7.6 and a probability of 1 gives about 92.4.

    Args:
        raw_prob: Probability to rescale, nominally in [0, 1].
        k: Steepness of the curve; larger values push scores towards the
            extremes faster.

    Returns:
        The rescaled score as a float in [0, 100].
    """
    exponent = -k * (raw_prob - 0.5)
    try:
        growth = math.exp(exponent)
    except OverflowError:
        # math.exp only overflows for a huge positive exponent, where the
        # logistic value is indistinguishable from its lower limit.
        return 0.0
    return 100.0 / (1.0 + growth)
def apply_ema(series, alpha=0.4):
    """Smooth a sequence with an exponential moving average.

    The first output equals the first input; every later output is
    ``alpha * current_input + (1 - alpha) * previous_output``.

    Args:
        series: Indexable, sliceable sequence of numbers.
        alpha: Weight given to the newest sample.

    Returns:
        A new list with one smoothed value per input, or an empty list
        when ``series`` is empty.
    """
    if not series:
        return []
    carry = 1 - alpha
    current = series[0]
    out = [current]
    for sample in series[1:]:
        current = alpha * sample + carry * current
        out.append(current)
    return out
def detect_emotion_spikes(emotion_series, z_thresh=1.8, min_len=5):
    """Flag unusually high points of a score series using z-scores.

    Each value is standardised as ``z = (x - mean) / std`` over the whole
    series, and values with ``z >= z_thresh`` are reported. Only upward
    deviations are flagged. The original author recommends a threshold of
    roughly 1.8-2.5, tuned to how volatile the data is.

    Args:
        emotion_series: Sequence of numeric scores.
        z_thresh: Minimum z-score for a value to count as a spike.
        min_len: Series shorter than this yield no spikes at all.

    Returns:
        A list of ``(index, value, z)`` tuples in series order, with ``z``
        rounded to two decimals.
    """
    if len(emotion_series) < min_len:
        return []
    centre = float(np.mean(emotion_series))
    # The small constant keeps the division defined for a constant series.
    spread = float(np.std(emotion_series)) + 1e-6
    scored = (
        (pos, val, (val - centre) / spread)
        for pos, val in enumerate(emotion_series)
    )
    return [(pos, val, round(z, 2)) for pos, val, z in scored if z >= z_thresh]
def infer_conflict_initiator(dialogue, spikes_by_speaker):
    """Guess which speaker most often triggers other speakers' spikes.

    Simple rule: for every spike of a speaker at position ``idx``, the
    speaker of the turn at ``idx - 1`` is charged one point, unless that
    is the same speaker. Spikes at position 0 have no preceding turn and
    are skipped.

    Args:
        dialogue: Sequence of turns whose first element is the speaker.
        spikes_by_speaker: Mapping from speaker to a list of
            ``(idx, value, z)`` spike tuples; ``idx`` is used directly as
            a position in ``dialogue``.

    Returns:
        The speaker charged most often (the earliest-charged one wins a
        tie), or ``None`` when nobody was charged.
    """
    tally = defaultdict(int)
    for victim, events in spikes_by_speaker.items():
        for position, _value, _z in events:
            if position == 0:
                continue
            trigger = dialogue[position - 1][0]
            if trigger != victim:
                tally[trigger] += 1
    if not tally:
        return None
    # max() keeps the first entry among equal counts, which matches a
    # stable descending sort over insertion order.
    top_speaker, _count = max(tally.items(), key=lambda entry: entry[1])
    return top_speaker