# NOTE(review): "Spaces: / Running / Running" was a Hugging Face Space status
# banner captured by the page scrape, not part of the program; kept as a comment.
| import re | |
| import torch | |
| import torch.nn.functional as F | |
| from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification | |
| from sentence_transformers import SentenceTransformer, util | |
# Device selection: prefer GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"โ ํ์ฌ ์คํ ํ๊ฒฝ: {device}")
# =============================================================================
# 2. Model loading (all three models are loaded once, at import time)
# =============================================================================
print("\nโณ [1/3] KoBART ์์ฝ ๋ชจ๋ธ ๋ก๋ฉ ์ค...")
# KoBART abstractive summarizer for Korean article bodies.
# transformers' pipeline() takes a device *index*: 0 = first GPU, -1 = CPU.
kobart_summarizer = pipeline(
    "summarization",
    model="gogamza/kobart-summarization",
    device=0 if torch.cuda.is_available() else -1
)
print("โณ [2/3] SBERT ์ ์ฌ๋ ๋ชจ๋ธ ๋ก๋ฉ ์ค...")
# Sentence-BERT encoder used for title/summary cosine similarity.
sbert_model = SentenceTransformer("jhgan/ko-sroberta-multitask")
print("โณ [3/3] NLI(์์ฐ์ด์ถ๋ก ) ๋ชจ๋ธ ๋ก๋ฉ ์ค...")
# Korean NLI (natural-language-inference) classifier; moved to `device` and
# put in eval mode because it is used for inference only (no training).
nli_model_name = "Huffon/klue-roberta-base-nli"
nli_tokenizer = AutoTokenizer.from_pretrained(nli_model_name)
nli_model = AutoModelForSequenceClassification.from_pretrained(nli_model_name).to(device)
nli_model.eval()
print("๐ ๋ชจ๋ ๋ชจ๋ธ ๋ก๋ ์๋ฃ!\n")
| # ============================================================================= | |
| # 3. ๋์ฐ๋ฏธ ํจ์ ์ ์ (์์ ์ฝ๋ ์คํ์ผ ์ ์ง + ํ์ํ ๋ถ๋ถ๋ง ๊ฐ์ ) | |
| # ============================================================================= | |
| def _clean_text(text: str) -> str: | |
| text = text.strip() | |
| text = re.sub(r"\s+", " ", text) | |
| return text | |
| def _split_sentences_ko(text: str): | |
| """look-behind ์์ด ๋ฌธ์ฅ ๋ถ๋ฆฌ(์๋ฌ ๋ฐฉ์ง).""" | |
| text = _clean_text(text) | |
| parts = re.split(r"(?<=[.!?])\s+", text) # ๊ณ ์ ๊ธธ์ด look-behind(1๊ธ์)๋ง ์ฌ์ฉ | |
| if len(parts) <= 1: | |
| parts = re.split(r"(?:๋ค)\s+", text) # ๋ง์นจํ ๊ฑฐ์ ์์ ๋ ๋ณด๊ฐ | |
| return [p.strip() for p in parts if p.strip()] | |
def summarize_kobart_strict(text):
    """Summarize an article body with KoBART, with sentence-based fallbacks.

    Bodies of three sentences or fewer are returned (cleaned) as-is.
    If the model output is suspiciously short, or the pipeline raises,
    the first three sentences are returned instead.
    """
    text = _clean_text(text)
    sentences = _split_sentences_ko(text)
    print("[DEBUG] len(text) =", len(text), "len(sents) =", len(sentences))
    print("[DEBUG] first3 =", " | ".join(sentences[:3]))
    # Only the sentence count decides whether the model is skipped.
    if len(sentences) <= 3:
        print("[DEBUG] <=3 sentences -> return as-is")
        if not sentences:
            return text
        return _clean_text(" ".join(sentences))
    try:
        raw = kobart_summarizer(
            text,
            min_length=30,
            max_length=90,
            num_beams=4,
            no_repeat_ngram_size=3,
            early_stopping=True,
            truncation=True,  # guard against over-long inputs
        )
        summary = _clean_text(raw[0]["summary_text"])
        print("[DEBUG] kobart_out =", summary)
        # Fall back only when the summary is absurdly short.
        if len(summary) < 10:
            print("[DEBUG] too short -> fallback to first 3 sentences")
            return _clean_text(" ".join(sentences[:3]))
        return summary
    except Exception as exc:
        print("๐จ [Error] ์์ฝ ๋ชจ๋ธ ์๋ฌ:", repr(exc))
        if not sentences:
            return text
        return _clean_text(" ".join(sentences[:3]))
def get_cosine_similarity(title, summary):
    """Return the SBERT cosine similarity between *title* and *summary* as a float."""
    title_vec = sbert_model.encode(_clean_text(title), convert_to_tensor=True)
    summary_vec = sbert_model.encode(_clean_text(summary), convert_to_tensor=True)
    similarity = util.cos_sim(title_vec, summary_vec)
    return float(similarity.item())
