Yosef Skolnick
Enhance RAG processing and UI components
bbde628
import streamlit as st
from typing import Dict, Any, List
import asyncio
import logging
import traceback
import nest_asyncio
# Apply nest_asyncio once during module import
nest_asyncio.apply()
# Setup logger
logger = logging.getLogger(__name__)
# Import our refactored modules
from ui.hebrew import handle_mixed_language_text
from ui.chat_render import display_chat_message, display_status_updates, format_source_html
from pipeline.rag import process_rag_request, create_async_execution_context, extract_citations
def process_prompt(prompt: str, rag_params: Dict[str, Any]):
"""
Process a user prompt and generate a response.
Args:
prompt (str): User input prompt (may contain template)
rag_params (Dict[str, Any]): RAG parameters from sidebar
"""
# Import here to avoid circular imports
from i18n import get_direction, get_text
from utils.sanitization import sanitize_html
from rag_processor import PIPELINE_VALIDATE_GENERATE_GPT4O
# Initialize session state if needed
if 'messages' not in st.session_state:
st.session_state.messages = []
text_direction = get_direction()
hebrew_font = st.session_state.hebrew_font
# Add the visible prompt to chat history (what the user sees)
st.session_state.messages.append({"role": "user", "content": prompt})
display_chat_message(st.session_state.messages[-1])
with st.chat_message("assistant"):
msg_placeholder = st.empty()
status_container = st.status(get_text('processing'), expanded=True)
chunks: List[str] = []
try:
def status_cb(m): status_container.update(label=f"{get_text('processing_step')} {m}")
def stream_cb(c):
# Sanitize the chunk before appending
if isinstance(c, str):
c = sanitize_html(c)
chunks.append(c)
# Process the entire response with mixed language handler first
joined_text = ''.join(chunks) + "▌" # Add cursor
display_html = handle_mixed_language_text(joined_text, hebrew_font)
# Sanitize the final HTML to prevent injection
safe_html = sanitize_html(display_html)
msg_placeholder.markdown(safe_html, unsafe_allow_html=True)
try:
# Get the current event loop or create a new one
loop = create_async_execution_context()
final_rag = loop.run_until_complete(
process_rag_request(
history=st.session_state.messages,
params=rag_params,
status_callback=status_cb,
stream_callback=stream_cb
)
)
# Extract citations after getting the response
cited_ids = []
if isinstance(final_rag, dict):
raw = final_rag.get("final_response", "")
# Only attempt to extract citations if we have a valid response
if raw and raw.strip():
cited_ids = extract_citations(raw)
except (RuntimeError, asyncio.CancelledError, asyncio.TimeoutError) as loop_err:
st.error(f"{get_text('error_async')} {loop_err}", icon="⚠️")
# Format error message with RTL support
err_html = f"""
<div dir='{text_direction}' class='{text_direction}-text hebrew-font'>
<strong>{get_text('request_error')}</strong><br>
{get_text('error_async')}<br>
{type(loop_err).__name__}
</div>
"""
# Sanitize error HTML
err_html = sanitize_html(err_html)
msg_placeholder.markdown(err_html, unsafe_allow_html=True)
# Create minimal final_rag result for error case
final_rag = {
"final_response": get_text('error_async'),
"error": str(loop_err),
"status_log": [f"Asyncio error: {type(loop_err).__name__}"],
"generator_input_documents": [],
"pipeline_used": "Error"
}
cited_ids = []
if isinstance(final_rag, dict):
# Sanitize raw content
raw = final_rag.get("final_response", "")
if isinstance(raw, str):
raw = sanitize_html(raw)
err = final_rag.get("error")
log = final_rag.get("status_log", [])
docs = final_rag.get("generator_input_documents", [])
pipeline = final_rag.get("pipeline_used", PIPELINE_VALIDATE_GENERATE_GPT4O)
# Process final content with mixed language handler
if not (err and raw.strip().startswith("<div")):
final = handle_mixed_language_text(raw, hebrew_font)
# Final sanitization after processing
final = sanitize_html(final)
else:
final = sanitize_html(raw)
msg_placeholder.markdown(final, unsafe_allow_html=True)
# Use the citations extracted earlier instead of making another async call
if cited_ids:
enumerated_docs = list(enumerate(docs, start=1))
docs_to_show = [(idx, doc) for idx, doc in enumerated_docs if str(idx) in cited_ids]
else:
docs_to_show = list(enumerate(docs, start=1))
if docs_to_show:
# Use a simple text title for the expander
with st.expander(f"{get_text('sources_title')} ({len(docs_to_show)})", expanded=False):
# Add RTL Hebrew wrapper with embed
expander_title = f"""
<div class='expander-title rtl-text hebrew-font' dir='rtl' lang="he">
{get_text('sources_text').format(len(docs_to_show))}
</div>
"""
st.markdown(expander_title, unsafe_allow_html=True)
# Container for all sources with RTL direction
st.markdown(f"""
<div dir='rtl' lang="he" class='expander-content rtl-text hebrew-font'>
""", unsafe_allow_html=True)
# Format each source consistently using our helper function
for idx, doc in docs_to_show:
source_html, text_html = format_source_html(doc, idx, hebrew_font, get_text)
st.markdown(source_html, unsafe_allow_html=True)
st.markdown(text_html, unsafe_allow_html=True)
st.markdown("</div>", unsafe_allow_html=True)
# store message
assistant_data = {
"role": "assistant",
"content": final,
"final_docs": docs,
"pipeline_used": pipeline,
"status_log": log,
"error": err
}
st.session_state.messages.append(assistant_data)
display_status_updates(log)
if err:
status_container.update(label=f"{get_text('error')}!", state="error", expanded=False)
else:
status_container.update(label=get_text('processing_complete'), state="complete", expanded=False)
else:
# Format communication error message with proper RTL support
err_msg = f"""
<div dir='{text_direction}' class='{text_direction}-text hebrew-font'>
<strong>{get_text('communication_error')}</strong>
</div>
"""
msg_placeholder.markdown(sanitize_html(err_msg), unsafe_allow_html=True)
st.session_state.messages.append({
"role": "assistant",
"content": get_text('communication_error'),
"final_docs": [],
"pipeline_used": "Error",
"status_log": ["Unexpected result"],
"error": "Unexpected"
})
status_container.update(label=f"{get_text('error')}!", state="error", expanded=False)
except Exception as e:
logger.exception("Unhandled exception in RAG processing")
traceback.print_exc()
# Format critical error with RTL support
err_html = f"""
<div dir='{text_direction}' class='{text_direction}-text hebrew-font'>
<strong>{get_text('critical_error')}</strong><br>
{get_text('reload')}
<details>
<summary>{get_text('details')}</summary>
<pre>{sanitize_html(traceback.format_exc())}</pre>
</details>
</div>
"""
# Sanitize error HTML
err_html = sanitize_html(err_html)
msg_placeholder.error(err_html, icon="🔥")
st.session_state.messages.append({
"role": "assistant",
"content": err_html,
"final_docs": [],
"pipeline_used": "Critical Error",
"status_log": [f"Critical: {type(e).__name__}"],
"error": str(e)
})
status_container.update(label=get_text('processing_error'), state="error", expanded=False)