Spaces:
Runtime error
| import gradio as gr | |
| import torch | |
| from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration | |
| import faiss # Ensure faiss is available | |
# Load the tokenizer, retriever, and model once at import time; predict() below
# reuses these module-level objects for every request.
MODEL_ID = "facebook/rag-sequence-nq"  # single source of truth for the checkpoint

tokenizer = RagTokenizer.from_pretrained(MODEL_ID)
# use_dummy_dataset=True loads a small dummy passage index rather than the full
# retrieval corpus, so retrieved context is very limited. index_name="exact"
# mirrors the dummy-dataset usage shown in the Transformers RAG documentation
# (NOTE(review): confirm this resolves the Space's startup error).
retriever = RagRetriever.from_pretrained(MODEL_ID, index_name="exact", use_dummy_dataset=True)
model = RagSequenceForGeneration.from_pretrained(MODEL_ID, retriever=retriever)
# Define prediction function
def predict(input_text):
    """Generate an answer for ``input_text`` with the RAG model.

    Args:
        input_text: Free-text question or clinical note entered in the Gradio
            textbox.

    Returns:
        The first decoded generation as a plain string, or an empty string
        when the input is empty or whitespace-only.
    """
    # Avoid a full retrieval + generation pass for a cleared textbox.
    if not input_text or not input_text.strip():
        return ""

    # The textbox invites long clinical notes; truncate so inputs longer than
    # the question encoder can embed do not fail. 512 is assumed to be the
    # BERT-based DPR question encoder's position limit — confirm against the
    # checkpoint config.
    inputs = tokenizer(
        [input_text],
        return_tensors="pt",
        truncation=True,
        max_length=512,
    )

    outputs = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
    )
    # batch_decode returns one string per generated sequence; a single input
    # was submitted, so take the first.
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
# Sample prompts offered beneath the textbox. gr.Interface expects each example
# as a list holding one value per input component; the interface has a single
# Textbox input, so every prompt is wrapped in a one-element list.
_EXAMPLE_PROMPTS = (
    "Patient admitted with a history of heart failure and requires detailed follow-up on cardiovascular treatment.",
    "What are the complications of diabetes mellitus that need to be monitored in this patient?",
    "Describe the appropriate treatment for acute respiratory distress syndrome in a critical care setting.",
    "Explain the signs and symptoms that indicate a neurological emergency in a stroke patient.",
    "What are the best practices for managing an infectious disease outbreak in a hospital setting?",
)
examples = [[prompt] for prompt in _EXAMPLE_PROMPTS]
# Create Gradio interface: one multi-line textbox in, plain text out, with the
# sample prompts from `examples` shown as clickable examples.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=10, placeholder="Enter your medical question or clinical notes here..."),
    outputs="text",
    examples=examples,
    title="MIMIC-IV RAG Implementation",
    # The previous description claimed answers were based on medical
    # literature. The app actually runs the open-domain facebook/rag-sequence-nq
    # checkpoint with a dummy retrieval index, so say so explicitly.
    description=(
        "Use RAG (Retrieval-Augmented Generation) to generate responses or provide additional "
        "information based on clinical notes and medical questions. "
        "Note: this demo runs the general-domain facebook/rag-sequence-nq model with a dummy "
        "retrieval index; its answers are not grounded in medical literature and must not be "
        "used for clinical decisions."
    ),
)

# Launch only when run as a script (e.g. `python app.py` on Spaces), so that
# importing this module does not start a server.
if __name__ == "__main__":
    iface.launch()