"""FastAPI + Gradio front-end for the NOLEJ RAG service."""

from fastapi import FastAPI
from pydantic import BaseModel, Field
import gradio as gr
from dotenv import load_dotenv

# Load .env BEFORE importing rag, so that module sees API keys / settings
# at its own import time.
load_dotenv()

from rag import answer_question

# NOTE(review): debug=True enables verbose error pages — confirm this is
# intended outside local development.
app = FastAPI(title="NOLEJ RAG", version="1.0", debug=True)
| |
|
| | |
class AskRequest(BaseModel):
    """Request body for POST /ask."""

    # min_length=3 rejects empty/trivial questions at the validation layer.
    question: str = Field(..., min_length=3)
| |
|
class AskResponse(BaseModel):
    """Response body for POST /ask, mirroring the dict returned by rag.answer_question."""

    answer: str
    citations: list     # presumably citation entries — TODO confirm shape against rag.answer_question
    used_chunks: list   # retrieved chunk dicts (see render_chunks below for the keys it reads)
| |
|
@app.get("/health")
async def health():
    """Liveness probe: always returns {"status": "ok"}."""
    return {"status": "ok"}
| |
|
@app.post("/ask", response_model=AskResponse)
async def ask(req: AskRequest):
    """Answer a single question via the RAG pipeline (no conversation context)."""
    return await answer_question(req.question)
| |
|
| |
|
| | |
# Custom CSS injected into the Gradio Blocks app: page width, chat panel
# sizing, and rounded corners on chat messages.
CSS = """
.gradio-container { max-width: 1100px !important; margin: 0 auto !important; }
#chatbot { height: 28vh; overflow: auto; border-radius: 14px; }
#panel { border-radius: 16px; }
.gr-chat-message { border-radius: 14px !important; padding: 10px 12px !important; }
"""
| |
|
| | def _content_to_text(content) -> str: |
| | """Gradio can store message content as str OR list of parts. Normalize to str.""" |
| | if content is None: |
| | return "" |
| | if isinstance(content, str): |
| | return content |
| | if isinstance(content, list): |
| | parts = [] |
| | for p in content: |
| | if p is None: |
| | continue |
| | if isinstance(p, str): |
| | parts.append(p) |
| | elif isinstance(p, dict): |
| | parts.append(str(p.get("text") or p.get("content") or "")) |
| | else: |
| | parts.append(str(p)) |
| | return "\n".join([x for x in parts if x.strip()]) |
| | if isinstance(content, dict): |
| | return str(content.get("text") or content.get("content") or "") |
| | return str(content) |
| |
|
| |
|
def format_history_for_followup(history_msgs):
    """Render recent chat turns as "User:/Assistant:" lines for follow-up context.

    ``history_msgs`` is a list of dicts like {"role": "user", "content": ...};
    content may be a str or a list of parts depending on the Gradio version.
    Only the last 12 messages are kept; non-dict entries, empty contents and
    unknown roles are skipped.
    """
    if not history_msgs:
        return ""

    rendered = []
    for message in history_msgs[-12:]:
        if not isinstance(message, dict):
            continue
        role = (message.get("role") or "").strip()
        text = _content_to_text(message.get("content")).strip()
        if not text:
            continue
        prefix = {"user": "User", "assistant": "Assistant"}.get(role)
        if prefix is not None:
            rendered.append(f"{prefix}: {text}")

    return "\n".join(rendered)
| |
|
| |
|
def render_chunks(chunks, n=3):
    """Build a markdown report of the top-``n`` retrieved chunks.

    ``chunks`` is the list of dicts from out["used_chunks"]. Each section
    shows the chunk text (truncated to 1200 chars), its source metadata and
    whichever retrieval scores are present. ``n`` is clamped to
    [1, len(chunks)].
    """
    if not chunks:
        return "Aucun chunk à afficher."

    count = max(1, min(n, len(chunks)))

    def score_text(value):
        # Scores may be absent (None) depending on the retrieval pipeline.
        return "—" if value is None else f"{value:.4f}"

    sections = []
    for rank, chunk in enumerate(chunks[:count], start=1):
        meta = chunk.get("metadata", {}) or {}
        source = meta.get("source_file", "unknown")
        doc_id = meta.get("doc_id", "unknown")
        chunk_index = meta.get("chunk_index", "?")
        chunk_id = meta.get("chunk_id", "")

        body = (chunk.get("text") or "").strip()
        if len(body) > 1200:
            body = body[:1200] + "…"

        # FAISS score is always shown; rerank/final only when present.
        score_parts = [f"FAISS score: {score_text(chunk.get('score'))}"]
        rerank = chunk.get("rerank_score")
        if rerank is not None:
            score_parts.append(f"Rerank score: {score_text(rerank)}")
        final = chunk.get("final_score")
        if final is not None:
            score_parts.append(f"Final: {score_text(final)}")
        score_line = "<sub>" + " | ".join(score_parts) + "</sub>"

        header = f"**#{rank} — {source} | doc_id={doc_id} | chunk_index={chunk_index}**"
        sections.append(
            f"{header}\n\n"
            f"{score_line}\n\n"
            f"```text\n{body}\n```\n"
            f"<sub>chunk_id: {chunk_id}</sub>\n"
        )

    return "\n\n---\n\n".join(sections)
