Files changed (1) hide show
  1. app.py +34 -6
app.py CHANGED
@@ -8,8 +8,9 @@ from langchain.embeddings import HuggingFaceEmbeddings
8
  from langchain.vectorstores import FAISS
9
  from langchain.llms import HuggingFacePipeline
10
  from langchain.chains import RetrievalQA
 
11
 
12
- from transformers.pipelines import pipeline # βœ… FIXED IMPORT
13
 
14
  # -------------------------------
15
  # Page Config
@@ -19,7 +20,7 @@ st.title("πŸ“„ Chat with Your Documents (RAG)")
19
  st.write("πŸš€ App started successfully")
20
 
21
  # -------------------------------
22
- # Load Documents (FIXED)
23
  # -------------------------------
24
  def load_documents(uploaded_files):
25
  documents = []
@@ -27,7 +28,6 @@ def load_documents(uploaded_files):
27
  for file in uploaded_files:
28
  file_extension = os.path.splitext(file.name)[1]
29
 
30
- # Save safely as temp file
31
  with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as tmp:
32
  tmp.write(file.getbuffer())
33
  temp_path = tmp.name
@@ -76,30 +76,56 @@ def create_vectorstore(chunks):
76
 
77
 
78
  # -------------------------------
79
- # Cached LLM (FIXED)
80
  # -------------------------------
81
  @st.cache_resource
82
  def load_llm():
83
  pipe = pipeline(
84
- "text2text-generation", # βœ… CORRECT TASK
85
  model="google/flan-t5-small",
86
  max_length=256
87
  )
88
  return HuggingFacePipeline(pipeline=pipe)
89
 
90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  # -------------------------------
92
  # Build QA Chain
93
  # -------------------------------
94
  def build_qa(vectorstore):
95
  llm = load_llm()
96
- retriever = vectorstore.as_retriever()
 
 
 
97
 
98
  qa = RetrievalQA.from_chain_type(
99
  llm=llm,
100
  retriever=retriever,
 
101
  return_source_documents=False
102
  )
 
103
  return qa
104
 
105
 
@@ -134,7 +160,9 @@ if uploaded_files:
134
  with st.spinner("πŸ€– Generating answer..."):
135
  try:
136
  result = qa_chain.run(query)
 
137
  st.markdown("### 🧠 Answer:")
138
  st.write(result)
 
139
  except Exception as e:
140
  st.error(f"❌ Error generating answer: {e}")
 
8
  from langchain.vectorstores import FAISS
9
  from langchain.llms import HuggingFacePipeline
10
  from langchain.chains import RetrievalQA
11
+ from langchain.prompts import PromptTemplate
12
 
13
+ from transformers.pipelines import pipeline
14
 
15
  # -------------------------------
16
  # Page Config
 
20
  st.write("πŸš€ App started successfully")
21
 
22
  # -------------------------------
23
+ # Load Documents
24
  # -------------------------------
25
  def load_documents(uploaded_files):
26
  documents = []
 
28
  for file in uploaded_files:
29
  file_extension = os.path.splitext(file.name)[1]
30
 
 
31
  with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as tmp:
32
  tmp.write(file.getbuffer())
33
  temp_path = tmp.name
 
76
 
77
 
78
  # -------------------------------
79
+ # Cached LLM
80
  # -------------------------------
81
  @st.cache_resource
82
  def load_llm():
83
  pipe = pipeline(
84
+ "text2text-generation",
85
  model="google/flan-t5-small",
86
  max_length=256
87
  )
88
  return HuggingFacePipeline(pipeline=pipe)
89
 
90
 
91
+ # -------------------------------
92
+ # Custom Prompt (IMPORTANT)
93
+ # -------------------------------
94
+ prompt_template = """
95
+ Use the following context to answer the question clearly.
96
+
97
+ Context:
98
+ {context}
99
+
100
+ Question:
101
+ {question}
102
+
103
+ Answer:
104
+ """
105
+
106
+ PROMPT = PromptTemplate(
107
+ template=prompt_template,
108
+ input_variables=["context", "question"]
109
+ )
110
+
111
+
112
  # -------------------------------
113
  # Build QA Chain
114
  # -------------------------------
115
  def build_qa(vectorstore):
116
  llm = load_llm()
117
+
118
+ retriever = vectorstore.as_retriever(
119
+ search_kwargs={"k": 3} # πŸ”₯ improves answer quality
120
+ )
121
 
122
  qa = RetrievalQA.from_chain_type(
123
  llm=llm,
124
  retriever=retriever,
125
+ chain_type_kwargs={"prompt": PROMPT},
126
  return_source_documents=False
127
  )
128
+
129
  return qa
130
 
131
 
 
160
  with st.spinner("πŸ€– Generating answer..."):
161
  try:
162
  result = qa_chain.run(query)
163
+
164
  st.markdown("### 🧠 Answer:")
165
  st.write(result)
166
+
167
  except Exception as e:
168
  st.error(f"❌ Error generating answer: {e}")