import json

import streamlit as st

from config import config
from utils.embedding import Embeddings
from utils.llm import LLM
from utils.vector_base import KnowledgeBase

def get_embedding_model():
    return Embeddings()


def get_llm(api_key):
    # LLM wraps a Cohere-backed client; the key is supplied via the UI below.
    return LLM(api_key)

def get_metadata(path):
    # Load paper titles and texts from the metadata JSON file.
    titles, texts = [], []
    with open(path, 'r', encoding='utf-8') as file:
        metadata = json.load(file)
    for data in metadata:
        titles.append(data['title'])
        texts.append(data['text'])
    return texts, titles
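
# The loader above assumes metadata shaped roughly like this (hypothetical
# sample; the actual file lives at config.PATH_METADATA):
# [
#     {"title": "Attention Is All You Need", "text": "The dominant sequence transduction models ..."},
#     {"title": "Another paper title", "text": "Full or preprocessed body text ..."}
# ]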

def combine_docs(indexes, texts):
    # Number the retrieved documents 1..n so the LLM can cite them by index.
    result = ""
    for i, index in enumerate(indexes):
        result += " [" + str(i + 1) + "] " + texts[index]
    return result
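
# Worked example: combine_docs([2, 0], ["a", "b", "c"]) returns " [1] c [2] a".
# Documents are renumbered 1..n in retrieval order, which matches the
# square-bracket citation format requested in the prompt below.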

def create_prompt(query, docs):
    prompt = f"""You are a language model integrated into a retrieval-augmented generation (RAG) system.
Your task is to answer the user's query strictly based on the provided documents. Do not invent, speculate, or include any information not found in the documents.
If the required information is available in the documents, use it to construct your response and cite the source by indicating the document number in square brackets. For example:
DL stands for Deep Learning, a subset of Machine Learning that involves learning complex non-linear relationships within large datasets [6].
If the information required to answer the query is not available in the documents, explicitly state:
"The required information is not available in the provided documents."
Ensure that:
- The response is based entirely on the content of the documents.
- Citations are accurate and directly linked to the information being cited.
- No assumptions, speculations, or fabricated details are included.
User query: {query}
Documents:
{docs}
"""
    return prompt

def main(query, search_types, llm_api_key):
    model, llm = get_embedding_model(), get_llm(llm_api_key)
    texts, titles = get_metadata(config.PATH_METADATA)
    embedding = model.get_query_embedding(query)
    knowledge_base = KnowledgeBase(config.PATH_FAISS, config.PATH_PREPROCESSING_TEXT)
    vector_search = []
    bm25_search = []
    if "Vector" in search_types:
        # FAISS returns a 2-D array of indices; take the row for this single query.
        vector_search = knowledge_base.search_by_embedding(embedding, 5)[0].tolist()
    if "BM25" in search_types:
        bm25_search = knowledge_base.search_by_BM25(query, 3)
    docs = combine_docs(vector_search + bm25_search, texts)
    prompt = create_prompt(query, docs)
    response = llm.generate_response(prompt)
    return response, docs
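
# Note: vector and BM25 hits are concatenated as raw indices, so a document
# found by both retrievers appears (and can be cited) twice.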

# Streamlit interface
if __name__ == '__main__':
    st.title("PaperRAG")
    st.subheader("RAG system for scientific papers with selectable search types")
    query = st.text_input("Enter your query")
    search_types = st.multiselect(
        "Select search types",
        options=["Vector", "BM25"],
        default=["Vector", "BM25"],
    )
    llm_api_key = st.text_input("Cohere API Key", type="password")
    if st.button("Get Response"):
        if query and llm_api_key:
            response, docs = main(query, search_types, llm_api_key)
            st.subheader("LLM Response:")
            st.text_area("Response", value=response, height=300)
            st.subheader("Citations:")
            st.text_area("Documents", value=docs, height=300)
        else:
            st.error("Please enter both a query and an API key.")
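
# To run the app locally (assuming this file is saved as app.py):
#   streamlit run app.py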