# PDF Learning Assistant — Streamlit app
# (Hosting-page residue "Spaces: Sleeping" removed; it was not part of the program.)
| import streamlit as st | |
| import os | |
| import tempfile | |
| import google.generativeai as genai | |
| from pypdf import PdfReader | |
| from pinecone import Pinecone | |
| import uuid | |
| import time | |
| import json | |
| from dotenv import load_dotenv | |
| from langchain_community.embeddings import HuggingFaceEmbeddings | |
| # Load environment variables from .env file | |
| load_dotenv() | |
| # Configuration | |
| st.set_page_config(page_title="PDF Learning Assistant", layout="wide") | |
| # Initialize session state | |
| if "messages" not in st.session_state: | |
| st.session_state.messages = [] | |
| if "current_pdf_content" not in st.session_state: | |
| st.session_state.current_pdf_content = "" | |
| if "current_pdf_name" not in st.session_state: | |
| st.session_state.current_pdf_name = "" | |
| if "index_name" not in st.session_state: | |
| st.session_state.index_name = "index1" # Using your specific index name | |
| if "is_initialized" not in st.session_state: | |
| st.session_state.is_initialized = False | |
| if "index_dimensions" not in st.session_state: | |
| st.session_state.index_dimensions = 1024 # Set this based on your Pinecone index | |
| if "quiz_submitted" not in st.session_state: | |
| st.session_state.quiz_submitted = False | |
| if "quiz_answers" not in st.session_state: | |
| st.session_state.quiz_answers = {} | |
| # Functions for PDF processing | |
| def extract_text_from_pdf(pdf_file): | |
| reader = PdfReader(pdf_file) | |
| text = "" | |
| for page in reader.pages: | |
| text += page.extract_text() + "\n" | |
| return text | |
| def chunk_text(text, chunk_size=1000, overlap=200): | |
| chunks = [] | |
| start = 0 | |
| text_length = len(text) | |
| while start < text_length: | |
| end = min(start + chunk_size, text_length) | |
| if end < text_length and end - start == chunk_size: | |
| # Find the last period or newline to make more natural chunks | |
| last_period = text.rfind('.', start, end) | |
| last_newline = text.rfind('\n', start, end) | |
| if last_period > start + chunk_size // 2: | |
| end = last_period + 1 | |
| elif last_newline > start + chunk_size // 2: | |
| end = last_newline + 1 | |
| chunks.append(text[start:end]) | |
| start = end - overlap if end < text_length else text_length | |
| return chunks | |
| # Embeddings and Vector Store functions | |
| def get_embedding_model(): | |
| # Using a model that produces 1024-dimensional embeddings | |
| return HuggingFaceEmbeddings(model_name="sentence-transformers/all-roberta-large-v1") | |
| def initialize_pinecone(): | |
| # Get API key from environment variables | |
| api_key = os.getenv("PINECONE_API_KEY") | |
| if not api_key: | |
| st.error("Pinecone API key not found. Please add it to your .env file as PINECONE_API_KEY=your_api_key") | |
| return False | |
| try: | |
| # Initialize Pinecone with your specific configuration | |
| pc = Pinecone(api_key=api_key) | |
| # Store Pinecone client in session state | |
| st.session_state.pinecone_client = pc | |
| # Check if your index exists | |
| index_list = [idx.name for idx in pc.list_indexes()] | |
| if st.session_state.index_name not in index_list: | |
| st.error(f"Index '{st.session_state.index_name}' not found in your Pinecone account.") | |
| st.info("Available indexes: " + ", ".join(index_list)) | |
| return False | |
| # Get index details to check dimensions | |
| try: | |
| index = pc.Index(st.session_state.index_name) | |
| index_stats = index.describe_index_stats() | |
| if 'dimension' in index_stats: | |
| st.session_state.index_dimensions = index_stats['dimension'] | |
| st.info(f"Detected index dimension: {st.session_state.index_dimensions}") | |
| else: | |
| st.warning("Could not detect index dimensions. Using default: 1024") | |
| except Exception as e: | |
| st.warning(f"Could not get index details: {str(e)}") | |
| return True | |
| except Exception as e: | |
| st.error(f"Error initializing Pinecone: {str(e)}") | |
| return False | |
| def get_pinecone_index(): | |
| # Connect to your existing index | |
| return st.session_state.pinecone_client.Index(st.session_state.index_name) | |
| def embed_chunks(chunks): | |
| model = get_embedding_model() | |
| embeddings = [] | |
| for chunk in chunks: | |
| # HuggingFaceEmbeddings returns a list with a single embedding | |
| embed = model.embed_documents([chunk])[0] | |
| embeddings.append(embed) | |
| return embeddings | |
| def store_embeddings(chunks, embeddings, pdf_name): | |
| index = get_pinecone_index() | |
| batch_size = 100 | |
| for i in range(0, len(chunks), batch_size): | |
| i_end = min(i + batch_size, len(chunks)) | |
| ids = [f"{pdf_name}-{uuid.uuid4()}" for _ in range(i, i_end)] | |
| metadata = [{"text": chunks[j], "pdf_name": pdf_name, "chunk_id": j} for j in range(i, i_end)] | |
| vectors = [(ids[j-i], embeddings[j], metadata[j-i]) for j in range(i, i_end)] | |
| try: | |
| index.upsert(vectors=vectors) | |
| st.success(f"Successfully stored batch {i//batch_size + 1} of chunks to Pinecone") | |
| except Exception as e: | |
| st.error(f"Error storing embeddings: {str(e)}") | |
| # Display the first embedding's dimension for debugging | |
| if embeddings and len(embeddings) > 0: | |
| st.info(f"Embedding dimension: {len(embeddings[0])}") | |
| return False | |
| st.success(f"Successfully stored all {len(chunks)} chunks to Pinecone") | |
| return True | |
| def search_similar_chunks(query, top_k=5, pdf_name=None): | |
| model = get_embedding_model() | |
| query_embedding = model.embed_query(query) | |
| index = get_pinecone_index() | |
| filter_query = {"pdf_name": pdf_name} if pdf_name else None | |
| results = index.query( | |
| vector=query_embedding, | |
| top_k=top_k, | |
| include_metadata=True, | |
| filter=filter_query | |
| ) | |
| return results.matches | |
| # Gemini LLM Integration | |
| def initialize_gemini(): | |
| api_key = os.getenv("GOOGLE_API_KEY") | |
| if not api_key: | |
| st.error("Google API key not found. Please add it to your .env file as GOOGLE_API_KEY=your_api_key") | |
| return False | |
| try: | |
| genai.configure(api_key=api_key) | |
| return True | |
| except Exception as e: | |
| st.error(f"Error initializing Google Generative AI: {str(e)}") | |
| return False | |
| def get_gemini_response(prompt, context=None, temperature=0.7): | |
| try: | |
| model = genai.GenerativeModel('gemini-2.0-flash') | |
| if context: | |
| full_prompt = f""" | |
| Context information: | |
| {context} | |
| Question: {prompt} | |
| Please provide a helpful, accurate response based on the context information provided. | |
| If the answer cannot be determined from the context, please state that clearly. | |
| """ | |
| else: | |
| full_prompt = prompt | |
| response = model.generate_content(full_prompt, generation_config={"temperature": temperature}) | |
| return response.text | |
| except Exception as e: | |
| st.error(f"Error getting response from Gemini: {str(e)}") | |
| return "Sorry, I couldn't generate a response at this time." | |
| # Quiz and Assignment Generation | |
| def generate_quiz(pdf_content, num_questions=5): | |
| prompt = f""" | |
| Based on the following content, generate a quiz with {num_questions} multiple-choice questions. | |
| For each question, provide 4 options and indicate the correct answer. | |
| Format the response as a JSON array of question objects with the structure: | |
| [ | |
| {{ | |
| "question": "Question text", | |
| "options": ["Option A", "Option B", "Option C", "Option D"], | |
| "correct_answer": "Correct option (A, B, C, or D)", | |
| "explanation": "Brief explanation of why this is the correct answer" | |
| }}, | |
| // more questions... | |
| ] | |
| Content: {pdf_content[:2000]}... (truncated for brevity) | |
| """ | |
| response = get_gemini_response(prompt, temperature=0.2) | |
| try: | |
| # Extract JSON from response if it's embedded in markdown or text | |
| if "```json" in response: | |
| json_start = response.find("```json") + 7 | |
| json_end = response.find("```", json_start) | |
| json_str = response[json_start:json_end].strip() | |
| elif "```" in response: | |
| json_start = response.find("```") + 3 | |
| json_end = response.find("```", json_start) | |
| json_str = response[json_start:json_end].strip() | |
| else: | |
| json_str = response | |
| quiz_data = json.loads(json_str) | |
| return quiz_data | |
| except Exception as e: | |
| st.error(f"Error parsing quiz response: {str(e)}") | |
| return [] | |
| def generate_assignment(pdf_content, assignment_type="short_answer", num_questions=3): | |
| prompt = f""" | |
| Based on the following content, generate a {assignment_type} assignment with {num_questions} questions. | |
| If the assignment type is 'short_answer', create questions that require brief explanations. | |
| If the assignment type is 'essay', create deeper questions that require longer responses. | |
| If the assignment type is 'research', create questions that encourage further exploration of the topics. | |
| Format the response as a JSON array with the structure: | |
| [ | |
| {{ | |
| "question": "Question text", | |
| "hints": ["Hint 1", "Hint 2"], | |
| "key_points": ["Key point 1", "Key point 2", "Key point 3"] | |
| }}, | |
| // more questions... | |
| ] | |
| Content: {pdf_content[:2000]}... (truncated for brevity) | |
| """ | |
| response = get_gemini_response(prompt, temperature=0.3) | |
| try: | |
| # Extract JSON from response if it's embedded in markdown or text | |
| if "```json" in response: | |
| json_start = response.find("```json") + 7 | |
| json_end = response.find("```", json_start) | |
| json_str = response[json_start:json_end].strip() | |
| elif "```" in response: | |
| json_start = response.find("```") + 3 | |
| json_end = response.find("```", json_start) | |
| json_str = response[json_start:json_end].strip() | |
| else: | |
| json_str = response | |
| assignment_data = json.loads(json_str) | |
| return assignment_data | |
| except Exception as e: | |
| st.error(f"Error parsing assignment response: {str(e)}") | |
| return [] | |
| # Callback functions for quiz submission and reset | |
| def submit_quiz(): | |
| st.session_state.quiz_submitted = True | |
| def reset_quiz(): | |
| st.session_state.quiz_submitted = False | |
| st.session_state.quiz_answers = {} | |
| # Streamlit UI | |
| def main(): | |
| st.title("📚 PDF Learning Assistant") | |
| # Initialize services | |
| if not st.session_state.is_initialized: | |
| with st.spinner("Initializing services..."): | |
| pinecone_init = initialize_pinecone() | |
| gemini_init = initialize_gemini() | |
| if pinecone_init and gemini_init: | |
| st.session_state.is_initialized = True | |
| st.success("Services initialized successfully!") | |
| # Display Pinecone connection info | |
| st.info(f""" | |
| Connected to Pinecone index: | |
| - Index name: {st.session_state.index_name} | |
| - Dimension: {st.session_state.index_dimensions} | |
| - Host: https://index1-mwog0w0.svc.aped-4627-b74a.pinecone.io | |
| - Region: us-east-1 | |
| - Type: Dense | |
| - Capacity: Serverless | |
| """) | |
| else: | |
| st.error("Failed to initialize all required services. Please check your API keys in the .env file.") | |
| # Show .env file template | |
| st.code(""" | |
| # Create a .env file in the same directory with the following content: | |
| PINECONE_API_KEY=your_pinecone_api_key | |
| GOOGLE_API_KEY=your_google_api_key | |
| """) | |
| return | |
| # Sidebar for PDF upload and main actions | |
| with st.sidebar: | |
| st.header("Upload PDF") | |
| uploaded_file = st.file_uploader("Choose a PDF file", type="pdf") | |
| if uploaded_file: | |
| with st.spinner("Processing PDF..."): | |
| # Save uploaded file to temp location | |
| with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as temp_file: | |
| temp_file.write(uploaded_file.getvalue()) | |
| temp_path = temp_file.name | |
| # Extract text | |
| pdf_text = extract_text_from_pdf(temp_path) | |
| os.unlink(temp_path) # Delete temp file | |
| # Store in session state | |
| st.session_state.current_pdf_content = pdf_text | |
| st.session_state.current_pdf_name = uploaded_file.name | |
| # Chunk and embed | |
| chunks = chunk_text(pdf_text) | |
| embeddings = embed_chunks(chunks) | |
| # Store in Pinecone | |
| success = store_embeddings(chunks, embeddings, st.session_state.current_pdf_name) | |
| if success: | |
| st.success(f"Successfully processed {uploaded_file.name}") | |
| else: | |
| st.error(f"Failed to process {uploaded_file.name}") | |
| st.divider() | |
| st.header("Learning Tools") | |
| # Only enable these buttons if a PDF is loaded | |
| if st.session_state.current_pdf_content: | |
| # Quiz generation | |
| quiz_questions = st.slider("Number of quiz questions", min_value=3, max_value=10, value=5) | |
| if st.button("Generate Quiz"): | |
| with st.spinner("Generating quiz..."): | |
| quiz_data = generate_quiz(st.session_state.current_pdf_content, num_questions=quiz_questions) | |
| st.session_state.quiz_data = quiz_data | |
| # Reset quiz state when generating a new quiz | |
| reset_quiz() | |
| # Assignment generation | |
| assignment_type = st.selectbox( | |
| "Assignment Type", | |
| ["short_answer", "essay", "research"] | |
| ) | |
| assignment_questions = st.slider("Number of assignment questions", min_value=1, max_value=5, value=3) | |
| if st.button("Generate Assignment"): | |
| with st.spinner("Generating assignment..."): | |
| assignment_data = generate_assignment( | |
| st.session_state.current_pdf_content, | |
| assignment_type, | |
| num_questions=assignment_questions | |
| ) | |
| st.session_state.assignment_data = assignment_data | |
| else: | |
| st.info("Please upload a PDF first to use these features") | |
| # Main content area | |
| tab1, tab2, tab3 = st.tabs(["Chatbot", "Quiz", "Assignment"]) | |
| # Tab 1: Chatbot | |
| with tab1: | |
| st.header("Chat with your PDF") | |
| # Display chat messages | |
| for message in st.session_state.messages: | |
| with st.chat_message(message["role"]): | |
| st.write(message["content"]) | |
| # Chat input | |
| if st.session_state.current_pdf_content: | |
| user_input = st.chat_input("Ask a question about your PDF...") | |
| if user_input: | |
| # Add user message to chat history | |
| st.session_state.messages.append({"role": "user", "content": user_input}) | |
| with st.chat_message("user"): | |
| st.write(user_input) | |
| # Generate response | |
| with st.chat_message("assistant"): | |
| with st.spinner("Thinking..."): | |
| # Search for relevant context | |
| similar_chunks = search_similar_chunks( | |
| user_input, | |
| top_k=3, | |
| pdf_name=st.session_state.current_pdf_name | |
| ) | |
| # Extract text from results | |
| context = "\n\n".join([match.metadata["text"] for match in similar_chunks]) | |
| # Get response from Gemini | |
| response = get_gemini_response(user_input, context) | |
| st.write(response) | |
| # Add assistant message to chat history | |
| st.session_state.messages.append({"role": "assistant", "content": response}) | |
| else: | |
| st.info("Please upload a PDF to start chatting") | |
| # Tab 2: Quiz | |
| with tab2: | |
| st.header("Quiz") | |
| if "quiz_data" in st.session_state and st.session_state.quiz_data: | |
| quiz_data = st.session_state.quiz_data | |
| # Quiz display logic - static until submitted | |
| if not st.session_state.quiz_submitted: | |
| # Quiz form | |
| with st.form(key="quiz_form"): | |
| for i, question in enumerate(quiz_data): | |
| st.subheader(f"Question {i+1}") | |
| st.write(question["question"]) | |
| options = question["options"] | |
| option_labels = ["A", "B", "C", "D"] | |
| # Create radio buttons for options | |
| answer = st.radio( | |
| "Select your answer:", | |
| options=option_labels[:len(options)], | |
| key=f"q{i}", | |
| index=None | |
| ) | |
| # Display options | |
| for j, option in enumerate(options): | |
| st.write(f"{option_labels[j]}: {option}") | |
| # Store answer in session state | |
| if answer: | |
| st.session_state.quiz_answers[i] = answer | |
| st.divider() | |
| # Submit button inside the form | |
| submit_button = st.form_submit_button("Submit Quiz") | |
| if submit_button: | |
| st.session_state.quiz_submitted = True | |
| else: | |
| # Show results after submission | |
| correct_count = 0 | |
| for i, question in enumerate(quiz_data): | |
| st.subheader(f"Question {i+1}") | |
| st.write(question["question"]) | |
| options = question["options"] | |
| option_labels = ["A", "B", "C", "D"] | |
| correct_letter = question["correct_answer"] | |
| user_answer = st.session_state.quiz_answers.get(i) | |
| # Display options with correct/incorrect indicators | |
| for j, option in enumerate(options): | |
| current_label = option_labels[j] | |
| if current_label == correct_letter: | |
| st.success(f"{current_label}: {option} ✓") | |
| if user_answer == current_label: | |
| correct_count += 1 | |
| elif user_answer == current_label: | |
| st.error(f"{current_label}: {option} ✗") | |
| else: | |
| st.write(f"{current_label}: {option}") | |
| # Show explanation | |
| st.info(f"Explanation: {question['explanation']}") | |
| st.divider() | |
| st.subheader(f"Your Score: {correct_count}/{len(quiz_data)}") | |
| if st.button("Retake Quiz"): | |
| reset_quiz() | |
| else: | |
| st.info("Generate a quiz from the sidebar to see it here") | |
| # Tab 3: Assignment | |
| # Tab 3: Assignment | |
| with tab3: | |
| st.header("Assignment") | |
| if "assignment_data" in st.session_state and st.session_state.assignment_data: | |
| assignment_data = st.session_state.assignment_data | |
| # Create a form for the assignment | |
| with st.form(key="assignment_form"): | |
| for i, question in enumerate(assignment_data): | |
| st.subheader(f"Question {i+1}") | |
| st.write(question["question"]) | |
| # Use a checkbox to toggle hints instead of a nested expander | |
| if "hints" in question and question["hints"]: | |
| show_hints = st.checkbox(f"Show hints for Question {i+1}", key=f"hint_checkbox_{i}") | |
| if show_hints: | |
| for hint in question["hints"]: | |
| st.write(f"- {hint}") | |
| # Input area for the answer | |
| st.text_area("Your Answer:", key=f"assignment_q{i}", height=150) | |
| st.divider() | |
| # Add the submit button as a direct child of the form | |
| submit_assignment = st.form_submit_button("Submit Assignment") | |
| # Process form submission outside the form block | |
| if submit_assignment: | |
| st.success("Assignment submitted! Here are the key points for each question:") | |
| for i, question in enumerate(assignment_data): | |
| with st.expander(f"Key Points for Question {i+1}", expanded=True): | |
| for point in question["key_points"]: | |
| st.write(f"- {point}") | |
| else: | |
| st.info("Generate an assignment from the sidebar to see it here") | |
| if __name__ == "__main__": | |
| main() |