from langchain_community.llms import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_community.document_loaders import WebBaseLoader
from langchain.text_splitter import CharacterTextSplitter
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
# Load the LLM and tokenizer (two Gemma variants are left commented out as alternatives)
#model_name = "google/gemma-2-2b"
#model_name = "google/gemma-1.1-2b-it"
model_name = "HuggingFaceH4/zephyr-7b-beta"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
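# Note: zephyr-7b-beta has 7B parameters and needs substantial RAM/VRAM in full
# precision. On constrained hardware, an assumed workaround (not required by this
# script) is passing device_map="auto" (needs the accelerate package) or a reduced
# torch_dtype to from_pretrained, or using a quantized variant of the model.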
# Create a text-generation pipeline
text_generation_pipeline = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=512,
    do_sample=True,  # sampling must be enabled for temperature to take effect
    temperature=0.7,
)
# Create a LangChain LLM from the pipeline
llm = HuggingFacePipeline(pipeline=text_generation_pipeline)
# Load and process documents (WebBaseLoader fetches a URL; TextLoader only reads local files)
#loader = WebBaseLoader("https://en.wikipedia.org/wiki/Cheetah")
loader = WebBaseLoader("https://en.wikipedia.org/wiki/Artificial_neuron")
documents = loader.load()
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
texts = text_splitter.split_documents(documents)
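# CharacterTextSplitter only splits on its default separator ("\n\n"), so a chunk
# can exceed chunk_size=1000 when a single paragraph is longer than that;
# RecursiveCharacterTextSplitter is the usual choice when strict sizing matters.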
# Create embeddings and vector store
embeddings = HuggingFaceEmbeddings()
db = Chroma.from_documents(texts, embeddings)
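# With no model_name argument, HuggingFaceEmbeddings falls back to a default
# sentence-transformers model (all-mpnet-base-v2 at the time of writing);
# pass model_name=... explicitly to pin the embedding model.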
# Create a retriever
retriever = db.as_retriever()
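# as_retriever() defaults to similarity search returning the top 4 chunks;
# search_kwargs={"k": ...} controls how many chunks reach the prompt.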
# Create a prompt template
template = """Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.

{context}

Question: {question}
Answer:"""
prompt = PromptTemplate(template=template, input_variables=["context", "question"])
# Create the RetrievalQA chain
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=retriever,
    return_source_documents=True,
    chain_type_kwargs={"prompt": prompt},
)
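# chain_type="stuff" concatenates every retrieved chunk into a single prompt;
# it is the simplest option but can overflow the model's context window when
# the retriever returns many or long chunks.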
# Example query
#query = "How fast can a cheetah run?"
query = "What is an artificial neuron?"
result = qa_chain.invoke({"query": query})
print("Question:", query)
print("Answer:", result["result"])