Spaces:
Runtime error
Runtime error
| from langchain_text_splitters import RecursiveCharacterTextSplitter | |
| from langchain.schema.document import Document | |
| from langchain_community.embeddings import HuggingFaceBgeEmbeddings | |
| from langchain_chroma import Chroma | |
| import spaces | |
| from langchain_text_splitters import MarkdownHeaderTextSplitter | |
| import os | |
| from transformers import AutoTokenizer | |
# Generator model whose tokenizer is loaded below.
model_name = "meta-llama/Llama-3.1-8B-Instruct"
# Hugging Face access token read from the environment (None when HF_TOKEN is
# unset); passed as `token` so authenticated downloads are possible.
api_token = os.getenv("HF_TOKEN")

# Tokenizer for `model_name`; create_rag_index uses it to measure chunk
# lengths in tokens rather than characters.
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    token=api_token,
)

# BGE embedding model placed on the CUDA device, producing normalized
# embeddings; queries are embedded with an empty instruction prefix.
embedding_model = HuggingFaceBgeEmbeddings(
    model_name="BAAI/bge-large-en-v1.5",
    model_kwargs=dict(device="cuda"),
    encode_kwargs=dict(normalize_embeddings=True),
    query_instruction="",
)
def create_rag_index(text_no_prefix, chunk_size=256):
    """Split text into token-sized chunks and index them in a new Chroma vectorstore.

    No file loading happens here: the caller passes already-extracted text.

    Args:
        text_no_prefix: The document text to index.
        chunk_size: Maximum chunk length in tokens, measured with the
            module-level ``tokenizer``. Defaults to 256, the same divisor
            ``run_naive_rag_query`` uses by default to turn a token budget
            into a number of chunks.

    Returns:
        A ``Chroma`` vectorstore, embedded with the module-level
        ``embedding_model``, holding one document per chunk. Each document's
        metadata carries ``start_index``, the chunk's character offset in
        ``text_no_prefix``. Every call gets its own collection, so separate
        indexes never see each other's chunks.
    """
    import uuid

    text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
        tokenizer,
        chunk_size=chunk_size,
        chunk_overlap=0,
        add_start_index=True,
        strip_whitespace=True,
        separators=["\n\n", "\n", ".", " ", ""],
    )
    # create_documents (unlike split_text, which returns bare strings) honours
    # add_start_index and records each chunk's offset in metadata["start_index"].
    docs = text_splitter.create_documents([text_no_prefix])

    # Without an explicit name, every call writes into the same default
    # collection of the process-wide in-memory Chroma client, so chunks from
    # previously indexed texts would show up in this index's search results.
    collection_name = f"rag-{uuid.uuid4().hex}"

    if not docs:
        # Nothing to embed (empty or whitespace-only text). Return an empty
        # store rather than handing Chroma an empty batch, which it may reject.
        return Chroma(collection_name=collection_name, embedding_function=embedding_model)

    return Chroma.from_documents(
        documents=docs,
        embedding=embedding_model,
        collection_name=collection_name,
    )
def run_naive_rag_query(vectorstore, query, rag_token_size, prefix, task, few_shot_examples, chunk_size=256):
    """Build a naive-RAG prompt from the chunks most similar to ``query``.

    k is the number of ``chunk_size``-token chunks that fit in the
    ``rag_token_size`` budget, with a minimum of 1. The top-k chunks are
    retrieved by similarity and spliced into the prompt. Each retrieved chunk
    is also printed to stdout between separator lines.

    Args:
        vectorstore: Store to search; must provide ``as_retriever`` (e.g. the
            object returned by ``create_rag_index``).
        query: Text used for the similarity search.
        rag_token_size: Target number of context tokens (int or float).
        prefix: Text placed before the retrieved context.
        task: Text appended directly after the retrieved context.
        few_shot_examples: Text appended directly after ``task``.
        chunk_size: Tokens per indexed chunk. It should match the chunk size
            the index was built with. Defaults to 256.

    Returns:
        The prompt, concatenated in this order with nothing inserted between
        the pieces:
        1. ``prefix``;
        2. the header "Retrieved context: " followed by a newline;
        3. the chunks, joined by blank lines;
        4. ``task``;
        5. ``few_shot_examples``.
    """
    # Cast to int: a float budget would otherwise give a float k
    # (1024.0 // 256 == 4.0), and Chroma rejects a non-int result count.
    k = max(1, int(rag_token_size // chunk_size))
    retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": k})
    retrieved_docs = retriever.invoke(query)

    # Debug dump of what was retrieved.
    for doc in retrieved_docs:
        print("=================")
        print(doc.page_content)
        print("=================")

    formatted_context = "\n\n".join(doc.page_content for doc in retrieved_docs)
    # No separator is added before `task`; presumably callers include any
    # leading newline they want in it (NOTE(review): confirm with callers).
    rag_context = prefix + "Retrieved context: \n" + formatted_context + task + few_shot_examples
    return rag_context