import argparse
import json
from pathlib import Path

from src.schemas.labels import (
    MARKER_MODE,
    QA_M_MODE,
    QA_B_MODE,
    SENTIMENT_LABELS,
)
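
# Map tokenizer max_length to the number of context words kept on each side
# of the entity; a larger sequence budget allows a wider window.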
MAXLEN_TO_WINDOW = {
    128: 30,
    256: 70,
    512: 120,
}


def _extract_window(text: str, offset: int, length: int, window_words: int) -> tuple[str, int]:
    """Extract a word-level window around the entity span.

    Returns (window_text, window_offset) where window_offset is the entity's
    character offset within the windowed text.
    """
    before = text[:offset]
    entity_span = text[offset:offset + length]
    after = text[offset + length:]
    words_before = before.split()
    words_after = after.split()
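    # Left context: keep at most window_words words, re-slicing from the
    # original text so the kept characters preserve their exact spacing.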
    if len(words_before) <= window_words:
        kept_before = before
    else:
        skip_chars = len(" ".join(words_before[:-window_words])) + 1
        kept_before = text[skip_chars:offset]
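    # Right context: keep at most window_words words; if the character cut
    # lands mid-word, extend it to the next space (or the end of the text).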
    if len(words_after) <= window_words:
        kept_after = after
    else:
        kept_words = words_after[:window_words]
        end_char = offset + length + len(" ".join(kept_words))
        if end_char < len(text) and text[end_char] != " ":
            space_pos = text.find(" ", end_char)
            end_char = space_pos if space_pos != -1 else len(text)
        kept_after = text[offset + length:end_char]
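    # The entity starts immediately after the kept left context.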
    window_offset = len(kept_before)
    window_text = kept_before + entity_span + kept_after
    return window_text, window_offset


def _build_marker_text(window_text: str, window_offset: int, length: int) -> str:
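    """Wrap the entity span in MARKER_MODE start/end tokens, then collapse
    whitespace so the marked text is single-spaced throughout."""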
    start = MARKER_MODE.entity_start
    end = MARKER_MODE.entity_end
    span = window_text[window_offset:window_offset + length]
    marked = (
        window_text[:window_offset]
        + f" {start} {span} {end} "
        + window_text[window_offset + length:]
    )
    return " ".join(marked.split())


def _build_qa_m_question(entity_text: str, entity_type: str) -> str:
    return QA_M_MODE.question_template.format(entity=entity_text, entity_type=entity_type)


def _build_qa_b_hypotheses(entity_text: str, entity_type: str) -> dict[str, str]:
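    # One QA-B entailment hypothesis per sentiment class, keyed by label.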
    return {
        sentiment: QA_B_MODE.hypothesis_template.format(
            entity=entity_text, entity_type=entity_type, sentiment=sentiment
        )
        for sentiment in SENTIMENT_LABELS.classes
    }


def augment_sample(sample: dict, window_words: int) -> dict:
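    """Attach windowed context, marker text, and QA templates to every
    position of every entity in one sample."""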
text = sample["text"]
augmented_entities = []
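    # The question and hypotheses depend only on the entity, so build them
    # once per entity and reuse them for every position.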
for e in sample["entities"]:
qa_m_question = _build_qa_m_question(e["entity_text"], e["entity_type"])
qa_b_hypotheses = _build_qa_b_hypotheses(e["entity_text"], e["entity_type"])
augmented_positions = []
for p in e["positions"]:
window_text, window_offset = _extract_window(
text, p["offset"], p["length"], window_words
)
marker_text = _build_marker_text(window_text, window_offset, p["length"])
augmented_positions.append({
"position_text": p["position_text"],
"offset": p["offset"],
"length": p["length"],
"entity_centered_window": window_text,
"marker_text": marker_text,
"qa_m_question": qa_m_question,
"qa_b_hypotheses": qa_b_hypotheses,
})
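        # Rebuild the entity record, carrying the label through only when the
        # source entity has one.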
        ent = {
            "entity_id": e["entity_id"],
            "entity_text": e["entity_text"],
            "entity_type": e["entity_type"],
            "positions": augmented_positions,
        }
        if "label" in e:
            ent["label"] = e["label"]
        augmented_entities.append(ent)
    return {
        "id": sample["id"],
        "text": text,
        "entities": augmented_entities,
    }


def augment(samples: list[dict], window_words: int) -> list[dict]:
    augmented = [augment_sample(s, window_words) for s in samples]
    total_positions = sum(
        len(e["positions"]) for s in augmented for e in s["entities"]
    )
    print(f"Augmented {len(augmented)} samples, {total_positions} positions (window={window_words} words/side)")
    return augmented


def save_jsonl(samples: list[dict], path: str | Path) -> None:
    path = Path(path)
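    # Create the output directory if needed, then write one JSON object per
    # line (JSONL); ensure_ascii=False keeps non-ASCII characters readable.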
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        for s in samples:
            f.write(json.dumps(s, ensure_ascii=False) + "\n")
    print(f"Saved {len(samples)} samples to {path}")


def main(
    input_path: str | Path = "data/data_preprocessed.jsonl",
    max_length: int = 256,
) -> list[dict]:
    if max_length not in MAXLEN_TO_WINDOW:
        raise ValueError(f"max_length must be one of {list(MAXLEN_TO_WINDOW.keys())}, got {max_length}")
    window_words = MAXLEN_TO_WINDOW[max_length]
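    # The output name encodes max_length so different window sizes are
    # written to separate files.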
output_path = f"data/data_augmented_{max_length}.jsonl"
with open(input_path, "r", encoding="utf-8") as f:
samples = [json.loads(line) for line in f]
print(f"Loaded {len(samples)} samples from {input_path}")
print(f"max_length={max_length} -> window={window_words} words/side")
augmented = augment(samples, window_words)
save_jsonl(augmented, output_path)
return augmented
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--max-length", type=int, default=256, choices=[128, 256, 512])
args = parser.parse_args()
main(max_length=args.max_length)