|
|
|
|
|
import os |
|
|
from pathlib import Path |
|
|
|
|
|
import json |
|
|
import torch |
|
|
import matplotlib |
|
|
matplotlib.use("Agg") |
|
|
import matplotlib.pyplot as plt |
|
|
import base64 |
|
|
from io import BytesIO |
|
|
from collections import defaultdict |
|
|
|
|
|
from models.model_loader import load_kote_model |
|
|
from utils.emotion_utils import ( |
|
|
parse_dialogue, adjusted_score, apply_ema, |
|
|
LABELS, NEGATIVE_EMOTIONS, detect_emotion_spikes, infer_conflict_initiator |
|
|
) |
|
|
from utils.prediction_utils import forecast_scores |
|
|
from report.report_generator import generate_text_report |
|
|
|
|
|
from utils.plot_fonts import setup_korean_font |
|
|
# Absolute path of the directory containing this file; passed to
# setup_korean_font so it can locate font assets relative to the project.
ROOT_DIR = Path(__file__).resolve().parent

# Register a Korean-capable matplotlib font BEFORE any figure is drawn —
# the plot titles/labels below contain Korean text.
setup_korean_font(project_root=ROOT_DIR)

# Load the KOTE emotion model once at import time and share it across all
# calls to analyze_dialogue.  NOTE(review): assumes load_kote_model()
# returns a callable mapping text -> per-label probability tensor — confirm
# against models.model_loader.
MODEL = load_kote_model()
|
|
|
|
|
def _score_sentence(text: str):
    """Score one sentence with the shared KOTE model.

    Args:
        text: A single utterance to classify.

    Returns:
        Mapping of each emotion label in ``LABELS`` to its predicted
        probability as a plain Python float.
    """
    # Inference only — disable autograd bookkeeping.
    with torch.no_grad():
        raw_probs = MODEL(text)[0].cpu().numpy().tolist()
    return dict(zip(LABELS, map(float, raw_probs)))
|
|
|
|
|
def analyze_dialogue(raw_text: str, smooth_alpha=0.4, z_thresh=1.8, forecast_steps=3,
                     alert_threshold=80.0):
    """Analyze a "speaker:sentence"-per-line dialogue for negative emotions.

    Scores every utterance with the KOTE model, tracks per-speaker series of
    negative-emotion scores, smooths them with an EMA, detects spikes,
    forecasts the next ``forecast_steps`` values, renders a line plot, and
    builds a text report.

    Args:
        raw_text: Dialogue text, one "speaker:sentence" pair per line.
        smooth_alpha: EMA smoothing factor forwarded to ``apply_ema``.
        z_thresh: Z-score threshold forwarded to ``detect_emotion_spikes``.
        forecast_steps: Number of future points to forecast per series.
        alert_threshold: Forecast score at or above which a future alert is
            recorded.  Default 80.0 matches the previously hard-coded value,
            so default behavior (including the alert message) is unchanged.

    Returns:
        Tuple of (report text, base64-encoded PNG plot or None, dict of
        intermediate structures).  Returns a short notice string, None and
        an empty dict when the dialogue parses to nothing.
    """
    dialogue = parse_dialogue(raw_text)
    if not dialogue:
        return "๋ํ๊ฐ ๋น์ด์์ต๋๋ค.", None, {}

    # speaker -> emotion label -> list of adjusted scores, one per utterance.
    per_speaker_series = defaultdict(lambda: defaultdict(list))
    per_utt_log = []

    # --- Step 1: score every utterance on the negative-emotion labels. ---
    for idx, (speaker, sentence) in enumerate(dialogue):
        label2prob = _score_sentence(sentence)
        neg_emotions = {}
        for lb in NEGATIVE_EMOTIONS:
            raw = label2prob.get(lb, 0.0)
            adj = adjusted_score(raw)
            neg_emotions[lb] = round(adj, 2)
            per_speaker_series[speaker][lb].append(adj)

        per_utt_log.append({
            "utterance_idx": idx,
            "speaker": speaker,
            "text": sentence,
            "negative_emotions": neg_emotions
        })

    spikes_by_speaker = defaultdict(list)
    future_alerts = defaultdict(list)

    # --- Step 2: plot smoothed series + forecasts; collect spikes/alerts. ---
    plot_buf = BytesIO()
    plt.figure(figsize=(11, 4))
    try:
        for spk in list(per_speaker_series.keys()):
            # Rank this speaker's emotions by their most recent score and
            # plot only the top 3 to keep the chart readable.
            last_scores = [(lb, per_speaker_series[spk][lb][-1])
                           for lb in NEGATIVE_EMOTIONS if per_speaker_series[spk][lb]]
            top = sorted(last_scores, key=lambda x: -x[1])[:3]
            plotted = False

            for lb, _ in top:
                sm = apply_ema(per_speaker_series[spk][lb], alpha=smooth_alpha)
                if len(sm) < 2:
                    # Too short to draw a line or detect a spike.
                    continue

                spikes_by_speaker[spk].extend(
                    detect_emotion_spikes(sm, z_thresh=z_thresh, min_len=5))

                fut = forecast_scores(sm, steps=forecast_steps)

                # Observed series (solid), then the forecast (dashed).
                plt.plot(range(len(sm)), sm, label=f"{spk}-{lb}")
                x_fut = [len(sm) + i for i in range(1, forecast_steps + 1)]
                plt.plot(x_fut, fut, linestyle="--")

                if any(v >= alert_threshold for v in fut):
                    # ':g' renders 80.0 as "80", so the default message is
                    # byte-identical to the old hard-coded "80+" text.
                    future_alerts[spk].append(
                        f"'{lb}' ๊ฐ์ ์ด ํฅํ {forecast_steps}ํด ๋ด {alert_threshold:g}+ ๋๋ฌ ๊ฐ๋ฅ")
                plotted = True

            if not plotted:
                # Fallback: draw any available series so this speaker is not
                # silently absent from the chart.
                any_lb = next(iter(per_speaker_series[spk]), None)
                if any_lb:
                    sm = apply_ema(per_speaker_series[spk][any_lb], alpha=smooth_alpha)
                    if len(sm) >= 2:
                        plt.plot(range(len(sm)), sm, label=f"{spk}-{any_lb}")

        plt.title("๋ถ์ ๊ฐ์ ํ๋ฆ(EMA) ๋ฐ ์์ธก")
        plt.xlabel("๋ฐํ ์์")
        plt.ylabel("๊ฐ์ ์ ์(0~100)")
        plt.grid(True)
        # Only add a legend when something was actually plotted; calling
        # legend() with no labeled artists triggers a matplotlib warning.
        if plt.gca().get_legend_handles_labels()[1]:
            plt.legend()
        plt.tight_layout()
        plt.savefig(plot_buf, format="png")
    finally:
        # Always release the figure — even if scoring/forecasting/plotting
        # raises — so repeated calls do not leak matplotlib figures.
        plt.close()

    img_b64 = base64.b64encode(plot_buf.getvalue()).decode("utf-8")

    initiator = infer_conflict_initiator(dialogue, spikes_by_speaker)

    report_text = generate_text_report(initiator, spikes_by_speaker, per_utt_log, future_alerts)

    result_struct = {
        "dialogue": dialogue,
        "per_utt_log": per_utt_log,
        # Plain dicts so callers get ordinary mappings, not defaultdicts.
        "spikes_by_speaker": dict(spikes_by_speaker),
        "future_alerts": dict(future_alerts),
        "initiator": initiator
    }
    return report_text, img_b64, result_struct