| |
| |
| |
|
|
| |
| from llama_index.core import VectorStoreIndex, StorageContext, load_index_from_storage |
| from llama_index.core.indices.query.query_transform import HyDEQueryTransform |
| from llama_index.core.query_engine import TransformQueryEngine |
| from langchain_nvidia_ai_endpoints import NVIDIARerank |
| from langchain_core.documents import Document as LangDocument |
| from llama_index.core.llms import ChatMessage, MessageRole |
| from llama_index.llms.nvidia import NVIDIA |
| from llama_index.embeddings.nvidia import NVIDIAEmbedding |
| from llama_index.core import Document as LlamaDocument |
| from llama_index.core import Settings |
| from llama_parse import LlamaParse |
| import streamlit as st |
| import os |
| |
| nvidia_api_key = os.getenv("NVIDIA_KEY") |
| llamaparse_api_key = os.getenv("PARSE_KEY") |
|
|
| |
| client = NVIDIA( |
| model="meta/llama-3.1-8b-instruct", |
| api_key=nvidia_api_key, |
| temperature=0.2, |
| top_p=0.7, |
| max_tokens=1024 |
| ) |
| embed_model = NVIDIAEmbedding( |
| model="nvidia/nv-embedqa-e5-v5", |
| api_key=nvidia_api_key, |
| truncate="NONE" |
| ) |
|
|
|
|
| reranker = NVIDIARerank( |
| model="nvidia/nv-rerankqa-mistral-4b-v3", |
| api_key=nvidia_api_key, |
| ) |
|
|
| |
| Settings.embed_model = embed_model |
| Settings.llm = client |
|
|
| |
| parser = LlamaParse( |
| api_key=llamaparse_api_key, |
| result_type="markdown", |
| verbose=True |
| ) |
|
|
| |
| script_dir = os.path.dirname(os.path.abspath(__file__)) |
| data_file = os.path.join(script_dir, "FreightsDataset.pdf") |
|
|
| |
| documents = parser.load_data(data_file) |
| print("Document Parsed") |
|
|
| |
| def split_text(text, max_tokens=512): |
| words = text.split() |
| chunks = [] |
| current_chunk = [] |
| current_length = 0 |
|
|
| for word in words: |
| word_length = len(word) |
| if current_length + word_length + 1 > max_tokens: |
| chunks.append(" ".join(current_chunk)) |
| current_chunk = [word] |
| current_length = word_length + 1 |
| else: |
| current_chunk.append(word) |
| current_length += word_length + 1 |
|
|
| if current_chunk: |
| chunks.append(" ".join(current_chunk)) |
|
|
| return chunks |
|
|
| |
| all_embeddings = [] |
| all_documents = [] |
|
|
| for doc in documents: |
| text_chunks = split_text(doc.text) |
| for chunk in text_chunks: |
| embedding = embed_model.get_text_embedding(chunk) |
| all_embeddings.append(embedding) |
| all_documents.append(LlamaDocument(text=chunk)) |
| print("Embeddings generated") |
|
|
| |
| index = VectorStoreIndex.from_documents(all_documents, embeddings=all_embeddings, embed_model=embed_model) |
| index.set_index_id("vector_index") |
| index.storage_context.persist("./storage") |
| print("Index created") |
|
|
| |
| storage_context = StorageContext.from_defaults(persist_dir="storage") |
| index = load_index_from_storage(storage_context, index_id="vector_index") |
| print("Index loaded") |
|
|
| |
| hyde = HyDEQueryTransform(include_original=True) |
| query_engine = index.as_query_engine() |
| hyde_query_engine = TransformQueryEngine(query_engine, hyde) |
|
|
| |
| def query_model_with_context(question): |
| |
| hyde_response = hyde_query_engine.query(question) |
| print(f"HyDE Response: {hyde_response}") |
|
|
| if isinstance(hyde_response, str): |
| hyde_query = hyde_response |
| else: |
| hyde_query = hyde_response.response |
|
|
| |
| retriever = index.as_retriever(similarity_top_k=3) |
| nodes = retriever.retrieve(hyde_query) |
|
|
| for node in nodes: |
| print(node) |
|
|
| |
| ranked_documents = reranker.compress_documents( |
| query=question, |
| documents=[LangDocument(page_content=node.text) for node in nodes] |
| ) |
|
|
| |
| print(f"Most relevant node: {ranked_documents[0].page_content}") |
|
|
| |
| context = ranked_documents[0].page_content |
|
|
| |
| |
| messages = [ |
| ChatMessage(role=MessageRole.SYSTEM, content=context), |
| ChatMessage(role=MessageRole.USER, content=str(question)) |
| ] |
|
|
| |
| completion = client.chat(messages) |
|
|
| |
| response_text = "" |
|
|
| if isinstance(completion, (list, tuple)): |
| |
| response_text = ' '.join(completion) |
| elif isinstance(completion, str): |
| |
| response_text = completion |
| else: |
| |
| response_text = str(completion) |
| |
| response_text = response_text.replace("assistant:", "Final Response:").strip() |
|
|
| return response_text |
|
|
|
|
| |
| st.title("Chat with HyDE and Rerank RAG Freights App") |
| question = st.text_input("Enter a relevant question to chat with the attached FreightsDataset file:") |
|
|
| if st.button("Submit"): |
| if question: |
| st.write("**RAG Response:**") |
| response = query_model_with_context(question) |
| st.write(response) |
| else: |
| st.warning("Please enter a question.") |
|
|
|
|