Files changed (1) hide show
  1. app.py +30 -17
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import streamlit as st
2
  import tempfile
 
3
 
4
  from langchain_community.document_loaders import PyPDFLoader, TextLoader
5
  from langchain.text_splitter import RecursiveCharacterTextSplitter
@@ -8,7 +9,7 @@ from langchain.vectorstores import FAISS
8
  from langchain.llms import HuggingFacePipeline
9
  from langchain.chains import RetrievalQA
10
 
11
- from transformers import pipeline
12
 
13
  # -------------------------------
14
  # Page Config
@@ -24,18 +25,23 @@ def load_documents(uploaded_files):
24
  documents = []
25
 
26
  for file in uploaded_files:
27
- # Save file safely using temp file
28
- with tempfile.NamedTemporaryFile(delete=False, suffix=file.name) as tmp:
 
 
29
  tmp.write(file.getbuffer())
30
  temp_path = tmp.name
31
 
32
- # Load based on type
33
- if file.name.endswith(".pdf"):
34
- loader = PyPDFLoader(temp_path)
35
- else:
36
- loader = TextLoader(temp_path)
 
 
37
 
38
- documents.extend(loader.load())
 
39
 
40
  return documents
41
 
@@ -52,7 +58,7 @@ def split_documents(documents):
52
 
53
 
54
  # -------------------------------
55
- # Cached Embeddings (IMPORTANT)
56
  # -------------------------------
57
  @st.cache_resource
58
  def get_embeddings():
@@ -70,13 +76,13 @@ def create_vectorstore(chunks):
70
 
71
 
72
  # -------------------------------
73
- # Cached LLM (IMPORTANT)
74
  # -------------------------------
75
  @st.cache_resource
76
  def load_llm():
77
  pipe = pipeline(
78
- "text-generation",
79
- model="google/flan-t5-small", # lightweight model
80
  max_length=256
81
  )
82
  return HuggingFacePipeline(pipeline=pipe)
@@ -108,6 +114,11 @@ uploaded_files = st.file_uploader(
108
  if uploaded_files:
109
  with st.spinner("📄 Processing documents..."):
110
  docs = load_documents(uploaded_files)
 
 
 
 
 
111
  chunks = split_documents(docs)
112
  vectorstore = create_vectorstore(chunks)
113
  qa_chain = build_qa(vectorstore)
@@ -121,7 +132,9 @@ if uploaded_files:
121
 
122
  if query:
123
  with st.spinner("🤖 Generating answer..."):
124
- result = qa_chain.run(query)
125
-
126
- st.markdown("### 🧠 Answer:")
127
- st.write(result)
 
 
 
1
  import streamlit as st
2
  import tempfile
3
+ import os
4
 
5
  from langchain_community.document_loaders import PyPDFLoader, TextLoader
6
  from langchain.text_splitter import RecursiveCharacterTextSplitter
 
9
  from langchain.llms import HuggingFacePipeline
10
  from langchain.chains import RetrievalQA
11
 
12
+ from transformers.pipelines import pipeline # ✅ FIXED IMPORT
13
 
14
  # -------------------------------
15
  # Page Config
 
25
  documents = []
26
 
27
  for file in uploaded_files:
28
+ file_extension = os.path.splitext(file.name)[1]
29
+
30
+ # Save safely as temp file
31
+ with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as tmp:
32
  tmp.write(file.getbuffer())
33
  temp_path = tmp.name
34
 
35
+ try:
36
+ if file_extension.lower() == ".pdf":
37
+ loader = PyPDFLoader(temp_path)
38
+ else:
39
+ loader = TextLoader(temp_path)
40
+
41
+ documents.extend(loader.load())
42
 
43
+ except Exception as e:
44
+ st.error(f"❌ Error loading file: {e}")
45
 
46
  return documents
47
 
 
58
 
59
 
60
  # -------------------------------
61
+ # Cached Embeddings
62
  # -------------------------------
63
  @st.cache_resource
64
  def get_embeddings():
 
76
 
77
 
78
  # -------------------------------
79
+ # Cached LLM (FIXED)
80
  # -------------------------------
81
  @st.cache_resource
82
  def load_llm():
83
  pipe = pipeline(
84
+ "text2text-generation", # ✅ CORRECT TASK
85
+ model="google/flan-t5-small",
86
  max_length=256
87
  )
88
  return HuggingFacePipeline(pipeline=pipe)
 
114
  if uploaded_files:
115
  with st.spinner("📄 Processing documents..."):
116
  docs = load_documents(uploaded_files)
117
+
118
+ if not docs:
119
+ st.error("❌ No valid documents loaded.")
120
+ st.stop()
121
+
122
  chunks = split_documents(docs)
123
  vectorstore = create_vectorstore(chunks)
124
  qa_chain = build_qa(vectorstore)
 
132
 
133
  if query:
134
  with st.spinner("🤖 Generating answer..."):
135
+ try:
136
+ result = qa_chain.run(query)
137
+ st.markdown("### 🧠 Answer:")
138
+ st.write(result)
139
+ except Exception as e:
140
+ st.error(f"❌ Error generating answer: {e}")