Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,16 +1,16 @@
|
|
| 1 |
-
# app.py
|
| 2 |
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
import streamlit as st
|
| 4 |
-
from langchain.document_loaders import PyPDFLoader
|
| 5 |
-
from langchain.embeddings import HuggingFaceEmbeddings
|
| 6 |
-
from langchain.vectorstores import FAISS
|
| 7 |
from langchain.chains import RetrievalQA
|
|
|
|
|
|
|
| 8 |
from langchain.llms import HuggingFacePipeline
|
| 9 |
from transformers import pipeline
|
| 10 |
from groq import Groq
|
| 11 |
-
import requests
|
| 12 |
-
from PyPDF2 import PdfReader
|
| 13 |
-
import io
|
| 14 |
|
| 15 |
# Set up API key for Groq API
|
| 16 |
#GROQ_API_KEY = "<redacted>"  # SECURITY: a live API key was committed here — revoke it and load the key from an environment variable instead
|
|
@@ -27,58 +27,67 @@ def get_groq_client():
|
|
| 27 |
groq_client = get_groq_client()
|
| 28 |
|
| 29 |
|
| 30 |
-
# Predefined PDF link
|
| 31 |
-
pdf_url = "https://drive.google.com/file/d/1P9InkDWyaybb8jR_xS4f4KsxTlYip8RA/view?usp=drive_link"
|
| 32 |
|
| 33 |
-
def extract_text_from_pdf(pdf_url):
|
| 34 |
-
"""Extract text from a PDF file given its Google Drive shared link."""
|
| 35 |
-
# Extract file ID from the Google Drive link
|
| 36 |
-
file_id = pdf_url.split('/d/')[1].split('/view')[0]
|
| 37 |
-
download_url = f"https://drive.google.com/uc?export=download&id={file_id}"
|
| 38 |
-
response = requests.get(download_url)
|
| 39 |
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
return text
|
| 45 |
-
else:
|
| 46 |
-
st.error("Failed to download PDF.")
|
| 47 |
-
return ""
|
| 48 |
|
| 49 |
-
|
| 50 |
-
|
|
|
|
|
|
|
| 51 |
|
| 52 |
-
|
| 53 |
-
|
| 54 |
|
| 55 |
-
|
| 56 |
-
|
|
|
|
|
|
|
|
|
|
| 57 |
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
|
|
|
|
|
|
| 62 |
|
| 63 |
-
|
| 64 |
-
st.info("Storing embeddings in FAISS...")
|
| 65 |
-
faiss_index = FAISS.from_texts([extracted_text], embeddings_model)
|
| 66 |
|
| 67 |
-
|
| 68 |
-
st.info("Setting up RAG pipeline...")
|
| 69 |
-
hf_pipeline = pipeline("text-generation", model="google/flan-t5-base", tokenizer="google/flan-t5-base")
|
| 70 |
-
llm = HuggingFacePipeline(pipeline=hf_pipeline)
|
| 71 |
|
| 72 |
-
|
| 73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
+
import requests
|
| 3 |
+
from io import BytesIO
|
| 4 |
+
from PyPDF2 import PdfReader
|
| 5 |
+
from sentence_transformers import SentenceTransformer
|
| 6 |
+
import faiss
|
| 7 |
import streamlit as st
|
|
|
|
|
|
|
|
|
|
| 8 |
from langchain.chains import RetrievalQA
|
| 9 |
+
from langchain.vectorstores import FAISS
|
| 10 |
+
from langchain.embeddings import HuggingFaceEmbeddings
|
| 11 |
from langchain.llms import HuggingFacePipeline
|
| 12 |
from transformers import pipeline
|
| 13 |
from groq import Groq
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
# Set up API key for Groq API
|
| 16 |
#GROQ_API_KEY = "<redacted>"  # SECURITY: a live API key was committed here — revoke it and load the key from an environment variable instead
|
|
|
|
| 27 |
groq_client = get_groq_client()
|
| 28 |
|
| 29 |
|
|
|
|
|
|
|
| 30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
+
def resolve_pdf_url(url):
    """Rewrite a Google Drive share link to its direct-download form.

    A Drive "/file/d/<id>/view" link serves an HTML preview page, not the
    PDF bytes, so downloading it as-is hands garbage to the PDF parser.
    Non-Drive URLs pass through unchanged.
    """
    if "drive.google.com" in url and "/d/" in url:
        file_id = url.split("/d/")[1].split("/")[0]
        return f"https://drive.google.com/uc?export=download&id={file_id}"
    return url


