udituen committed on
Commit
362de84
Β·
1 Parent(s): ec4695f

code refactor

Browse files
src/streamlit_app.py β†’ app_archive.py RENAMED
File without changes
chains/qa_chain.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Conversational QA chain setup."""
2
+
3
+ from langchain.chains import ConversationalRetrievalChain
4
+ from langchain.memory import ConversationBufferMemory
5
+
6
+
7
def create_qa_chain(llm, retriever):
    """Build a ConversationalRetrievalChain backed by buffer memory.

    Args:
        llm: Language model used to generate answers.
        retriever: Document retriever supplying context passages.

    Returns:
        ConversationalRetrievalChain: Chain that keeps chat history in
        memory and also returns the source documents for each answer.
    """
    # Memory stores Q/A turns under "chat_history"; output_key tells it
    # which chain output to persist (the chain emits several keys).
    conversation_memory = ConversationBufferMemory(
        memory_key="chat_history",
        return_messages=True,
        output_key="answer",
    )

    return ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=retriever,
        memory=conversation_memory,
        return_source_documents=True,
    )
config.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Configuration settings for the RAG application."""

# Built-in demo document offered for download in the sidebar.
SAMPLE_TEXT = """Fertilizers help improve soil nutrients and crop yield.
Irrigation methods vary depending on climate and crop type.
Crop rotation can enhance soil health and reduce pests.
Composting is an organic way to enrich the soil.
Weed management is essential for higher productivity."""

# Suggested prompts shown once a document has been processed.
EXAMPLE_QUESTIONS = [
    "What is this document about?",
    "What is the role of fertilizers in agriculture?",
    "Why is crop rotation important?",
    "How does composting help farming?",
]


# Model configurations
QWEN_MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"  # HuggingFace model id for the chat LLM
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"  # sentence-transformers embedding model

# Generation parameters
MAX_NEW_TOKENS = 256  # cap on tokens generated per answer
TEMPERATURE = 0.7  # sampling temperature
TOP_P = 0.95  # nucleus-sampling cutoff
25
+
models/llm_loader.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """LLM loading and initialization."""
2
+
3
+ import torch
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
5
+ from langchain_community.llms import HuggingFacePipeline
6
+ import streamlit as st
7
+
8
+
9
@st.cache_resource
def load_qwen_llm(model_name, max_new_tokens=256, temperature=0.7, top_p=0.95):
    """Load a Qwen causal LM and wrap it as a LangChain LLM.

    Cached with st.cache_resource so the weights are loaded only once
    per Streamlit server process.

    Args:
        model_name: HuggingFace model identifier.
        max_new_tokens: Maximum tokens to generate per call.
        temperature: Sampling temperature.
        top_p: Nucleus sampling parameter.

    Returns:
        HuggingFacePipeline: Wrapped LLM usable in LangChain chains.
    """
    use_gpu = torch.cuda.is_available()

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    # Half precision and automatic device placement only when CUDA is present;
    # on CPU we stay in float32 and let transformers pick the default device.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16 if use_gpu else torch.float32,
        device_map="auto" if use_gpu else None,
        trust_remote_code=True,
    )

    text_generator = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        return_full_text=False,  # emit only the newly generated text, not the prompt
    )

    return HuggingFacePipeline(pipeline=text_generator)
