Spaces:
Sleeping
Sleeping
| """ | |
| IJNet Chatbot Assistant — Streamlit Interface (v2) | |
| ---------------------------------------------------- | |
| Improvements over v1: | |
| - Streaming responses (word-by-word output) | |
| - Sidebar filters (region, opportunity type) | |
| - Guardrail feedback in the UI | |
| - Better error handling with user-friendly messages | |
| """ | |
| import os | |
| import sys | |
| import streamlit as st | |
| from pathlib import Path | |
| os.environ["TRANSFORMERS_VERBOSITY"] = "error" | |
| os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1" | |
| sys.path.insert(0, str(Path(__file__).parent)) | |
| from src.ingest import ( | |
| load_knowledge_base, | |
| build_documents, | |
| get_embeddings, | |
| build_vector_store, | |
| load_vector_store, | |
| ) | |
| from src.retriever import HybridRetriever | |
| from src.chain import IJNetRAGChain | |
# ---------------------------------------------------------------------------
# PAGE CONFIG
# ---------------------------------------------------------------------------
# Browser-tab title/icon and a narrow, centred layout. This runs at import
# time, before any other st.* call in this file.
st.set_page_config(
    layout="centered",
    page_icon="🌍",
    page_title="IJNet Assistant",
)
# ---------------------------------------------------------------------------
# CUSTOM CSS
# ---------------------------------------------------------------------------
# Inject page-wide styles on every script run. The class names are used by
# HTML emitted elsewhere in this file:
#   .main-header   -> render_header()
#   .source-card   -> render_sources() (title, "Browse on IJNet" link, .meta line)
#   .guardrail-msg -> highlighted callout for guardrail-blocked replies in main()
# The `.stButton > button` rule restyles Streamlit buttons (full width,
# left-aligned text), which is how the example-query buttons get their look.
st.markdown("""
<style>
.main-header {
text-align: center;
padding: 1rem 0 0.5rem 0;
}
.main-header h1 {
color: #1a5276;
font-size: 2rem;
margin-bottom: 0.2rem;
}
.main-header p {
color: #666;
font-size: 0.95rem;
}
.source-card {
background: #f8f9fa;
border-left: 3px solid #2980b9;
padding: 0.6rem 0.8rem;
margin: 0.3rem 0;
border-radius: 0 4px 4px 0;
font-size: 0.85rem;
}
.source-card a {
color: #2980b9;
text-decoration: none;
font-weight: 600;
}
.source-card .meta {
color: #888;
font-size: 0.78rem;
}
.guardrail-msg {
background: #fff3cd;
border-left: 3px solid #ffc107;
padding: 0.8rem;
border-radius: 0 4px 4px 0;
font-size: 0.9rem;
}
.stButton > button {
width: 100%;
text-align: left;
padding: 0.4rem 0.8rem;
font-size: 0.85rem;
}
</style>
""", unsafe_allow_html=True)
# ---------------------------------------------------------------------------
# INITIALIZATION
# ---------------------------------------------------------------------------
@st.cache_resource(show_spinner="Loading knowledge base and search index...")
def initialize_pipeline():
    """Build the retrieval pipeline once and reuse it across reruns.

    Streamlit re-executes the whole script on every interaction. Without
    ``st.cache_resource``, the knowledge base, the embedding model and the
    FAISS index would be reloaded on every click and keystroke. The cached
    objects are shared by all sessions served by this process.

    NOTE(review): because the HybridRetriever is now shared across sessions,
    confirm it keeps no per-query mutable state.

    Steps:
        1. Load ``data/knowledge_base.json`` and build the document list.
        2. Load the FAISS index from ``data/faiss_index`` if that path exists.
           Otherwise build it via ``build_vector_store``, which is given the
           index path (presumably to persist it — TODO confirm).
        3. Wrap both in a HybridRetriever: 8 semantic + 8 BM25 candidates,
           5 final results.

    Returns:
        tuple: ``(retriever, documents)``.

    Raises:
        Whatever the loading/building helpers raise. Exceptions are not
        cached by ``st.cache_resource``, so a later rerun retries from
        scratch.
    """
    data_dir = Path(__file__).parent / "data"
    kb_path = data_dir / "knowledge_base.json"
    index_path = data_dir / "faiss_index"

    kb = load_knowledge_base(str(kb_path))
    # Documents are needed even when the FAISS index already exists,
    # because the BM25 half of the hybrid retriever indexes them directly.
    documents = build_documents(kb)
    embeddings = get_embeddings()

    if index_path.exists():
        vector_store = load_vector_store(str(index_path), embeddings)
    else:
        vector_store = build_vector_store(documents, embeddings, str(index_path))

    retriever = HybridRetriever(
        vector_store=vector_store,
        documents=documents,
        semantic_k=8,
        bm25_k=8,
        final_k=5,
    )
    return retriever, documents
