derek-thomas
committed on
Commit
·
3772eaf
1
Parent(s):
7d5ff0e
Move client instantiation
Browse files
- src/utilities.py +5 -3
src/utilities.py
CHANGED
|
@@ -12,7 +12,6 @@ USERNAME = os.environ["USERNAME"]
|
|
| 12 |
OG_DATASET = f"{USERNAME}/dataset-creator-reddit-{SUBREDDIT}"
|
| 13 |
PROCESSED_DATASET = os.environ['PROCESSED_DATASET']
|
| 14 |
|
| 15 |
-
client = Client("derek-thomas/nomic-embeddings")
|
| 16 |
logger = setup_logger(__name__)
|
| 17 |
|
| 18 |
|
|
@@ -29,6 +28,9 @@ async def load_datasets():
|
|
| 29 |
|
| 30 |
|
| 31 |
def merge_and_update_datasets(dataset, original_dataset):
|
|
|
|
|
|
|
|
|
|
| 32 |
# Merge and figure out which rows need to be updated with embeddings
|
| 33 |
odf = original_dataset['train'].to_pandas()
|
| 34 |
df = dataset['train'].to_pandas()
|
|
@@ -50,13 +52,13 @@ def merge_and_update_datasets(dataset, original_dataset):
|
|
| 50 |
# Iterate over the DataFrame rows where 'embedding' is None
|
| 51 |
for index, row in merged_df[merged_df['embedding'].isnull()].iterrows():
|
| 52 |
# Update 'embedding' for the current row using our function
|
| 53 |
-
merged_df.at[index, 'embedding'] = update_embeddings(row['content'])
|
| 54 |
|
| 55 |
dataset['train'] = Dataset.from_pandas(merged_df)
|
| 56 |
logger.info(f"Updated {updated_rows} rows")
|
| 57 |
return dataset
|
| 58 |
|
| 59 |
|
| 60 |
-
def update_embeddings(content):
|
| 61 |
embedding = client.predict(content, api_name="/embed")
|
| 62 |
return np.array(embedding)
|
|
|
|
| 12 |
OG_DATASET = f"{USERNAME}/dataset-creator-reddit-{SUBREDDIT}"
|
| 13 |
PROCESSED_DATASET = os.environ['PROCESSED_DATASET']
|
| 14 |
|
|
|
|
| 15 |
logger = setup_logger(__name__)
|
| 16 |
|
| 17 |
|
|
|
|
| 28 |
|
| 29 |
|
| 30 |
def merge_and_update_datasets(dataset, original_dataset):
|
| 31 |
+
# Get client
|
| 32 |
+
client = Client("derek-thomas/nomic-embeddings")
|
| 33 |
+
|
| 34 |
# Merge and figure out which rows need to be updated with embeddings
|
| 35 |
odf = original_dataset['train'].to_pandas()
|
| 36 |
df = dataset['train'].to_pandas()
|
|
|
|
| 52 |
# Iterate over the DataFrame rows where 'embedding' is None
|
| 53 |
for index, row in merged_df[merged_df['embedding'].isnull()].iterrows():
|
| 54 |
# Update 'embedding' for the current row using our function
|
| 55 |
+
merged_df.at[index, 'embedding'] = update_embeddings(content=row['content'], client=client)
|
| 56 |
|
| 57 |
dataset['train'] = Dataset.from_pandas(merged_df)
|
| 58 |
logger.info(f"Updated {updated_rows} rows")
|
| 59 |
return dataset
|
| 60 |
|
| 61 |
|
| 62 |
+
def update_embeddings(content, client):
|
| 63 |
embedding = client.predict(content, api_name="/embed")
|
| 64 |
return np.array(embedding)
|