Spaces:
Sleeping
Sleeping
"""
app.py
──────
Gradio UI — the entry point for Hugging Face Spaces.

All retrieval and answer generation is delegated to ``rag_pipeline.py``.
This module only builds that pipeline once at import time, lays out the
Gradio interface, wires its events to ``chat()``, and answers the in-chat
slash debug commands (``/status``, ``/debug``, ``/retrieve <query>``).
"""
import logging
import sys

import gradio as gr

from config import cfg
from rag_pipeline import RAGPipeline, build_pipeline
# ── Gradio version guard ──────────────────────────────────────────────────────
# The keyword arguments accepted by gr.Chatbot may differ between Gradio
# releases. Read the installed constructor's signature once so the UI code
# only passes optional settings that this version actually understands.
import inspect as _inspect

_chatbot_params = {*_inspect.signature(gr.Chatbot.__init__).parameters}
_SUPPORTS_COPY, _SUPPORTS_BUBBLE = (
    _option in _chatbot_params
    for _option in ("show_copy_button", "bubble_full_width")
)
# ── Logging setup ─────────────────────────────────────────────────────────────
# Route records to stdout (basicConfig would otherwise default to stderr)
# using a "time | level | logger | message" layout at INFO level.
logging.basicConfig(
    format="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
# ── Pipeline (initialised once at startup) ────────────────────────────────────
# Build the RAG pipeline a single time, at import. A failure must not stop the
# UI from loading, so the error text is kept in ``init_error`` and surfaced to
# users by chat() and the /status debug command instead.
pipeline: RAGPipeline | None = None
init_error: str | None = None
try:
    pipeline = build_pipeline()
except Exception as exc:  # top-level boundary: record and log, never crash the app
    # ``str(exc)`` is "" for exceptions raised without a message. Falling back
    # to ``repr`` keeps ``init_error`` truthy, so the failure is still reported
    # instead of being mistaken for "no error".
    init_error = str(exc) or repr(exc)
    logger.exception("Pipeline initialisation failed: %s", init_error)
# ── Helpers ───────────────────────────────────────────────────────────────────
def _msg(role: str, content: str) -> dict:
    """Build one chat message in Gradio's ``{"role", "content"}`` dict format.

    Args:
        role: Who sent the message, e.g. ``"user"`` or ``"assistant"``.
        content: The message text.

    Returns:
        A new dict holding exactly the keys ``role`` and ``content``.
    """
    return dict(role=role, content=content)
# Text-column names the /debug report treats as conventional; any other
# detected column triggers a hint to add it to config.text_column_candidates.
_STANDARD_TEXT_COLUMNS = ("text", "content", "body", "page_content", "extracted_text")

# Reply for "/" messages that match no known debug command.
_DEBUG_HELP = (
    "**Debug commands:**\n"
    "- `/debug` β dataset columns, row count, sample text\n"
    "- `/status` β pipeline health and vector count\n"
    "- `/retrieve your question` β raw retrieval results"
)


def _debug_dataset_report() -> str:
    """Build the ``/debug`` reply: dataset columns, row counts and a sample row.

    ``data_loader`` is imported lazily here, so the other debug commands keep
    working even if that module cannot be imported.
    """
    from data_loader import get_dataset_info

    info = get_dataset_info()
    if info["status"] == "error":
        return f"β **Dataset error:**\n```\n{info['error']}\n```"

    lines = [
        "### π Dataset Debug Info",
        f"**Dataset:** `{info['dataset']}`",
        f"**Total rows:** {info['total_rows']}",
        f"**All columns:** `{info['columns']}`",
        f"**Detected text column:** `{info['detected_text_col']}`",
        f"**Non-empty rows:** {info['non_empty_rows']}",
        "",
        "**Sample text from row 0:**",
        f"```\n{info['sample_text']}\n```",
        "",
    ]
    if info["detected_text_col"] not in _STANDARD_TEXT_COLUMNS:
        lines.append(
            f"β οΈ **Column `{info['detected_text_col']}` is not a standard name.**\n"
            "Add it to `text_column_candidates` in `config.py`."
        )
    lines.append(
        "β **No usable text rows found.**"
        if info["non_empty_rows"] == 0
        else "β Dataset looks healthy."
    )
    return "\n".join(lines)


def _debug_retrieve_report(test_query: str) -> str:
    """Build the ``/retrieve`` reply: up to 5 chunks retrieved for *test_query*.

    Each chunk is shown with its source metadata and its first 300 characters.
    ``vector_store`` is imported only once a query will actually be run.
    """
    if not test_query:
        return "Usage: `/retrieve your test query here`"
    if pipeline is None:
        return "β Pipeline not initialised."
    import vector_store as vs_module

    docs = vs_module.retrieve(pipeline._index, test_query, k=5)
    if not docs:
        return (
            f"β **No chunks retrieved** for: `{test_query}`\n"
            "FAISS index may be empty or text column is wrong."
        )
    lines = [f"### π Retrieved {len(docs)} chunks for: `{test_query}`\n"]
    for i, doc in enumerate(docs, 1):
        # Prefer the "source" metadata key, then "source_row", else "?".
        src = doc.metadata.get("source", doc.metadata.get("source_row", "?"))
        lines.append(f"**Chunk {i}** (source: {src})")
        lines.append(f"```\n{doc.page_content[:300]}\n```")
    return "\n".join(lines)


def _debug_status_report() -> str:
    """Build the ``/status`` reply: the startup error, or vector count plus config."""
    if init_error:
        return f"β **Pipeline failed:**\n```\n{init_error}\n```"
    if pipeline is None:
        return "β Pipeline is None β startup may still be in progress."
    total_vectors = pipeline._index.index.ntotal
    lines = [
        "### β Pipeline Status",
        f"**FAISS vectors:** {total_vectors}",
        f"**Groq model:** `{cfg.groq_model}`",
        f"**Dataset:** `{cfg.hf_dataset}`",
        f"**Chunk size:** {cfg.chunk_size} | **Top-K:** {cfg.top_k}",
        (
            "\nβ **0 vectors β retrieval will always fail!**"
            if total_vectors == 0
            else "\nβ Index looks healthy."
        ),
    ]
    return "\n".join(lines)


def _handle_debug_command(command: str) -> str:
    """Answer an in-chat slash command with a Markdown debug report.

    Lets maintainers inspect the dataset, the pipeline and raw retrieval
    straight from the chat box, without a terminal. Surrounding whitespace
    and the letter case of the command word are ignored:

    - ``/debug``            dataset columns, row count, sample text
    - ``/status``           pipeline health and FAISS vector count
    - ``/retrieve <query>`` raw retrieval results (a bare ``/retrieve``
      returns a usage hint)

    Args:
        command: The raw chat message, expected to start with ``/``.

    Returns:
        Markdown for the assistant reply. Unrecognised commands get the list
        of supported commands.

    Errors raised by the data loader, the vector store or the index are not
    caught here; they propagate to the caller.
    """
    stripped = command.strip()
    cmd = stripped.lower()
    if cmd == "/debug":
        return _debug_dataset_report()
    if cmd == "/status":
        return _debug_status_report()
    # Split the *stripped* message on its first whitespace run. Slicing the raw
    # message by len("/retrieve ") would cut into the query when the message
    # has leading spaces, and a bare "/retrieve" or a tab separator would never
    # reach the retrieve branch.
    parts = stripped.split(maxsplit=1)
    if parts and parts[0].lower() == "/retrieve":
        return _debug_retrieve_report(parts[1] if len(parts) > 1 else "")
    return _DEBUG_HELP
def chat(user_message: str, history: list, show_sources: bool):
    """Handle one chat turn submitted from the Gradio UI.

    Args:
        user_message: Text from the input box. A message starting with ``/``
            (after stripping whitespace) is routed to the debug commands.
        history: The current conversation as a list of ``{"role", "content"}``
            message dicts. ``None`` is treated as empty. The list is not
            mutated; a new list is returned.
        show_sources: When true, also return the formatted source excerpts
            of the answer.

    Returns:
        ``(textbox_value, new_history, sources_markdown)``:

        - ``""``, which clears the input box;
        - the history with this turn appended (unchanged for a blank message);
        - the sources Markdown (``""`` when hidden, on error, or for
          non-query replies).
    """
    history = list(history or [])
    if not (user_message or "").strip():
        # Blank submission: nothing to answer, so keep the conversation as is.
        return "", history, ""

    def _reply(bot_reply: str, sources_md: str = "") -> tuple[str, list, str]:
        # Append the user/assistant pair and clear the input box.
        new_history = history + [_msg("user", user_message), _msg("assistant", bot_reply)]
        return "", new_history, sources_md

    # ── Debug slash commands (work even when startup failed) ─────────────────
    if user_message.strip().startswith("/"):
        try:
            bot_reply = _handle_debug_command(user_message)
        except Exception as exc:  # a broken debug command must not crash the turn
            logger.exception("Debug command failed: %s", exc)
            bot_reply = f"**Debug command failed:**\n```\n{type(exc).__name__}: {exc}\n```"
        return _reply(bot_reply)

    if init_error:
        return _reply(f"β οΈ **Setup error:** {init_error}\n\nCheck Space secrets and logs.")
    if pipeline is None:
        # No recorded error, yet no pipeline. Say so explicitly rather than
        # letting ``None.query`` fail into the generic error message below.
        return _reply(
            "**Setup error:** the pipeline is not initialised.\n\n"
            "Check Space secrets and logs."
        )

    try:
        response = pipeline.query(user_message)
        bot_reply = response.answer
        sources_md = response.format_sources() if show_sources else ""
    except Exception as exc:
        logger.exception("Error during query: %s", exc)
        bot_reply = "π Something went wrong while consulting the stars. Please try again."
        sources_md = ""
    return _reply(bot_reply, sources_md)
# ── Gradio UI ─────────────────────────────────────────────────────────────────
# Stylesheet injected into the Blocks page. It sets a serif body font, centres
# the title banner, mutes the source-excerpt text and hides the page footer.
# The result starts with a newline and ends each rule with one.
CSS = "\n" + "\n".join(
    [
        "body, .gradio-container { font-family: 'Georgia', serif; }",
        ".title-banner { text-align: center; padding: 1rem 0 0.5rem; }",
        ".title-banner h1 { font-size: 2rem; letter-spacing: 0.04em; }",
        ".sources-box { font-size: 0.82rem; color: #718096; }",
        "footer { display: none !important; }",
    ]
) + "\n"

# Preset questions, each rendered as a button that copies its text into the
# input box.
EXAMPLE_QUESTIONS = [
    "What is the difference between the Sun sign and Rising sign?",
    "Explain what retrograde motion means for planets.",
    "What are the 12 houses in a birth chart?",
    "How do I interpret a conjunction aspect?",
    "What does it mean when Mars is in Aries?",
    "Explain the concept of planetary dignities and debilities.",
    "What is the difference between sidereal and tropical zodiac?",
    "How does the Moon sign influence emotions?",
]
# Apply an indigo/purple/slate palette when this Gradio build exposes the
# ``gr.themes`` API; otherwise pass ``theme=None`` and keep the default look.
_SUPPORTS_THEMES = hasattr(gr, "themes") and hasattr(gr.themes, "Base")
if _SUPPORTS_THEMES:
    _theme = gr.themes.Base(
        primary_hue="indigo",
        secondary_hue="purple",
        neutral_hue="slate",
    )
else:
    _theme = None
# The Chatbot component doubles as the conversation store (see the event
# wiring at the bottom of this block), so no separate gr.State is needed.
with gr.Blocks(title=cfg.app_title, theme=_theme, css=CSS) as demo:
    # ── Header ────────────────────────────────────────────────────────────────
    gr.HTML("""
<div class="title-banner">
<h1>π AstroBot Demo</h1>
<p style="color:#9b8ec4; font-size:1.05rem;">
Your AI Astrology Assistant Β· Powered by Groq LLaMA-3.1-8b-instant
</p>
</div>
""")

    # ── Disclaimer — fully inline styles for reliability ──────────────────────
    gr.HTML("""
<div style="background-color:#3b3777; color:#f0eeff; border:1px solid #6c67c4;
border-radius:8px; padding:10px 16px; font-size:0.92rem;
margin-bottom:8px; line-height:1.6;">
π <strong style="color:#ffffff;">For students only.</strong>
AstroBot explains astrological concepts drawn from custom course materials.
It does <strong style="color:#ffffff;">not</strong> make personal predictions
or interpret individual birth charts.
</div>
""")

    # ── Main layout ───────────────────────────────────────────────────────────
    with gr.Row():
        with gr.Column(scale=3):
            # Only pass Chatbot options the installed Gradio version accepts.
            _chatbot_kwargs = {"label": "AstroBot", "height": 500}
            if _SUPPORTS_BUBBLE:
                _chatbot_kwargs["bubble_full_width"] = False
            if _SUPPORTS_COPY:
                _chatbot_kwargs["show_copy_button"] = True
            if "type" in _chatbot_params:
                # chat() builds {"role", "content"} dicts, i.e. "messages" format.
                _chatbot_kwargs["type"] = "messages"
            chatbot = gr.Chatbot(**_chatbot_kwargs)
            with gr.Row():
                txt_input = gr.Textbox(
                    placeholder="Ask a concept question about astrologyβ¦",
                    show_label=False,
                    scale=9,
                )
                send_btn = gr.Button("Ask β¨", variant="primary", scale=1)
        with gr.Column(scale=1):
            gr.Markdown("### βοΈ Options")
            _checkbox_kwargs = {"label": "Show source excerpts", "value": False}
            _checkbox_params = set(_inspect.signature(gr.Checkbox.__init__).parameters)
            if "info" in _checkbox_params:
                _checkbox_kwargs["info"] = "Display course material passages used to answer."
            show_sources = gr.Checkbox(**_checkbox_kwargs)
            gr.Markdown("### π‘ Example Questions")
            for q in EXAMPLE_QUESTIONS:
                # Bind q as a default argument so each button fills in its own
                # question (a plain closure would only see the last loop value).
                gr.Button(q, size="sm").click(fn=lambda x=q: x, outputs=txt_input)
            gr.Markdown(
                "---\nπ οΈ **Debug commands:**\n"
                "`/status` Β· `/debug` Β· `/retrieve <query>`"
            )

    # ── Source citations panel ────────────────────────────────────────────────
    sources_display = gr.Markdown(
        value="", label="Source Excerpts", elem_classes=["sources-box"]
    )

    # ── Event wiring ──────────────────────────────────────────────────────────
    # The Chatbot is both an input and an output of chat(). It hands chat() the
    # conversation currently on screen and receives the extended one back, so
    # history persists across turns. Reading a gr.State that is never written
    # back would reset the history to [] on every message.
    _chat_event = {
        "fn": chat,
        "inputs": [txt_input, chatbot, show_sources],
        "outputs": [txt_input, chatbot, sources_display],
    }
    send_btn.click(**_chat_event)
    txt_input.submit(**_chat_event)

    gr.Markdown(
        "_Built with [Groq](https://groq.com) Β· [LangChain](https://langchain.com) Β· "
        "[Hugging Face](https://huggingface.co) β for astrology students everywhere π_"
    )
if __name__ == "__main__":
    # Listen on every network interface (0.0.0.0) at port 7860 so the server
    # is reachable from outside the host/container when run directly.
    demo.launch(server_port=7860, server_name="0.0.0.0")