import os
import json
# Pin CUDA device enumeration to PCI bus order and expose only GPU index 2,
# so the embedding model lands on a single, predictable device.
# NOTE(review): these env vars only take effect if set before any
# CUDA-using library initialises — hence they precede the import below.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
from sentence_transformers import SentenceTransformer, util
import numpy as np
# General-purpose sentence-embedding model used for semantic search
# over Wikipedia paragraphs throughout this script.
model = SentenceTransformer('all-MiniLM-L6-v2')
def find_wiki_anchor_robust(doc_text, wiki_list, top_k=20):
    """Find a Wikipedia paragraph semantically similar to *doc_text* and of
    comparable length.

    Parameters
    ----------
    doc_text : str
        Target document text.
    wiki_list : list[str]
        Wikipedia article texts. Each article is split into paragraphs on
        blank lines; only paragraphs longer than 20 words are considered.
    top_k : int, optional
        Number of nearest-neighbour candidates to inspect (default 20).

    Returns
    -------
    str or None
        Best-matching paragraph, or None when there is nothing to search
        (empty document or no usable wiki paragraphs). The original code
        raised ZeroDivisionError / failed on these edge cases.
    """
    doc_len = len(doc_text.split())
    # Guard: an empty document would divide by zero in the length-ratio test.
    if doc_len == 0:
        return None
    # 1. Split each article into paragraphs so candidates are on the same
    #    scale as the document; drop very short fragments (<= 20 words).
    wiki_chunks = []
    for text in wiki_list:
        wiki_chunks.extend(
            p.strip() for p in text.split('\n\n') if len(p.split()) > 20
        )
    # Guard: encoding/searching an empty corpus would fail.
    if not wiki_chunks:
        return None
    # 2. Embed the document and every candidate paragraph.
    doc_emb = model.encode(doc_text, convert_to_tensor=True)
    chunk_embs = model.encode(wiki_chunks, convert_to_tensor=True)
    # 3. Retrieve the top_k most semantically similar paragraphs.
    hits = util.semantic_search(doc_emb, chunk_embs, top_k=top_k)[0]
    # 4. Prefer the highest-ranked hit within +/- 20% of the document length.
    for hit in hits:
        candidate_text = wiki_chunks[hit['corpus_id']]
        cand_len = len(candidate_text.split())
        if 0.8 <= (cand_len / doc_len) <= 1.2:
            return candidate_text
    # Fallback: among the retrieved hits, take the closest word count.
    closest_hit = min(
        hits,
        key=lambda h: abs(len(wiki_chunks[h['corpus_id']].split()) - doc_len),
    )
    return wiki_chunks[closest_hit['corpus_id']]
import textstat
def get_linguistic_metrics(text):
    """Return standard readability scores plus the word count for *text*.

    Keys: 'fkgl' (Flesch-Kincaid grade), 'gunning_fog', 'smog_index',
    'word_count' (whitespace-token count).
    """
    metrics = {
        "fkgl": textstat.flesch_kincaid_grade(text),
        "gunning_fog": textstat.gunning_fog(text),
        "smog_index": textstat.smog_index(text),
    }
    metrics["word_count"] = len(text.split())
    return metrics
def get_lexical_complexity(text):
    """Type-token ratio: unique (lower-cased) words / total words.

    NOTE: the original docstring called this "lexical density" (content
    words / total words), but the code computes vocabulary diversity —
    no content-word filtering (e.g. via NLTK pos_tag) is performed.

    Returns 0 for empty or whitespace-only input.
    """
    words = text.lower().split()
    # Guard against division by zero on empty input (original behavior kept).
    if not words:
        return 0
    return len(set(words)) / len(words)
