"""Entity-centered window augmentation for JSONL samples.

Reads preprocessed samples, extracts a word-level context window around each
entity position, and attaches marker-mode text plus QA-M / QA-B prompts.
"""

import argparse
import json
from pathlib import Path

from src.schemas.labels import (
    MARKER_MODE,
    QA_M_MODE,
    QA_B_MODE,
    SENTIMENT_LABELS,
)

# Tokenizer max sequence length -> number of words kept on EACH side of the
# entity when extracting the context window.
MAXLEN_TO_WINDOW = {
    128: 30,
    256: 70,
    512: 120,
}


def _extract_window(text: str, offset: int, length: int, window_words: int) -> tuple[str, int]:
    """Extract a word-level window around the entity span.

    Returns (window_text, window_offset) where window_offset is the entity's
    character offset within the windowed text.
    """
    before = text[:offset]
    entity_span = text[offset:offset + length]
    after = text[offset + length:]

    words_before = before.split()
    words_after = after.split()

    if len(words_before) <= window_words:
        kept_before = before
    else:
        # NOTE(review): assumes words are single-space separated. With runs
        # of whitespace this undercounts and keeps slightly more context,
        # but the cut still lands on a whitespace boundary, never mid-word.
        skip_chars = len(" ".join(words_before[:-window_words])) + 1
        kept_before = text[skip_chars:offset]

    if len(words_after) <= window_words:
        kept_after = after
    else:
        kept_words = words_after[:window_words]
        end_char = offset + length + len(" ".join(kept_words))
        # The join above does not count the separator right after the entity,
        # so end_char may land inside the last kept word; extend to the next
        # space so that word is not truncated.
        if end_char < len(text) and text[end_char] != " ":
            space_pos = text.find(" ", end_char)
            end_char = space_pos if space_pos != -1 else len(text)
        kept_after = text[offset + length:end_char]

    window_offset = len(kept_before)
    window_text = kept_before + entity_span + kept_after
    return window_text, window_offset


def _build_marker_text(window_text: str, window_offset: int, length: int) -> str:
    """Wrap the entity span in MARKER_MODE start/end tokens.

    Whitespace in the result is normalized to single spaces.
    """
    start = MARKER_MODE.entity_start
    end = MARKER_MODE.entity_end
    span = window_text[window_offset:window_offset + length]
    marked = (
        window_text[:window_offset]
        + f" {start} {span} {end} "
        + window_text[window_offset + length:]
    )
    # Collapse any doubled spaces introduced by the marker insertion.
    return " ".join(marked.split())


def _build_qa_m_question(entity_text: str, entity_type: str) -> str:
    """Render the QA-M question for one entity."""
    return QA_M_MODE.question_template.format(entity=entity_text, entity_type=entity_type)


def _build_qa_b_hypotheses(entity_text: str, entity_type: str) -> dict[str, str]:
    """Render one QA-B hypothesis per sentiment class, keyed by sentiment."""
    return {
        sentiment: QA_B_MODE.hypothesis_template.format(
            entity=entity_text, entity_type=entity_type, sentiment=sentiment
        )
        for sentiment in SENTIMENT_LABELS.classes
    }


def augment_sample(sample: dict, window_words: int) -> dict:
    """Augment a single sample: window + marker text + QA prompts per position.

    The input sample is not mutated; a new dict with the same id/text and
    augmented entities is returned. An entity's optional "label" key is
    carried over only when present.
    """
    text = sample["text"]
    augmented_entities = []
    for e in sample["entities"]:
        # QA prompts depend only on the entity, so build them once per entity
        # and share across all of its positions.
        qa_m_question = _build_qa_m_question(e["entity_text"], e["entity_type"])
        qa_b_hypotheses = _build_qa_b_hypotheses(e["entity_text"], e["entity_type"])

        augmented_positions = []
        for p in e["positions"]:
            window_text, window_offset = _extract_window(
                text, p["offset"], p["length"], window_words
            )
            marker_text = _build_marker_text(window_text, window_offset, p["length"])
            augmented_positions.append({
                "position_text": p["position_text"],
                "offset": p["offset"],
                "length": p["length"],
                "entity_centered_window": window_text,
                "marker_text": marker_text,
                "qa_m_question": qa_m_question,
                "qa_b_hypotheses": qa_b_hypotheses,
            })

        ent = {
            "entity_id": e["entity_id"],
            "entity_text": e["entity_text"],
            "entity_type": e["entity_type"],
            "positions": augmented_positions,
        }
        if "label" in e:
            ent["label"] = e["label"]
        augmented_entities.append(ent)

    return {
        "id": sample["id"],
        "text": text,
        "entities": augmented_entities,
    }


def augment(samples: list[dict], window_words: int) -> list[dict]:
    """Augment every sample and log a position count summary."""
    augmented = [augment_sample(s, window_words) for s in samples]
    total_positions = sum(
        len(e["positions"]) for s in augmented for e in s["entities"]
    )
    print(f"Augmented {len(augmented)} samples, {total_positions} positions (window={window_words} words/side)")
    return augmented


def save_jsonl(samples: list[dict], path: str | Path) -> None:
    """Write samples as UTF-8 JSONL, creating parent directories as needed."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        for s in samples:
            f.write(json.dumps(s, ensure_ascii=False) + "\n")
    print(f"Saved {len(samples)} samples to {path}")


def main(
    input_path: str | Path = "data/data_preprocessed.jsonl",
    max_length: int = 256,
) -> list[dict]:
    """Load, augment, and save samples for the given tokenizer max_length.

    Raises:
        ValueError: if max_length is not a key of MAXLEN_TO_WINDOW.
    """
    if max_length not in MAXLEN_TO_WINDOW:
        raise ValueError(f"max_length must be one of {list(MAXLEN_TO_WINDOW.keys())}, got {max_length}")
    window_words = MAXLEN_TO_WINDOW[max_length]
    output_path = f"data/data_augmented_{max_length}.jsonl"

    with open(input_path, "r", encoding="utf-8") as f:
        samples = [json.loads(line) for line in f]
    print(f"Loaded {len(samples)} samples from {input_path}")
    print(f"max_length={max_length} -> window={window_words} words/side")

    augmented = augment(samples, window_words)
    save_jsonl(augmented, output_path)
    return augmented


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--max-length", type=int, default=256, choices=[128, 256, 512])
    args = parser.parse_args()
    main(max_length=args.max_length)