| |
|
| |
|
async def chat_ask(user_message, history_msgs, show_sources, chunks_state):
    """Handle one chat turn: run the RAG pipeline and refresh history/sources.

    Returns a 6-tuple for the Gradio outputs: (chat history, textbox value,
    sources-accordion update, chunks state, shown-chunk count, sources
    markdown).
    """
    question = (user_message or "").strip()
    if len(question) < 3:
        # Too short to be a real question: leave history/state untouched,
        # hide the sources panel and nudge the user.
        return history_msgs, "", gr.update(visible=False), chunks_state, 3, "Pose une question plus précise 🙂"

    # Feed recent turns to the pipeline so follow-up questions resolve.
    context = format_history_for_followup(history_msgs)
    if context:
        effective_question = f"{question}\n\n[Conversation context]\n{context}"
    else:
        effective_question = question

    out = await answer_question(effective_question)

    new_history = list(history_msgs or [])
    new_history.append({"role": "user", "content": question})
    new_history.append({"role": "assistant", "content": out["answer"]})

    new_chunks = out.get("used_chunks", []) or []
    top_n = 3  # reset the "show more" counter on every new answer
    markdown = render_chunks(new_chunks, n=top_n)

    return (
        new_history,
        "",
        gr.update(visible=bool(show_sources)),
        new_chunks,
        top_n,
        markdown,
    )
| |
|
| |
|
def show_more(chunks_state, shown_n):
    """Bump the shown-chunk counter by 3 and re-render the sources markdown."""
    current = shown_n if shown_n else 3  # falsy counter falls back to the default of 3
    new_count = current + 3
    return new_count, render_chunks(chunks_state, n=new_count)
| |
|
def reset_chat():
    """Return fresh values for chat history, textbox, chunk state, counter and sources panel."""
    empty_history = []
    empty_chunks = []
    return empty_history, "", empty_chunks, 3, "Aucun chunk à afficher."
| |
|
| |
|
# --- Gradio UI --------------------------------------------------------------
with gr.Blocks(title="NOLEJ", css=CSS, theme="soft") as nolej_ui:
    gr.Markdown("# 🧠 NOLEJ — Neurosciences RAG")
    gr.Markdown("Interface de chat — réponses **uniquement** à partir des documents fournis.")

    # Per-session state: chunks used for the last answer, and how many are shown.
    chunks_state = gr.State([])
    shown_n_state = gr.State(3)

    with gr.Row():
        with gr.Column(scale=6, min_width=520):
            # chat_ask stores history as {"role": ..., "content": ...} dicts,
            # so the Chatbot must use the "messages" format (the legacy
            # tuples format would fail to render those entries).
            chatbot = gr.Chatbot(
                elem_id="chatbot", label="NOLEJ Chat", height=300, type="messages"
            )

            with gr.Row():
                show_sources = gr.Checkbox(value=True, label="Afficher sources")
                reset_btn = gr.Button(" Réinitialiser")

            msg = gr.Textbox(placeholder="Pose ta question…", show_label=False)
            send = gr.Button(" Envoyer")

        with gr.Column(scale=4, min_width=420):
            sources_box = gr.Accordion("📌 Sources / Chunks (Top 3 par défaut)", open=False, visible=True)
            with sources_box:
                sources_md = gr.Markdown("Aucun chunk à afficher.")
                more_btn = gr.Button("➕ Plus (afficher 3 de plus)")

    # Enter key and the send button run the same handler with the same wiring.
    msg.submit(
        chat_ask,
        inputs=[msg, chatbot, show_sources, chunks_state],
        outputs=[chatbot, msg, sources_box, chunks_state, shown_n_state, sources_md],
    )
    send.click(
        chat_ask,
        inputs=[msg, chatbot, show_sources, chunks_state],
        outputs=[chatbot, msg, sources_box, chunks_state, shown_n_state, sources_md],
    )

    more_btn.click(
        show_more,
        inputs=[chunks_state, shown_n_state],
        outputs=[shown_n_state, sources_md],
    )

    reset_btn.click(
        reset_chat,
        inputs=[],
        outputs=[chatbot, msg, chunks_state, shown_n_state, sources_md],
    )


# Serve the Gradio UI at "/" on the same FastAPI app that exposes the JSON API.
app = gr.mount_gradio_app(app, nolej_ui, path="/")