# project-tdm / mismatch_model.py
# NOTE(review): the lines below looked like pasted repo metadata
# (branch "hy", topic "mismatch", commit 1225cdd) — kept as a comment
# so the module parses; confirm against the original repository.
import re
import torch
import torch.nn.functional as F
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
from sentence_transformers import SentenceTransformer, util
# ๋””๋ฐ”์ด์Šค ์„ค์ • (GPU ์šฐ์„ , ์—†์œผ๋ฉด CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"โœ… ํ˜„์žฌ ์‹คํ–‰ ํ™˜๊ฒฝ: {device}")
# =============================================================================
# 2. ๋ชจ๋ธ ๋กœ๋“œ
# =============================================================================
print("\nโณ [1/3] KoBART ์š”์•ฝ ๋ชจ๋ธ ๋กœ๋”ฉ ์ค‘...")
kobart_summarizer = pipeline(
"summarization",
model="gogamza/kobart-summarization",
device=0 if torch.cuda.is_available() else -1
)
print("โณ [2/3] SBERT ์œ ์‚ฌ๋„ ๋ชจ๋ธ ๋กœ๋”ฉ ์ค‘...")
sbert_model = SentenceTransformer("jhgan/ko-sroberta-multitask")
print("โณ [3/3] NLI(์ž์—ฐ์–ด์ถ”๋ก ) ๋ชจ๋ธ ๋กœ๋”ฉ ์ค‘...")
nli_model_name = "Huffon/klue-roberta-base-nli"
nli_tokenizer = AutoTokenizer.from_pretrained(nli_model_name)
nli_model = AutoModelForSequenceClassification.from_pretrained(nli_model_name).to(device)
nli_model.eval()
print("๐ŸŽ‰ ๋ชจ๋“  ๋ชจ๋ธ ๋กœ๋“œ ์™„๋ฃŒ!\n")
# =============================================================================
# 3. Helper functions (keeps the older code's style; only what was needed improved)
# =============================================================================
def _clean_text(text: str) -> str:
text = text.strip()
text = re.sub(r"\s+", " ", text)
return text
def _split_sentences_ko(text: str):
"""look-behind ์—†์ด ๋ฌธ์žฅ ๋ถ„๋ฆฌ(์—๋Ÿฌ ๋ฐฉ์ง€)."""
text = _clean_text(text)
parts = re.split(r"(?<=[.!?])\s+", text) # ๊ณ ์ • ๊ธธ์ด look-behind(1๊ธ€์ž)๋งŒ ์‚ฌ์šฉ
if len(parts) <= 1:
parts = re.split(r"(?:๋‹ค)\s+", text) # ๋งˆ์นจํ‘œ ๊ฑฐ์˜ ์—†์„ ๋•Œ ๋ณด๊ฐ•
return [p.strip() for p in parts if p.strip()]
def summarize_kobart_strict(text):
    """Summarize Korean article text with KoBART, with sentence-based fallbacks.

    Behavior:
      * Three or fewer sentences: return the cleaned input unchanged.
      * Model output shorter than 10 characters, or any pipeline error:
        return the first three sentences instead.
    """
    text = _clean_text(text)
    sents = _split_sentences_ko(text)
    print("[DEBUG] len(text) =", len(text), "len(sents) =", len(sents))
    print("[DEBUG] first3 =", " | ".join(sents[:3]))
    # Only the sentence count decides whether summarization is worthwhile.
    if len(sents) <= 3:
        print("[DEBUG] <=3 sentences -> return as-is")
        return _clean_text(" ".join(sents)) if sents else text
    # Past the guard above `sents` has >= 4 entries, so this is non-empty.
    fallback = _clean_text(" ".join(sents[:3]))
    try:
        generated = kobart_summarizer(
            text,
            min_length=30,
            max_length=90,
            num_beams=4,
            no_repeat_ngram_size=3,
            early_stopping=True,
            truncation=True,  # guard against over-long inputs
        )
        summary = _clean_text(generated[0]["summary_text"])
    except Exception as e:
        print("๐Ÿšจ [Error] ์š”์•ฝ ๋ชจ๋ธ ์—๋Ÿฌ:", repr(e))
        return fallback
    print("[DEBUG] kobart_out =", summary)
    # Only fall back when the summary is nonsensically short.
    if len(summary) < 10:
        print("[DEBUG] too short -> fallback to first 3 sentences")
        return fallback
    return summary
def get_cosine_similarity(title, summary):
    """SBERT cosine similarity between the (cleaned) title and summary."""
    title_vec = sbert_model.encode(_clean_text(title), convert_to_tensor=True)
    summary_vec = sbert_model.encode(_clean_text(summary), convert_to_tensor=True)
    return float(util.cos_sim(title_vec, summary_vec).item())
def _nli_forward(premise: str, hypothesis: str) -> torch.Tensor:
    """Run the NLI model on one (premise, hypothesis) pair.

    Returns the softmax probability vector over the model's 3 classes.
    """
    encoded = nli_tokenizer(
        _clean_text(premise),
        _clean_text(hypothesis),
        return_tensors="pt",
        truncation=True,
        max_length=512,
    ).to(device)
    # RoBERTa-style models take no token_type_ids; drop them if the
    # tokenizer produced any.
    encoded.pop("token_type_ids", None)
    with torch.no_grad():
        logits = nli_model(**encoded).logits[0]
    return F.softmax(logits, dim=-1)
