Spaces:
Running
Running
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from transformers import BertTokenizer, BertModel | |
| import pickle | |
| import re | |
| import os | |
| import sys | |
| import numpy as np | |
| from collections import defaultdict | |
| # ============================================================================= | |
| # 1. Model class definitions | |
| # ============================================================================= | |
# (1) Rule-based scorer class
class RuleBasedScorer:
    """Score a headline against four rule-based patterns (ids 11-14).

    Each pattern id owns a ``word -> weight`` dictionary in ``self.patterns``.
    A title's score for a pattern combines the log-scaled weights of the
    words found in that dictionary with a punctuation score derived from
    repeated marks and ellipses.

    Instances of this class are also loaded from a pickle elsewhere in the
    file, so the attribute names set in ``__init__`` (``patterns``,
    ``pattern_names``, ``symbol_patterns``) must not be renamed: an unpickled
    object restores its attributes without running ``__init__``.

    NOTE(review): the non-ASCII string literals and regex character classes
    below look like mis-decoded UTF-8 (presumably Korean). They are kept
    exactly as found; restore them from the original source file if possible.
    """

    def __init__(self):
        # Per-pattern word weights; empty until populated (e.g. by training
        # code or by unpickling a fitted scorer).
        self.patterns = {
            11: defaultdict(float), 12: defaultdict(float),
            13: defaultdict(float), 14: defaultdict(float)
        }
        # Human-readable label for each pattern id.
        self.pattern_names = {
            11: '์๋ฌธ ์ ๋ฐํ(๋ถํธ)', 12: '์๋ฌธ ์ ๋ฐํ(์๋)',
            13: '์ ์ ํํ ์ฌ์ฉํ', 14: '์์ด/์ค์๋ง ์ฌ์ฉํ'
        }
        # Punctuation detectors: a mark repeated back-to-back ("??", "!!")
        # and an ellipsis ("..." or the single-character form).
        self.symbol_patterns = {
            'repeated': re.compile(r'([!?โฆ~])\1+'),
            'ellipsis': re.compile(r'\.\.\.|โฆ')
        }

    def get_score(self, title):
        """Return the best-matching pattern and its score for ``title``.

        Args:
            title: Headline to score. Non-string values are converted with
                ``str()`` before any matching is done.

        Returns:
            dict with keys ``score`` (capped at 100), ``pattern`` (the id
            11-14 with the highest score; ties resolve to the lowest id) and
            ``pattern_name`` (label looked up in ``self.pattern_names``).
        """
        # Normalise once so that tokenisation and punctuation matching agree
        # on the input; previously only tokenisation was str()-safe.
        text = str(title)

        # 1. Tokenise into runs of word characters.
        words = re.findall(r'[๊ฐ-ํฃA-Za-z0-9]+', text)

        # 2. Punctuation score: 30 per repeated-mark run, 10 per ellipsis.
        repeated_hits = len(self.symbol_patterns['repeated'].findall(text))
        ellipsis_hits = len(self.symbol_patterns['ellipsis'].findall(text))
        symbol_score = (repeated_hits * 30) + (ellipsis_hits * 10)

        # 3. Per-pattern score.
        scores = {}
        for pattern_id in (11, 12, 13, 14):
            # A scorer restored from an older pickle may lack a pattern id,
            # hence .get() instead of indexing.
            weights = self.patterns.get(pattern_id)
            word_score = 0
            if weights is not None:
                for word in words:
                    # Membership test first: indexing a defaultdict would
                    # silently insert the missing word.
                    if word in weights:
                        word_score += np.log1p(weights[word]) * 2

            if pattern_id == 11:
                # Punctuation-only pattern.
                scores[pattern_id] = symbol_score
            elif pattern_id == 12:
                # Words plus half of the punctuation score.
                scores[pattern_id] = word_score + (symbol_score * 0.5)
            else:
                # Patterns 13 and 14 are word-only.
                scores[pattern_id] = word_score

        # 4. Pick the highest-scoring pattern. `scores` always holds four
        # entries, so no empty-case handling is needed.
        best_pattern = max(scores, key=scores.get)
        return {
            'score': min(100, scores[best_pattern]),
            'pattern': best_pattern,
            'pattern_name': self.pattern_names.get(best_pattern, '์ ์ ์์')
        }
# Guard against pickle loading errors: expose RuleBasedScorer as an attribute
# of the __main__ module as well.
# NOTE(review): presumably the scorer pickle was written by a script in which
# the class lived in __main__ - confirm against the training code.
import __main__

