|
|
import pandas as pd |
|
|
from transformers import DistilBertTokenizer, DistilBertModel |
|
|
import torch |
|
|
import os |
|
|
|
|
|
|
|
|
# Column -> dtype overrides for read_csv: every listed column of books.csv is
# read as a Python str, so values such as ISBNs and years are not coerced to
# numbers.
dtype_spec = {
    column: str
    for column in (
        'ISBN',
        'Book-Title',
        'Book-Author',
        'Year-Of-Publication',
        'Publisher',
        'Image-URL-S',
        'Image-URL-M',
        'Image-URL-L',
    )
}
|
|
|
|
|
|
|
|
# Load the semicolon-delimited, Latin-1 encoded book catalogue. Malformed
# lines are skipped instead of aborting the load, and the columns listed in
# dtype_spec are read as str.
books_df = pd.read_csv(
    "books.csv",
    sep=';',
    encoding='latin1',
    dtype=dtype_spec,
    on_bad_lines='skip',
)

# Pretrained uncased DistilBERT tokenizer and encoder used by
# get_bert_embedding().
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertModel.from_pretrained('distilbert-base-uncased')
|
|
|
|
|
|
|
|
|
|
|
def get_bert_embedding(text):
    """Return the DistilBERT hidden state at the first token position of *text*.

    Args:
        text: The string to embed. Missing values -- ``None`` or a float NaN,
            which is what pandas produces for empty CSV cells even when the
            column dtype is ``str`` -- are embedded as the empty string rather
            than being passed to the tokenizer, which would reject them.

    Returns:
        A 1-D NumPy array: ``last_hidden_state[0, 0]`` of the module-level
        ``model``. Its length is the model's hidden size (presumably 768 for
        distilbert-base-uncased -- confirm via ``model.config``).
    """
    # Missing titles come through apply() as NaN floats; don't crash on them.
    if text is None or (isinstance(text, float) and pd.isna(text)):
        text = ""

    # Truncate to the model's 512-token limit.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        outputs = model(**inputs)

    # Batch element 0, sequence position 0 (the [CLS] slot for BERT-style
    # tokenizers, per the variable name).
    cls_embedding = outputs.last_hidden_state[0, 0].numpy()
    return cls_embedding
|
|
|
|
|
|
|
|
|
|
|
# Destination CSV for rows with embeddings, and the checkpoint file that
# stores the index of the next row still to be processed.
output_file = "books_with_embeddings.csv"
progress_file = "progress.txt"

# Pick up from the saved checkpoint when there is one; otherwise start at 0.
try:
    with open(progress_file, "r") as checkpoint:
        start_idx = int(checkpoint.read().strip())
except FileNotFoundError:
    start_idx = 0
    print("Starting from the beginning.")
else:
    print(f"Resuming from row {start_idx}.")
|
|
|
|
|
# Rows embedded between checkpoints; smaller batches lose less work when
# the run is interrupted.
batch_size = 10
total_rows = len(books_df)

# Write the CSV header on a fresh run. Also write it when resuming but the
# output file no longer exists; otherwise the recreated file would start
# with data rows and no header.
write_header = start_idx == 0 or not os.path.exists(output_file)
|
|
|
|
|
# First row index not yet covered by the checkpoint; reported on interrupt.
# Initialised here so the handler is valid even before any batch completes.
resume_from = start_idx

try:
    for idx in range(start_idx, total_rows, batch_size):
        batch_df = books_df.iloc[idx: idx + batch_size].copy()

        # Embed each title and convert to a plain list in one pass, so
        # to_csv writes "[x, y, ...]" rather than a NumPy array repr.
        batch_df['embedding'] = batch_df['Book-Title'].apply(
            lambda title: get_bert_embedding(title).tolist()
        )

        # The header-bearing write starts the file fresh ('w'); later batches
        # append. This keeps a leftover output from an earlier run from
        # gaining stale rows and a second header.
        batch_df.to_csv(
            output_file,
            mode='w' if write_header else 'a',
            header=write_header,
            index=False,
        )
        write_header = False

        # Checkpoint via temp file + os.replace so an interrupt mid-write can
        # never leave progress_file empty or truncated (which would make the
        # int() parse fail on the next resume).
        next_idx = min(idx + batch_size, total_rows)
        tmp_progress = progress_file + ".tmp"
        with open(tmp_progress, "w") as pf:
            pf.write(str(next_idx))
        os.replace(tmp_progress, progress_file)
        resume_from = next_idx

        print(f"Processed rows {idx} to {next_idx} out of {total_rows}.")

except KeyboardInterrupt:
    # NOTE(review): an interrupt that lands after to_csv but before the
    # checkpoint leaves that batch in the CSV without recording it, so it
    # will be written again on resume.
    print(f"Process interrupted at row {resume_from}. Progress saved in {progress_file}.")
|
|
|