def _nli_forward(premise: str, hypothesis: str) -> torch.Tensor:
    """Run the NLI model on a (premise, hypothesis) pair; return 3-class softmax probs."""
    encoded = nli_tokenizer(
        _clean_text(premise),
        _clean_text(hypothesis),
        return_tensors="pt",
        truncation=True,
        max_length=512,
    ).to(device)
    # Drop token_type_ids if the tokenizer produced them — this model's
    # forward() is not given them (presumably RoBERTa-style; no-op otherwise).
    encoded.pop("token_type_ids", None)
    with torch.no_grad():
        logits = nli_model(**encoded).logits[0]
        return F.softmax(logits, dim=-1)
def infer_nli_label_indices() -> dict:
    """Probe the NLI model to discover which logit index means which label.

    Instead of assuming an [entailment, neutral, contradiction] ordering,
    run one obviously-entailed pair and one obviously-contradictory pair and
    read off the argmax indices; neutral is whichever index remains.  If the
    two probes collide, a second contradiction pair is tried, and neutral
    defaults to index 1 when the set arithmetic leaves nothing.
    """
    premise = "์ํฅ๋ฏผ์ ์ถ๊ตฌ ์ ์๋ค."
    entailed_hyp = "์ํฅ๋ฏผ์ ์ถ๊ตฌ ์ ์๋ค."
    contradicted_hyp = "์ํฅ๋ฏผ์ ์ถ๊ตฌ ์ ์๊ฐ ์๋๋ค."
    entail_idx = int(torch.argmax(_nli_forward(premise, entailed_hyp)).item())
    contra_idx = int(torch.argmax(_nli_forward(premise, contradicted_hyp)).item())
    if entail_idx == contra_idx:
        # Ambiguous probe: retry contradiction with a second, unrelated pair.
        retry_probs = _nli_forward("์ค๋์ ๋ง๋ค.", "์ค๋์ ๋น๊ฐ ์จ๋ค.")
        contra_idx = int(torch.argmax(retry_probs).item())
    remaining = list({0, 1, 2} - {entail_idx, contra_idx})
    neutral_idx = int(remaining[0]) if remaining else 1
    return {"entailment": entail_idx, "neutral": neutral_idx, "contradiction": contra_idx}
NLI_IDX = infer_nli_label_indices()
def get_mismatch_score(summary, title):
    """NLI-based mismatch between a summary (premise) and a title (hypothesis).

    Returns a 4-tuple (mismatch, entailment, neutral, contradiction), each
    rounded to 4 decimals.  mismatch = 1 - entailment, clamped to [0, 1]:
    a *lack of entailment* (including neutral) counts as mismatch, not only
    outright contradiction.
    """
    probs = _nli_forward(summary, title)
    entail = float(probs[NLI_IDX["entailment"]].item())
    neutral = float(probs[NLI_IDX["neutral"]].item())
    contra = float(probs[NLI_IDX["contradiction"]].item())
    # Missing entailment — not just contradiction — is treated as mismatch.
    mismatch = min(1.0, max(0.0, 1.0 - entail))
    return round(mismatch, 4), round(entail, 4), round(neutral, 4), round(contra, 4)
# =============================================================================
# 4. Final main function (original name and return type preserved)
# =============================================================================
def calculate_mismatch_score(article_title, article_body):
    """Score how strongly a title mismatches its article body.

    final_score = 0.6 * (1 - SBERT cosine similarity of title vs. body summary)
                + 0.4 * (NLI mismatch, i.e. 1 - entailment)
    A final score of 0.45 or more is flagged as suspicious.
    Returns {"score": float, "reason": str, "recommendation": str}.
    """
    # 1) Summarize the body with KoBART (sentence-based fallbacks inside).
    summary = summarize_kobart_strict(article_body)
    # 2) Semantic distance between title and summary.
    semantic_distance = 1 - get_cosine_similarity(article_title, summary)
    # 3) NLI mismatch (1 - entailment) plus the raw probabilities for debugging.
    nli_mismatch, entail, neutral, contra = get_mismatch_score(summary, article_title)
    # 4) Weighted combination (weights 0.6 / 0.4, same as before).
    final_score = 0.6 * semantic_distance + 0.4 * nli_mismatch
    reason = (
        f"[๋๋ฒ๊ทธ ๋ชจ๋]\n"
        f"1. ์์ฝ๋ฌธ: {summary}\n"
        f"2. SBERT ๊ฑฐ๋ฆฌ: {semantic_distance:.4f}\n"
        f"3. NLI ๋ถ์ผ์น(1-entail): {nli_mismatch:.4f}\n"
        f" - entail: {entail:.4f}, neutral: {neutral:.4f}, contradiction: {contra:.4f}\n"
        f" - label_idx: {NLI_IDX}"
    )
    # 5) Human-readable verdict based on the 0.45 threshold.
    recommendation = (
        "์ ๋ชฉ์ด ๋ณธ๋ฌธ์ ๋ด์ฉ์ ์๊ณกํ๊ฑฐ๋(ํจ์ ๋ถ์กฑ) ๊ณผ์ฅ/์์๋ ๊ฐ๋ฅ์ฑ์ด ๋์ต๋๋ค."
        if final_score >= 0.45
        else "์ ๋ชฉ๊ณผ ๋ณธ๋ฌธ์ ๋ด์ฉ์ด ๋์ฒด๋ก ์ผ์นํฉ๋๋ค."
    )
    return {
        "score": round(final_score, 4),
        "reason": reason,
        "recommendation": recommendation,
    }