Himel09 committed on
Commit
f3c9795
·
verified ·
1 Parent(s): 81aead1

Create src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +151 -37
src/streamlit_app.py CHANGED
@@ -1,40 +1,154 @@
1
- import altair as alt
2
- import numpy as np
3
  import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
  import pandas as pd
4
  import streamlit as st
5
+ from langchain_community.llms import Ollama
6
+ from langchain_community.document_loaders import PyPDFLoader
7
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
8
+ from langchain_community.embeddings import OllamaEmbeddings
9
+ from langchain_community.vectorstores import Chroma
10
+ from langchain_core.prompts import ChatPromptTemplate
11
+ from langchain_core.output_parsers import StrOutputParser
12
+ from langchain_groq import ChatGroq
13
 
14
+
15
# --- Page configuration and introduction ---
st.set_page_config(page_title="📘 PDF Q&A Generator", page_icon="🤖", layout="wide")
# FIX: title said "GORQ" — the provider used below is Groq (ChatGroq).
st.title("📘 PDF Question–Answer Generator (Groq + RAG)")

st.markdown("""
Welcome! Upload a PDF and ask questions about its content.
The system will generate answers and save all Q&A pairs as a CSV.
""")


# --- Sidebar: Groq API key entry ---
st.sidebar.header("🔑 API Settings")
groq_api_key = st.sidebar.text_input("Enter your Groq API Key:", type="password")

# Halt the script run until a non-blank key is provided.
# (text_input returns "", so a single strip-based truthiness check suffices —
# the original `not key or key.strip() == ""` double check was redundant.)
if not groq_api_key.strip():
    st.warning("⚠️ Please enter your Groq API Key to proceed.")
    st.stop()

try:
    groq_api_key = groq_api_key.strip()
    llm = ChatGroq(model="llama-3.1-8b-instant", api_key=groq_api_key, temperature=0)

    # Cheap round-trip call to verify the key/connection actually work
    # before the rest of the app runs.
    response = llm.invoke("Hello")

except Exception as e:
    # Broad catch is intentional here: any auth or network failure should
    # surface to the user and stop the app, not crash it.
    st.error(f"❌ Invalid Groq API Key or connection error: {e}")
    st.stop()
44
uploaded_file = st.file_uploader("📄 Upload a PDF file", type=["pdf"])
if not uploaded_file:
    st.info("Please upload a PDF file to begin.")
    st.stop()


# Re-index when nothing has been processed yet OR when the user uploads a
# different file. FIX: the original keyed only on `"processed"`, so the first
# PDF's index was reused forever, even after uploading a new document.
if not st.session_state.get("processed") or st.session_state.get("pdf_name") != uploaded_file.name:
    with st.spinner("📚 Loading and splitting PDF..."):
        # PyPDFLoader needs a filesystem path, so persist the upload first.
        pdf_path = os.path.join("temp.pdf")
        with open(pdf_path, "wb") as f:
            f.write(uploaded_file.read())

        loader = PyPDFLoader(pdf_path)
        documents = loader.load()

        # Overlapping chunks keep sentence context intact across boundaries.
        splitter = RecursiveCharacterTextSplitter(chunk_size=600, chunk_overlap=100)
        texts = splitter.split_documents(documents)

        # NOTE(review): embeddings come from a local Ollama server; this
        # assumes one is running with "mxbai-embed-large" pulled — confirm
        # in the deployment environment.
        embedding = OllamaEmbeddings(model="mxbai-embed-large")
        vectorstore = Chroma.from_documents(documents=texts, embedding=embedding)
        retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 4})

        st.session_state["retriever"] = retriever
        st.session_state["texts"] = texts
        st.session_state["processed"] = True
        st.session_state["pdf_name"] = uploaded_file.name

st.success(f"✅ Processed {len(st.session_state['texts'])} text chunks from your PDF.")
72
+
73
# System instructions for the Q&A generation task. The retrieved document
# text is interpolated into {context} when the chain runs. The exact wording
# (including the Q1:/A1: output format) is load-bearing: parse_qa_pairs
# below relies on the model following this format.
system_prompt = (
    "You are an intelligent question–answer generation assistant. "
    "Your task is to read the provided text content (retrieved from a PDF document) "
    "and create meaningful, diverse, and contextually accurate question–answer pairs.\n\n"
    "Follow these rules strictly:\n"
    "1. Generate clear and concise questions based only on the given text.\n"
    "2. Each question must be answerable from the context — do not invent facts.\n"
    "3. Write the corresponding answer immediately after each question.\n"
    "4. Prefer factual, conceptual, or reasoning-based questions rather than trivial ones.\n"
    "5. Output format must be clean and structured like this:\n\n"
    "Q1: <question text>\n"
    "A1: <answer text>\n\n"
    "Q2: <question text>\n"
    "A2: <answer text>\n\n"
    "6. If the text contains multiple sections, cover all major ideas fairly.\n"
    "7. Avoid repeating the same type of question; vary the question style (factual, analytical, summary, etc.).\n\n"
    "Your output should only include the question–answer pairs. Do not add explanations or comments.\n\n"
    "Here is the context:\n\n{context}"
)

