| import argparse |
| import json |
| from pathlib import Path |
| from src.schemas.labels import ( |
| MARKER_MODE, |
| QA_M_MODE, |
| QA_B_MODE, |
| SENTIMENT_LABELS, |
| ) |
|
|
# Maps tokenizer max sequence length -> number of words kept on each side of
# the entity when building its context window (used by _extract_window).
MAXLEN_TO_WINDOW = {
    128: 30,
    256: 70,
    512: 120,
}
|
|
|
|
| def _extract_window(text: str, offset: int, length: int, window_words: int) -> tuple[str, int]: |
| """Extract a word-level window around the entity span. |
| |
| Returns (window_text, window_offset) where window_offset is the entity's |
| character offset within the windowed text. |
| """ |
| before = text[:offset] |
| entity_span = text[offset:offset + length] |
| after = text[offset + length:] |
|
|
| words_before = before.split() |
| words_after = after.split() |
|
|
| if len(words_before) <= window_words: |
| kept_before = before |
| else: |
| skip_chars = len(" ".join(words_before[:-window_words])) + 1 |
| kept_before = text[skip_chars:offset] |
|
|
| if len(words_after) <= window_words: |
| kept_after = after |
| else: |
| kept_words = words_after[:window_words] |
| end_char = offset + length + len(" ".join(kept_words)) |
| if end_char < len(text) and text[end_char] != " ": |
| space_pos = text.find(" ", end_char) |
| end_char = space_pos if space_pos != -1 else len(text) |
| kept_after = text[offset + length:end_char] |
|
|
| window_offset = len(kept_before) |
| window_text = kept_before + entity_span + kept_after |
| return window_text, window_offset |
|
|
|
|
def _build_marker_text(window_text: str, window_offset: int, length: int) -> str:
    """Surround the entity span with the MARKER_MODE start/end tokens.

    The result is whitespace-normalized (single spaces, no leading/trailing).
    """
    head = window_text[:window_offset]
    span = window_text[window_offset:window_offset + length]
    tail = window_text[window_offset + length:]
    combined = f"{head} {MARKER_MODE.entity_start} {span} {MARKER_MODE.entity_end} {tail}"
    return " ".join(combined.split())
|
|
|
|
def _build_qa_m_question(entity_text: str, entity_type: str) -> str:
    """Render the QA-M question for one entity from the mode's template."""
    template = QA_M_MODE.question_template
    return template.format(entity=entity_text, entity_type=entity_type)
|
|
|
|
def _build_qa_b_hypotheses(entity_text: str, entity_type: str) -> dict[str, str]:
    """Render one QA-B hypothesis per sentiment class, keyed by sentiment."""
    hypotheses: dict[str, str] = {}
    for sentiment in SENTIMENT_LABELS.classes:
        hypotheses[sentiment] = QA_B_MODE.hypothesis_template.format(
            entity=entity_text, entity_type=entity_type, sentiment=sentiment
        )
    return hypotheses
|
|
|
|
def augment_sample(sample: dict, window_words: int) -> dict:
    """Return a copy of `sample` whose entity positions carry the windowed
    text, marker text, and QA question/hypothesis views.

    The QA question and hypotheses depend only on the entity, so they are
    built once per entity and shared across all of its positions.
    """
    text = sample["text"]
    new_entities = []

    for entity in sample["entities"]:
        question = _build_qa_m_question(entity["entity_text"], entity["entity_type"])
        hypotheses = _build_qa_b_hypotheses(entity["entity_text"], entity["entity_type"])

        new_positions = []
        for pos in entity["positions"]:
            win_text, win_offset = _extract_window(
                text, pos["offset"], pos["length"], window_words
            )
            new_positions.append({
                "position_text": pos["position_text"],
                "offset": pos["offset"],
                "length": pos["length"],
                "entity_centered_window": win_text,
                "marker_text": _build_marker_text(win_text, win_offset, pos["length"]),
                "qa_m_question": question,
                "qa_b_hypotheses": hypotheses,
            })

        new_entity = {
            "entity_id": entity["entity_id"],
            "entity_text": entity["entity_text"],
            "entity_type": entity["entity_type"],
            "positions": new_positions,
        }
        # "label" is only present on labeled (train/dev) data; copy it through.
        if "label" in entity:
            new_entity["label"] = entity["label"]
        new_entities.append(new_entity)

    return {"id": sample["id"], "text": text, "entities": new_entities}
|
|
|
|
def augment(samples: list[dict], window_words: int) -> list[dict]:
    """Augment every sample and print a summary of the work done."""
    augmented = []
    for sample in samples:
        augmented.append(augment_sample(sample, window_words))

    total_positions = 0
    for sample in augmented:
        for entity in sample["entities"]:
            total_positions += len(entity["positions"])

    print(f"Augmented {len(augmented)} samples, {total_positions} positions (window={window_words} words/side)")
    return augmented
|
|
|
|
| def save_jsonl(samples: list[dict], path: str | Path) -> None: |
| path = Path(path) |
| path.parent.mkdir(parents=True, exist_ok=True) |
| with open(path, "w", encoding="utf-8") as f: |
| for s in samples: |
| f.write(json.dumps(s, ensure_ascii=False) + "\n") |
| print(f"Saved {len(samples)} samples to {path}") |
|
|
|
|
def main(
    input_path: str | Path = "data/data_preprocessed.jsonl",
    max_length: int = 256,
    output_path: str | Path | None = None,
) -> list[dict]:
    """Load preprocessed samples, augment them, and save the result.

    Args:
        input_path: JSONL file of preprocessed samples (one object per line).
        max_length: Tokenizer max length; must be a key of MAXLEN_TO_WINDOW.
        output_path: Destination JSONL path. Defaults to
            ``data/data_augmented_<max_length>.jsonl`` (previous behavior).

    Returns:
        The list of augmented samples.

    Raises:
        ValueError: If ``max_length`` is not a supported value.
    """
    if max_length not in MAXLEN_TO_WINDOW:
        raise ValueError(f"max_length must be one of {list(MAXLEN_TO_WINDOW.keys())}, got {max_length}")

    window_words = MAXLEN_TO_WINDOW[max_length]
    if output_path is None:
        output_path = f"data/data_augmented_{max_length}.jsonl"

    with open(input_path, "r", encoding="utf-8") as f:
        # Skip blank lines (e.g. a trailing newline) so json.loads never
        # chokes on an empty string.
        samples = [json.loads(line) for line in f if line.strip()]

    print(f"Loaded {len(samples)} samples from {input_path}")
    print(f"max_length={max_length} -> window={window_words} words/side")
    augmented = augment(samples, window_words)
    save_jsonl(augmented, output_path)
    return augmented
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Augment preprocessed samples with entity-centered windows and QA templates."
    )
    parser.add_argument(
        "--input-path",
        default="data/data_preprocessed.jsonl",
        help="Path to the preprocessed JSONL input.",
    )
    parser.add_argument(
        "--max-length",
        type=int,
        default=256,
        # Derive the valid choices from MAXLEN_TO_WINDOW instead of
        # duplicating the literal list, so the two cannot drift apart.
        choices=sorted(MAXLEN_TO_WINDOW),
        help="Tokenizer max length; selects the words-per-side window size.",
    )
    args = parser.parse_args()
    main(input_path=args.input_path, max_length=args.max_length)
|
|