# emotion_predictor.py
"""Negative-emotion analysis for dialogues written as "speaker:sentence" lines.

The module scores every utterance with a model loaded once at import time,
smooths each speaker's negative-emotion series, detects spikes, forecasts the
next turns, renders a PNG chart and produces a text report.
"""
import os
from pathlib import Path
import json
import torch
import matplotlib

matplotlib.use("Agg")  # server / headless environment: no GUI backend
import matplotlib.pyplot as plt
import base64
from io import BytesIO
from collections import defaultdict

from models.model_loader import load_kote_model
from utils.emotion_utils import (
    parse_dialogue,
    adjusted_score,
    apply_ema,
    LABELS,
    NEGATIVE_EMOTIONS,
    detect_emotion_spikes,
    infer_conflict_initiator,
)
from utils.prediction_utils import forecast_scores
from report.report_generator import generate_text_report
from utils.plot_fonts import setup_korean_font

ROOT_DIR = Path(__file__).resolve().parent
setup_korean_font(project_root=ROOT_DIR)

# Global cache: the model is loaded once, when the module is imported
# (i.e. at server start-up).
MODEL = load_kote_model()

# Minimum series length handed to detect_emotion_spikes.
_SPIKE_MIN_LEN = 5
# A forecast value at or above this score raises an alert. The alert message
# below spells the same number out as "80+"; keep the two in sync.
_ALERT_THRESHOLD = 80.0
# Number of emotion labels drawn per speaker (kept small for readability).
_TOP_LABELS_PER_SPEAKER = 3


def _score_sentence(text: str) -> dict:
    """Score one sentence with the cached model.

    Args:
        text: The sentence to score.

    Returns:
        A ``{label: probability}`` dict built by pairing ``LABELS`` with the
        model output in order.
    """
    # NOTE(review): assumes MODEL(text) returns a batch whose first element is
    # a tensor of per-label probabilities ordered like LABELS - confirm
    # against load_kote_model. zip() silently truncates on a length mismatch.
    with torch.no_grad():
        probs = MODEL(text)[0].cpu().numpy().tolist()
    return {label: float(p) for label, p in zip(LABELS, probs)}


def _collect_negative_scores(dialogue):
    """Score each utterance and gather the negative-emotion scores.

    Args:
        dialogue: Iterable of ``(speaker, sentence)`` pairs.

    Returns:
        ``(per_speaker_series, per_utt_log)`` where ``per_speaker_series`` maps
        ``speaker -> label -> list of adjusted scores`` (one per utterance of
        that speaker) and ``per_utt_log`` is a list with one dict per
        utterance.
    """
    per_speaker_series = defaultdict(lambda: defaultdict(list))
    per_utt_log = []
    for idx, (speaker, sentence) in enumerate(dialogue):
        label2prob = _score_sentence(sentence)
        neg_emotions = {}
        for lb in NEGATIVE_EMOTIONS:
            # The original inline note says adjusted_score yields 0-100.
            adj = adjusted_score(label2prob.get(lb, 0.0))
            neg_emotions[lb] = round(adj, 2)  # rounded for the log only
            per_speaker_series[speaker][lb].append(adj)
        per_utt_log.append({
            "utterance_idx": idx,
            "speaker": speaker,
            "text": sentence,
            "negative_emotions": neg_emotions,
        })
    return per_speaker_series, per_utt_log


