import contextlib
import logging
import os
import sqlite3
import string
import time

import click
import spacy
import torch
import tqdm

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Any non-empty value for DISABLE_TIMER turns timing output off.
DISABLE_TIMER = bool(os.environ.get("DISABLE_TIMER"))


@contextlib.contextmanager
def timer(timer_name="timer", pbar=None, pos=0):
    # A generator-based context manager must yield exactly once, even when
    # timing is disabled; returning before the yield would make contextlib
    # raise "generator didn't yield".
    if DISABLE_TIMER:
        yield
        return
    start = time.time()
    yield
    end = time.time()
    if pbar is not None:
        pbar.display(f"Time taken in [{timer_name}]: {end - start:.3e}", pos=pos)
    else:
        logger.info(f"Time taken in [{timer_name}]: {end - start:.3e}")


def dict_factory(cursor, row):
    # Row factory so that queries return dictionaries keyed by column name
    # instead of plain tuples.
    d = {}
    for idx, col in enumerate(cursor.description):
        d[col[0]] = row[idx]
    return d
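
# With dict_factory installed as conn.row_factory, a SELECT over a source
# table yields rows such as (columns beyond region_id and phrases are an
# assumption about the schema):
#
#     {"region_id": 1, "phrases": "a red car\ta parked car"}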


class RowDataset(torch.utils.data.Dataset):
    def __init__(self, rows, nlp):
        self.nlp = nlp
        self.rows = rows

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        return self.rows[idx]

    def _collate_fn(self, batch):
        # NOTE: join the tab-separated phrases with '. ' so that spaCy treats
        # them as separate sentences; sentence boundaries matter for POS tagging.
        all_phrases = [row["phrases"].replace("\t", ". ") for row in batch]
        all_nouns, all_noun_chunks_ls = get_noun_and_noun_chunks(all_phrases, self.nlp)
        if len(all_nouns) != len(batch) or len(all_noun_chunks_ls) != len(batch):
            raise ValueError("Length of all_nouns and batch should be the same")
        for sample, nouns, noun_chunks in zip(batch, all_nouns, all_noun_chunks_ls):
            sample["nouns"] = nouns
            sample["noun_chunks"] = noun_chunks
        return batch
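
# With num_workers > 0, _collate_fn (and therefore the spaCy pipeline) runs
# inside the DataLoader worker processes, parallelizing POS tagging across
# batches; this relies on the nlp object being picklable, which holds for
# standard spaCy pipelines.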


@click.command()
@click.option("--db", help="Path to the database file")
@click.option("--debug", is_flag=True, help="Debug mode")
@click.option("--batch_size", type=int, default=200_000, help="Batch size")
@click.option("--num_workers", type=int, default=12, help="Number of workers")
def main(db: str, debug, batch_size, num_workers):
    nlp = spacy.load("en_core_web_lg", disable=["ner"])
    conn = sqlite3.connect(db)
    # Return rows as dictionaries instead of tuples.
    conn.row_factory = dict_factory
    cursor = conn.cursor()
    table_names, table_schemas = get_tables_with_name_and_schema(cursor)
    logger.info(f"Table Names: {table_names}")
    for table_name in table_names:
        if table_name.endswith("_extension"):
            logger.info(f"Skipping table {table_name}, as it is already an extension")
            continue
        extract_pos_to_table(
            nlp, cursor, conn, table_name, debug=debug, batch_size=batch_size, num_workers=num_workers
        )


def extract_pos_to_table(nlp, cursor, conn, table_name, batch_size=200_000, num_workers=8, debug=False):
    if table_name.endswith("_extension"):
        logger.info(f"Skipping table {table_name}, as it is already an extension")
        return
    pos_table_name = f"{table_name}_pos_extension"
    logger.info(f"Extracting POS from table {table_name} to table {pos_table_name}")
    logger.info(f"Batch size: {batch_size}, num_workers: {num_workers}")
    # Drop the extension table if it exists, then recreate it from scratch.
    cursor.execute(f"DROP TABLE IF EXISTS {pos_table_name}")
    conn.commit()
    cursor.execute(
        f"""
        CREATE TABLE IF NOT EXISTS {pos_table_name} (
            region_id INTEGER PRIMARY KEY,
            nouns TEXT,
            noun_chunks TEXT,
            FOREIGN KEY (region_id) REFERENCES {table_name}(region_id)
        )
        """
    )
    conn.commit()
    cursor.execute(f"SELECT * FROM {table_name}" + (" LIMIT 10" if debug else ""))
    rows = cursor.fetchall()
    rows_dataset = RowDataset(rows, nlp)
    rows_dataloader = torch.utils.data.DataLoader(
        rows_dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=rows_dataset._collate_fn
    )
    # Track progress in rows rather than batches; the last batch may be
    # smaller than batch_size, so advance by the actual batch length.
    pbar = tqdm.tqdm(total=len(rows))
    for batch in rows_dataloader:
        for sample in batch:
            region_id = sample["region_id"]
            nouns = sample["nouns"]
            nouns_str = "\t".join(nouns)
            noun_chunks = sample["noun_chunks"]
            noun_chunks_str = "\t".join(noun_chunks)
            cursor.execute(
                f"INSERT INTO {pos_table_name} (region_id, nouns, noun_chunks) VALUES (?, ?, ?)",
                (region_id, nouns_str, noun_chunks_str),
            )
        # Commit once per batch rather than once per row.
        conn.commit()
        pbar.update(len(batch))
    conn.commit()
    pbar.close()
    logger.info(f"Finished extracting POS from table {table_name} to table {pos_table_name}")


def get_noun_and_noun_chunks(texts, nlp):
    docs = nlp.pipe(texts)
    noun_chunks_ls = []
    nouns = []
    for doc in docs:
        nouns.append(normalize_nouns(doc))
        noun_chunks_ls.append([chunk.text for chunk in doc.noun_chunks])
    return nouns, noun_chunks_ls


def normalize_nouns(doc):
    # Keep only nouns and proper nouns, lemmatized, lowercased, and with
    # surrounding punctuation stripped.
    normalized_nouns = [
        token.lemma_.lower().strip(string.punctuation)
        for token in doc
        if token.pos_ == "NOUN" or token.pos_ == "PROPN"
    ]
    return normalized_nouns
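
# Illustrative input/output for normalize_nouns (the exact lemmas depend on
# the loaded spaCy model, so this output is an assumption):
#
#     doc = nlp("Two red cars parked near the buildings.")
#     normalize_nouns(doc)  # -> ["car", "building"]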


def get_tables_with_name_and_schema(cursor):
    # List all tables from sqlite_master; this relies on the dict-style
    # row factory installed in main().
    cursor.execute("SELECT name, sql FROM sqlite_master WHERE type='table'")
    table_names = []
    table_schemas = []
    for result in cursor.fetchall():
        table_name, table_schema = result["name"], result["sql"]
        logger.info(f"Table Name: {table_name}")
        logger.info(f"Table Schema: {table_schema}\n")
        table_names.append(table_name)
        table_schemas.append(table_schema)
    return table_names, table_schemas


if __name__ == "__main__":
    main()
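
# Example invocation (the script and database file names are hypothetical):
#
#     DISABLE_TIMER=1 python extract_pos.py --db regions.db --batch_size 50000 --num_workers 8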