# SmartLead / src/post_extraction_tools/data_quality_enhancer.py
# Author: Subhajit Chakraborty
# Last commit: "update files(6)" (99de885)
import re
import json
import torch
# Canonical industry labels that add_ind_key() assigns to companies.
# add_ind_key() takes the index of the best-matching vector loaded from
# "industry_embeddings" in src/data/key_industry_embeddings.json and uses it to
# index this list. The ORDER of these entries must therefore match the row
# order of that file. Do not reorder, insert or remove labels without
# regenerating the embeddings.
# NOTE(review): a few labels overlap ("Oil, Gas & Mining" vs "Mining & Metals",
# "Energy & Utilities" vs "Utilities & Waste Management"), so near-ties between
# them are possible — confirm this is intended.
key_industry_types = [
    "Information Technology & Services",
    "Software & SaaS",
    "Artificial Intelligence & Machine Learning",
    "Cybersecurity & Data Protection",
    "Cloud Computing & DevOps",
    "Telecommunications & Networking",
    "Semiconductors & Electronics",
    "Computer Hardware & Manufacturing",
    "Internet & Digital Media",
    "E-commerce & Retail",
    "Finance & Banking",
    "Insurance & Risk Management",
    "Investment & Asset Management",
    "Real Estate & Construction",
    "Architecture & Urban Planning",
    "Automotive & Transportation",
    "Aerospace & Defense",
    "Energy & Utilities",
    "Oil, Gas & Mining",
    "Renewable Energy & Sustainability",
    "Manufacturing & Industrial Engineering",
    "Logistics, Supply Chain & Warehousing",
    "Consumer Goods & FMCG",
    "Food & Beverages",
    "Agriculture & AgriTech",
    "Pharmaceuticals & Biotechnology",
    "Healthcare & Medical Devices",
    "Education & EdTech",
    "Media & Entertainment",
    "Sports & Recreation",
    "Travel, Tourism & Hospitality",
    "Legal Services & Law Firms",
    "Accounting & Financial Services",
    "Human Resources & Staffing",
    "Consulting & Business Services",
    "Marketing, Advertising & PR",
    "Design, Art & Creative Services",
    "Nonprofit & Social Impact",
    "Government & Public Administration",
    "Environmental Services",
    "Research & Development",
    "Mining & Metals",
    "Textiles, Apparel & Fashion",
    "Chemicals & Materials",
    "Marine & Shipping",
    "Utilities & Waste Management",
    "Printing & Publishing",
    "Electronics Repair & Maintenance",
    "Aviation & Airlines"
]
# Legal-entity / corporate-form words removed from company names before they are
# compared. A negative lookahead `(?!\w)` is used instead of a trailing `\b`:
# after a form that ends in "." (e.g. "s.a."), `\b` would require a following
# word character, so such forms never matched at the end of a name.
_LEGAL_SUFFIX_RE = re.compile(
    r"\b(?:inc\.?|ltd\.?|llc\.?|co\.?|corp\.?|corporation|limited|pvt|private|"
    r"public|plc|gmbh|s\.a\.|srl|bv|ag|oy|ab|spa|sas|sdn bhd|holdings|group|"
    r"company|enterprises|technologies)(?!\w)"
)
# Anything that is not a (Unicode) letter/digit or whitespace; "_" is listed
# explicitly because `\w` would otherwise keep it.
_NON_WORD_RE = re.compile(r"[^\w\s]|_")
_WHITESPACE_RE = re.compile(r"\s+")


def clean_company_name(name: str) -> list:
    """Normalize a company name for similarity comparison.

    The name is lower-cased, common legal-form words ("inc", "ltd", "gmbh",
    "s.a.", "holdings", ...) are removed, punctuation is dropped and runs of
    whitespace are collapsed. Non-ASCII letters (e.g. "ü", CJK) are kept, so
    non-Latin names do not collapse to an empty string.

    Args:
        name: raw company name; ``None`` is treated as ``""``.

    Returns:
        A one-element list holding the cleaned name. It is a list because
        callers pass it straight to ``embedder.encode``.
    """
    name = (name or "").lower()
    name = _LEGAL_SUFFIX_RE.sub("", name)
    name = _NON_WORD_RE.sub("", name)
    return [_WHITESPACE_RE.sub(" ", name).strip()]