models/retriever.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Document retrieval system."""
2
+
3
+ from langchain_community.vectorstores import FAISS
4
+ from langchain_community.embeddings import HuggingFaceEmbeddings
5
+
6
+
7
def build_retriever(docs, embedding_model_name="all-MiniLM-L6-v2"):
    """Create a FAISS-backed retriever over the given texts.

    Args:
        docs: List of text documents.
        embedding_model_name: Name of the sentence-embedding model.

    Returns:
        Retriever object backed by an in-memory FAISS index.
    """
    embedder = HuggingFaceEmbeddings(model_name=embedding_model_name)
    vector_store = FAISS.from_texts(docs, embedder)
    return vector_store.as_retriever()
streamlit_app.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Main Streamlit application."""
2
+
3
+ import streamlit as st
4
+ from ui.sidebar import render_sidebar
5
+ from ui.chat import render_chat_interface
6
+
7
+
8
def initialize_session_state():
    """Seed session-state keys with their defaults on first run."""
    # Only missing keys are set, so reruns never clobber existing state.
    defaults = {
        "chat_history": [],
        "qa_chain": None,
        "document_processed": False,
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value
16
+
17
+
18
def main():
    """Main application entry point."""
    # Page configuration must run before any other Streamlit call.
    st.set_page_config(page_title="DocsQA", page_icon="", layout="wide")

    initialize_session_state()

    # Page header
    st.title("DocsQA: Chat with Your Document")
    st.markdown("Upload a document and have a conversation about its contents! (Powered by Qwen)")

    # Sidebar (upload/controls) and the main chat panel.
    render_sidebar()
    render_chat_interface()


if __name__ == "__main__":
    main()
ui/chat.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Chat interface components."""
2
+
3
+ import streamlit as st
4
+
5
+
6
def render_chat_interface():
    """Render the main chat interface."""
    if st.session_state.document_processed:
        # Replay past turns, then wait for the next question.
        _display_chat_history()
        _handle_user_input()
    else:
        # No index yet — point the user at the sidebar workflow.
        st.info("<-- Please upload a document in the sidebar and click 'Process Document' to start chatting!")
17
+
18
+
19
def _display_chat_history():
    """Replay every stored message (plus any sources) into the chat area."""
    for entry in st.session_state.chat_history:
        with st.chat_message(entry["role"]):
            st.markdown(entry["content"])

            # Assistant turns may carry the retrieved passages they used.
            if entry["role"] == "assistant" and "sources" in entry:
                with st.expander("View Sources"):
                    for idx, passage in enumerate(entry["sources"], start=1):
                        st.markdown(f"**Source {idx}:** {passage}")
30
+
31
+
32
def _handle_user_input():
    """Capture a new user question and kick off a response."""
    prompt = st.chat_input("Ask a question about your document...")
    if not prompt:
        return

    # Record and echo the user's message before answering.
    st.session_state.chat_history.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    _generate_response(prompt)
47
+
48
+
49
def _generate_response(prompt):
    """Run the QA chain on *prompt* and render the assistant's reply."""
    with st.chat_message("assistant"), st.spinner("Thinking..."):
        # The whole render path sits inside the try so any failure —
        # in the chain call or in result handling — is shown to the user.
        try:
            result = st.session_state.qa_chain({"question": prompt})

            answer = result["answer"]
            # Raw text of every retrieved passage backing the answer.
            sources = [doc.page_content for doc in result.get("source_documents", [])]

            st.markdown(answer)

            if sources:
                with st.expander("View Sources"):
                    for idx, passage in enumerate(sources, start=1):
                        st.markdown(f"**Source {idx}:** {passage}")

            st.session_state.chat_history.append(
                {"role": "assistant", "content": answer, "sources": sources}
            )
        except Exception as exc:
            # Surface the failure and persist it so it survives reruns.
            error_msg = f"Sorry, I encountered an error: {str(exc)}"
            st.error(error_msg)
            st.session_state.chat_history.append(
                {"role": "assistant", "content": error_msg}
            )
ui/sidebar.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Sidebar UI components."""

import streamlit as st

from chains.qa_chain import create_qa_chain
from config import (
    EMBEDDING_MODEL_NAME,
    EXAMPLE_QUESTIONS,
    MAX_NEW_TOKENS,
    QWEN_MODEL_NAME,
    SAMPLE_TEXT,
    TEMPERATURE,
    TOP_P,
)
from models.llm_loader import load_qwen_llm
from models.retriever import build_retriever
# Fixed module path: the helpers added in this commit live in
# utils/doc_processor.py; importing utils.document_processor raised
# ModuleNotFoundError at app startup.
from utils.doc_processor import read_uploaded_file
10
+
11
+
12
def render_sidebar():
    """Render the sidebar: upload widget, example questions, and controls."""
    with st.sidebar:
        st.header("📄 Document Upload")

        # Let users grab a ready-made sample document to try the app.
        st.download_button(
            label="📄 Download Sample File",
            data=SAMPLE_TEXT,
            file_name="sample_agri.txt",
            mime="text/plain",
        )

        uploaded_file = st.file_uploader("Upload your file", type=["txt", "pdf"])
        if uploaded_file is not None:
            st.success(f"{uploaded_file.name}")
            _handle_document_upload(uploaded_file)

        # Extra controls appear only once they make sense.
        if st.session_state.document_processed:
            _render_example_questions()
        if st.session_state.chat_history:
            _render_clear_button()
42
+
43
+
44
def _handle_document_upload(uploaded_file):
    """Build the retrieval pipeline when the user clicks Process Document."""
    if not st.button("Process Document", type="primary"):
        return

    with st.spinner("Processing document..."):
        try:
            docs = read_uploaded_file(uploaded_file)

            if docs:
                # Assemble retriever -> LLM -> conversational chain.
                retriever = build_retriever(docs, EMBEDDING_MODEL_NAME)
                llm = load_qwen_llm(
                    QWEN_MODEL_NAME,
                    MAX_NEW_TOKENS,
                    TEMPERATURE,
                    TOP_P,
                )

                st.session_state.qa_chain = create_qa_chain(llm, retriever)
                st.session_state.document_processed = True
                st.session_state.chat_history = []  # fresh conversation per document

                st.success(f"Processed {len(docs)} text chunks!")
                st.rerun()
            else:
                st.error("No content found in file.")

        except Exception as e:
            st.error(f"Error: {str(e)}")
71
+
72
+
73
def _render_example_questions():
    """Offer one-click example questions below a divider."""
    st.markdown("---")
    st.subheader("💡 Example Questions")
    for question in EXAMPLE_QUESTIONS:
        # Unique key per question keeps Streamlit widget state distinct.
        if st.button(question, key=f"example_{question}"):
            st.session_state.user_input = question
            st.rerun()
81
+
82
+
83
def _render_clear_button():
    """Show a button that wipes the stored conversation."""
    st.markdown("---")
    if st.button("🗑️ Clear Chat History"):
        st.session_state.chat_history = []
        st.rerun()
utils/doc_processor.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Utilities for processing uploaded documents."""
2
+
3
+ import io
4
+
5
+ try:
6
+ from pypdf import PdfReader
7
+ except ImportError:
8
+ from PyPDF2 import PdfReader
9
+
10
+
11
def read_uploaded_file(uploaded_file):
    """Read an uploaded TXT or PDF file into a list of text chunks.

    Args:
        uploaded_file: Streamlit UploadedFile object.

    Returns:
        list: Text chunks extracted from the document.
    """
    # Rewind in case the stream was already consumed (e.g. on a rerun).
    uploaded_file.seek(0)

    # Dispatch on the browser-reported MIME type; anything that is not a
    # PDF is treated as plain UTF-8 text.
    if uploaded_file.type == "application/pdf":
        return process_pdf(uploaded_file)
    return process_text(uploaded_file)
28
+
29
+
30
def process_pdf(uploaded_file):
    """Extract text from a PDF file and split it into chunks.

    Args:
        uploaded_file: File-like object holding PDF bytes.

    Returns:
        list: Stripped, non-empty text chunks (one per line of text).
    """
    pdf_reader = PdfReader(io.BytesIO(uploaded_file.read()))
    # extract_text() can return None for pages with no extractable text
    # (e.g. scanned images); coerce those to "" instead of letting the
    # original `None + "\n"` concatenation raise TypeError.
    text = "\n".join(page.extract_text() or "" for page in pdf_reader.pages)
    return split_into_chunks(text)
37
+
38
+
39
def process_text(uploaded_file):
    """Decode an uploaded text file as UTF-8 and split it into chunks."""
    raw_bytes = uploaded_file.read()
    return split_into_chunks(raw_bytes.decode("utf-8"))
43
+
44
+
45
def split_into_chunks(text):
    """Split *text* on newlines, dropping blank or whitespace-only lines."""
    return [line.strip() for line in text.split("\n") if line.strip()]