|
|
"""Gradio chat interface for HPMOR Q&A system.""" |
|
|
|
|
|
import gradio as gr |
|
|
import json |
|
|
from typing import List, Tuple, Optional |
|
|
from datetime import datetime |
|
|
|
|
|
from src.rag_engine import RAGEngine |
|
|
from src.model_chain import ModelType |
|
|
from src.config import config |
|
|
|
|
|
|
|
|
class ChatInterface:
    """Gradio-based chat interface for HPMOR Q&A.

    Wires a RAGEngine to a Gradio Blocks UI: the user asks a question, the
    engine retrieves context and answers in character, and the UI displays
    the answer along with the retrieved sources and which model produced it.
    """

    def __init__(self):
        """Initialize the chat interface and its RAG engine."""
        print("Initializing HPMOR Q&A Chat Interface...")
        # force_recreate=False: reuse any previously built vector index.
        self.engine = RAGEngine(force_recreate=False)
        # Kept in sync by clear_conversation(); the visible transcript
        # itself lives in the Gradio Chatbot component.
        self.conversation_history = []

    def format_sources(self, sources: List[dict]) -> str:
        """Format retrieved source chunks as Markdown for display.

        Args:
            sources: Dicts carrying ``chapter_number``, ``chapter_title``,
                ``score`` and ``text_preview`` keys (engine output format).

        Returns:
            A Markdown string, or a placeholder when ``sources`` is empty.
        """
        if not sources:
            return "No sources found"

        formatted = []
        for i, source in enumerate(sources, 1):
            formatted.append(
                f"**Source {i}** - Chapter {source['chapter_number']}: {source['chapter_title']}\n"
                f"Relevance Score: {source['score']:.2f}\n"
                f"Preview: *{source['text_preview'][:150]}...*"
            )
        return "\n\n".join(formatted)

    def process_message(
        self,
        message: str,
        history: List[List[str]],
        model_choice: str,
        top_k: int,
        show_sources: bool
    ) -> Tuple[str, str, str]:
        """Run one question through the RAG engine.

        Args:
            message: The user's question.
            history: Prior ``[user, assistant]`` message pairs.
            model_choice: UI label of the requested model.
            top_k: Requested number of context chunks.
            show_sources: Whether to format sources for display.

        Returns:
            ``(answer, sources_markdown, model_info_markdown)``. On failure
            the answer slot carries the error text.
        """
        # Reject empty or whitespace-only submissions up front.
        if not message or not message.strip():
            return "", "", "Please enter a question."

        # Map UI labels to engine model types (None = let the engine route).
        model_map = {
            "Auto (Intelligent Routing)": None,
            "Local Small (Fast)": ModelType.LOCAL_SMALL,
            "Local Large (Better)": ModelType.LOCAL_LARGE,
            "Groq API (Best)": ModelType.GROQ_API
        }
        force_model = model_map.get(model_choice)
        # NOTE(review): force_model and top_k are collected from the UI but
        # never forwarded to self.engine.chat() below, so the "Model
        # Selection" and "Number of Context Chunks" controls currently have
        # no effect. TODO: confirm RAGEngine.chat()'s signature and pass
        # them through.

        # Rebuild the transcript as OpenAI-style role/content dicts.
        messages = []
        for user_msg, assistant_msg in history:
            if user_msg:
                messages.append({"role": "user", "content": user_msg})
            if assistant_msg:
                messages.append({"role": "assistant", "content": assistant_msg})
        messages.append({"role": "user", "content": message})

        try:
            response = self.engine.chat(messages, stream=False)

            # Defensive: coerce non-string answers so the UI never crashes.
            answer = response.get("answer", "No response generated")
            if not isinstance(answer, str):
                answer = str(answer)

            # Summarize which model answered and how much context it used.
            model_info = f"**Model Used:** {response.get('model_used', 'Unknown')}"
            if response.get("fallback_used"):
                model_info += " (via fallback)"
            model_info += f"\n**Context Size:** {response.get('context_size', 0)} characters"

            sources_text = ""
            if show_sources and response.get("sources"):
                sources_text = self.format_sources(response["sources"])

            return answer, sources_text, model_info

        except Exception as e:
            # Surface the failure in the chat rather than crashing the UI.
            error_msg = f"Error: {str(e)}"
            return error_msg, "", "Error occurred"

    def clear_conversation(self):
        """Clear conversation history and the engine's cache.

        Returns:
            Values for the ``(chatbot, sources_display, msg_input,
            model_info)`` outputs: clear the transcript, blank the sources
            and input box, and show a confirmation message.
        """
        self.conversation_history = []
        self.engine.clear_cache()
        return None, "", "", "Conversation cleared"

    def create_interface(self) -> "gr.Blocks":
        """Create the Gradio interface.

        Returns:
            The assembled ``gr.Blocks`` application (not yet launched).
        """
        with gr.Blocks(title="HPMOR Q&A System", theme=gr.themes.Soft()) as interface:
            # --- Header -------------------------------------------------
            gr.Markdown(
                """
                # 🧙‍♂️ Chat with Harry James Potter-Evans-Verres

                Hello! I'm Harry Potter-Evans-Verres from "Harry Potter and the Methods of Rationality."
                Ask me anything about my adventures, experiments with magic, or my thoughts on rationality and science.
                I'll respond based on my experiences and the scientific method, of course!

                *Powered by RAG with Ollama (local) and Groq API for complex reasoning*
                """
            )

            # --- Main layout: chat on the left, settings on the right ---
            with gr.Row():
                with gr.Column(scale=2):
                    chatbot = gr.Chatbot(
                        label="Chat",
                        height=500,
                        show_copy_button=True
                    )

                    with gr.Row():
                        msg_input = gr.Textbox(
                            label="Your Question",
                            placeholder="Ask me anything... For example: 'What do you think about magic?' or 'Tell me about your experiments'",
                            lines=2,
                            scale=4
                        )
                        submit_btn = gr.Button("Send", variant="primary", scale=1)

                with gr.Column(scale=1):
                    gr.Markdown("### Settings")

                    model_choice = gr.Radio(
                        choices=[
                            "Auto (Intelligent Routing)",
                            "Local Small (Fast)",
                            "Local Large (Better)",
                            "Groq API (Best)"
                        ],
                        value="Auto (Intelligent Routing)",
                        label="Model Selection"
                    )

                    top_k = gr.Slider(
                        minimum=1,
                        maximum=10,
                        value=5,
                        step=1,
                        label="Number of Context Chunks"
                    )

                    show_sources = gr.Checkbox(
                        value=True,
                        label="Show Sources"
                    )

                    clear_btn = gr.Button("Clear Conversation", variant="secondary")

                    gr.Markdown("### Model Info")
                    model_info = gr.Markdown(
                        value="Ready to answer questions",
                        elem_classes=["model-info"]
                    )

            # --- Retrieved-sources panel below the chat -----------------
            with gr.Row():
                sources_display = gr.Markdown(
                    label="Retrieved Sources",
                    value="",
                    visible=True
                )

            # Clickable example prompts that fill the input box.
            gr.Examples(
                examples=[
                    "Harry, how did you first react when you learned magic was real?",
                    "What's your opinion on the way Hogwarts teaches magic?",
                    "Can you explain your scientific experiments with magic?",
                    "What do you think about Hermione?",
                    "How do you apply rationality to magical problems?",
                    "What's your relationship with Professor Quirrell like?",
                ],
                inputs=msg_input,
                label="Example Questions for Harry"
            )

            def respond(message, history, model, topk, sources):
                """Handle a submission: answer, then update the transcript."""
                answer, sources_text, info = self.process_message(
                    message, history, model, topk, sources
                )
                # Only record real exchanges: a blank submission leaves the
                # transcript untouched (the status message still updates).
                if message and message.strip():
                    history.append([message, answer])
                return "", history, sources_text, info

            # Both Enter-in-textbox and the Send button trigger respond().
            msg_input.submit(
                respond,
                inputs=[msg_input, chatbot, model_choice, top_k, show_sources],
                outputs=[msg_input, chatbot, sources_display, model_info]
            )

            submit_btn.click(
                respond,
                inputs=[msg_input, chatbot, model_choice, top_k, show_sources],
                outputs=[msg_input, chatbot, sources_display, model_info]
            )

            clear_btn.click(
                self.clear_conversation,
                outputs=[chatbot, sources_display, msg_input, model_info]
            )

        # NOTE(review): assigning .css after construction may be ignored by
        # newer Gradio versions (the documented path is gr.Blocks(css=...));
        # verify against the installed version.
        interface.css = """
        .model-info {
            background-color: #f0f0f0;
            padding: 10px;
            border-radius: 5px;
            font-size: 0.9em;
        }
        """

        return interface

    def launch(self):
        """Launch the Gradio interface on the configured host/port."""
        interface = self.create_interface()

        print("\nLaunching HPMOR Q&A Chat Interface...")
        print(f"Server will be available at: http://localhost:{config.gradio_server_port}")

        interface.launch(
            server_name="0.0.0.0",  # listen on all interfaces
            server_port=config.gradio_server_port,
            share=config.gradio_share,
            favicon_path=None
        )
|
|
|
|
|
|
|
|
def main():
    """Entry point: build the chat interface and start serving it."""
    ChatInterface().launch()
|
|
|
|
|
|
|
|
# Run the app only when executed as a script, not on import.
if __name__ == "__main__":
    main()