"""

HuggingFaceμ—μ„œ λͺ¨λΈμ„ λ‘œλ“œν•˜μ—¬ μΆ”λ‘ ν•˜λŠ” 예제



Usage:

    from inference_example import extract_sentences

    results = extract_sentences("μ‚Όμ„±μ „μžμ˜ 싀적이 μ‹œμž₯ μ˜ˆμƒμ„ μƒνšŒν–ˆλ‹€. ...")

"""

import re
from typing import Dict, List, Optional

import torch
from transformers import AutoTokenizer

from model import (
    DocumentEncoderConfig,
    DocumentEncoderForExtractiveSummarization,
    IDX_TO_ROLE,
)


def split_into_sentences(text: str) -> List[str]:
    """Split *text* into sentences at whitespace following '.', '!' or '?'.

    Surrounding whitespace is stripped from each sentence and empty
    fragments are dropped; an empty/whitespace-only input yields [].
    """
    fragments = (frag.strip() for frag in re.split(r"(?<=[.!?])\s+", text.strip()))
    return [frag for frag in fragments if frag]


def extract_sentences(
    text: str,
    model_name_or_path: str = "./",  # local path or HuggingFace repo ID
    top_k: int = 3,
    threshold: float = 0.5,
    device: Optional[str] = None,
) -> Dict:
    """Extract representative sentences from *text* and classify their roles.

    Args:
        text: Input text (e.g. a financial report).
        model_name_or_path: Model path or HuggingFace repo ID.
        top_k: Maximum number of sentences to select.
        threshold: Minimum score for a sentence to be selectable.
        device: "cuda" or "cpu"; auto-detected when None.

    Returns:
        dict with 'sentences', 'all_scores', 'all_roles', 'selected'.
        NOTE: only the first ``config.max_sentences`` sentences are scored,
        so 'sentences' may be longer than 'all_scores'/'all_roles' when the
        input exceeds that limit.
    """
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")

    config = DocumentEncoderConfig.from_pretrained(model_name_or_path)
    model = DocumentEncoderForExtractiveSummarization.from_pretrained(
        model_name_or_path, config=config
    ).to(device)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

    sentences = split_into_sentences(text)
    if not sentences:
        return {"sentences": [], "all_scores": [], "all_roles": [], "selected": []}

    max_sentences = config.max_sentences
    max_length = config.max_length

    # The model scores at most max_sentences sentences; anything beyond is
    # silently dropped from the scoring (see docstring NOTE above).
    real = sentences[:max_sentences]
    num_real = len(real)

    # Tokenize all real sentences in a single batched call, then pad the
    # sentence dimension with all-zero rows (zero attention mask) up to
    # max_sentences — equivalent to per-sentence encoding, but one pass.
    enc = tokenizer(
        real,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    input_ids = enc["input_ids"]
    attention_mask = enc["attention_mask"]
    pad_rows = max_sentences - num_real
    if pad_rows > 0:
        zeros = torch.zeros(pad_rows, max_length, dtype=torch.long)
        input_ids = torch.cat([input_ids, zeros], dim=0)
        attention_mask = torch.cat([attention_mask, zeros], dim=0)

    # Add a leading batch dimension of 1 and move everything to the device.
    input_ids = input_ids.unsqueeze(0).to(device)
    attention_mask = attention_mask.unsqueeze(0).to(device)
    document_mask = torch.zeros(1, max_sentences, device=device)
    document_mask[0, :num_real] = 1  # 1 marks real sentences, 0 marks padding

    with torch.no_grad():
        scores, role_logits = model(input_ids, attention_mask, document_mask)

    # Keep only the scores/roles for real (non-padded) sentences.
    scores_list = scores[0, :num_real].tolist()
    role_indices = role_logits[0, :num_real].argmax(dim=-1).tolist()
    roles_list = [IDX_TO_ROLE[idx] for idx in role_indices]

    # Select sentences at or above the threshold, keep the top_k by score,
    # then restore original document order for readable output.
    selected = [
        {"index": i, "sentence": sent, "score": score, "role": role}
        for i, (sent, score, role) in enumerate(zip(sentences, scores_list, roles_list))
        if score >= threshold
    ]
    selected.sort(key=lambda x: x["score"], reverse=True)
    selected = selected[:top_k]
    selected.sort(key=lambda x: x["index"])

    return {
        "sentences": sentences,
        "all_scores": scores_list,
        "all_roles": roles_list,
        "selected": selected,
    }


if __name__ == "__main__":
    text = """

    μ‚Όμ„±μ „μžμ˜ 2024λ…„ 4λΆ„κΈ° 싀적이 μ‹œμž₯ μ˜ˆμƒμ„ μƒνšŒν–ˆλ‹€.

    λ©”λͺ¨λ¦¬ λ°˜λ„μ²΄ 가격 μƒμŠΉμœΌλ‘œ μ˜μ—…μ΄μ΅μ΄ μ „λΆ„κΈ° λŒ€λΉ„ 30% μ¦κ°€ν–ˆλ‹€.

    HBM3E 양산이 λ³Έκ²©ν™”λ˜λ©΄μ„œ AI λ°˜λ„μ²΄ μ‹œμž₯ 점유율이 ν™•λŒ€λ  전망이닀.

    λ‹€λ§Œ, 쀑ꡭ μ‹œμž₯의 λΆˆν™•μ‹€μ„±μ΄ μ—¬μ „νžˆ 리슀크 μš”μΈμœΌλ‘œ μž‘μš©ν•˜κ³  μžˆλ‹€.

    νšŒμ‚¬λŠ” μ˜¬ν•΄ μ„€λΉ„ 투자λ₯Ό 20% ν™•λŒ€ν•  κ³„νšμ΄λ‹€.

    """

    result = extract_sentences(text, model_name_or_path="./")

    print("=" * 60)
    print("전체 λ¬Έμž₯ 뢄석:")
    for i, (s, sc, r) in enumerate(zip(result["sentences"], result["all_scores"], result["all_roles"])):
        marker = "*" if sc >= 0.5 else " "
        print(f"  {marker} {i+1}. [{sc:.4f}] [{r:10s}] {s}")

    print(f"\nμ„ νƒλœ λŒ€ν‘œλ¬Έμž₯:")
    for item in result["selected"]:
        print(f"  - [{item['score']:.4f}] [{item['role']:10s}] {item['sentence']}")