Spaces:
Runtime error
Runtime error
| from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer, pipeline | |
| from langchain.llms import HuggingFaceHub, HuggingFacePipeline | |
| from dotenv import load_dotenv | |
| from langchain.embeddings import HuggingFaceBgeEmbeddings | |
| from langchain.vectorstores import Chroma | |
| from langchain.chains import RetrievalQA | |
| import textwrap | |
| import os | |
def load_vector_store():
    """Open the persisted Chroma store and return a retriever over it.

    The ``BAAI/bge-small-en`` embedding model is instantiated on CPU with
    embedding normalization enabled, then used as the embedding function
    for the Chroma database persisted under ``vector stores/ncertdb``.

    Returns:
        The retriever produced by ``Chroma.as_retriever`` with
        ``search_kwargs={"k": 2}``.
    """
    embeddings = HuggingFaceBgeEmbeddings(
        model_name="BAAI/bge-small-en",
        model_kwargs={"device": "cpu"},
        encode_kwargs={"normalize_embeddings": True},
    )
    print('Embeddings loaded!')

    # Named ``vector_store`` rather than reusing the function's own name,
    # so the local does not shadow ``load_vector_store`` inside its body.
    vector_store = Chroma(
        persist_directory='vector stores/ncertdb',
        embedding_function=embeddings,
    )
    print('Vector store loaded!')

    return vector_store.as_retriever(search_kwargs={"k": 2})
| #model | |
def load_model():
    """Create the Hugging Face Hub LLM used by the QA chain.

    Calls ``load_dotenv()`` so a local ``.env`` file can supply
    ``HUGGINGFACEHUB_API_TOKEN``, verifies that the token is present, builds
    a ``HuggingFaceHub`` client for ``llmware/bling-sheared-llama-1.3b-0.1``
    limited to 100 new tokens, and prints the reply to a short test prompt.

    Returns:
        The configured ``HuggingFaceHub`` LLM instance.

    Raises:
        RuntimeError: If ``HUGGINGFACEHUB_API_TOKEN`` is missing or empty
            after loading the environment.
    """
    load_dotenv()

    # os.getenv reads from os.environ, so writing the value straight back is
    # a no-op when the token exists and a TypeError ("str expected, not
    # NoneType") when it does not. Fail early with an actionable message.
    if not os.getenv("HUGGINGFACEHUB_API_TOKEN"):
        raise RuntimeError(
            "HUGGINGFACEHUB_API_TOKEN is not set; define it in the "
            "environment or in a .env file before loading the model."
        )

    llm = HuggingFaceHub(
        repo_id='llmware/bling-sheared-llama-1.3b-0.1',
        model_kwargs={'max_new_tokens': 100},
    )
    # Smoke test: issues one request so configuration problems surface here.
    print(llm('HI!'))
    return llm
def qa_chain():
    """Assemble the retrieval question-answering chain.

    Loads the retriever first and the LLM second, then combines them with
    ``RetrievalQA.from_chain_type`` using the ``'stuff'`` chain type.
    Source documents are returned with each answer and verbose mode is on.

    Returns:
        The object returned by ``RetrievalQA.from_chain_type``.
    """
    doc_retriever = load_vector_store()
    language_model = load_model()
    return RetrievalQA.from_chain_type(
        llm=language_model,
        chain_type='stuff',
        retriever=doc_retriever,
        return_source_documents=True,
        verbose=True,
    )
def wrap_text_preserve_newlines(text, width=110):
    """Wrap ``text`` to ``width`` columns while keeping its line breaks.

    The text is split on ``'\\n'``, each piece is wrapped on its own with
    ``textwrap.fill``, and the pieces are joined back with ``'\\n'``, so the
    original newlines (including blank lines) survive the wrapping.

    Args:
        text: The string to wrap.
        width: Maximum line width passed to ``textwrap.fill``.

    Returns:
        The wrapped string.
    """
    return '\n'.join(
        textwrap.fill(segment, width=width) for segment in text.split('\n')
    )
def process_llm_response(llm_response):
    """Print the chain's answer followed by the sources it came from.

    Args:
        llm_response: Mapping with a ``'result'`` entry holding the answer
            text and a ``"source_documents"`` entry holding an iterable of
            documents; each document's ``metadata['source']`` is printed.
    """
    answer = llm_response['result']
    print(wrap_text_preserve_newlines(answer))

    print('\n\nSources:')
    for document in llm_response["source_documents"]:
        print(document.metadata['source'])
# Script body: build the chain once, ask a sample question, print the result.
qa = qa_chain()
question = 'What are types of Embedded system?'
response = qa(question)
process_llm_response(response)