Spaces:
Runtime error
Runtime error
| import argparse | |
| import os | |
| from dataclasses import dataclass | |
| import chromadb | |
| import yaml | |
| from langchain.chains.llm import LLMChain | |
| from langchain.vectorstores.chroma import Chroma | |
| from langchain_community.embeddings import HuggingFaceEmbeddings | |
| from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint | |
| from langchain_core.prompts import PromptTemplate | |
# Location of the YAML settings file, relative to the working directory.
CONFIG_PATH = os.path.join("config", "default_config.yaml")
# Directory holding the persistent Chroma vector store.
CHROMA_PATH = "chroma"
# Local cache directory for downloaded embedding-model weights.
MODEL_CACHE = "model_cache"
# Prompt sent to the LLM; {context} receives the retrieved chunks and
# {question} the user's query.
PROMPT_TEMPLATE = (
    "\n"
    "Answer the question based only on the following context:\n"
    "{context}\n"
    "---\n"
    "Answer the question based on the above context: {question}\n"
)
def main(argv=None):
    """Answer a command-line question with retrieval-augmented generation.

    Embeds the query with the local ``all-MiniLM-L6-v2`` model, retrieves the
    five most relevant chunks from the persistent Chroma collection
    ``jscholar_rag`` under ``CHROMA_PATH``, and asks the
    ``HuggingFaceH4/zephyr-7b-beta`` inference endpoint to answer using that
    context. Prints the collection size, then the answer, then the ``source``
    metadata of every retrieved chunk.

    Args:
        argv: Optional argument list for the CLI parser. ``None`` (the default)
            makes ``argparse`` read ``sys.argv[1:]``, as before.
    """
    # Create CLI.
    parser = argparse.ArgumentParser()
    parser.add_argument("query_text", type=str, help="The query text.")
    args = parser.parse_args(argv)
    query_text = args.query_text

    # Prepare the DB.
    collection_name = "jscholar_rag"
    hf_embed_func = HuggingFaceEmbeddings(
        model_name="all-MiniLM-L6-v2",
        model_kwargs={"device": "cpu"},
        encode_kwargs={"normalize_embeddings": False},
        cache_folder=MODEL_CACHE,
    )
    # Share a single client between raw chromadb access and the LangChain
    # wrapper instead of opening the same on-disk store with two clients.
    client = chromadb.PersistentClient(path=CHROMA_PATH)
    db = Chroma(
        client=client,
        collection_name=collection_name,
        embedding_function=hf_embed_func,
    )
    print(f"Total Embeddings: {client.get_collection(name=collection_name).count()}")

    # Search the DB.
    results = db.similarity_search_with_relevance_scores(query_text, k=5)
    # Give up when nothing came back or the best hit is only a weak match.
    if not results or results[0][1] < 0.1:
        print("Unable to find matching results.")
        return

    context_text = "\n\n---\n\n".join(doc.page_content for doc, _score in results)
    # from_template derives the input variables ("context", "question") from
    # the placeholders. The old code passed the retrieved text and the query
    # *values* as input_variables, which is wrong. It could only work if
    # LangChain re-derives the variables from the template itself.
    prompt = PromptTemplate.from_template(PROMPT_TEMPLATE)

    llm = HuggingFaceEndpoint(
        repo_id="HuggingFaceH4/zephyr-7b-beta",
        task="text-generation",
        top_k=30,
        temperature=0.1,
        repetition_penalty=1.03,
        max_new_tokens=512,
    )
    chat_model = LLMChain(prompt=prompt, llm=llm)
    # LLMChain.invoke returns a dict; the generated answer is under "text".
    response = chat_model.invoke({"question": query_text, "context": context_text})

    sources = [doc.metadata.get("source") for doc, _score in results]
    print(response.get("text"))
    print(f"Citations: {sources}")
def load_config(path=None):
    """Read and parse the YAML configuration file.

    Args:
        path: File to load. ``None`` (the default) uses ``CONFIG_PATH``;
            it is resolved at call time, not at definition time.

    Returns:
        The object produced by ``yaml.safe_load`` for the document.
    """
    if path is None:
        path = CONFIG_PATH
    # Decode explicitly as UTF-8 rather than relying on the platform's
    # locale encoding. safe_load only builds plain Python objects.
    with open(path, "r", encoding="utf-8") as file:
        return yaml.safe_load(file)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()