# Nolej — main.py
# FastAPI + Gradio entry point for the NOLEJ RAG application.
from fastapi import FastAPI
from pydantic import BaseModel, Field
import gradio as gr
from dotenv import load_dotenv
load_dotenv()
from rag import answer_question
# ASGI application; Gradio UI is mounted on it at the bottom of this file.
# NOTE(review): debug=True exposes tracebacks in HTTP responses — confirm this
# is intentional before a public deployment.
app = FastAPI(title="NOLEJ RAG", version="1.0", debug=True)
# -------- API --------
class AskRequest(BaseModel):
    """Request body for POST /ask."""

    # Free-text question; min_length=3 rejects trivially short input at validation time.
    question: str = Field(..., min_length=3)
class AskResponse(BaseModel):
    """Response body for POST /ask — mirrors the dict returned by rag.answer_question."""

    answer: str  # generated answer text
    citations: list  # citation entries (element shape defined by rag.answer_question)
    used_chunks: list  # retrieved chunks with metadata and scores (see render_chunks below)
@app.get("/health")
async def health():
    """Liveness probe: always returns {"status": "ok"}."""
    return {"status": "ok"}
@app.post("/ask", response_model=AskResponse)
async def ask(req: AskRequest):
    """Answer a question from the indexed documents via the RAG pipeline.

    answer_question is expected to return a dict matching AskResponse
    (FastAPI validates/serializes it through response_model).
    """
    return await answer_question(req.question)
# -------- Gradio UI --------
# Styling tweaks for the embedded Gradio UI: container width, chat panel
# height/scrolling, rounded corners on panels and chat bubbles.
CSS = """
.gradio-container { max-width: 1100px !important; margin: 0 auto !important; }
#chatbot { height: 28vh; overflow: auto; border-radius: 14px; }
#panel { border-radius: 16px; }
.gr-chat-message { border-radius: 14px !important; padding: 10px 12px !important; }
"""
def _content_to_text(content) -> str:
"""Gradio can store message content as str OR list of parts. Normalize to str."""
if content is None:
return ""
if isinstance(content, str):
return content
if isinstance(content, list):
parts = []
for p in content:
if p is None:
continue
if isinstance(p, str):
parts.append(p)
elif isinstance(p, dict):
parts.append(str(p.get("text") or p.get("content") or ""))
else:
parts.append(str(p))
return "\n".join([x for x in parts if x.strip()])
if isinstance(content, dict):
return str(content.get("text") or content.get("content") or "")
return str(content)
def format_history_for_followup(history_msgs):
    """Render recent chat history as a plain-text transcript for follow-ups.

    history_msgs is a list of {"role": ..., "content": ...} dicts; content may
    be a str or a list of parts depending on the Gradio version. Only the last
    12 messages (~6 turns) are kept; non-dict entries, empty contents and roles
    other than user/assistant are skipped.
    """
    if not history_msgs:
        return ""
    transcript = []
    for msg in history_msgs[-12:]:  # ~6 user/assistant turns
        if not isinstance(msg, dict):
            continue
        role = (msg.get("role") or "").strip()
        text = _content_to_text(msg.get("content")).strip()
        if not text:
            continue
        if role == "user":
            transcript.append(f"User: {text}")
        elif role == "assistant":
            transcript.append(f"Assistant: {text}")
    return "\n".join(transcript)
def render_chunks(chunks, n=3):
    """Render the top-*n* retrieved chunks as Markdown for the sources panel.

    Parameters
    ----------
    chunks : list[dict]
        Entries as returned in ``out["used_chunks"]``. Each may carry ``text``,
        ``metadata`` (source_file / doc_id / chunk_index / chunk_id) and the
        scores ``score`` (FAISS), ``rerank_score`` and ``final_score``.
    n : int
        Number of chunks to show; clamped to [1, len(chunks)].

    Returns
    -------
    str
        Markdown with one section per chunk separated by horizontal rules, or
        a French placeholder when there is nothing to show.
    """
    if not chunks:
        return "Aucun chunk à afficher."
    n = max(1, min(n, len(chunks)))

    def fmt(score):
        # Em dash for a missing score; assumes scores are numeric when present.
        return "—" if score is None else f"{score:.4f}"

    sections = []
    for i in range(n):
        ch = chunks[i]
        meta = ch.get("metadata", {}) or {}
        source = meta.get("source_file", "unknown")
        did = meta.get("doc_id", "unknown")
        chunk_index = meta.get("chunk_index", "?")
        chunk_id = meta.get("chunk_id", "")
        text = (ch.get("text") or "").strip()
        # scores
        faiss_score = ch.get("score", None)
        rerank_score = ch.get("rerank_score", None)
        final_score = ch.get("final_score", None)
        # BUGFIX: the rank and file name used to be fused ("**#1doc.pdf | ...")
        # because the f-string lacked a separator between {i+1} and {source}.
        hdr = f"**#{i+1} {source} | doc_id={did} | chunk_index={chunk_index}**"
        score_line = (
            f"<sub>"
            f"FAISS score: {fmt(faiss_score)}"
            f"{' | Rerank score: ' + fmt(rerank_score) if rerank_score is not None else ''}"
            f"{' | Final: ' + fmt(final_score) if final_score is not None else ''}"
            f"</sub>"
        )
        if len(text) > 1200:  # keep the panel readable
            text = text[:1200] + "…"
        sections.append(
            f"{hdr}\n\n"
            f"{score_line}\n\n"
            f"```text\n{text}\n```\n"
            f"<sub>chunk_id: {chunk_id}</sub>\n"
        )
    return "\n\n---\n\n".join(sections)
