# app.py
import streamlit as st
import os
import faiss
import pickle
from sentence_transformers import SentenceTransformer
from groq import Groq
from dotenv import load_dotenv
import re  # Regular expressions, used in expand_query_with_llm_app

# --- Page Configuration (MUST BE THE FIRST STREAMLIT COMMAND) ---
st.set_page_config(page_title="RAG BITS Tutor", page_icon="🎓")
# --- Configuration and Model Loading (kept at module level so caching can apply) ---
# Caching is important for the large models and data; st.cache_resource (assumed
# here as the caching mechanism) makes Streamlit load them only once per session.
@st.cache_resource
def load_models_and_data():
    # Load environment variables (if a .env file is present in the Space)
    load_dotenv()
    groq_api_key_app = os.getenv("GROQ_API_KEY")

    # Paths to the index and chunks
    output_folder = "faiss_index_bits"  # Must be present in the HF Space
    index_path = os.path.join(output_folder, "bits_tutor.index")
    chunks_path = os.path.join(output_folder, "bits_chunks.pkl")

    # Load FAISS index
    if not os.path.exists(index_path):
        st.error(f"FAISS index not found at: {index_path}")
        return None, None, None, None
    index_loaded = faiss.read_index(index_path)

    # Load chunks
    if not os.path.exists(chunks_path):
        st.error(f"Chunks file not found at: {chunks_path}")
        return None, None, None, None
    with open(chunks_path, "rb") as f:
        chunks_loaded = pickle.load(f)

    # Load embedding model
    embedding_model_name_app = "Sahajtomar/German-semantic"
    embedding_model_loaded = SentenceTransformer(embedding_model_name_app)

    # Initialize the Groq client
    if not groq_api_key_app:
        st.error("GROQ_API_KEY not found. Please add it as a Secret in the Hugging Face Space settings.")
        return None, None, None, None
    groq_client_loaded = Groq(api_key=groq_api_key_app)

    return index_loaded, chunks_loaded, embedding_model_loaded, groq_client_loaded
# Load models and data when the app starts.
# Note: load_models_and_data() calls st.error(), a Streamlit command, so
# st.set_page_config() must run BEFORE the first possible st.error() call.
faiss_index, chunks_data, embedding_model, groq_client = load_models_and_data()
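# Streamlit reruns this script from top to bottom on every user interaction;
# the cache decorator on load_models_and_data keeps these heavy objects loaded
# once and reused across reruns instead of being rebuilt each time.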
# --- Core RAG Functions (adapted for the Streamlit app) ---
def retrieve_relevant_chunks_app(query, k=5):
    # Uses the globally loaded embedding_model, faiss_index, and chunks_data
    if embedding_model is None or faiss_index is None or chunks_data is None:
        st.warning("Models or data not loaded correctly. Cannot retrieve chunks.")
        return []
    query_embedding = embedding_model.encode([query], convert_to_numpy=True)
    distances, indices = faiss_index.search(query_embedding, k)
    retrieved_chunks_data = [(chunks_data[i], distances[0][j]) for j, i in enumerate(indices[0])]
    return retrieved_chunks_data
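# Illustration (hypothetical values): retrieve_relevant_chunks_app("Was ist IT-Alignment?", k=2)
# might return [("IT-Alignment bezeichnet ...", 0.41), ("Eine IT-Strategie ...", 0.58)].
# FAISS returns one distance per hit; with an L2-based index (which the ascending
# distance sort in retrieve_with_expanded_queries_app below assumes), a smaller
# distance means a closer match.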
def generate_answer_app(query, retrieved_chunks_data):
    # Uses the globally loaded groq_client
    if groq_client is None:
        st.error("Groq client not initialized. Cannot generate answer.")
        return "Error: LLM client not available."
    context = "\n\n".join([chunk_text for chunk_text, dist in retrieved_chunks_data])
    # This prompt_template stays in German, as it instructs the LLM for the German-speaking tutor
    prompt_template = f"""Beantworte die folgende Frage ausschließlich basierend auf dem bereitgestellten Kontext aus den Lehrmaterialien zur Business IT Strategie.
Antworte auf Deutsch.

Kontext:
{context}

Frage: {query}

Antwort:
"""
    try:
        chat_completion = groq_client.chat.completions.create(
            messages=[{"role": "user", "content": prompt_template}],
            model="llama3-70b-8192",
            temperature=0.3,
        )
        return chat_completion.choices[0].message.content
    except Exception as e:
        # Developer-facing error; the UI shows this or a generic fallback message
        st.error(f"Error during LLM request in generate_answer_app: {e}")
        return "An error occurred while generating the answer."
def expand_query_with_llm_app(original_query, llm_client_app):
    # Uses the passed llm_client_app (which should be the global groq_client)
    if llm_client_app is None:
        st.warning("LLM client for query expansion not initialized.")
        return [original_query]
    # This prompt_template also stays in German
    prompt_template_expansion = f"""Gegeben ist die folgende Nutzerfrage zum Thema "Business IT Strategie": "{original_query}"
Bitte generiere 2-3 alternative Formulierungen dieser Frage ODER eine Liste von 3-5 sehr relevanten Schlüsselbegriffen/Konzepten,
die helfen würden, in einer Wissensdatenbank nach Antworten zu dieser Frage zu suchen.
Formatiere die Ausgabe klar, z.B. als nummerierte Liste für alternative Fragen oder als kommaseparierte Liste für Schlüsselbegriffe.
Gib NUR die alternativen Formulierungen oder die Schlüsselbegriffe aus. Keine Einleitungssätze.
"""
    try:
        chat_completion = llm_client_app.chat.completions.create(
            messages=[
                {
                    "role": "user",
                    "content": prompt_template_expansion,
                }
            ],
            model="llama3-8b-8192",  # lighter model than the 70B used for answering
            temperature=0.5,
        )
        expanded_terms_text = chat_completion.choices[0].message.content
        # Clean up the raw LLM output line by line
        cleaned_queries = []
        potential_queries = expanded_terms_text.split('\n')
        for line in potential_queries:
            line = line.strip()
            line = re.sub(r"^\s*\d+\.\s*", "", line)  # strip list numbering like "1. "
            line = re.sub(r"^\s*[-\*]\s*", "", line)  # strip bullet markers like "- " or "* "
            line = line.strip()
            # Skip empty lines, preamble phrases, and fragments that are too short
            if not line or \
               line.lower().startswith("here are") or \
               line.lower().startswith("sicher, hier sind") or \
               line.lower().startswith("alternative formulierungen:") or \
               line.lower().startswith("*alternative formulierungen:**") or \
               len(line) < 5:
                continue
            cleaned_queries.append(line)
        # If the model returned a single comma-separated keyword list, split it up
        if len(cleaned_queries) == 1 and ',' in cleaned_queries[0] and len(cleaned_queries[0].split(',')) > 1:
            final_expanded_list = [term.strip() for term in cleaned_queries[0].split(',') if term.strip() and len(term.strip()) > 4]
        else:
            final_expanded_list = cleaned_queries
        # Combine the original query with the expansions, dropping case-insensitive duplicates
        all_queries = [original_query]
        for q_exp in final_expanded_list:
            is_duplicate = False
            for q_all in all_queries:
                if q_all.lower() == q_exp.lower():
                    is_duplicate = True
                    break
            if not is_duplicate:
                all_queries.append(q_exp)
        return all_queries[:4]  # original query plus at most three expansions
    except Exception as e:
        st.warning(f"Error during query expansion with LLM: {e}")
        return [original_query]
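# Illustration (hypothetical LLM output): given the lines
#   "Sicher, hier sind alternative Formulierungen:"
#   "1. Wie unterstützt die IT-Strategie die Geschäftsziele?"
#   "2. Was bedeutet Business-IT-Alignment?"
# the cleanup above would keep only the two questions, and the function would
# return [original_query, "Wie unterstützt ...", "Was bedeutet ..."].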
def retrieve_with_expanded_queries_app(original_query, llm_client_app, retrieve_func, k_per_expansion=2):
    expanded_queries = expand_query_with_llm_app(original_query, llm_client_app)
    # For UI feedback / debugging:
    # st.write("Using the following queries for retrieval after expansion:")
    # for i, eq_query in enumerate(expanded_queries):
    #     st.caption(f" ExpQuery {i}: {eq_query}")
    all_retrieved_chunks_data = []
    for eq_query in expanded_queries:
        retrieved_for_eq = retrieve_func(eq_query, k=k_per_expansion)
        all_retrieved_chunks_data.extend(retrieved_for_eq)
    # Deduplicate: keep each chunk once, with the smallest distance seen for it
    unique_chunks_dict = {}
    for chunk_text, distance in all_retrieved_chunks_data:
        if chunk_text not in unique_chunks_dict or distance < unique_chunks_dict[chunk_text]:
            unique_chunks_dict[chunk_text] = distance
    # Sort by ascending distance and keep the closest chunks for the context
    sorted_unique_chunks_data = sorted(unique_chunks_dict.items(), key=lambda item: item[1])
    final_chunks_for_context = sorted_unique_chunks_data[:5]
    # For UI feedback / debugging:
    # st.write(f"\n{len(final_chunks_for_context)} unique chunks were selected for the context.")
    return final_chunks_for_context
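# Illustration (hypothetical values): if the same chunk is retrieved for two
# expanded queries with distances 0.62 and 0.55, only the 0.55 entry survives
# deduplication, and the five smallest-distance chunks overall become the context.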
# --- Streamlit UI ---
st.title("🎓 RAG Study Tutor for Business IT Strategy")
st.write("Ask your questions about the content of the lecture notes and case studies (in German).")

# User query input field (remains in German for the user)
user_query_streamlit = st.text_input("Deine Frage:", "")

# Option to use query expansion
use_expansion = st.checkbox("Use Query Expansion (may improve results for some questions)", value=True)

if user_query_streamlit:
    # Proceed only if models and data were loaded successfully
    if faiss_index is not None and chunks_data is not None and embedding_model is not None and groq_client is not None:
        with st.spinner("Searching for relevant information and generating answer..."):
            retrieved_chunks = []
            if use_expansion:
                st.caption("Query expansion is active...")
                retrieved_chunks = retrieve_with_expanded_queries_app(user_query_streamlit, groq_client, retrieve_relevant_chunks_app, k_per_expansion=2)
            else:
                st.caption("Direct retrieval...")
                retrieved_chunks = retrieve_relevant_chunks_app(user_query_streamlit, k=3)
            if retrieved_chunks:
                # Optional display of the retrieved context snippets (for debugging or transparency):
                # with st.expander("Show retrieved context snippets (German)"):
                #     for i, (chunk, dist) in enumerate(retrieved_chunks):
                #         st.caption(f"Chunk {i+1} (Distance: {dist:.2f})")
                #         st.markdown(f"_{chunk[:200]}..._")
                #         st.divider()
                answer = generate_answer_app(user_query_streamlit, retrieved_chunks)
                st.subheader("Tutor's Answer:")
                st.markdown(answer)  # The answer itself is in German
            else:
                st.warning("No relevant information could be found for your query.")
    else:
        st.error("The application could not be initialized correctly. Please check for error messages related to model or data loading.")
st.sidebar.header("About this Project")
st.sidebar.info(
    "This RAG application was developed as part of the 'AI Applications' module. "
    "It uses Sentence Transformers for embeddings, FAISS for vector search, "
    "and an LLM via Groq for answer generation."
)
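To try the app outside the Space (a minimal sketch, assuming the faiss_index_bits folder sits next to app.py): install the dependencies implied by the imports, for example with pip install streamlit faiss-cpu sentence-transformers groq python-dotenv, put GROQ_API_KEY=... into a local .env file, and start the app with streamlit run app.py.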