"""Preprocess raw entity-sentiment samples into a clean JSONL dataset."""

import json
from pathlib import Path

from src.schemas.labels import LABEL_REMAP, SENTIMENT_LABELS

VALID_LABELS = set(SENTIMENT_LABELS.classes)
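
# Each sample is assumed to look roughly like the dict below. The shape is
# inferred from the field accesses in this module; the values are illustrative:
#
# {
#     "text": "Acme Corp shares rose sharply.",
#     "entities": [
#         {
#             "entity_id": 1,
#             "entity_text": "Acme Corp",
#             "entity_type": "ORG",
#             "label": "positive",
#             "positions": [
#                 {"offset": 0, "length": 9, "position_text": "Acme Corp"},
#             ],
#         },
#     ],
# }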


def _deduplicate_articles(samples: list[dict]) -> list[dict]:
    """Remove samples with duplicate article text, keeping the first occurrence."""
    seen_texts = set()
    deduped = []
    removed = 0
    for s in samples:
        if s["text"] in seen_texts:
            removed += 1
            continue
        seen_texts.add(s["text"])
        deduped.append(s)
    print(f"Deduplicated articles: removed {removed}, kept {len(deduped)}")
    return deduped


def _remap_labels(samples: list[dict]) -> list[dict]:
    """Remap non-standard labels and drop entities with unmappable labels."""
    mapping = LABEL_REMAP.mapping
    remapped_count = 0
    dropped_count = 0
    for s in samples:
        cleaned_entities = []
        for e in s["entities"]:
            label = e["label"]
            if label in VALID_LABELS:
                cleaned_entities.append(e)
            elif label in mapping:
                e["label"] = mapping[label]
                cleaned_entities.append(e)
                remapped_count += 1
            else:
                dropped_count += 1
        s["entities"] = cleaned_entities
    print(f"Labels remapped: {remapped_count}, dropped (unmappable): {dropped_count}")
    return samples
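
# Illustrative example (the actual contents of LABEL_REMAP live in
# src.schemas.labels; the values here are assumptions): if
# LABEL_REMAP.mapping == {"pos": "positive"}, an entity labeled "pos" is
# rewritten to "positive", while an entity labeled "unknown" is dropped.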


def _fix_position_text(samples: list[dict]) -> list[dict]:
    """Overwrite position_text with the actual span from the article text
    when only the letter case differs."""
    fixed = 0
    for s in samples:
        text = s["text"]
        for e in s["entities"]:
            for p in e["positions"]:
                end = p["offset"] + p["length"]
                if end > len(text):
                    # Span runs past the end of the article text; leave it untouched.
                    continue
                actual = text[p["offset"]:end]
                if actual != p["position_text"] and actual.lower() == p["position_text"].lower():
                    p["position_text"] = actual
                    fixed += 1
    print(f"Position text case mismatches fixed: {fixed}")
    return samples
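
# For example (hypothetical data): with text "Apple shares rose" and a
# position {"offset": 0, "length": 5, "position_text": "apple"}, the stored
# span is rewritten to "Apple" because only the casing differs.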


def _merge_entities(samples: list[dict]) -> list[dict]:
    """Merge entity records sharing (entity_text, label) within a sample.

    Collects all positions into a single entity record per unique
    (entity_text.lower(), label) pair.
    """
    total_merged = 0
    different_label = 0
    for s in samples:
        merged: dict[tuple[str, str], dict] = {}
        label_seen: dict[str, str] = {}
        for e in s["entities"]:
            key = e["entity_text"].lower()
            # Count entity texts that appear under more than one label; such
            # pairs are kept as separate records rather than merged.
            if key in label_seen and label_seen[key] != e["label"]:
                different_label += 1
            merge_key = (key, e["label"])
            if merge_key in merged:
                merged[merge_key]["positions"].extend(e["positions"])
                total_merged += 1
            else:
                merged[merge_key] = {
                    "entity_id": e["entity_id"],
                    "entity_text": e["entity_text"],
                    "entity_type": e["entity_type"],
                    "positions": list(e["positions"]),
                    "label": e["label"],
                }
            label_seen[key] = e["label"]
        s["entities"] = list(merged.values())
    print(f"Entities merged: {total_merged}, different-label pairs: {different_label}")
    return samples
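
# Sketch of the effect (hypothetical data): two records for "Acme Corp" with
# label "positive" and one position each collapse into a single record whose
# "positions" list holds both spans; the first record's entity_id is kept.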


def _deduplicate_positions(samples: list[dict]) -> list[dict]:
    """Remove exact-duplicate positions (same offset and length) within each entity."""
    removed = 0
    for s in samples:
        for e in s["entities"]:
            seen = set()
            unique = []
            for p in e["positions"]:
                span = (p["offset"], p["length"])
                if span in seen:
                    removed += 1
                    continue
                seen.add(span)
                unique.append(p)
            e["positions"] = unique
    print(f"Exact-duplicate positions removed: {removed}")
    return samples
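
# E.g. (hypothetical data): two positions both at offset 5 with length 4
# collapse to one entry. Only identical (offset, length) pairs are affected
# here; overlapping spans are handled by the two functions below.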


def _resolve_same_offset(samples: list[dict]) -> list[dict]:
    """For positions sharing the same offset, keep only the longest span."""
    resolved = 0
    for s in samples:
        for e in s["entities"]:
            by_offset: dict[int, dict] = {}
            for p in e["positions"]:
                off = p["offset"]
                if off in by_offset:
                    # A span at this offset was already seen: count the conflict
                    # (regardless of which span arrives first) and keep the longer.
                    resolved += 1
                    if p["length"] > by_offset[off]["length"]:
                        by_offset[off] = p
                else:
                    by_offset[off] = p
            e["positions"] = sorted(by_offset.values(), key=lambda p: p["offset"])
    print(f"Same-offset positions resolved (kept longest): {resolved}")
    return samples
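
# E.g. (hypothetical data): spans (offset=10, length=4) and (offset=10,
# length=9) reduce to the length-9 span, whichever order they arrive in.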


def _resolve_partial_overlaps(samples: list[dict]) -> list[dict]:
    """Resolve positions that partially overlap (different offsets, shared characters).

    Positions are sorted by offset; whenever two adjacent spans overlap,
    the longer one is kept and the shorter one is discarded.
    """
    resolved = 0
    for s in samples:
        for e in s["entities"]:
            positions = sorted(e["positions"], key=lambda p: p["offset"])
            kept: list[dict] = []
            for p in positions:
                if not kept:
                    kept.append(p)
                    continue
                prev = kept[-1]
                prev_end = prev["offset"] + prev["length"]
                if p["offset"] < prev_end:
                    # Overlap with the previous kept span: keep the longer one.
                    # Replacing kept[-1] is safe because p starts at or after the
                    # end of kept[-2], so no earlier span can be re-overlapped.
                    resolved += 1
                    if p["length"] > prev["length"]:
                        kept[-1] = p
                else:
                    kept.append(p)
            e["positions"] = kept
    print(f"Partial overlaps resolved (kept longest): {resolved}")
    return samples
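
# E.g. (hypothetical data): spans (offset=0, length=5) and (offset=3,
# length=8) share characters 3-4, so only the longer length-8 span survives.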


def preprocess(samples: list[dict]) -> list[dict]:
    """Run the full preprocessing pipeline on raw samples."""
    samples = _deduplicate_articles(samples)
    samples = _remap_labels(samples)
    samples = _fix_position_text(samples)
    samples = _merge_entities(samples)
    samples = _deduplicate_positions(samples)
    samples = _resolve_same_offset(samples)
    samples = _resolve_partial_overlaps(samples)

    total_entities = sum(len(s["entities"]) for s in samples)
    print(f"Preprocessing complete: {len(samples)} samples, {total_entities} entities")
    return samples
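
# Note on ordering: _merge_entities runs before the position-level passes so
# that duplicate or overlapping spans brought together by a merge are still
# caught by _deduplicate_positions, _resolve_same_offset, and
# _resolve_partial_overlaps.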


def save_jsonl(samples: list[dict], path: str | Path) -> None:
    """Write preprocessed samples as JSONL (one JSON object per line)."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        for s in samples:
            f.write(json.dumps(s, ensure_ascii=False) + "\n")
    print(f"Saved {len(samples)} samples to {path}")


def main(
    input_path: str | Path = "data/data_raw.json",
    output_path: str | Path = "data/data_preprocessed.jsonl",
) -> list[dict]:
    """Load raw data, preprocess, and save."""
    with open(input_path, "r", encoding="utf-8") as f:
        raw = json.load(f)

    print(f"Loaded {len(raw)} raw samples from {input_path}")
    samples = preprocess(raw)
    save_jsonl(samples, output_path)
    return samples


if __name__ == "__main__":
    main()