def _find_duplicate_indices(companies: list, embedder, name_threshold: float,
                            industry_threshold: float) -> set:
    """Return the indices of companies that duplicate an earlier, kept company.

    Company ``j`` is marked as a duplicate of an earlier, not-yet-discarded
    company ``i`` when three conditions hold:

    * their normalized ``country`` values are equal;
    * the similarity of their cleaned names is >= ``name_threshold``;
    * the similarity of their ``industry_type`` strings is >= ``industry_threshold``.

    Each match is printed.
    """
    # Encode every company exactly once. The previous nested loop re-encoded the
    # inner company on every outer iteration, i.e. O(n^2) model calls.
    # `or ""` guards against keys that are present but set to None.
    name_embs = [
        embedder.encode(clean_company_name(c.get("company_name") or ""))
        for c in companies
    ]
    ind_embs = [embedder.encode([c.get("industry_type") or ""]) for c in companies]
    countries = [(c.get("country") or "").lower().strip() for c in companies]

    duplicate_idx = set()
    for i, c1 in enumerate(companies):
        if i in duplicate_idx:
            continue
        for j in range(i + 1, len(companies)):
            # Cheap country test first: different countries are never merged,
            # so there is no need to compute similarities for them.
            if j in duplicate_idx or countries[i] != countries[j]:
                continue
            name_sim = embedder.similarity(name_embs[i], name_embs[j]).item()
            ind_sim = embedder.similarity(ind_embs[i], ind_embs[j]).item()
            if name_sim >= name_threshold and ind_sim >= industry_threshold:
                duplicate_idx.add(j)
                print(f"Duplicate found: {c1.get('company_name')} and {companies[j].get('company_name')}")
                print(f"Name similarity: {name_sim}, Industry similarity: {ind_sim}")
    return duplicate_idx


def enhancer(data: dict, embedder, *, name_threshold: float = 0.6,
             industry_threshold: float = 0.6) -> list:
    """Remove near-duplicate companies and tag the survivors with a key industry.

    Two companies are duplicates when their normalized countries match and
    the embedding similarities of their cleaned names and of their industry
    strings reach ``name_threshold`` / ``industry_threshold``. The first
    occurrence is kept and later ones are dropped. The remaining companies are
    then passed through ``add_ind_key``.

    Args:
        data: mapping with an optional ``"companies"`` list of company dicts.
        embedder: object exposing ``encode(list_of_str)`` and
            ``similarity(a, b)`` (returning something with ``.item()``),
            presumably a sentence-transformers model.
        name_threshold: minimum name similarity to count as a duplicate.
        industry_threshold: minimum industry similarity to count as a duplicate.

    Returns:
        ``[{"companies": kept_companies},
        {"duplicate_company_names": removed_names}]``; ``removed_names``
        follows input order.
    """
    companies = data.get("companies", [])
    duplicate_idx = _find_duplicate_indices(
        companies, embedder, name_threshold, industry_threshold
    )
    duplicate_comps = []
    if duplicate_idx:
        print(duplicate_idx)
        # Sorted so the reported names follow input order (set order is arbitrary).
        duplicate_comps = [companies[i].get("company_name", "") for i in sorted(duplicate_idx)]
        companies = [c for idx, c in enumerate(companies) if idx not in duplicate_idx]
        print(f"Removed {len(duplicate_idx)} duplicate entries.", flush=True)
    else:
        print("No duplicate entries found.", flush=True)
    print("Now adding the industry keys...")
    companies = add_ind_key(companies, embedder)
    print("Added Industry keys", flush=True)
    return [{"companies": companies}, {"duplicate_company_names": duplicate_comps}]
def add_ind_key(data: list, embedder,
                embeddings_path: str = "src/data/key_industry_embeddings.json") -> list:
    """Tag each company with the closest canonical label from ``key_industry_types``.

    For every company dict without a ``"key_industry"`` entry, its
    ``industry_type`` string is embedded and compared with the label vectors
    stored under ``"industry_embeddings"`` in ``embeddings_path``. The label at
    the best-matching index is written to ``c["key_industry"]``. Companies that
    already have the key are left untouched. ``data`` is modified in place and
    also returned.

    Args:
        data: list of company dicts.
        embedder: object exposing ``encode(list_of_str)`` and ``similarity(a, b)``.
        embeddings_path: JSON file holding the label embeddings. A relative path
            is resolved against the current working directory.

    Raises:
        ValueError: if the number of stored embeddings differs from
            ``len(key_industry_types)``; the index -> label mapping would
            otherwise be wrong or raise ``IndexError``.
    """
    with open(embeddings_path, "r", encoding="utf-8") as f:
        key_ind_embs = json.load(f)["industry_embeddings"]
    if len(key_ind_embs) != len(key_industry_types):
        raise ValueError(
            f"{embeddings_path} has {len(key_ind_embs)} industry embeddings but "
            f"key_industry_types has {len(key_industry_types)} labels"
        )
    # Convert the JSON list-of-lists to a tensor once, instead of having
    # similarity() convert it again for every company.
    key_ind_tensor = torch.tensor(key_ind_embs)
    for c in data:
        if "key_industry" in c:
            continue
        comp_emb = embedder.encode([c.get("industry_type") or ""])
        sim_t = embedder.similarity(key_ind_tensor, comp_emb)
        # argmax over the flattened scores is the row of the best-matching label.
        c["key_industry"] = key_industry_types[torch.argmax(sim_t).item()]
    return data