# Chat template: fixed system instructions plus the user's {question}.
prompt = ChatPromptTemplate.from_messages([
    ("system", system_prompt),
    ("user", "{question}")
])


# Generation model. Rebinds `llm` (the earlier temperature=0 instance was
# only used to validate the API key); 0.7 here gives varied question styles.
llm = ChatGroq(model="llama-3.1-8b-instant",
               api_key=groq_api_key, temperature=0.7)
# Converts the chat model's message response into a plain string.
parser = StrOutputParser()
103
+
104
def create_rag_chain(retriever, model, prompt):
    """Build a retrieval-augmented chain: retrieve context for the question,
    fill the prompt, call the model, and parse the reply to a plain string.

    Note: the chain also uses the module-level `parser` for the final step.
    """
    def fetch_context(user_input):
        # FIX: accept either a bare question string or a {"question": ...}
        # dict. The UI invokes the chain with a dict, and the original passed
        # that dict straight to the retriever, which needs a query string.
        question = user_input["question"] if isinstance(user_input, dict) else user_input
        docs = retriever.get_relevant_documents(question)
        context = "\n\n".join(doc.page_content for doc in docs)
        return {"context": context, "question": question}

    # LangChain coerces the plain function into a RunnableLambda via `|`.
    chain = fetch_context | prompt | model | parser
    return chain

rag_chain = create_rag_chain(st.session_state["retriever"], llm, prompt)
115
+
116
def parse_qa_pairs(model_output):
    """Extract "Qn: ... / An: ..." pairs from the model's raw text.

    Returns a list of {"Question": ..., "Answer": ...} dicts; text that
    contains no such pairs yields an empty list.
    """
    qa_pattern = re.compile(r"Q\d+:\s*(.*?)\nA\d+:\s*(.*?)(?=\nQ\d+:|\Z)", re.DOTALL)
    pairs = []
    for match in qa_pattern.finditer(model_output):
        question, answer = match.groups()
        pairs.append({"Question": question.strip(), "Answer": answer.strip()})
    return pairs
120
+
121
+
122
st.subheader("💬 Ask Questions from the PDF")
user_question = st.text_input("Enter your question or request Q&A generation:")

# Accumulate every parsed Q&A pair across Streamlit reruns for CSV export.
if "qa_data" not in st.session_state:
    st.session_state.qa_data = []

if st.button("Generate Answer") and user_question.strip():
    with st.spinner("🤖 Generating answer..."):
        # Rebuild the chain so it always uses the current PDF's retriever.
        rag_chain = create_rag_chain(st.session_state["retriever"], llm, prompt)
        # FIX: pass the question as a plain string. The chain's first step
        # feeds its input directly to the retriever, which expects a query
        # string — the original wrapped it in a dict, so the retriever
        # received {"question": ...} instead of the question text.
        model_output = rag_chain.invoke(user_question)

        # Parse Q&A pairs and keep them for the CSV download below.
        parsed_qa = parse_qa_pairs(model_output)
        st.session_state.qa_data.extend(parsed_qa)

        for i, item in enumerate(parsed_qa, start=1):
            question = item.get("Question", "No Question Found")
            answer = item.get("Answer", "No Answer Found")
            st.markdown(f"**Q{i}:** {question}")
            st.markdown(f"**A{i}:** {answer}")
            st.markdown("---")  # separator between Q&A
144
+
145
+
146
+
147
# Offer a CSV export once at least one Q&A pair has been generated.
if st.session_state.qa_data:
    qa_frame = pd.DataFrame(st.session_state.qa_data)
    csv_bytes = qa_frame.to_csv(index=False).encode("utf-8")
    st.download_button(
        label="📥 Download Q&A CSV",
        data=csv_bytes,
        file_name="qa_results.csv",
        mime="text/csv"
    )