Spaces:
Sleeping
Sleeping
Update classification.py
Browse files- classification.py +4 -3
classification.py
CHANGED
|
@@ -173,12 +173,12 @@ def process_categories(categories, model):
|
|
| 173 |
def match_categories(df, category_df, treshold=0.45):
|
| 174 |
for topic in category_df['topic']:
|
| 175 |
df[topic] = 0
|
| 176 |
-
for
|
| 177 |
if isinstance(ebd_content, torch.Tensor):
|
| 178 |
cos_scores = util.cos_sim(ebd_content, torch.stack(list(category_df['Embeddings']), dim=0))[0]
|
| 179 |
high_score_indices = [i for i, score in enumerate(cos_scores) if score > treshold]
|
| 180 |
for j in high_score_indices:
|
| 181 |
-
df.loc[
|
| 182 |
return df
|
| 183 |
|
| 184 |
def save_data(df, filename):
|
|
@@ -193,9 +193,10 @@ def classification(column, file_path, categories, treshold):
|
|
| 193 |
|
| 194 |
# Initialize models
|
| 195 |
model_ST = initialize_models()
|
| 196 |
-
|
| 197 |
# Generate embeddings for df
|
| 198 |
df = generate_embeddings(df, model_ST, column)
|
|
|
|
| 199 |
|
| 200 |
|
| 201 |
category_df = process_categories(categories, model_ST)
|
|
|
|
| 173 |
def match_categories(df, category_df, treshold=0.45):
    """Score every row of ``df`` against every category and record the matches.

    One column per entry of ``category_df['topic']`` is added to ``df``. For
    each row whose ``'Embeddings'`` value is a ``torch.Tensor``, the cosine
    similarity to every category embedding is computed; wherever the
    similarity is strictly greater than ``treshold`` the numeric score is
    written into that row's topic column. All other cells stay at ``0.0``.

    Args:
        df: DataFrame with an ``'Embeddings'`` column. Rows whose value is
            not a ``torch.Tensor`` are skipped.
        category_df: DataFrame with a ``'topic'`` column (used as the new
            column names) and an ``'Embeddings'`` column of tensors.
        treshold: Minimum (exclusive) cosine similarity for a match. The
            misspelled name is kept because callers may pass it by keyword.

    Returns:
        The same ``df`` object, modified in place, with the topic columns
        added.
    """
    # Snapshot topics positionally so that score index ``j`` always maps to
    # the j-th category, regardless of how ``category_df`` is indexed.
    topics = list(category_df['topic'])
    for topic in topics:
        # Float zero keeps the column dtype compatible with the scores
        # written below (no int -> float upcast on assignment).
        df[topic] = 0.0

    if not topics:
        # Nothing to compare against; torch.stack would fail on an empty list.
        return df

    # The category matrix is identical for every row: build it once.
    category_matrix = torch.stack(list(category_df['Embeddings']), dim=0)

    # Iterate over index *labels* (not enumerate positions) because df.loc
    # is label-based; positions would hit the wrong rows on a filtered or
    # otherwise non-default index.
    for row_label, ebd_content in df['Embeddings'].items():
        if not isinstance(ebd_content, torch.Tensor):
            continue
        cos_scores = util.cos_sim(ebd_content, category_matrix)[0]
        for j, score in enumerate(cos_scores):
            if score > treshold:
                # Store the numeric score itself, not a string literal.
                df.loc[row_label, topics[j]] = float(score)
    return df
|
| 183 |
|
| 184 |
def save_data(df, filename):
|
|
|
|
| 193 |
|
| 194 |
# Initialize models
|
| 195 |
model_ST = initialize_models()
|
| 196 |
+
print('Generating Embeddings')
|
| 197 |
# Generate embeddings for df
|
| 198 |
df = generate_embeddings(df, model_ST, column)
|
| 199 |
+
print('Embeddings Generated')
|
| 200 |
|
| 201 |
|
| 202 |
category_df = process_categories(categories, model_ST)
|