def _plot_speaker(ax, spk, series, smooth_alpha, z_thresh, forecast_steps):
    """Draw one speaker's top emotion curves and forecasts on ``ax``.

    Args:
        ax: Matplotlib axes to draw on.
        spk: Speaker name, used in legend labels.
        series: Mapping ``label -> list of adjusted scores`` for this speaker.
        smooth_alpha: ``alpha`` forwarded to ``apply_ema``.
        z_thresh: ``z_thresh`` forwarded to ``detect_emotion_spikes``.
        forecast_steps: ``steps`` forwarded to ``forecast_scores``.

    Returns:
        ``(spikes, alerts)``: the spikes detected across the plotted labels and
        the alert messages for forecasts reaching the alert threshold.
    """
    # Only the labels with the highest most-recent score are drawn.
    last_scores = [
        (lb, series[lb][-1]) for lb in NEGATIVE_EMOTIONS if series[lb]
    ]
    top = sorted(last_scores, key=lambda item: -item[1])
    top = top[:_TOP_LABELS_PER_SPEAKER]

    spikes = []
    alerts = []
    plotted = False
    for lb, _ in top:
        sm = apply_ema(series[lb], alpha=smooth_alpha)
        if len(sm) < 2:
            continue  # a single point cannot be drawn as a line

        # Spike detection. The original note describes each spike as
        # (idx, score, z).
        # NOTE(review): idx presumably indexes this speaker's own series, not
        # the global utterance order - confirm infer_conflict_initiator
        # expects that.
        spikes.extend(
            detect_emotion_spikes(sm, z_thresh=z_thresh, min_len=_SPIKE_MIN_LEN)
        )

        fut = forecast_scores(sm, steps=forecast_steps)

        n_obs = len(sm)
        (observed_line,) = ax.plot(
            list(range(n_obs)), sm, label=f"{spk}-{lb}"
        )
        # Observed points occupy x = 0..n_obs-1, so the first forecast belongs
        # at x = n_obs. Sizing x from len(fut) keeps x and y aligned even if
        # the forecaster returns a different number of values than requested.
        x_fut = list(range(n_obs, n_obs + len(fut)))
        # Reuse the observed line's colour so the dashed forecast can be
        # matched to its series (it has no legend entry of its own).
        ax.plot(x_fut, fut, linestyle="--", color=observed_line.get_color())

        # Alert when any forecast value reaches the threshold.
        if any(v >= _ALERT_THRESHOLD for v in fut):
            alerts.append(f"'{lb}' 감정이 향후 {forecast_steps}턴 내 80+ 도달 가능")
        plotted = True

    if not plotted:
        # Nothing passed the length check above: fall back to the first label
        # recorded for this speaker, if it yields a drawable series.
        any_lb = next(iter(series), None)
        if any_lb:
            sm = apply_ema(series[any_lb], alpha=smooth_alpha)
            if len(sm) >= 2:
                ax.plot(range(len(sm)), sm, label=f"{spk}-{any_lb}")

    return spikes, alerts


def analyze_dialogue(
    raw_text: str,
    smooth_alpha: float = 0.4,
    z_thresh: float = 1.8,
    forecast_steps: int = 3,
):
    """Analyze a dialogue and build a report, a chart and structured data.

    Args:
        raw_text: Dialogue text made of "speaker:sentence" lines, parsed by
            ``parse_dialogue``.
        smooth_alpha: Smoothing factor forwarded to ``apply_ema``.
        z_thresh: Threshold forwarded to ``detect_emotion_spikes``.
        forecast_steps: Number of future steps requested from
            ``forecast_scores``.

    Returns:
        A ``(report_text, img_b64, result_struct)`` tuple: the text report, the
        chart as a base64-encoded PNG string, and a dict with the keys
        ``dialogue``, ``per_utt_log``, ``spikes_by_speaker``,
        ``future_alerts`` and ``initiator``. When the parsed dialogue is empty
        the function returns ``("대화가 비어있습니다.", None, {})``.
    """
    dialogue = parse_dialogue(raw_text)
    if not dialogue:
        return "대화가 비어있습니다.", None, {}

    # Per-utterance negative-emotion scores.
    per_speaker_series, per_utt_log = _collect_negative_scores(dialogue)

    # Smoothing + spike detection + forecasting, drawn on one figure.
    spikes_by_speaker = defaultdict(list)
    future_alerts = defaultdict(list)
    plot_buf = BytesIO()

    fig, ax = plt.subplots(figsize=(11, 4))
    try:
        for spk, series in per_speaker_series.items():
            spikes, alerts = _plot_speaker(
                ax, spk, series, smooth_alpha, z_thresh, forecast_steps
            )
            # Only create an entry when there is something to record, so the
            # result contains exactly the speakers that had spikes / alerts.
            if spikes:
                spikes_by_speaker[spk].extend(spikes)
            if alerts:
                future_alerts[spk].extend(alerts)

        ax.set_title("부정 감정 흐름(EMA) 및 예측")
        ax.set_xlabel("발화 순서")
        ax.set_ylabel("감정 점수(0~100)")
        ax.grid(True)
        # Skip the legend when nothing was drawn; an empty legend only
        # produces a "no artists with labels" warning.
        handles, _ = ax.get_legend_handles_labels()
        if handles:
            ax.legend()
        fig.tight_layout()
        fig.savefig(plot_buf, format="png")
    finally:
        # Always release the figure, including when a helper raises;
        # pyplot keeps figures alive until they are closed explicitly.
        plt.close(fig)

    img_b64 = base64.b64encode(plot_buf.getvalue()).decode("utf-8")

    # Estimate who initiated the conflict.
    initiator = infer_conflict_initiator(dialogue, spikes_by_speaker)

    # Build the report text.
    report_text = generate_text_report(
        initiator, spikes_by_speaker, per_utt_log, future_alerts
    )

    result_struct = {
        "dialogue": dialogue,
        "per_utt_log": per_utt_log,
        # Plain dicts so callers do not receive auto-creating defaultdicts.
        "spikes_by_speaker": dict(spikes_by_speaker),
        "future_alerts": dict(future_alerts),
        "initiator": initiator,
    }
    return report_text, img_b64, result_struct