def get_chain(retriever, api_key: str) -> IJNetRAGChain:
    """Return a new IJNetRAGChain bound to *retriever* and the Groq *api_key*.

    The chain is configured for Groq's ``llama-3.3-70b-versatile`` model at
    temperature 0.1, i.e. near-deterministic answers.
    """
    llm_settings = {
        "model_name": "llama-3.3-70b-versatile",
        "temperature": 0.1,
    }
    return IJNetRAGChain(retriever=retriever, groq_api_key=api_key, **llm_settings)
# ---------------------------------------------------------------------------
# UI COMPONENTS
# ---------------------------------------------------------------------------
def render_header():
    """Draw the page banner: the app title followed by a one-line tagline.

    The markup relies on the ``main-header`` CSS class injected at module load.
    """
    banner_html = """
    <div class="main-header">
        <h1>🌍 IJNet Assistant</h1>
        <p>Find journalism opportunities, resources, and expert guidance worldwide</p>
    </div>
    """
    st.markdown(banner_html, unsafe_allow_html=True)
# Section landing pages on ijnet.org, keyed by a source's "type" field. Cards
# link to the section rather than to the individual item, presumably because
# the sample knowledge base's per-item URLs are not real — confirm.
_SOURCE_SECTION_URLS = {
    "opportunity": "https://ijnet.org/en/opportunities",
    "article": "https://ijnet.org/en/resources",
    "ijnet_info": "https://ijnet.org/en/about",
}


def render_sources(sources: list[dict]):
    """Render retrieved sources as styled cards inside a collapsed expander.

    Each card shows:
        * the source title, or "Unknown" if it is missing or empty;
        * a "Browse on IJNet" link to the IJNet section matching the source's
          ``type`` (falling back to the IJNet home page);
        * a metadata line built from whichever of ``opp_type``, ``deadline``,
          ``organization``, ``author`` and ``date`` are present.

    Cards are rendered with ``unsafe_allow_html=True``, so the title and
    metadata text are HTML-escaped first. Otherwise characters such as ``<``
    or ``&`` in knowledge-base text would be parsed as markup and mangle or
    hide the card.

    Does nothing when *sources* is empty or None.
    """
    import html  # local import: only this block needs HTML escaping

    if not sources:
        return
    with st.expander(f"📚 Sources ({len(sources)}) — *sample knowledge base*", expanded=False):
        for src in sources:
            title = html.escape(str(src.get("title") or "Unknown"))
            section_url = _SOURCE_SECTION_URLS.get(src.get("type", ""), "https://ijnet.org")

            meta_parts = []
            if src.get("opp_type"):
                meta_parts.append(f"📋 {str(src['opp_type']).capitalize()}")
            if src.get("deadline"):
                meta_parts.append(f"⏰ Deadline: {src['deadline']}")
            if src.get("organization"):
                meta_parts.append(f"🏢 {src['organization']}")
            if src.get("author"):
                meta_parts.append(f"✍️ {src['author']}")
            if src.get("date"):
                meta_parts.append(f"📅 {src['date']}")
            meta_str = html.escape(" · ".join(meta_parts))

            st.markdown(
                f'<div class="source-card">'
                f'<strong>{title}</strong> '
                f'<a href="{section_url}" target="_blank" style="font-size:0.78rem;">[Browse on IJNet ↗]</a><br>'
                f'<span class="meta">{meta_str}</span>'
                f'</div>',
                unsafe_allow_html=True,
            )