__main__.RuleBasedScorer = RuleBasedScorer
# (2) KoBERT classifier
class FishingClassifier(nn.Module):
    """BERT encoder followed by dropout and a linear classification head.

    Args:
        bert: Encoder module. It is called as
            ``bert(input_ids=..., attention_mask=..., return_dict=False)``
            and must return a 2-tuple whose second item is the pooled
            sentence representation.
        num_classes: Number of output logits.
    """

    # Width used when the encoder does not expose ``config.hidden_size``.
    DEFAULT_HIDDEN_SIZE = 768

    def __init__(self, bert, num_classes=2):
        super().__init__()
        self.bert = bert
        self.dropout = nn.Dropout(0.3)
        # Size the head from the encoder's own config instead of assuming
        # 768, so encoders of other widths work; fall back to 768 when the
        # encoder carries no config.
        hidden_size = getattr(
            getattr(bert, "config", None), "hidden_size", self.DEFAULT_HIDDEN_SIZE
        )
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, input_ids, mask):
        """Return the class logits for a batch.

        Args:
            input_ids: Token ids, forwarded to the encoder as ``input_ids``.
            mask: Attention mask, forwarded as ``attention_mask``.

        Returns:
            Output of the linear head applied to the dropped-out pooled
            encoder output.
        """
        # return_dict=False makes the encoder return a plain tuple; only the
        # pooled output (second item) feeds the head.
        _, pooled = self.bert(input_ids=input_ids, attention_mask=mask, return_dict=False)
        return self.fc(self.dropout(pooled))
| # ============================================================================= | |
| # 2. Model loading | |
| # ============================================================================= | |
print("[AggroModel] ์์คํ ๋ก๋ฉ ์์...")

from kobert_transformers import get_tokenizer

# Module-level singletons populated below and read by the scoring function.
aggro_model = None
tokenizer = None
rule_scorer = None
device = torch.device("cpu")
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# A. Rule-based scorer.
# SECURITY NOTE: pickle.load can execute arbitrary code while unpickling.
# Only ship a rule_based_scorer.pkl that comes from a trusted source.
try:
    with open(os.path.join(BASE_DIR, "rule_based_scorer.pkl"), "rb") as f:
        rule_scorer = pickle.load(f)
    print("โ [Aggro] ๊ท์น ๋ชจ๋ธ ๋ก๋ ์ฑ๊ณต")
except Exception as e:
    # Best effort: fall back to an empty scorer, but report the real cause
    # (missing file, incompatible pickle, ...) instead of hiding it.
    print("โ ๏ธ [Aggro] ๊ท์น ๋ชจ๋ธ ์์, ๋น ๊ฐ์ฒด ์์ฑ")
    print(f"[Aggro] rule scorer load error: {e!r}")
    rule_scorer = RuleBasedScorer()

# B. KoBERT classifier.
try:
    print("๐ KoBERT ๋ชจ๋ธ ๋ก๋ฉ ์ค...")

    # Tokenizer provided by kobert_transformers.
    tokenizer = get_tokenizer()

    # Encoder backbone wrapped by the classification head.
    bert_base = BertModel.from_pretrained('monologg/kobert')
    aggro_model = FishingClassifier(bert_base).to(device)

    # Fine-tuned weights: prefer the .pth file, fall back to the .pt file.
    pth_path = os.path.join(BASE_DIR, "bert_fishing_model_best.pth")
    pt_path = os.path.join(BASE_DIR, "kobert_aggro_score.pt")
    final_path = pth_path if os.path.exists(pth_path) else pt_path

    if os.path.exists(final_path):
        # SECURITY NOTE: torch.load unpickles the file; load trusted
        # checkpoints only.
        checkpoint = torch.load(final_path, map_location=device)

        # Accept three layouts: a training checkpoint dict, a bare
        # state_dict, or a whole pickled model object.
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            loaded_state_dict = checkpoint['model_state_dict']
        elif isinstance(checkpoint, dict):
            loaded_state_dict = checkpoint
        else:
            loaded_state_dict = checkpoint.state_dict()

        new_state_dict = {}
        for k, v in loaded_state_dict.items():
            name = k
            # 1. Strip the leading "module." prefix from key names.
            if name.startswith('module.'):
                name = name[len('module.'):]
            # 2. The checkpoint calls the head "classifier"; this model
            #    calls it "fc".
            if 'classifier' in name:
                new_name = name.replace('classifier', 'fc')
                print(f"๐ง [Fix] ์ด๋ฆ ๋ณ๊ฒฝ ์ ์ฉ: {name} -> {new_name}")
                name = new_name
            new_state_dict[name] = v

        # 3. Non-strict load so that naming differences surface as
        #    missing/unexpected keys instead of raising.
        missing_keys, unexpected_keys = aggro_model.load_state_dict(new_state_dict, strict=False)

        # A head left at random initialisation makes every prediction
        # meaningless, so check both fc.weight and fc.bias.
        if any(key.startswith("fc.") for key in missing_keys):
            print("๐จ [CRITICAL] fc ๋ ์ด์ด๊ฐ ์ฌ์ ํ ๋ก๋๋์ง ์์์ต๋๋ค! (์ ์ ๊ณ ์ ์์ธ)")
            print(f"Missing Keys: {missing_keys}")
        else:
            print("โ [Success] fc ๋ ์ด์ด(๋ถ๋ฅ๊ธฐ)๊ฐ ์ ์์ ์ผ๋ก ๋ก๋๋์์ต๋๋ค!")

        aggro_model.eval()
        print(f"โ [Aggro] KoBERT ๋ชจ๋ธ ๋ก๋ ์๋ฃ: {os.path.basename(final_path)}")
    else:
        print("โ ๏ธ [Aggro] ๊ฐ์ค์น ํ์ผ(.pth/.pt)์ ์ฐพ์ ์ ์์ต๋๋ค!")
        # Without fine-tuned weights the head is untrained; disable the model.
        aggro_model = None
