File size: 5,372 Bytes
723bbe6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99de885
 
 
 
 
 
723bbe6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99de885
723bbe6
 
 
 
 
 
 
 
 
99de885
723bbe6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8f6827d
723bbe6
8f6827d
723bbe6
 
 
8f6827d
723bbe6
 
 
 
 
 
 
145772b
723bbe6
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import re
import json
import torch

# Canonical industry labels that add_ind_key() assigns to a company as its
# "key_industry". add_ind_key() indexes this list with the argmax over the
# similarities against the embeddings loaded from
# src/data/key_industry_embeddings.json, so the order here presumably has to
# mirror the row order of that file -- TODO confirm before reordering/adding.
key_industry_types = [
    "Information Technology & Services",
    "Software & SaaS",
    "Artificial Intelligence & Machine Learning",
    "Cybersecurity & Data Protection",
    "Cloud Computing & DevOps",
    "Telecommunications & Networking",
    "Semiconductors & Electronics",
    "Computer Hardware & Manufacturing",
    "Internet & Digital Media",
    "E-commerce & Retail",
    "Finance & Banking",
    "Insurance & Risk Management",
    "Investment & Asset Management",
    "Real Estate & Construction",
    "Architecture & Urban Planning",
    "Automotive & Transportation",
    "Aerospace & Defense",
    "Energy & Utilities",
    "Oil, Gas & Mining",
    "Renewable Energy & Sustainability",
    "Manufacturing & Industrial Engineering",
    "Logistics, Supply Chain & Warehousing",
    "Consumer Goods & FMCG",
    "Food & Beverages",
    "Agriculture & AgriTech",
    "Pharmaceuticals & Biotechnology",
    "Healthcare & Medical Devices",
    "Education & EdTech",
    "Media & Entertainment",
    "Sports & Recreation",
    "Travel, Tourism & Hospitality",
    "Legal Services & Law Firms",
    "Accounting & Financial Services",
    "Human Resources & Staffing",
    "Consulting & Business Services",
    "Marketing, Advertising & PR",
    "Design, Art & Creative Services",
    "Nonprofit & Social Impact",
    "Government & Public Administration",
    "Environmental Services",
    "Research & Development",
    "Mining & Metals",
    "Textiles, Apparel & Fashion",
    "Chemicals & Materials",
    "Marine & Shipping",
    "Utilities & Waste Management",
    "Printing & Publishing",
    "Electronics Repair & Maintenance",
    "Aviation & Airlines"
]

def clean_company_name(name: str) -> list:
    """Normalise a company name so that two spellings compare equal.

    The name is lower-cased, common legal-form / filler words (``inc``,
    ``ltd``, ``gmbh``, ``s.a.``, ``group`` ...) are removed, every character
    that is not ``a-z``, ``0-9`` or whitespace is dropped, and runs of
    whitespace are collapsed to a single space.

    Args:
        name: Raw company name. ``None`` or an empty string is treated as
            an empty name instead of raising.

    Returns:
        A one-element list containing the cleaned name. The list wrapper is
        kept because callers in this file pass the result straight to
        ``embedder.encode``.
    """
    name = (name or "").lower()
    # A trailing ``\b`` cannot match right after a dot that is followed by
    # whitespace or end-of-string, so every dotted suffix keeps its final dot
    # optional (``s\.a\.?`` like ``inc\.?``); the leftover dot is removed by
    # the punctuation pass below.
    name = re.sub(
        r'\b(inc\.?|ltd\.?|llc\.?|co\.?|corp\.?|corporation|limited|pvt|private|public|plc|gmbh|s\.a\.?|srl|bv|ag|oy|ab|spa|sas|sdn bhd|holdings|group|company|enterprises|technologies)\b',
        '',
        name,
    )
    name = re.sub(r'[^a-z0-9\s]', '', name)
    return [re.sub(r'\s+', ' ', name).strip()]