def render_debug(debug_info: dict):
    """Display retrieval diagnostics for one answer in a collapsed expander.

    Left column: ``debug_info["classification"]`` rendered as JSON.

    Right column:
        * the ``debug_info["num_retrieved"]`` count;
        * the optional ``semantic_top3`` and ``bm25_top3`` lists of
          ``(title, score)`` pairs, each printed as ``score — title``.
          Titles are cut to 50 characters. Semantic scores use 3 decimals;
          BM25 scores use 1.

    Raises KeyError if ``classification`` or ``num_retrieved`` is missing.
    """
    rankings = (
        ("**Semantic Top-3:**", "semantic_top3", ".3f"),
        ("**BM25 Top-3:**", "bm25_top3", ".1f"),
    )
    with st.expander("🔍 Debug: Retrieval Details", expanded=False):
        left, right = st.columns(2)
        with left:
            st.markdown("**Query Classification**")
            st.json(debug_info["classification"])
        with right:
            st.markdown(f"**Retrieved:** {debug_info['num_retrieved']} documents")
            for heading, key, spec in rankings:
                st.markdown(heading)
                for title, score in debug_info.get(key, []):
                    st.text(f" {score:{spec}} — {title[:50]}")
def handle_error(e: Exception):
    """Show a user-friendly message for *e*, chosen by inspecting its text.

    The lower-cased exception text is matched against substrings that appear
    in common API/HTTP client errors. Categories are checked in this order:
    authentication, rate limit, timeout, connection, model unavailable.
    Anything unmatched gets a generic message (first 200 characters) plus an
    expander holding the full text.

    The authentication patterns are deliberately specific. A bare "invalid"
    or "auth" match would also catch unrelated failures, e.g.
    OpenAI-compatible APIs such as Groq tag many 4xx errors (unknown model,
    oversized request) as ``invalid_request_error``. Such errors would be
    mislabelled as a bad API key, and the model-unavailable branch could
    never fire.
    """
    error_msg = str(e).lower()

    def mentions(*needles: str) -> bool:
        """True if any of *needles* occurs in the lower-cased error text."""
        return any(needle in error_msg for needle in needles)

    if mentions(
        "invalid_api_key", "invalid api key", "api_key", "api key",
        "authentication", "unauthorized", "error code: 401",
    ):
        st.error("🔑 **Invalid API key.** Please check your Groq API key in the sidebar.")
    elif mentions("rate_limit", "rate limit", "429"):
        st.warning(
            "⏳ **Rate limit reached.** Groq's free tier has request limits. "
            "Please wait 30-60 seconds and try again."
        )
    elif mentions("timeout", "timed out"):
        st.warning(
            "⏱️ **Request timed out.** The server took too long to respond. "
            "Please try again — if the issue persists, try a shorter question."
        )
    elif mentions("connection", "network"):
        st.error(
            "🌐 **Connection error.** Could not reach the Groq API. "
            "Please check your internet connection and try again."
        )
    elif "model" in error_msg and mentions(
        "not found", "unavailable", "does not exist", "decommissioned", "model_not_found",
    ):
        st.error(
            "🤖 **Model unavailable.** The selected LLM model is temporarily unavailable. "
            "Please try again in a few minutes."
        )
    else:
        st.error(f"❌ **Error:** {str(e)[:200]}")
        with st.expander("Full error details"):
            st.code(str(e))