def infer_nli_label_indices() -> dict:
    """Probe the NLI head with known pairs to map class names to output indices.

    Older code hard-coded the [entailment, neutral, contradiction] order;
    this calibrates the ordering at startup instead.
    """
    premise = "์†ํฅ๋ฏผ์€ ์ถ•๊ตฌ ์„ ์ˆ˜๋‹ค."
    entail_hyp = "์†ํฅ๋ฏผ์€ ์ถ•๊ตฌ ์„ ์ˆ˜๋‹ค."
    contra_hyp = "์†ํฅ๋ฏผ์€ ์ถ•๊ตฌ ์„ ์ˆ˜๊ฐ€ ์•„๋‹ˆ๋‹ค."
    entail_idx = int(torch.argmax(_nli_forward(premise, entail_hyp)).item())
    contra_idx = int(torch.argmax(_nli_forward(premise, contra_hyp)).item())
    if entail_idx == contra_idx:
        # Ambiguous probe: retry contradiction with a clearly opposite pair.
        retry = _nli_forward("์˜ค๋Š˜์€ ๋ง‘๋‹ค.", "์˜ค๋Š˜์€ ๋น„๊ฐ€ ์˜จ๋‹ค.")
        contra_idx = int(torch.argmax(retry).item())
    remaining = [i for i in (0, 1, 2) if i not in (entail_idx, contra_idx)]
    # If the probes still collide, default the neutral slot to index 1.
    neutral_idx = remaining[0] if remaining else 1
    return {"entailment": entail_idx, "neutral": neutral_idx, "contradiction": contra_idx}
NLI_IDX = infer_nli_label_indices()
def get_mismatch_score(summary, title):
    """NLI-based mismatch between *summary* (premise) and *title* (hypothesis).

    Uses 1 - P(entailment) instead of P(contradiction) alone, so titles the
    body merely fails to support are flagged too — not just outright
    contradictions (which missed baity/implied headlines).

    Returns a 4-tuple rounded to 4 decimals:
    (mismatch, entailment, neutral, contradiction); mismatch is clamped
    to [0, 1].
    """
    probs = _nli_forward(summary, title)
    entail, neutral, contra = (
        float(probs[NLI_IDX[label]].item())
        for label in ("entailment", "neutral", "contradiction")
    )
    mismatch = max(0.0, min(1.0, 1.0 - entail))
    return round(mismatch, 4), round(entail, 4), round(neutral, 4), round(contra, 4)
# =============================================================================
# 4. Final entry point (keeps the old function name and return format)
# =============================================================================
def calculate_mismatch_score(article_title, article_body):
    """Score how badly a title mismatches its article body.

    Weighted blend of two signals:
      - SBERT semantic distance (1 - cosine similarity), weight 0.6
      - NLI mismatch (1 - entailment), weight 0.4
    A blended score of 0.45 or more is reported as risky.

    Returns a dict: {"score": float, "reason": str, "recommendation": str}.
    """
    w_distance, w_nli = 0.6, 0.4
    threshold = 0.45
    # 1) Summarize the body so we compare the title to the gist, not raw text.
    summary = summarize_kobart_strict(article_body)
    # 2) Semantic distance between title and summary embeddings.
    semantic_distance = 1 - get_cosine_similarity(article_title, summary)
    # 3) NLI mismatch plus the raw class probabilities for the debug report.
    nli_mismatch, entail, neutral, contra = get_mismatch_score(summary, article_title)
    final_score = w_distance * semantic_distance + w_nli * nli_mismatch
    reason = (
        f"[๋””๋ฒ„๊ทธ ๋ชจ๋“œ]\n"
        f"1. ์š”์•ฝ๋ฌธ: {summary}\n"
        f"2. SBERT ๊ฑฐ๋ฆฌ: {semantic_distance:.4f}\n"
        f"3. NLI ๋ถˆ์ผ์น˜(1-entail): {nli_mismatch:.4f}\n"
        f" - entail: {entail:.4f}, neutral: {neutral:.4f}, contradiction: {contra:.4f}\n"
        f" - label_idx: {NLI_IDX}"
    )
    recommendation = (
        "์ œ๋ชฉ์ด ๋ณธ๋ฌธ์˜ ๋‚ด์šฉ์„ ์™œ๊ณกํ•˜๊ฑฐ๋‚˜(ํ•จ์˜ ๋ถ€์กฑ) ๊ณผ์žฅ/์•”์‹œ๋  ๊ฐ€๋Šฅ์„ฑ์ด ๋†’์Šต๋‹ˆ๋‹ค."
        if final_score >= threshold
        else "์ œ๋ชฉ๊ณผ ๋ณธ๋ฌธ์˜ ๋‚ด์šฉ์ด ๋Œ€์ฒด๋กœ ์ผ์น˜ํ•ฉ๋‹ˆ๋‹ค."
    )
    return {
        "score": round(final_score, 4),
        "reason": reason,
        "recommendation": recommendation,
    }