async def chat_ask(user_message, history_msgs, show_sources, chunks_state):
    """Handle one chat turn: query the RAG pipeline and refresh the UI.

    Returns a 6-tuple matching the Gradio outputs wiring:
    (chatbot history, textbox value, sources-accordion update,
     chunks state, number of chunks currently shown, sources markdown).
    """
    user_message = (user_message or "").strip()
    if len(user_message) < 3:
        # BUGFIX: the hint used to be written into the sources panel while the
        # very same return value hid that panel (visible=False), so the user
        # never saw it. Open the panel so the hint is visible, and keep the
        # typed text in the textbox instead of discarding it.
        return (
            history_msgs,
            user_message,
            gr.update(visible=True, open=True),
            chunks_state,
            3,
            "Pose une question plus précise 🙂",
        )
    # Prepend recent turns so the pipeline can resolve follow-up questions.
    convo = format_history_for_followup(history_msgs)
    effective_question = f"{user_message}\n\n[Conversation context]\n{convo}" if convo else user_message
    out = await answer_question(effective_question)
    answer = out["answer"]
    history_msgs = (history_msgs or []) + [
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": answer},
    ]
    chunks_state = out.get("used_chunks", []) or []
    shown_n = 3  # reset the "show more" counter on every new answer
    sources_md = render_chunks(chunks_state, n=shown_n)
    return (
        history_msgs,
        "",
        gr.update(visible=bool(show_sources)),
        chunks_state,
        shown_n,
        sources_md,
    )
def show_more(chunks_state, shown_n):
    """Reveal three more chunks in the sources panel.

    Returns the new count and the re-rendered markdown; render_chunks clamps
    the count to the number of available chunks.
    """
    next_n = 3 + (shown_n or 3)
    return next_n, render_chunks(chunks_state, n=next_n)
def reset_chat():
    """Return the cleared value for every chat-related output component."""
    empty_history = []
    empty_chunks = []
    # (chatbot, textbox, chunks state, shown count, sources markdown)
    return empty_history, "", empty_chunks, 3, "Aucun chunk à afficher."
# Build the Gradio Blocks UI: chat on the left, sources/chunks panel on the right.
with gr.Blocks(title="NOLEJ", css=CSS, theme="soft") as nolej_ui:
    gr.Markdown("# 🧠 NOLEJ — Neurosciences RAG")
    gr.Markdown("Interface de chat — réponses **uniquement** à partir des documents fournis.")

    # Cross-callback state: chunks retrieved for the last answer, and how many
    # of them are currently rendered in the sources panel.
    chunks_state = gr.State([])
    shown_n_state = gr.State(3)

    with gr.Row():
        with gr.Column(scale=6, min_width=520):
            chatbot = gr.Chatbot(elem_id="chatbot", label="NOLEJ Chat", height=300)
            with gr.Row():
                show_sources = gr.Checkbox(value=True, label="Afficher sources")
                reset_btn = gr.Button(" Réinitialiser")
            msg = gr.Textbox(placeholder="Pose ta question…", show_label=False)
            send = gr.Button(" Envoyer")
        with gr.Column(scale=4, min_width=420):
            sources_box = gr.Accordion("📌 Sources / Chunks (Top 3 par défaut)", open=False, visible=True)
            with sources_box:
                sources_md = gr.Markdown("Aucun chunk à afficher.")
                more_btn = gr.Button("➕ Plus (afficher 3 de plus)")

    # Enter in the textbox and the send button trigger the same handler with
    # the same inputs/outputs wiring (order must match chat_ask's return tuple).
    msg.submit(
        chat_ask,
        inputs=[msg, chatbot, show_sources, chunks_state],
        outputs=[chatbot, msg, sources_box, chunks_state, shown_n_state, sources_md],
    )
    send.click(
        chat_ask,
        inputs=[msg, chatbot, show_sources, chunks_state],
        outputs=[chatbot, msg, sources_box, chunks_state, shown_n_state, sources_md],
    )
    more_btn.click(
        show_more,
        inputs=[chunks_state, shown_n_state],
        outputs=[shown_n_state, sources_md],
    )
    reset_btn.click(
        reset_chat,
        inputs=[],
        outputs=[chatbot, msg, chunks_state, shown_n_state, sources_md],
    )

# Mount Gradio on FastAPI (ASGI compatible)
app = gr.mount_gradio_app(app, nolej_ui, path="/")