File size: 5,340 Bytes
399f588
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import argparse
import json
from pathlib import Path
from src.schemas.labels import (
    MARKER_MODE,
    QA_M_MODE,
    QA_B_MODE,
    SENTIMENT_LABELS,
)

# Tokenizer max sequence length -> number of context words kept on EACH side
# of an entity when building its centered window.  Larger max_length budgets
# allow wider textual context around the entity span.
MAXLEN_TO_WINDOW = {
    128: 30,
    256: 70,
    512: 120,
}


def _extract_window(text: str, offset: int, length: int, window_words: int) -> tuple[str, int]:
    """Extract a word-level window around the entity span.

    Returns (window_text, window_offset) where window_offset is the entity's
    character offset within the windowed text.
    """
    before = text[:offset]
    entity_span = text[offset:offset + length]
    after = text[offset + length:]

    words_before = before.split()
    words_after = after.split()

    if len(words_before) <= window_words:
        kept_before = before
    else:
        skip_chars = len(" ".join(words_before[:-window_words])) + 1
        kept_before = text[skip_chars:offset]

    if len(words_after) <= window_words:
        kept_after = after
    else:
        kept_words = words_after[:window_words]
        end_char = offset + length + len(" ".join(kept_words))
        if end_char < len(text) and text[end_char] != " ":
            space_pos = text.find(" ", end_char)
            end_char = space_pos if space_pos != -1 else len(text)
        kept_after = text[offset + length:end_char]

    window_offset = len(kept_before)
    window_text = kept_before + entity_span + kept_after
    return window_text, window_offset


def _build_marker_text(window_text: str, window_offset: int, length: int) -> str:
    """Wrap the entity span in marker tokens and collapse whitespace runs."""
    entity = window_text[window_offset:window_offset + length]
    prefix = window_text[:window_offset]
    suffix = window_text[window_offset + length:]
    pieces = (
        prefix,
        f" {MARKER_MODE.entity_start} {entity} {MARKER_MODE.entity_end} ",
        suffix,
    )
    # Normalize to single spaces so the inserted padding never doubles up.
    return " ".join("".join(pieces).split())


def _build_qa_m_question(entity_text: str, entity_type: str) -> str:
    """Render the QA-M question for one entity from the shared template."""
    template = QA_M_MODE.question_template
    return template.format(entity=entity_text, entity_type=entity_type)


def _build_qa_b_hypotheses(entity_text: str, entity_type: str) -> dict[str, str]:
    """Render one QA-B hypothesis per sentiment class, keyed by sentiment."""
    hypotheses: dict[str, str] = {}
    for sentiment in SENTIMENT_LABELS.classes:
        hypotheses[sentiment] = QA_B_MODE.hypothesis_template.format(
            entity=entity_text,
            entity_type=entity_type,
            sentiment=sentiment,
        )
    return hypotheses


def augment_sample(sample: dict, window_words: int) -> dict:
    """Return a copy of ``sample`` whose positions carry window/marker/QA fields.

    The QA question and hypotheses depend only on the entity, so they are
    built once per entity and shared across all of its positions.
    """
    text = sample["text"]
    out_entities = []

    for entity in sample["entities"]:
        name = entity["entity_text"]
        etype = entity["entity_type"]
        question = _build_qa_m_question(name, etype)
        hypotheses = _build_qa_b_hypotheses(name, etype)

        positions = []
        for pos in entity["positions"]:
            offset = pos["offset"]
            length = pos["length"]
            window_text, window_offset = _extract_window(text, offset, length, window_words)
            positions.append({
                "position_text": pos["position_text"],
                "offset": offset,
                "length": length,
                "entity_centered_window": window_text,
                "marker_text": _build_marker_text(window_text, window_offset, length),
                "qa_m_question": question,
                "qa_b_hypotheses": hypotheses,
            })

        new_entity = {
            "entity_id": entity["entity_id"],
            "entity_text": name,
            "entity_type": etype,
            "positions": positions,
        }
        # "label" is absent on unlabeled (inference-time) samples; copy only
        # when present so the output schema mirrors the input.
        if "label" in entity:
            new_entity["label"] = entity["label"]
        out_entities.append(new_entity)

    return {"id": sample["id"], "text": text, "entities": out_entities}


def augment(samples: list[dict], window_words: int) -> list[dict]:
    """Augment every sample and print a summary of the work done."""
    results: list[dict] = []
    total_positions = 0
    for sample in samples:
        augmented_sample = augment_sample(sample, window_words)
        for entity in augmented_sample["entities"]:
            total_positions += len(entity["positions"])
        results.append(augmented_sample)
    print(f"Augmented {len(results)} samples, {total_positions} positions (window={window_words} words/side)")
    return results


def save_jsonl(samples: list[dict], path: str | Path) -> None:
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        for s in samples:
            f.write(json.dumps(s, ensure_ascii=False) + "\n")
    print(f"Saved {len(samples)} samples to {path}")


def main(
    input_path: str | Path = "data/data_preprocessed.jsonl",
    max_length: int = 256,
    output_path: str | Path | None = None,
) -> list[dict]:
    """Load preprocessed samples, augment them, and write the result.

    Args:
        input_path: JSONL file of preprocessed samples.
        max_length: Tokenizer max length; must be a key of MAXLEN_TO_WINDOW.
        output_path: Destination JSONL file.  Defaults to
            ``data/data_augmented_{max_length}.jsonl`` (previously this was
            hard-coded; the parameter keeps the default behavior).

    Returns:
        The augmented samples.

    Raises:
        ValueError: If ``max_length`` is not a supported value.
    """
    if max_length not in MAXLEN_TO_WINDOW:
        raise ValueError(f"max_length must be one of {list(MAXLEN_TO_WINDOW.keys())}, got {max_length}")

    window_words = MAXLEN_TO_WINDOW[max_length]
    if output_path is None:
        output_path = f"data/data_augmented_{max_length}.jsonl"

    with open(input_path, "r", encoding="utf-8") as f:
        samples = [json.loads(line) for line in f]

    print(f"Loaded {len(samples)} samples from {input_path}")
    print(f"max_length={max_length} -> window={window_words} words/side")
    augmented = augment(samples, window_words)
    save_jsonl(augmented, output_path)
    return augmented


if __name__ == "__main__":
    # CLI wrapper around main().  Exposes --input-path (which main() already
    # accepts) in addition to --max-length, so the script is not tied to the
    # default data location.
    parser = argparse.ArgumentParser(
        description="Augment preprocessed samples with entity-centered windows."
    )
    parser.add_argument(
        "--input-path",
        default="data/data_preprocessed.jsonl",
        help="JSONL file of preprocessed samples.",
    )
    parser.add_argument("--max-length", type=int, default=256, choices=[128, 256, 512])
    args = parser.parse_args()
    main(input_path=args.input_path, max_length=args.max_length)