def download_pdf(url, timeout=30):
    """Download a PDF over HTTP and return it as an in-memory binary buffer.

    Args:
        url: URL of the PDF; Google Drive share links are rewritten to
            direct-download form first (see resolve_pdf_url).
        timeout: Seconds to wait for the server; without it requests.get
            can block forever on a stalled connection.

    Returns:
        io.BytesIO wrapping the response body.

    Raises:
        requests.HTTPError: on non-2xx responses (raise_for_status).
        requests.RequestException: on network failures or timeouts.
    """
    response = requests.get(resolve_pdf_url(url), timeout=timeout)
    response.raise_for_status()
    return BytesIO(response.content)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
+
def extract_text_from_pdf(pdf_data):
    """Concatenate the text of every page of a PDF.

    Args:
        pdf_data: A binary file-like object (e.g. BytesIO) holding the PDF.

    Returns:
        Page texts joined with newlines; pages that yield no text are
        skipped.
    """
    reader = PdfReader(pdf_data)
    # Call extract_text() once per page — the original evaluated it twice
    # (once in the filter, once in the join), doubling the parsing work.
    page_texts = (page.extract_text() for page in reader.pages)
    return "\n".join(t for t in page_texts if t)
|
| 41 |
|
| 42 |
+
def preprocess_text(text):
    """Normalize whitespace: collapse runs of spaces, tabs, and newlines
    into single spaces and trim the ends."""
    tokens = text.split()
    return " ".join(tokens)
|
| 44 |
|
| 45 |
+
def build_faiss_index(embeddings, texts):
    """Embed `texts` and return a LangChain FAISS vector store over them.

    Args:
        embeddings: A LangChain embeddings object (e.g. HuggingFaceEmbeddings).
        texts: List of document strings to index.

    Returns:
        A FAISS vector store ready for similarity search / .as_retriever().
    """
    # The original built faiss.IndexFlatL2(embeddings.embedding_dim) by
    # hand, but HuggingFaceEmbeddings exposes no `embedding_dim` attribute,
    # and FAISS(embeddings, index) does not match LangChain's constructor
    # (it also needs a docstore and index-to-docstore-id mapping) — both
    # raised at runtime. FAISS.from_texts handles embedding, index sizing,
    # and the docstore in one call.
    return FAISS.from_texts(texts, embeddings)
|
| 50 |
|
| 51 |
+
# URLs of ASD-related PDF documents.
# NOTE(review): all three entries currently point at the same Drive file —
# replace them with distinct ASD-related literature links.
pdf_links = [
    "https://drive.google.com/file/d/1P9InkDWyaybb8jR_xS4f4KsxTlYip8RA/view?usp=drive_link",
    "https://drive.google.com/file/d/1P9InkDWyaybb8jR_xS4f4KsxTlYip8RA/view?usp=drive_link",
    "https://drive.google.com/file/d/1P9InkDWyaybb8jR_xS4f4KsxTlYip8RA/view?usp=drive_link",
]
|
| 57 |
|
| 58 |
+
# --- Streamlit UI: ingest PDFs, build the index, answer queries ---

st.title("ASD Diagnosis and Therapy Chatbot")
st.markdown("This application assists in diagnosing types of ASD and recommends evidence-based therapies and treatments.")

with st.spinner("Downloading and extracting text from PDFs..."):
    texts = []
    for link in pdf_links:
        try:
            pdf_data = download_pdf(link)
        except requests.RequestException as exc:
            # Skip unreachable documents instead of crashing the whole app.
            st.warning(f"Could not download {link}: {exc}")
            continue
        text = extract_text_from_pdf(pdf_data)
        cleaned_text = preprocess_text(text)
        texts.append(cleaned_text)

with st.spinner("Generating embeddings..."):
    embeddings_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    text_store = build_faiss_index(embeddings_model, texts)

with st.spinner("Setting up the RAG pipeline..."):
    # The Hub model id is "gpt2" (no hyphen); "gpt-2" does not resolve.
    # Replace with a model optimized for medical text, if available.
    hf_pipeline = pipeline("text-generation", model="gpt2")
    llm = HuggingFacePipeline(pipeline=hf_pipeline)
    # RetrievalQA must be built via its factory — the bare constructor
    # requires a pre-assembled combine_documents_chain and raised here.
    qa_chain = RetrievalQA.from_chain_type(llm=llm, retriever=text_store.as_retriever())

query = st.text_input("Ask a question about ASD diagnosis, types, or therapies:")
if query:
    with st.spinner("Processing your query..."):
        answer = qa_chain.run(query)
    st.success("Answer:")
    st.write(answer)

st.markdown("---")
st.markdown("### Example Queries:")
st.markdown("- What type of ASD does an individual with sensory issues have?")
st.markdown("- What therapies are recommended for social communication challenges?")
st.markdown("- What treatments are supported by clinical guidelines for repetitive behaviors?")

st.markdown("---")
st.markdown("Powered by Streamlit, Hugging Face, and LangChain")
|