import os
import re
import warnings
from typing import Union

import numpy as np
import pandas as pd
import PyPDF2
import torch
import gradio as gr
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer
from cassandra.cluster import Cluster
from cassandra.auth import PlainTextAuthProvider
from dotenv import load_dotenv, find_dotenv

warnings.filterwarnings("ignore")
# Load environment variables
load_dotenv(find_dotenv())
ASTRADB_TOKEN = os.getenv("ASTRADB_TOKEN")
ASTRADB_SECURE_BUNDLE_PATH = os.getenv("ASTRADB_SECURE_BUNDLE_PATH")

# AstraDB connection setup. The CQL driver reaches Astra through a secure
# connect bundle downloaded from the Astra dashboard, not through the Data
# API endpoint; the bundle-path env var here is a placeholder name.
auth_provider = PlainTextAuthProvider(username="token", password=ASTRADB_TOKEN)
cluster = Cluster(
    cloud={"secure_connect_bundle": ASTRADB_SECURE_BUNDLE_PATH},
    auth_provider=auth_provider,
)
session = cluster.connect("your_keyspace_name")  # replace with your keyspace
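# The INSERT and ANN queries below assume a table with a vector column and a
# storage-attached index. A minimal sketch of that schema, using the same
# placeholder names as this file (768 matches DPR's embedding size):
#
#   CREATE TABLE your_keyspace_name.your_table_name (
#       title text PRIMARY KEY,
#       text text,
#       embeddings vector<float, 768>
#   );
#   CREATE CUSTOM INDEX IF NOT EXISTS embeddings_idx
#       ON your_keyspace_name.your_table_name (embeddings)
#       USING 'StorageAttachedIndex';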
# Load DPR encoders and tokenizers (DPR embeddings are 768-dimensional)
ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
q_encoder = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
def process_pdfs(parent_dir: Union[str, list]):
    """Reads each PDF and returns a dataframe with one row per <=512-character chunk of page text."""
    df = pd.DataFrame(columns=["title", "text"])
    if isinstance(parent_dir, str):
        parent_dir = [parent_dir]
    for file_path in parent_dir:
        if not file_path.lower().endswith(".pdf"):  # Reject non-PDF files
            raise Exception("only pdf files are supported")
        file_name = os.path.basename(file_path)
        with open(file_path, "rb") as pdf_file:
            pdf_reader = PyPDF2.PdfReader(pdf_file)
            for i, page in enumerate(pdf_reader.pages):
                txt = (page.extract_text() or "").replace("\n", "").replace("\t", "")
                txt = re.sub(r" +", " ", txt)  # Collapse runs of spaces
                # Split long pages into 512-character chunks, keeping the remainder
                while len(txt) > 512:
                    new_data = pd.DataFrame([[f"{file_name}-page-{i}", txt[:512]]], columns=["title", "text"])
                    df = pd.concat([df, new_data], ignore_index=True)
                    txt = txt[512:]
                if txt:
                    new_data = pd.DataFrame([[f"{file_name}-page-{i}", txt]], columns=["title", "text"])
                    df = pd.concat([df, new_data], ignore_index=True)
    return df
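# Usage sketch (hypothetical path): process_pdfs("docs/report.pdf") yields rows
# titled "report.pdf-page-0", "report.pdf-page-1", ..., one row per chunk, so a
# 1300-character page produces three rows (512 + 512 + 276 characters).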
def process_dataset(df):
    """Embeds each text chunk with the DPR context encoder and stores it in AstraDB."""
    if len(df) == 0:
        raise Exception("empty pdf files, or can't read text from them")
    insert_cql = "INSERT INTO your_table_name (title, text, embeddings) VALUES (%s, %s, %s)"
    for _, row in df.iterrows():
        tokens = ctx_tokenizer(row["text"], truncation=True, max_length=512, return_tensors="pt")
        with torch.no_grad():  # Inference only; no gradients needed
            embed = ctx_encoder(**tokens).pooler_output[0].numpy().tolist()
        session.execute(insert_cql, (row["title"], row["text"], embed))
    return df
def search(query, k=3):
    """Embeds the query and returns the k most similar chunks via ANN vector search."""
    try:
        tokens = q_tokenizer(query, truncation=True, max_length=512, return_tensors="pt")
        with torch.no_grad():
            query_embed = q_encoder(**tokens).pooler_output[0].numpy().tolist()
        # ANN vector search in AstraDB (requires the vector column and
        # storage-attached index sketched above)
        cql = """
            SELECT title, text, embeddings
            FROM your_table_name
            ORDER BY embeddings ANN OF %s
            LIMIT %s
        """
        rows = session.execute(cql, (query_embed, k))
        retrieved_examples = [
            {"title": row.title, "text": row.text, "embeddings": np.array(row.embeddings)}
            for row in rows
        ]
        if not retrieved_examples:
            return "no matching documents found"
        titles = [example["title"] for example in retrieved_examples]
        out = (
            f"**title** : {retrieved_examples[0]['title']},\n"
            f"content: {retrieved_examples[0]['text']}\n\n\n"
            f"**similar resources:** {titles}\n"
        )
    except Exception as e:
        out = f"error in search: {e}"
    return out
def predict(query, file_paths, k=3):
    """Indexes the uploaded PDFs, then searches them for the query."""
    try:
        df = process_pdfs(file_paths)
        process_dataset(df)
        out = search(query, k=k)
    except Exception as e:
        out = f"error in predict: {e}"
    return out
# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align: center'> PDF Search Engine </h1>")
    with gr.Row():
        with gr.Column():
            files = gr.Files(label="Upload PDFs", type="filepath", file_count="multiple")
            query = gr.Text(label="query")
            with gr.Accordion("number of references", open=False):
                k = gr.Number(value=3, show_label=False, precision=0, minimum=1, container=False)
            button = gr.Button("search")
        with gr.Column():
            output = gr.Markdown(label="output")
    button.click(predict, inputs=[query, files, k], outputs=output)

demo.launch()
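# A sketch of the requirements.txt this Space would need, inferred from the
# imports above (left unpinned; exact versions are not specified here):
#
#   PyPDF2
#   pandas
#   numpy
#   torch
#   transformers
#   gradio
#   cassandra-driver
#   python-dotenv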