# sv-task / src/models/augment.py
# data preprocessing classes
# (repository-viewer header: author "lamossta", commit 399f588)
import argparse
import json
import re
from pathlib import Path

from src.schemas.labels import (
    MARKER_MODE,
    QA_M_MODE,
    QA_B_MODE,
    SENTIMENT_LABELS,
)
# Tokenizer max sequence length -> number of words kept on EACH side of the
# entity when extracting its centered window (see _extract_window). Longer
# sequence budgets allow wider context windows.
MAXLEN_TO_WINDOW = {
    128: 30,
    256: 70,
    512: 120,
}
def _extract_window(text: str, offset: int, length: int, window_words: int) -> tuple[str, int]:
"""Extract a word-level window around the entity span.
Returns (window_text, window_offset) where window_offset is the entity's
character offset within the windowed text.
"""
before = text[:offset]
entity_span = text[offset:offset + length]
after = text[offset + length:]
words_before = before.split()
words_after = after.split()
if len(words_before) <= window_words:
kept_before = before
else:
skip_chars = len(" ".join(words_before[:-window_words])) + 1
kept_before = text[skip_chars:offset]
if len(words_after) <= window_words:
kept_after = after
else:
kept_words = words_after[:window_words]
end_char = offset + length + len(" ".join(kept_words))
if end_char < len(text) and text[end_char] != " ":
space_pos = text.find(" ", end_char)
end_char = space_pos if space_pos != -1 else len(text)
kept_after = text[offset + length:end_char]
window_offset = len(kept_before)
window_text = kept_before + entity_span + kept_after
return window_text, window_offset
def _build_marker_text(window_text: str, window_offset: int, length: int) -> str:
    """Wrap the entity span in marker tokens and normalize whitespace."""
    head = window_text[:window_offset]
    entity = window_text[window_offset:window_offset + length]
    tail = window_text[window_offset + length:]
    marked = (
        f"{head} {MARKER_MODE.entity_start} {entity} {MARKER_MODE.entity_end} {tail}"
    )
    # Collapse whitespace runs introduced by the marker insertion.
    return " ".join(marked.split())
def _build_qa_m_question(entity_text: str, entity_type: str) -> str:
    """Render the QA-M question for one entity from the mode's template."""
    template = QA_M_MODE.question_template
    return template.format(entity=entity_text, entity_type=entity_type)
def _build_qa_b_hypotheses(entity_text: str, entity_type: str) -> dict[str, str]:
    """Render one QA-B hypothesis per sentiment class, keyed by sentiment."""
    hypotheses: dict[str, str] = {}
    for sentiment in SENTIMENT_LABELS.classes:
        hypotheses[sentiment] = QA_B_MODE.hypothesis_template.format(
            entity=entity_text, entity_type=entity_type, sentiment=sentiment
        )
    return hypotheses
def augment_sample(sample: dict, window_words: int) -> dict:
    """Return a copy of *sample* with window, marker and QA views per position.

    The QA-M question and QA-B hypotheses are built once per entity and shared
    by all of its positions; window and marker text are position-specific.
    """
    text = sample["text"]
    out_entities = []
    for entity in sample["entities"]:
        name = entity["entity_text"]
        etype = entity["entity_type"]
        question = _build_qa_m_question(name, etype)
        hypotheses = _build_qa_b_hypotheses(name, etype)

        out_positions = []
        for pos in entity["positions"]:
            win_text, win_offset = _extract_window(
                text, pos["offset"], pos["length"], window_words
            )
            out_positions.append({
                "position_text": pos["position_text"],
                "offset": pos["offset"],
                "length": pos["length"],
                "entity_centered_window": win_text,
                "marker_text": _build_marker_text(win_text, win_offset, pos["length"]),
                "qa_m_question": question,
                "qa_b_hypotheses": hypotheses,
            })

        new_entity = {
            "entity_id": entity["entity_id"],
            "entity_text": name,
            "entity_type": etype,
            "positions": out_positions,
        }
        # Gold label is optional (absent for unlabeled/inference data).
        if "label" in entity:
            new_entity["label"] = entity["label"]
        out_entities.append(new_entity)

    return {
        "id": sample["id"],
        "text": text,
        "entities": out_entities,
    }
def augment(samples: list[dict], window_words: int) -> list[dict]:
    """Augment every sample and report the total number of entity positions."""
    result = []
    for sample in samples:
        result.append(augment_sample(sample, window_words))

    total_positions = 0
    for sample in result:
        for entity in sample["entities"]:
            total_positions += len(entity["positions"])

    print(f"Augmented {len(result)} samples, {total_positions} positions (window={window_words} words/side)")
    return result
def save_jsonl(samples: list[dict], path: str | Path) -> None:
path = Path(path)
path.parent.mkdir(parents=True, exist_ok=True)
with open(path, "w", encoding="utf-8") as f:
for s in samples:
f.write(json.dumps(s, ensure_ascii=False) + "\n")
print(f"Saved {len(samples)} samples to {path}")
def main(
    input_path: str | Path = "data/data_preprocessed.jsonl",
    max_length: int = 256,
) -> list[dict]:
    """Load preprocessed samples, augment for *max_length*, save, and return them.

    Args:
        input_path: JSONL file of preprocessed samples.
        max_length: Tokenizer max sequence length; must be a key of
            MAXLEN_TO_WINDOW, which determines the words-per-side window.

    Raises:
        ValueError: If *max_length* has no configured window size.
    """
    if max_length not in MAXLEN_TO_WINDOW:
        raise ValueError(f"max_length must be one of {list(MAXLEN_TO_WINDOW.keys())}, got {max_length}")
    window_words = MAXLEN_TO_WINDOW[max_length]
    output_path = f"data/data_augmented_{max_length}.jsonl"

    with open(input_path, "r", encoding="utf-8") as handle:
        samples = [json.loads(row) for row in handle]
    print(f"Loaded {len(samples)} samples from {input_path}")
    print(f"max_length={max_length} -> window={window_words} words/side")

    augmented_samples = augment(samples, window_words)
    save_jsonl(augmented_samples, output_path)
    return augmented_samples
if __name__ == "__main__":
    # CLI entry point: only the tokenizer max_length is configurable; it
    # selects the window size via MAXLEN_TO_WINDOW inside main().
    cli = argparse.ArgumentParser()
    cli.add_argument("--max-length", type=int, default=256, choices=[128, 256, 512])
    cli_args = cli.parse_args()
    main(max_length=cli_args.max_length)