import os
import argparse

from datasets import load_dataset
from langchain.schema import Document
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import LlamaCpp
from langchain.chains import RetrievalQA
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
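# Note: these imports target the classic `langchain` package (pre-0.1 layout).
# On newer installs the same classes live under `langchain_community`
# (e.g. `from langchain_community.vectorstores import Chroma`), so pin the
# langchain version this script was written against.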
# Build (or load) the vector database and assemble the retriever
def initialize_database():
    print("🔹 Loading medical dataset...")
    ds = load_dataset("lavita/ChatDoctor-HealthCareMagic-100k", split="train")
    # Use the first 1,000 Q&A pairs to keep indexing time manageable
    qa_pairs = [{"question": x["instruction"], "answer": x["output"]} for x in ds.select(range(1000))]

    # Convert to LangChain Documents
    print("🔹 Converting to LangChain documents...")
    docs = [
        Document(
            page_content=f"Question: {item['question']}\nAnswer: {item['answer']}",
            metadata={"source": "ChatDoctor"}
        )
        for item in qa_pairs
    ]

    # Embedding model used for similarity search
    print("🔹 Embedding documents...")
    embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

    # ChromaDB setup: build the index on first run, reuse it afterwards
    persist_dir = "./chroma_medical_db"
    if not os.path.exists(persist_dir):
        print("🔹 Creating new ChromaDB...")
        vectorstore = Chroma.from_documents(docs, embedding_model, persist_directory=persist_dir)
        vectorstore.persist()
    else:
        print("🔹 Loading existing ChromaDB...")
        vectorstore = Chroma(persist_directory=persist_dir, embedding_function=embedding_model)

    # Set up the retriever
    retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
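    # With k=3 the retriever returns the three nearest Q&A documents per query;
    # RetrievalQA's default "stuff" chain pastes all of them into the prompt,
    # so the combined text must fit within the model's n_ctx below.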
    # Local LLM setup
    print("🔹 Loading local LLM model...")
    llm = LlamaCpp(
        model_path="models/tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf",
        n_ctx=1024,          # context window in tokens
        temperature=0.7,
        max_tokens=512,      # cap on generated tokens
        streaming=True,
        callbacks=[StreamingStdOutCallbackHandler()],  # stream tokens to stdout
        verbose=True,
        f16_kv=True,         # half-precision key/value cache
        use_mlock=True,      # keep model pages locked in RAM
        use_mmap=True,       # memory-map the model file
        n_threads=4,
        n_batch=64
    )
    # Build RAG QA chain
    print("🔹 Building RAG chain...")
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        retriever=retriever,
        return_source_documents=True
    )
    return qa_chain
# Handle a single query (note: this rebuilds the whole pipeline on every call)
def handle_query(query):
    qa_chain = initialize_database()
    print(f"🔹 Query: {query}")
    result = qa_chain(query)
    response = {
        "answer": result['result'],
        "sources": result['source_documents']
    }
    return response
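# handle_query can also be called programmatically; a minimal sketch with a
# hypothetical question:
#   resp = handle_query("What can cause persistent headaches?")
#   print(resp["answer"])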
# Main CLI functionality
def main():
    parser = argparse.ArgumentParser(description="Medical Question-Answering CLI Application")
    parser.add_argument("query", type=str, help="Query to ask the medical AI agent")
    args = parser.parse_args()
    query = args.query
    result = handle_query(query)
    print("\n🧠 Answer:")
    print(result["answer"])
    print("\nSource Documents:")
    for doc in result["sources"]:
        # source_documents holds LangChain Document objects, not dicts
        print(doc.page_content)
if __name__ == "__main__":
    main()
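# Example invocation (assuming this file is saved as app.py and the GGUF model
# has been downloaded to models/):
#   python app.py "What are the symptoms of iron deficiency anemia?"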