def enhancer(data: dict, embedder, *, name_threshold: float = 0.6,
             industry_threshold: float = 0.6) -> list:
    """Remove near-duplicate companies and attach a ``key_industry`` label.

    Two companies are considered duplicates when their normalised countries
    are equal, the similarity of their cleaned-name embeddings is at least
    ``name_threshold`` and the similarity of their industry-type embeddings
    is at least ``industry_threshold``. Of a duplicate pair the entry that
    appears first in the input is kept; entries already marked as duplicates
    are never used as the reference for further comparisons.

    Args:
        data: Mapping with a ``"companies"`` list of company dicts. The keys
            ``company_name``, ``industry_type`` and ``country`` are read
            from each dict; missing or ``None`` values count as ``""``.
        embedder: Object exposing ``encode(list_of_str)`` and
            ``similarity(a, b)`` whose result supports ``.item()``.
            NOTE(review): looks like a sentence-transformers
            ``SentenceTransformer`` -- confirm against the caller.
        name_threshold: Minimum name similarity for a duplicate.
        industry_threshold: Minimum industry similarity for a duplicate.

    Returns:
        ``[{"companies": kept}, {"duplicate_company_names": names}]`` where
        ``kept`` is the de-duplicated list after ``add_ind_key`` has run on
        it and ``names`` holds the ``company_name`` of every removed entry,
        in input order.
    """
    companies = data.get("companies", [])
    num_of_companies = len(companies)

    # Encode every company exactly once. Doing this inside the pairwise loop
    # would repeat the (expensive) encode call O(n^2) times.
    name_embeddings = []
    ind_embeddings = []
    countries = []
    for company in companies:
        name_embeddings.append(
            embedder.encode(clean_company_name(company.get("company_name") or ""))
        )
        ind_embeddings.append(
            embedder.encode([company.get("industry_type") or ""])
        )
        countries.append((company.get("country") or "").lower().strip())

    duplicate_idx = set()
    duplicate_comps = []

    for i in range(num_of_companies):
        if i in duplicate_idx:
            continue

        for j in range(i + 1, num_of_companies):
            if j in duplicate_idx:
                continue
            # Cheapest test first: different countries can never be duplicates,
            # so the similarity computations are skipped entirely.
            if countries[i] != countries[j]:
                continue

            name_sim = embedder.similarity(name_embeddings[i], name_embeddings[j]).item()
            if name_sim < name_threshold:
                continue
            ind_sim = embedder.similarity(ind_embeddings[i], ind_embeddings[j]).item()
            if ind_sim < industry_threshold:
                continue

            duplicate_idx.add(j)
            print(f"Duplicate found: {companies[i].get('company_name')} and {companies[j].get('company_name')}")
            print(f"Name similarity: {name_sim}, Industry similarity: {ind_sim}")

    if duplicate_idx:
        print(duplicate_idx)
        # sorted() makes the report deterministic; .get() mirrors the lookups
        # above so an entry without a name cannot raise after all the work.
        duplicate_comps = [companies[i].get("company_name") for i in sorted(duplicate_idx)]
        companies = [c for idx, c in enumerate(companies) if idx not in duplicate_idx]
        print(f"Removed {len(duplicate_idx)} duplicate entries.",flush=True)
    else:
        print("No duplicate entries found.",flush=True)

    print("Now adding the industry keys...")
    companies = add_ind_key(companies, embedder)
    print("Added Industry keys",flush=True)

    return [{"companies": companies}, {"duplicate_company_names": duplicate_comps}]




def add_ind_key(data: list, embedder) -> list:
    """Label each company with the closest entry of ``key_industry_types``.

    For every dict in ``data`` that has no ``"key_industry"`` key yet, its
    ``industry_type`` is encoded, compared against the pre-computed
    embeddings stored under ``"industry_embeddings"`` in
    ``src/data/key_industry_embeddings.json`` and the label at the
    best-matching index of ``key_industry_types`` is stored as
    ``"key_industry"``. Dicts are modified in place.

    Args:
        data: List of company dicts. A missing or ``None`` ``industry_type``
            is encoded as ``""``.
        embedder: Object exposing ``encode(list_of_str)`` and
            ``similarity(a, b)``.
            NOTE(review): looks like a sentence-transformers
            ``SentenceTransformer`` -- confirm against the caller.

    Returns:
        The same ``data`` list that was passed in.

    Raises:
        OSError: If the embeddings file cannot be opened (only attempted
            when at least one entry needs a label).
    """
    # Nothing to label: skip reading the embeddings file altogether.
    if all("key_industry" in c for c in data):
        return data

    with open("src/data/key_industry_embeddings.json", "r", encoding="utf-8") as f:
        key_ind_embs = json.load(f)["industry_embeddings"]

    for c in data:
        if "key_industry" in c:
            continue
        comp_emb = embedder.encode([c.get("industry_type") or ""])
        sim_t = embedder.similarity(key_ind_embs, comp_emb)
        # Position of the best match in the flattened similarity result;
        # presumably one row per key_industry_types entry -- TODO confirm.
        max_sim_idx = torch.argmax(sim_t).item()
        c["key_industry"] = key_industry_types[max_sim_idx]

    return data