import spacy
# Small English pipeline used for dependency parsing below.
# NOTE(review): en_core_web_sm is the lightweight CPU model, NOT the
# transformer variant the original comment claimed — use en_core_web_trf
# if higher parsing accuracy is actually required.
nlp = spacy.load("en_core_web_sm")
def get_parse_tree_stats(text):
    """Average dependency-parse tree depth over the sentences of *text*.

    Depth is counted from the sentence root (root = depth 1, so the value
    is the number of nodes on the longest root-to-leaf path).

    Returns
    -------
    numpy floating or int
        Mean depth across sentences, or 0 when the text has no sentences.
    """
    # Hoisted out of the sentence loop — the original redefined this
    # closure on every iteration for no benefit.
    def _max_depth(node, depth):
        children = list(node.children)
        if not children:
            # Leaf token: the path ends here.
            return depth
        return max(_max_depth(child, depth + 1) for child in children)

    doc = nlp(text)
    depths = [_max_depth(sent.root, 1) for sent in doc.sents]
    return np.mean(depths) if depths else 0
import pandas as pd
# Accumulates one result dict per (index, label) pair; also serves as the
# resume state when the script is re-run.
processed_data = []
from datasets import load_dataset
# Full English Wikipedia dump — the anchor-candidate corpus.
ds = load_dataset("wikimedia/wikipedia", "20231101.en")
wiki_list=[item['text'] for item in ds['train']]
import json
# Synthetic documents with per-difficulty-label text variants.
with open("/home/mshahidul/readctrl/data/synthetic_dataset_diff_labels/syn_data_diff_labels_en_v1.json", "r") as f:
    res = json.load(f)
# my_target_documents=[item['text'] for item in ds['train'].select(range(5))]
my_target_documents = []
save_path=f"/home/mshahidul/readctrl/data/data_annotator_data/crowdsourcing_input_en.json"
# Resume support: reload any previously saved results so finished
# (index, label) pairs are skipped in the loop below.
if os.path.exists(save_path):
    with open(save_path, "r") as f:
        processed_data = json.load(f)
# Flatten each source item into one work unit per difficulty label.
for item in res:
    for key,value in item['diff_label_texts'].items():
        my_target_documents.append({
            "index": item['index'],
            "label": key,
            "text": value
        })
import tqdm
for doc in tqdm.tqdm(my_target_documents):
    # Skip units already present in the (possibly reloaded) results.
    if any(d['index']==doc['index'] and d['label']==doc['label'] for d in processed_data):
        print(f"Skipping already processed index {doc['index']} label {doc['label']}")
        continue
    # A. Find the semantically closest, length-matched wiki paragraph.
    wiki_anchor = find_wiki_anchor_robust(doc['text'], wiki_list)
    # B. Readability and parse-depth metrics for BOTH the document
    #    and its wiki anchor, so deltas can be compared.
    doc_metrics = get_linguistic_metrics(doc['text'])
    wiki_metrics = get_linguistic_metrics(wiki_anchor)
    doc_parse = get_parse_tree_stats(doc['text'])
    wiki_parse = get_parse_tree_stats(wiki_anchor)
    # C. Store paired texts plus headline metrics and the FKGL delta.
    processed_data.append({
        "index": doc['index'],
        "label": doc['label'],
        "original_doc": doc['text'],
        "wiki_anchor": wiki_anchor,
        "doc_fkgl": doc_metrics['fkgl'],
        "wiki_fkgl": wiki_metrics['fkgl'],
        "doc_tree_depth": doc_parse,
        "wiki_tree_depth": wiki_parse,
        "fkgl_delta": doc_metrics['fkgl'] - wiki_metrics['fkgl']
    })
    # Checkpoint every 5 accumulated results so a crash loses little work.
    if len(processed_data) % 5 == 0:
        with open(save_path, "w") as f:
            json.dump(processed_data, f, indent=2)
        print(f"Processed {len(processed_data)} documents so far.")
import json
# Final write: persist everything, including results past the last checkpoint.
with open(save_path, "w") as f:
    json.dump(processed_data, f, indent=2)