# kshs33_emotion_predict / emotion_predictor.py
# Source: Hugging Face Space upload by leewatson (verified commit c8ad3d8,
# "Update emotion_predictor.py").
# emotion_predictor.py
import os
from pathlib import Path
import json
import torch
import matplotlib
matplotlib.use("Agg") # ์„œ๋ฒ„/ํ—ค๋“œ๋ฆฌ์Šค ํ™˜๊ฒฝ
import matplotlib.pyplot as plt
import base64
from io import BytesIO
from collections import defaultdict
from models.model_loader import load_kote_model
from utils.emotion_utils import (
parse_dialogue, adjusted_score, apply_ema,
LABELS, NEGATIVE_EMOTIONS, detect_emotion_spikes, infer_conflict_initiator
)
from utils.prediction_utils import forecast_scores
from report.report_generator import generate_text_report
from utils.plot_fonts import setup_korean_font
# Project root = directory containing this file; used to locate bundled font assets.
ROOT_DIR = Path(__file__).resolve().parent
# Register a Korean-capable matplotlib font so Korean plot titles/labels render.
setup_korean_font(project_root=ROOT_DIR)
# Global model cache: loaded once at server startup, reused for every request.
MODEL = load_kote_model()
def _score_sentence(text: str):
    """Score one sentence with the cached KOTE model.

    Runs inference under ``torch.no_grad()`` and returns a
    ``{label: probability}`` dict with one float entry per label in LABELS.
    """
    with torch.no_grad():
        raw_probs = MODEL(text)[0].cpu().numpy().tolist()
    # Pair every label with its probability, coercing each value to a plain float.
    return dict(zip(LABELS, map(float, raw_probs)))
def analyze_dialogue(raw_text: str, smooth_alpha=0.4, z_thresh=1.8, forecast_steps=3):
    """
    Analyze a dialogue for negative-emotion dynamics per speaker.

    Input: dialogue text, one utterance per line in "speaker:sentence" form
    (parsed by ``parse_dialogue``).
    Output: (report text, base64-encoded PNG plot, intermediate structured data);
    for empty input the plot is ``None`` and the data dict is empty.

    Parameters
    ----------
    raw_text : str
        Raw dialogue text.
    smooth_alpha : float
        EMA smoothing factor forwarded to ``apply_ema``.
    z_thresh : float
        Z-score threshold forwarded to ``detect_emotion_spikes``.
    forecast_steps : int
        Number of future turns forecast via ``forecast_scores``.
    """
    dialogue = parse_dialogue(raw_text)
    if not dialogue:
        return "๋Œ€ํ™”๊ฐ€ ๋น„์–ด์žˆ์Šต๋‹ˆ๋‹ค.", None, {}
    # Per-utterance negative-emotion scores, keyed speaker -> label -> [score, ...].
    per_speaker_series = defaultdict(lambda: defaultdict(list))
    per_utt_log = []
    for idx, (speaker, sentence) in enumerate(dialogue):
        label2prob = _score_sentence(sentence)
        neg_emotions = {}
        for lb in NEGATIVE_EMOTIONS:
            raw = label2prob.get(lb, 0.0)
            adj = adjusted_score(raw)  # rescaled to 0~100
            neg_emotions[lb] = round(adj, 2)
            per_speaker_series[speaker][lb].append(adj)
        per_utt_log.append({
            "utterance_idx": idx,
            "speaker": speaker,
            "text": sentence,
            "negative_emotions": neg_emotions
        })
    # Smoothing + spike detection + forecasting.
    spikes_by_speaker = defaultdict(list)
    future_alerts = defaultdict(list)
    plot_buf = BytesIO()
    plt.figure(figsize=(11, 4))
    speakers = list(per_speaker_series.keys())
    for spk in speakers:
        # For readability, only plot the most prominent negative emotions:
        # criterion is the top-3 by the most recent (last) value.
        last_scores = [(lb, per_speaker_series[spk][lb][-1]) for lb in NEGATIVE_EMOTIONS if per_speaker_series[spk][lb]]
        top = sorted(last_scores, key=lambda x: -x[1])[:3]
        plotted = False
        for lb, _ in top:
            raw = per_speaker_series[spk][lb]
            sm = apply_ema(raw, alpha=smooth_alpha)
            if len(sm) < 2:
                continue
            # Spike detection on the smoothed series.
            spikes = detect_emotion_spikes(sm, z_thresh=z_thresh, min_len=5)
            for s in spikes:
                # s: (idx, score, z)
                spikes_by_speaker[spk].append(s)
            # Forecast the next `forecast_steps` turns.
            fut = forecast_scores(sm, steps=forecast_steps)
            # Visualization: solid line for observed, dashed for forecast.
            x_obs = list(range(len(sm)))
            plt.plot(x_obs, sm, label=f"{spk}-{lb}")
            x_fut = [len(sm)+i for i in range(1, forecast_steps+1)]
            plt.plot(x_fut, fut, linestyle="--")
            # Alert rule: warn when any forecast value reaches 80 or above.
            if any(v >= 80.0 for v in fut):
                future_alerts[spk].append(f"'{lb}' ๊ฐ์ •์ด ํ–ฅํ›„ {forecast_steps}ํ„ด ๋‚ด 80+ ๋„๋‹ฌ ๊ฐ€๋Šฅ")
            plotted = True
        # Ensure at least one series was drawn for this speaker.
        if not plotted:
            # If nothing passed the criteria, draw one label anyway.
            any_lb = next(iter(per_speaker_series[spk]), None)
            if any_lb:
                sm = apply_ema(per_speaker_series[spk][any_lb], alpha=smooth_alpha)
                if len(sm) >= 2:
                    plt.plot(range(len(sm)), sm, label=f"{spk}-{any_lb}")
    plt.title("๋ถ€์ • ๊ฐ์ • ํ๋ฆ„(EMA) ๋ฐ ์˜ˆ์ธก")
    plt.xlabel("๋ฐœํ™” ์ˆœ์„œ")
    plt.ylabel("๊ฐ์ • ์ ์ˆ˜(0~100)")
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    # Render to an in-memory PNG and encode as base64 for transport.
    plt.savefig(plot_buf, format="png")
    plt.close()
    img_b64 = base64.b64encode(plot_buf.getvalue()).decode("utf-8")
    # Estimate which speaker initiated the conflict from the spike timeline.
    initiator = infer_conflict_initiator(dialogue, spikes_by_speaker)
    # Generate the human-readable report text.
    report_text = generate_text_report(initiator, spikes_by_speaker, per_utt_log, future_alerts)
    result_struct = {
        "dialogue": dialogue,
        "per_utt_log": per_utt_log,
        "spikes_by_speaker": {k: v for k, v in spikes_by_speaker.items()},
        "future_alerts": {k: v for k, v in future_alerts.items()},
        "initiator": initiator
    }
    return report_text, img_b64, result_struct