# deepspeed/scripts/tools/add_pos_table_annotation_db.py
import contextlib
import logging
import os
import sqlite3
import string
import time

import click
import spacy
import torch
import tqdm

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Set DISABLE_TIMER to any non-empty value in the environment to silence the timer.
DISABLE_TIMER = bool(os.environ.get("DISABLE_TIMER", False))


@contextlib.contextmanager
def timer(timer_name="timer", pbar=None, pos=0):
    if DISABLE_TIMER:
        # A @contextmanager generator must still yield exactly once;
        # returning without yielding raises RuntimeError on `with` entry.
        yield
        return
    start = time.time()
    yield
    end = time.time()
    if pbar is not None:
        pbar.display(f"Time taken in [{timer_name}]: {end - start:.3e}s", pos=pos)
    else:
        logger.info(f"Time taken in [{timer_name}]: {end - start:.3e}s")
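
# `timer` is not used anywhere below; a minimal usage sketch (the timed
# statement here is hypothetical, not part of this script):
#
#     with timer("nlp.pipe"):
#         docs = list(nlp.pipe(texts))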


def dict_factory(cursor, row):
    # NOTE: return rows as dictionaries keyed by column name instead of tuples.
    d = {}
    for idx, col in enumerate(cursor.description):
        d[col[0]] = row[idx]
    return d
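
# With dict_factory installed as conn.row_factory, a query such as
# `SELECT region_id, phrases FROM regions` yields rows like
# {"region_id": 1, "phrases": "a red car\ttwo dogs"}; the table and
# values here are illustrative, not from a real database.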


class RowDataset(torch.utils.data.Dataset):
    """Thin Dataset wrapper so a DataLoader can parallelize spaCy tagging over rows."""

    def __init__(self, rows, nlp):
        self.nlp = nlp
        self.rows = rows

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        return self.rows[idx]

    def _collate_fn(self, batch):
        # NOTE: join the tab-separated phrases with '. ' so spaCy treats each
        # phrase as its own sentence when assigning part-of-speech tags.
        all_phrases = [row["phrases"].replace("\t", ". ") for row in batch]
        all_nouns, all_noun_chunks_ls = get_noun_and_noun_chunks(all_phrases, self.nlp)
        if len(all_nouns) != len(batch) or len(all_noun_chunks_ls) != len(batch):
            raise ValueError("all_nouns and all_noun_chunks_ls must match the batch length")
        for sample, nouns, noun_chunks in zip(batch, all_nouns, all_noun_chunks_ls):
            sample["nouns"] = nouns
            sample["noun_chunks"] = noun_chunks
        return batch
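
# Sketch of what _collate_fn does to one batch element (values are
# hypothetical and depend on the spaCy model):
#   in:  {"region_id": 7, "phrases": "a red car\ttwo dogs"}
#   out: {"region_id": 7, "phrases": ..., "nouns": ["car", "dog"],
#         "noun_chunks": ["a red car", "two dogs"]}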


@click.command()
@click.option("--db", help="Path to the SQLite database file")
@click.option("--debug", is_flag=True, help="Debug mode (process only 10 rows per table)")
@click.option("--batch_size", type=int, default=200_000, help="Batch size")
@click.option("--num_workers", type=int, default=12, help="Number of DataLoader workers")
def main(db: str, debug, batch_size, num_workers):
    nlp = spacy.load("en_core_web_lg", disable=["ner"])
    conn = sqlite3.connect(db)
    # NOTE: return rows as dictionaries keyed by column name instead of tuples.
    conn.row_factory = dict_factory
    cursor = conn.cursor()
    table_names, table_schemas = get_tables_with_name_and_schema(cursor)
    logger.info(f"Table Names: {table_names}")
    for table_name in table_names:
        if table_name.endswith("_extension"):
            logger.info(f"Skipping table {table_name}, as it is already an extension")
            continue
        extract_pos_to_table(
            nlp, cursor, conn, table_name, debug=debug, batch_size=batch_size, num_workers=num_workers
        )


def extract_pos_to_table(nlp, cursor, conn, table_name, batch_size=200_000, num_workers=8, debug=False):
    if table_name.endswith("_extension"):
        logger.info(f"Skipping table {table_name}, as it is already an extension")
        return
    pos_table_name = f"{table_name}_pos_extension"
    logger.info(f"Extracting POS from table {table_name} to table {pos_table_name}")
    logger.info(f"Batch size: {batch_size}, num_workers: {num_workers}")

    # Drop the POS extension table if it already exists, then recreate it.
    cursor.execute(f"DROP TABLE IF EXISTS {pos_table_name}")
    conn.commit()
    cursor.execute(
        f"""
        CREATE TABLE IF NOT EXISTS {pos_table_name} (
            region_id INTEGER PRIMARY KEY,
            nouns TEXT,
            noun_chunks TEXT,
            FOREIGN KEY (region_id) REFERENCES {table_name}(region_id)
        )
        """
    )
    conn.commit()

    cursor.execute(f"SELECT * FROM {table_name}" + (" LIMIT 10" if debug else ""))
    rows = cursor.fetchall()
    rows_dataset = RowDataset(rows, nlp)
    rows_dataloader = torch.utils.data.DataLoader(
        rows_dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=rows_dataset._collate_fn
    )
    # Track progress in rows, not batches; the last batch may be smaller than
    # batch_size, so advance by the actual batch length instead of batch_size.
    pbar = tqdm.tqdm(total=len(rows))
    for batch in rows_dataloader:
        for sample in batch:
            region_id = sample["region_id"]
            nouns_str = "\t".join(sample["nouns"])
            noun_chunks_str = "\t".join(sample["noun_chunks"])
            cursor.execute(
                f"INSERT INTO {pos_table_name} (region_id, nouns, noun_chunks) VALUES (?, ?, ?)",
                (region_id, nouns_str, noun_chunks_str),
            )
        conn.commit()
        pbar.update(len(batch))
    pbar.close()
    conn.commit()
    logger.info(f"Finished extracting POS from table {table_name} to table {pos_table_name}")
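
# A sketch of how the extension table can be consumed afterwards; the
# `regions` table name is illustrative:
#
#     SELECT r.region_id, p.nouns, p.noun_chunks
#     FROM regions AS r
#     JOIN regions_pos_extension AS p USING (region_id);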


def get_noun_and_noun_chunks(texts, nlp):
    # nlp.pipe streams the texts through spaCy in batches, which is much
    # faster than calling nlp() on each text individually.
    docs = nlp.pipe(texts)
    noun_chunks_ls = []
    nouns = []
    for doc in docs:
        nouns.append(normalize_nouns(doc))
        noun_chunks_ls.append([chunk.text for chunk in doc.noun_chunks])
    return nouns, noun_chunks_ls


def normalize_nouns(doc):
    # Keep nouns and proper nouns, lowercase their lemmas, and strip any
    # leading/trailing punctuation.
    normalized_nouns = [
        token.lemma_.lower().strip(string.punctuation)
        for token in doc
        if token.pos_ == "NOUN" or token.pos_ == "PROPN"
    ]
    return normalized_nouns
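
# For example (outputs depend on the spaCy model, so treat this as indicative):
# in "A red car. Two dogs.", "car" and "dogs" are tagged NOUN, so
# normalize_nouns returns ["car", "dog"] after lemmatization and lowercasing.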


def get_tables_with_name_and_schema(cursor):
    # List every table together with its CREATE statement from sqlite_master.
    cursor.execute("SELECT name, sql FROM sqlite_master WHERE type='table'")
    table_names = []
    table_schemas = []
    for result in cursor.fetchall():
        # row_factory is dict_factory, so rows come back as dictionaries.
        table_name, table_schema = result["name"], result["sql"]
        logger.info(f"Table Name: {table_name}")
        logger.info(f"Table Schema: {table_schema}\n")
        table_names.append(table_name)
        table_schemas.append(table_schema)
    return table_names, table_schemas


if __name__ == "__main__":
    main()
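
# Example invocation (the database path is illustrative):
#   python add_pos_table_annotation_db.py --db /path/to/regions.db --batch_size 50000 --num_workers 8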