# Canned prompts offered as one-click buttons before the first message.
# Only the first six are shown in the UI (main slices EXAMPLE_QUERIES[:6]).
EXAMPLE_QUERIES: list[str] = [
    "What opportunities are available for investigative journalists in Africa?",
    "Find fellowships with deadlines in the next 30 days",
    "What resources does IJNet have on AI tools for journalists?",
    "Can you summarize the latest opportunities for product/design people in newsrooms?",
    "Which IJNet newsletter should I subscribe to?",
    "What grants are available for data journalism projects?",
    "Tell me about digital security tools for journalists",
    "What training programs exist for journalists in the Middle East?",
]
# ---------------------------------------------------------------------------
# MAIN APP
# ---------------------------------------------------------------------------
def _render_sidebar() -> tuple[str, dict, bool]:
    """Render the sidebar controls and return the user's settings.

    Returns:
        ``(api_key, ui_filters, debug_mode)``:

        * ``api_key`` — resolved from the text box, then the ``GROQ_API_KEY``
          environment variable, then ``st.secrets``; ``""`` if none is set.
        * ``ui_filters`` — a dict with only the non-"All" choices, under
          ``"region"`` / ``"opp_type"``.
        * ``debug_mode`` — whether the retrieval debug panel is enabled.
    """
    with st.sidebar:
        st.markdown("### ⚙️ Configuration")
        api_key = st.text_input(
            "Groq API Key",
            type="password",
            placeholder="gsk_...",
            help="Get a free API key at https://console.groq.com",
        ).strip()  # pasted keys often carry stray whitespace
        if not api_key:
            env_key = os.environ.get("GROQ_API_KEY", "")
            if env_key:
                api_key = env_key
                st.success("Using API key from environment")
        if not api_key:
            # Secrets are an optional source: reading st.secrets may raise
            # when no secrets file is configured, which just means "no key".
            try:
                api_key = st.secrets.get("GROQ_API_KEY", "")
                if api_key:
                    st.success("Using configured API key")
            except Exception:
                pass

        st.markdown("---")
        st.markdown("### 🔎 Filters")
        filter_region = st.selectbox(
            "Region",
            ["All", "Africa", "Asia", "Europe", "Latin America",
             "Middle East", "South Asia", "Global"],
            index=0,
            help="Prioritize results from this region",
        )
        filter_type = st.selectbox(
            "Opportunity Type",
            ["All", "fellowship", "grant", "training", "award"],
            index=0,
            help="Prioritize this type of opportunity",
        )

        st.markdown("---")
        debug_mode = st.checkbox("Show retrieval debug info", value=False)

        st.markdown("---")
        st.markdown("### 📊 Knowledge Base")
        st.markdown("""
        - **20** opportunity records
        - **6** resource articles
        - **Hybrid retrieval**: Semantic + BM25
        - **LLM**: Llama 3.3 70B via Groq
        - **Embeddings**: MiniLM-L6-v2
        """)

        st.markdown("---")
        st.markdown(
            "<small>Built with LangChain, FAISS, Groq, and HuggingFace.<br>"
            "Prototype for IJNet chatbot assistance.</small>",
            unsafe_allow_html=True,
        )

    ui_filters = {}
    if filter_region != "All":
        ui_filters["region"] = filter_region
    if filter_type != "All":
        ui_filters["opp_type"] = filter_type
    return api_key, ui_filters, debug_mode


def _render_example_queries():
    """Offer the first six EXAMPLE_QUERIES as one-click buttons.

    A click stores the text in ``st.session_state.pending_query`` and reruns
    the script. main() then consumes it exactly as if it had been typed.
    """
    st.markdown("#### Try asking:")
    cols = st.columns(2)
    for i, example in enumerate(EXAMPLE_QUERIES[:6]):
        with cols[i % 2]:
            if st.button(example, key=f"example_{i}", use_container_width=True):
                st.session_state.pending_query = example
                st.rerun()


def _render_guardrail(text: str):
    """Render a guardrail-blocked reply in the highlighted ``guardrail-msg`` callout."""
    st.markdown(f'<div class="guardrail-msg">{text}</div>', unsafe_allow_html=True)


def _render_history():
    """Replay the stored conversation, including sources and debug panels.

    Guardrail replies are re-rendered in their callout style, so they look
    the same after a rerun as when they were first shown.
    """
    for msg in st.session_state.messages:
        with st.chat_message(msg["role"]):
            if msg.get("guardrail"):
                _render_guardrail(msg["content"])
            else:
                st.markdown(msg["content"])
            if msg.get("sources"):
                render_sources(msg["sources"])
            if msg.get("debug"):
                render_debug(msg["debug"])


