# project-tdm / aggro_model.py
# branch: hy (final), commit ab47fb4
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertTokenizer, BertModel
import pickle
import re
import os
import sys
import numpy as np
from collections import defaultdict
# =============================================================================
# 1. Model class definitions
# =============================================================================
# (1) Rule-based scorer class
class RuleBasedScorer:
    """Rule-based clickbait scorer.

    Scores a title against four pattern categories (ids 11-14) using
    per-pattern word-weight dictionaries plus exaggerated-punctuation
    heuristics, then reports the single highest-scoring pattern
    (score capped at 100).
    """

    def __init__(self):
        # Per-pattern word dictionaries: word -> weight.
        # Normally populated by unpickling a trained instance at import time;
        # a freshly constructed scorer has empty dictionaries and only the
        # punctuation rules fire.
        self.patterns = {
            11: defaultdict(float), 12: defaultdict(float),
            13: defaultdict(float), 14: defaultdict(float)
        }
        self.pattern_names = {
            11: '의문 유발형(부호)', 12: '의문 유발형(은닉)',
            13: '선정표현 사용형', 14: '속어/줄임말 사용형'
        }
        # Punctuation patterns: only exaggerated marks count, a lone '?' does not.
        self.symbol_patterns = {
            'repeated': re.compile(r'([!?…~])\1+'),  # repeated marks (??, !!)
            'ellipsis': re.compile(r'\.\.\.|…')      # ellipsis
        }

    def get_score(self, title):
        """Return ``{'score', 'pattern', 'pattern_name'}`` for *title*.

        *title* is coerced to ``str`` up front so non-string inputs
        (e.g. a NaN read from a DataFrame) do not crash the regex scans.
        """
        # FIX: the original applied str() only for word extraction and passed
        # the raw value to the symbol regexes, raising TypeError on non-str.
        text = str(title)
        # 1. Tokenize: runs of Hangul / Latin letters / digits.
        words = re.findall(r'[가-힣A-Za-z0-9]+', text)
        scores = {}
        # 2. Punctuation score.
        rep = len(self.symbol_patterns['repeated'].findall(text))
        ell = len(self.symbol_patterns['ellipsis'].findall(text))
        symbol_score = (rep * 30) + (ell * 10)
        # 3. Per-pattern (11-14) scores.
        for p in [11, 12, 13, 14]:
            word_score = 0
            if p in self.patterns:  # defensive: dict may come from an older pickle
                for word in words:
                    if word in self.patterns[p]:
                        # Log scale keeps very frequent words from dominating.
                        word_score += np.log1p(self.patterns[p][word]) * 2
            if p == 11:      # question-mark pattern: punctuation only
                total = symbol_score
            elif p == 12:    # hidden-question pattern: words + half punctuation
                total = word_score + (symbol_score * 0.5)
            else:            # 13 (sensational wording), 14 (slang): words only
                total = word_score
            scores[p] = total
        # 4. Pick the strongest pattern; cap the score at 100.
        if not scores:
            return {'score': 0, 'pattern': 0, 'pattern_name': '정상'}
        max_pattern = max(scores, key=scores.get)
        max_score = min(100, scores[max_pattern])
        return {
            'score': max_score,
            'pattern': max_pattern,
            'pattern_name': self.pattern_names.get(max_pattern, '알 수 없음')
        }
# ๐Ÿšจ Pickle ๋กœ๋”ฉ ์—๋Ÿฌ ๋ฐฉ์ง€์šฉ
import __main__
setattr(__main__, "RuleBasedScorer", RuleBasedScorer)
# (2) KoBERT model class
class FishingClassifier(nn.Module):
    """Binary classifier head on top of a (Ko)BERT encoder.

    Feeds the encoder's pooled output through dropout into a linear layer
    mapping the 768-dim pooled vector to *num_classes* logits.
    """

    def __init__(self, bert, num_classes=2):
        super().__init__()
        self.bert = bert
        self.dropout = nn.Dropout(0.3)
        self.fc = nn.Linear(768, num_classes)

    def forward(self, input_ids, mask):
        # return_dict=False yields a (sequence_output, pooled_output) tuple.
        encoder_out = self.bert(input_ids=input_ids, attention_mask=mask, return_dict=False)
        pooled = encoder_out[1]
        dropped = self.dropout(pooled)
        return self.fc(dropped)
# =============================================================================
# 2. Model loading (runs at import time)
# =============================================================================
print("[AggroModel] 시스템 로딩 시작...")
from kobert_transformers import get_tokenizer

# Module-level globals populated below; consumers must handle None.
aggro_model = None
tokenizer = None
rule_scorer = None
device = torch.device("cpu")
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# A. Rule-based model: unpickle if present, otherwise fall back to an empty scorer.
try:
    with open(os.path.join(BASE_DIR, "rule_based_scorer.pkl"), "rb") as f:
        # NOTE(review): pickle.load is only safe because this .pkl ships with
        # the project — never point this at user-supplied data.
        rule_scorer = pickle.load(f)
    print("✅ [Aggro] 규칙 모델 로드 성공")
# FIX: was a bare `except:` — don't swallow SystemExit/KeyboardInterrupt.
except Exception:
    print("⚠️ [Aggro] 규칙 모델 없음, 빈 객체 생성")
    rule_scorer = RuleBasedScorer()

