import re import json import torch key_industry_types = [ "Information Technology & Services", "Software & SaaS", "Artificial Intelligence & Machine Learning", "Cybersecurity & Data Protection", "Cloud Computing & DevOps", "Telecommunications & Networking", "Semiconductors & Electronics", "Computer Hardware & Manufacturing", "Internet & Digital Media", "E-commerce & Retail", "Finance & Banking", "Insurance & Risk Management", "Investment & Asset Management", "Real Estate & Construction", "Architecture & Urban Planning", "Automotive & Transportation", "Aerospace & Defense", "Energy & Utilities", "Oil, Gas & Mining", "Renewable Energy & Sustainability", "Manufacturing & Industrial Engineering", "Logistics, Supply Chain & Warehousing", "Consumer Goods & FMCG", "Food & Beverages", "Agriculture & AgriTech", "Pharmaceuticals & Biotechnology", "Healthcare & Medical Devices", "Education & EdTech", "Media & Entertainment", "Sports & Recreation", "Travel, Tourism & Hospitality", "Legal Services & Law Firms", "Accounting & Financial Services", "Human Resources & Staffing", "Consulting & Business Services", "Marketing, Advertising & PR", "Design, Art & Creative Services", "Nonprofit & Social Impact", "Government & Public Administration", "Environmental Services", "Research & Development", "Mining & Metals", "Textiles, Apparel & Fashion", "Chemicals & Materials", "Marine & Shipping", "Utilities & Waste Management", "Printing & Publishing", "Electronics Repair & Maintenance", "Aviation & Airlines" ] def clean_company_name(name: str) -> str: name = name.lower() name = re.sub(r'\b(inc\.?|ltd\.?|llc\.?|co\.?|corp\.?|corporation|limited|pvt|private|public|plc|gmbh|s\.a\.|srl|bv|ag|oy|ab|spa|sas|sdn bhd|holdings|group|company|enterprises|technologies)\b', '', name) name = re.sub(r'[^a-z0-9\s]', '', name) return [re.sub(r'\s+', ' ', name).strip()] def enhancer(data: object, embedder) -> list: """ Enhances the data quality by removing duplicates """ # def tokenize(text): # return set(re.findall(r"[a-zA-Z0-9]+", text.lower())) companies = data.get("companies", []) num_of_companies = len(companies) duplicate_idx = set() duplicate_comps = [] for i in range(num_of_companies): if i in duplicate_idx: continue c1 = companies[i] # name1 = tokenize(c1.get("company_name", "")) # ind1 = tokenize(c1.get("industry_type", "")) c1_name_embedding = embedder.encode(clean_company_name(c1.get("company_name", ""))) c1_ind_embedding = embedder.encode([c1.get("industry_type", "")]) # c1["ind_embedding"] = c1_ind_embedding c1_country = c1.get("country", "").lower().strip() # print(name1, ind1, country1) for j in range(i+1, num_of_companies): if j in duplicate_idx: continue c2 = companies[j] c2_name_embedding = embedder.encode(clean_company_name(c2.get("company_name", ""))) c2_ind_embedding = embedder.encode([c2.get("industry_type", "")]) # c2["ind_embedding"] = c2_ind_embedding c2_country = c2.get("country", "").lower().strip() name_sim = embedder.similarity(c1_name_embedding, c2_name_embedding).item() ind_sim = embedder.similarity(c1_ind_embedding, c2_ind_embedding).item() if name_sim >= 0.6 and ind_sim >= 0.6 and c1_country == c2_country: duplicate_idx.add(j) print(f"Duplicate found: {c1.get('company_name')} and {c2.get('company_name')}") print(f"Name similarity: {name_sim}, Industry similarity: {ind_sim}") # print(name2, ind2, country2) # name_intersection = len(name1 & name2) # industry_intersection = len(ind1 & ind2) # if (name_intersection >= max(len(name1), len(name2)) / 2 and # industry_intersection >= max(len(ind1), len(ind2)) / 2 and # country1 == country2): # duplicate_idx.add(j) if duplicate_idx: print(duplicate_idx) duplicate_comps = [companies[i]["company_name"] for i in duplicate_idx] companies = [c for idx, c in enumerate(companies) if idx not in duplicate_idx] print(f"Removed {len(duplicate_idx)} duplicate entries.",flush=True) else: print("No duplicate entries found.",flush=True) print("Now adding the industry keys...") companies = add_ind_key(companies, embedder) print("Added Industry keys",flush=True) return [{"companies": companies}, {"duplicate_company_names": duplicate_comps}] def add_ind_key(data: list, embedder) -> list: with open("src/data/key_industry_embeddings.json", "r") as f: key_ind_embs = json.load(f)["industry_embeddings"] for c in data: if "key_industry" not in c: comp_emb = embedder.encode([c["industry_type"]]) sim_t = embedder.similarity(key_ind_embs, comp_emb) max_sim_idx = torch.argmax(sim_t).item() c["key_industry"] = key_industry_types[max_sim_idx] return data