# NOTE(review): the lines below ("Spaces: / Sleeping / Sleeping") were
# Hugging Face Space status-badge residue from a web scrape, not code.
"""Medical QA Assistant.

A Streamlit app that answers medical questions with a RAG
(Retrieval-Augmented Generation) model over a PubMed QA passage index.

NOTE(review): reconstructed from a whitespace-mangled scrape; model and
path identifiers are kept from the original where possible.
"""
import streamlit as st
import torch
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
from datasets import load_dataset

# pubmed_qa requires an explicit config name, and the original
# split="test" does not exist ("pqa_labeled" only ships a "train" split).
# The dataset is loaded but never read below — kept for compatibility;
# presumably intended as the source of the passage index. TODO confirm.
dataset = load_dataset("pubmed_qa", "pqa_labeled", split="train")

tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")

# passages_path must point at a dataset prepared with RagRetriever's
# custom-index tooling — TODO confirm "./path_to_dataset" exists.
retriever = RagRetriever.from_pretrained(
    "facebook/rag-token-nq",
    index_name="compressed",
    passages_path="./path_to_dataset",
)

# Attach the retriever so model.generate() performs retrieval internally.
# The original passed raw retriever output as context_input_ids, which is
# not a valid argument for generate().
model = RagSequenceForGeneration.from_pretrained(
    "facebook/rag-token-nq", retriever=retriever
)

st.title('Medical QA Assistant')
st.markdown("This app uses a RAG model to answer medical queries based on the PubMed QA dataset.")

# User input for query
user_query = st.text_input("Ask a medical question:")
if user_query:
    inputs = tokenizer(user_query, return_tensors="pt")
    input_ids = inputs["input_ids"]

    # End-to-end answer generation: question encoding, retrieval, and
    # decoding all happen inside generate() because the retriever is
    # attached to the model.
    with torch.no_grad():
        generated_ids = model.generate(input_ids=input_ids)
    answer = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

    st.write(f"Answer: {answer}")

    # Re-run retrieval explicitly so the supporting passages can be shown.
    # RagRetriever.__call__ takes the question *embeddings*, not token ids
    # (the original called a nonexistent retriever.retrieve(input_ids)).
    st.subheader("Relevant Documents:")
    with torch.no_grad():
        question_hidden_states = model.question_encoder(input_ids)[0]
    docs = retriever(
        input_ids.numpy(),
        question_hidden_states.numpy(),
        return_tensors="pt",
    )
    # get_doc_dicts returns one dict per batch item, each holding a list
    # of passage texts under "text".
    doc_dicts = retriever.index.get_doc_dicts(docs["doc_ids"])
    for doc in doc_dicts:
        for text in doc["text"]:
            st.write(text[:300] + '...')  # first 300 characters of each passage