# B. KoBERT model
try:
    print("🔄 KoBERT 모델 로딩 중...")
    # Tokenizer loader: downloads and wires up the sentencepiece vocab itself.
    tokenizer = get_tokenizer()
    # Base encoder (monologg/kobert checkpoint layout).
    bert_base = BertModel.from_pretrained('monologg/kobert')
    aggro_model = FishingClassifier(bert_base).to(device)

    # Prefer the .pth checkpoint; fall back to the .pt file.
    pth_path = os.path.join(BASE_DIR, "bert_fishing_model_best.pth")
    pt_path = os.path.join(BASE_DIR, "kobert_aggro_score.pt")
    final_path = pth_path if os.path.exists(pth_path) else pt_path
    if os.path.exists(final_path):
        # NOTE(review): torch.load without weights_only=True unpickles arbitrary
        # objects; acceptable only for trusted, bundled checkpoints (and the
        # whole-model branch below actually relies on it).
        checkpoint = torch.load(final_path, map_location=device)
        # Extract a state dict whether the file holds a wrapper dict,
        # a bare state dict, or a whole pickled model object.
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            loaded_state_dict = checkpoint['model_state_dict']
        elif isinstance(checkpoint, dict):
            loaded_state_dict = checkpoint
        else:
            loaded_state_dict = checkpoint.state_dict()

        new_state_dict = {}
        for k, v in loaded_state_dict.items():
            name = k
            # 1. Strip DataParallel's 'module.' prefix.
            if name.startswith('module.'):
                name = name[7:]
            # 2. Rename 'classifier' -> 'fc' to match FishingClassifier.
            if 'classifier' in name:
                new_name = name.replace('classifier', 'fc')
                print(f"🔧 [Fix] 이름 변경 적용: {name} -> {new_name}")
                name = new_name
            new_state_dict[name] = v

        # 3. Load (strict=False so key mismatches don't abort; inspect instead).
        missing_keys, unexpected_keys = aggro_model.load_state_dict(new_state_dict, strict=False)
        # The classifier head MUST load — otherwise every title gets the same score.
        if any("fc.weight" in key for key in missing_keys):
            print("🚨 [CRITICAL] fc 레이어가 여전히 로드되지 않았습니다! (점수 고정 원인)")
            print(f"Missing Keys: {missing_keys}")
        else:
            print("✅ [Success] fc 레이어(분류기)가 정상적으로 로드되었습니다!")
        aggro_model.eval()
        print(f"✅ [Aggro] KoBERT 모델 로드 완료: {os.path.basename(final_path)}")
    else:
        print("⚠️ [Aggro] 가중치 파일(.pth/.pt)을 찾을 수 없습니다!")
        aggro_model = None
except Exception as e:
    print(f"🚨 [Aggro] 모델 로딩 중 에러 발생: {e}")
    aggro_model = None
# =============================================================================
# 3. Main entry point
# =============================================================================
def get_aggro_score(title: str) -> dict:
    """Score how clickbait-like *title* is.

    Combines the rule-based scorer (used only as a gating "safety net") with
    the KoBERT classifier probability.

    Returns:
        dict with 'score' (float, 0.0-1.0), 'reason' (severity label) and
        'recommendation' (editing advice).
    """
    # 1. Rule score — best effort: on any scorer failure it stays 0.
    rule_score = 0.0
    try:
        res = rule_scorer.get_score(title)
        rule_score = res['score']
    # FIX: was a bare `except:`; keep the deliberate best-effort behavior
    # but stop swallowing SystemExit/KeyboardInterrupt.
    except Exception:
        pass

    # 2. KoBERT score (0-100); 50.0 fallback on inference errors.
    bert_score = 0.0
    if aggro_model and tokenizer:
        try:
            inputs = tokenizer(
                title,
                return_tensors='pt',
                padding="max_length",
                truncation=True,
                max_length=64
            )
            input_ids = inputs['input_ids'].to(device)
            mask = inputs['attention_mask'].to(device)
            # FIX: removed the typo'd, unused `oken_type_ids = inputs['token_type_ids']`
            # line — it could raise KeyError for tokenizers that omit
            # token_type_ids, silently forcing the 50.0 fallback below.
            # Debug logging: healthy tokenization shows varied ids, not all zeros.
            print(f"\n👉 [토큰 확인] 입력: '{title}'")
            print(f"👉 [토큰 ID]: {input_ids[:15]} ...")
            with torch.no_grad():
                outputs = aggro_model(input_ids, mask)
                # Temperature 2.0 softens the softmax so scores aren't all-or-nothing.
                probs = F.softmax(outputs / 2.0, dim=1)
                bert_score = probs[0][1].item() * 100
        except Exception as e:
            print(f"🚨 [BERT Error] {e}")
            bert_score = 50.0

    # Safety net: a low rule score dampens the BERT score.
    if rule_score < 5:
        bert_score *= 0.3
    elif rule_score < 20:
        bert_score *= 0.8

    # 3. Weighted sum (rule weight is currently 0 — BERT drives the score,
    # the rule model only gates it above).
    w_rule = 0.0
    w_bert = 1.0
    final_score = (rule_score * w_rule) + (bert_score * w_bert)

    # 4. Normalize to 0.0-1.0.
    normalized_score = min(final_score / 100.0, 1.0)

    # 5. Severity tier.
    if final_score >= 80:
        reason = "매우 높음 🔴"
        recommendation = "전면 수정 권장"
    elif final_score >= 60:
        reason = "높음 🟠"
        recommendation = "과장된 표현 수정 필요"
    elif final_score >= 40:
        reason = "보통 🟡"
        recommendation = "일부 표현 중립화 권장"
    else:
        reason = "낮음 🟢"
        recommendation = "적절한 제목입니다"

    return {
        "score": round(normalized_score, 4),
        "reason": reason,
        "recommendation": recommendation
    }