except Exception as e:
    print(f"๐จ [Aggro] ๋ชจ๋ธ ๋ก๋ฉ ์ค ์๋ฌ ๋ฐ์: {e}")
    aggro_model = None
| # ============================================================================= | |
| # 3. Main function | |
| # ============================================================================= | |
def get_aggro_score(title: str) -> dict:
    """Score how sensational ("aggro") a headline is.

    Combines the module-level rule scorer and KoBERT classifier. The rule
    score currently has weight 0.0 in the final sum and only influences the
    result by damping the BERT score when it is low.

    Args:
        title: Headline text.

    Returns:
        dict with keys:
            ``score``: final score scaled to at most 1.0, rounded to 4
                decimals.
            ``reason``: grade label chosen from the 0-100 final score
                (thresholds 80 / 60 / 40).
            ``recommendation``: advice string matching the grade.
    """
    # 1. Rule-based score (best effort: a scorer failure must not break the
    #    request, but it is reported rather than silently dropped).
    rule_score = 0.0
    try:
        rule_score = rule_scorer.get_score(title)['score']
    except Exception as e:
        print(f"[Rule Error] {e}")

    # 2. KoBERT score on a 0-100 scale.
    bert_score = 0.0
    # Explicit None checks: truthiness of a tokenizer/model object may depend
    # on __len__ and is not what is meant here.
    if aggro_model is not None and tokenizer is not None:
        try:
            inputs = tokenizer(
                title,
                return_tensors='pt',
                padding="max_length",
                truncation=True,
                max_length=64
            )
            # Only the tensors the model consumes are read; requiring
            # 'token_type_ids' here would force the fallback score whenever
            # the tokenizer does not return that key.
            input_ids = inputs['input_ids'].to(device)
            mask = inputs['attention_mask'].to(device)

            # Debug aid: a healthy encoding shows varied ids, a broken
            # vocabulary shows mostly zeros. Print the first 15 ids of the
            # single sequence in the batch.
            print(f"\n๐ [ํ ํฐ ํ์ธ] ์ ๋ ฅ: '{title}'")
            print(f"๐ [ํ ํฐ ID]: {input_ids[0, :15]} ...")

            with torch.no_grad():
                outputs = aggro_model(input_ids, mask)
                # Logits are divided by 2.0 before softmax, which flattens
                # the probabilities.
                probs = F.softmax(outputs / 2.0, dim=1)
                # Probability of class index 1, as a percentage.
                bert_score = probs[0][1].item() * 100
        except Exception as e:
            print(f"๐จ [BERT Error] {e}")
            # Neutral fallback when inference fails.
            bert_score = 50.0

    # Safety net: damp the BERT score when the rules see little or nothing
    # suspicious in the title.
    if rule_score < 5:
        bert_score *= 0.3
    elif rule_score < 20:
        bert_score *= 0.8

    # 3. Weighted sum (rule score currently carries no direct weight).
    w_rule = 0.0
    w_bert = 1.0
    final_score = (rule_score * w_rule) + (bert_score * w_bert)

    # 4. Scale to at most 1.0 for the caller.
    normalized_score = min(final_score / 100.0, 1.0)

    # 5. Grade on the 0-100 scale.
    if final_score >= 80:
        reason = "๋งค์ฐ ๋์ ๐ด"
        recommendation = "์ ๋ฉด ์์ ๊ถ์ฅ"
    elif final_score >= 60:
        reason = "๋์ ๐ "
        recommendation = "๊ณผ์ฅ๋ ํํ ์์ ํ์"
    elif final_score >= 40:
        reason = "๋ณดํต ๐ก"
        recommendation = "์ผ๋ถ ํํ ์ค๋ฆฝํ ๊ถ์ฅ"
    else:
        reason = "๋ฎ์ ๐ข"
        recommendation = "์ ์ ํ ์ ๋ชฉ์ ๋๋ค"

    return {
        "score": round(normalized_score, 4),
        "reason": reason,
        "recommendation": recommendation
    }