def _get_or_create_chain(retriever, api_key: str):
    """Return this session's RAG chain, rebuilding it if the API key changed.

    The key used to build the chain is remembered in
    ``st.session_state.rag_chain_api_key``. This way, correcting a wrong key
    in the sidebar takes effect immediately instead of only after
    "Clear Chat". A rebuilt chain starts with empty conversation history.

    Returns None (after showing an error) if the chain cannot be built.
    """
    chain = st.session_state.rag_chain
    if chain is not None and st.session_state.get("rag_chain_api_key") == api_key:
        return chain
    try:
        chain = get_chain(retriever, api_key)
    except Exception as e:
        handle_error(e)
        return None
    st.session_state.rag_chain = chain
    st.session_state.rag_chain_api_key = api_key
    return chain


def _stream_answer(chain, query: str, ui_filters: dict, debug_mode: bool):
    """Show *query*, stream the chain's answer, and append both to history.

    Guardrail-blocked answers are joined and shown in the callout style
    without sources. Normal answers are streamed token by token and then
    followed by their sources and, if *debug_mode* is on, the debug panel.
    Any exception is reported through handle_error(); in that case the
    user message stays in history without an assistant reply.
    """
    st.session_state.messages.append({"role": "user", "content": query})
    with st.chat_message("user"):
        st.markdown(query)

    with st.chat_message("assistant"):
        try:
            result = chain.query_stream(
                question=query,
                filters=ui_filters or None,
                include_debug=debug_mode,
            )

            if result.get("guardrail_blocked"):
                full_answer = "".join(result["answer_stream"])
                _render_guardrail(full_answer)
                st.session_state.messages.append({
                    "role": "assistant",
                    "content": full_answer,
                    "guardrail": True,
                })
                return

            full_answer = st.write_stream(result["answer_stream"])
            render_sources(result["sources"])
            debug = result.get("debug") if debug_mode else None
            if debug:
                render_debug(debug)

            msg_data = {
                "role": "assistant",
                "content": full_answer,
                "sources": result["sources"],
            }
            if debug:
                msg_data["debug"] = debug
            st.session_state.messages.append(msg_data)
        except Exception as e:
            handle_error(e)


def main():
    """Streamlit entry point; the whole script re-runs on every interaction."""
    render_header()
    api_key, ui_filters, debug_mode = _render_sidebar()

    # --- Initialize pipeline (cached across reruns) ---
    try:
        retriever, _documents = initialize_pipeline()
    except Exception as e:
        st.error(f"Failed to initialize pipeline: {e}")
        st.info("Try deleting `data/faiss_index/` and restarting.")
        return

    # --- Chat state (persists across reruns within a browser session) ---
    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "rag_chain" not in st.session_state:
        st.session_state.rag_chain = None

    # --- Collect this run's query ---
    # At the top level of the main area, st.chat_input is pinned to the
    # bottom of the page, so reading it before the examples does not move it.
    # Knowing the query up front lets the example buttons disappear as soon
    # as a conversation starts, instead of lingering (and being dead) while
    # the first answer streams.
    pending = st.session_state.pop("pending_query", None)
    user_input = st.chat_input("Ask about journalism opportunities, resources, or training...")
    query = pending or user_input

    if not st.session_state.messages and not query:
        _render_example_queries()

    _render_history()

    if query:
        if not api_key:
            st.warning("⚠️ Please enter your Groq API key in the sidebar to start chatting.")
            return
        chain = _get_or_create_chain(retriever, api_key)
        if chain is None:
            return
        _stream_answer(chain, query, ui_filters, debug_mode)

    # --- Clear chat button ---
    if st.session_state.messages:
        if st.button("🗑️ Clear Chat", use_container_width=True):
            st.session_state.messages = []
            if st.session_state.rag_chain:
                st.session_state.rag_chain.reset_history()
            st.session_state.rag_chain = None
            st.rerun()


if __name__ == "__main__":
    main()