# status-law-gbot — app.py
# Refactor chat interface to use tabs for organization and add model training features
# (commit 8f0f58a)
import gradio as gr
import os
from huggingface_hub import InferenceClient
from config.constants import DEFAULT_SYSTEM_MESSAGE
from config.settings import DEFAULT_MODEL, HF_TOKEN
from src.knowledge_base.vector_store import create_vector_store, load_vector_store
from web.training_interface import (
get_models_df,
generate_chat_analysis,
register_model_action,
start_finetune_action
)
# Fail fast when the Hugging Face API token is not configured.
if not HF_TOKEN:
    raise ValueError("HUGGINGFACE_TOKEN not found in environment variables")

# Inference client bound to the configured default model.
client = InferenceClient(
    DEFAULT_MODEL,
    token=HF_TOKEN,
)

# Per-conversation retrieved context, keyed by conversation id.
context_store = {}
def get_context(message, conversation_id):
    """Retrieve knowledge-base context relevant to *message*.

    Runs a similarity search over the vector store and caches the resulting
    context text in ``context_store`` under *conversation_id*.

    Args:
        message: The user's question to search context for.
        conversation_id: Key under which the retrieved context is cached.

    Returns:
        Formatted context text, or an empty string when the knowledge base
        is unavailable or the search fails. (Bug fix: previously a missing
        knowledge base returned the user-facing string "Knowledge base not
        found. Please create it first.", which respond() then appended to
        the system prompt as if it were real context.)
    """
    vector_store = load_vector_store()
    if vector_store is None:
        print("Knowledge base not found. Please create it first.")
        return ""
    try:
        # Top-3 most similar documents, each labelled with its source file.
        context_docs = vector_store.similarity_search(message, k=3)
        context_text = "\n\n".join(
            f"From {doc.metadata.get('source', 'unknown')}: {doc.page_content}"
            for doc in context_docs
        )
        # Cache the context for this conversation.
        context_store[conversation_id] = context_text
        return context_text
    except Exception as e:
        print(f"Error getting context: {str(e)}")
        return ""
def respond(
    message,
    history,
    conversation_id,
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Stream a model response for *message*, grounded in KB context.

    Yields ``(chat_history, conversation_id)`` pairs as tokens arrive so the
    Gradio chatbot updates incrementally. Once the response is complete the
    full transcript is persisted via DatasetManager (best effort).

    Args:
        message: Current user message.
        history: Prior turns as a list of ``(user, assistant)`` pairs.
        conversation_id: Existing conversation id, or falsy to start anew.
        system_message: Base system prompt; KB context is appended to it.
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling threshold.
    """
    # Create an ID for a brand-new conversation.
    if not conversation_id:
        import uuid
        conversation_id = str(uuid.uuid4())

    # Get context from the knowledge base (empty string if unavailable).
    context = get_context(message, conversation_id)

    # Build the OpenAI-style message list, folding context into the system prompt.
    messages = [{"role": "system", "content": system_message}]
    if context:
        messages[0]["content"] += f"\n\nContext for response:\n{context}"
    for user_msg, assistant_msg in history:
        messages.extend([
            {"role": "user", "content": user_msg},
            {"role": "assistant", "content": assistant_msg},
        ])
    messages.append({"role": "user", "content": message})

    # Send the API request and stream the response.
    response = ""
    is_complete = False
    try:
        for chunk in client.chat_completion(
            messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        ):
            # Stop as soon as the model reports a finish reason.
            if hasattr(chunk.choices[0], 'finish_reason') and chunk.choices[0].finish_reason is not None:
                is_complete = True
                break
            token = chunk.choices[0].delta.content
            if token:
                response += token
                # Bug fix: include the prior history so earlier turns stay
                # visible in the chatbot instead of being replaced by only
                # the current exchange.
                yield history + [(message, response)], conversation_id

        # Save history if the response is complete (non-empty as fallback).
        if is_complete or response:
            messages.append({"role": "assistant", "content": response})
            try:
                from src.knowledge_base.dataset import DatasetManager
                dataset = DatasetManager(token=HF_TOKEN)  # Explicitly pass the token
                success, msg = dataset.save_chat_history(conversation_id, messages)
                print(f"Chat history save attempt: {success}, Message: {msg}")  # Debug log
                if not success:
                    print(f"Failed to save chat history: {msg}")
            except Exception as e:
                import traceback
                print(f"Exception while saving chat history: {str(e)}")
                print(traceback.format_exc())  # Full traceback for debugging
    except Exception as e:
        print(f"Error generating response: {str(e)}")
        yield history + [(message, "An error occurred while generating the response.")], conversation_id
def build_kb():
    """Build/refresh the knowledge base and return a status string."""
    try:
        _ok, status = create_vector_store()
    except Exception as exc:
        return f"Error creating knowledge base: {str(exc)}"
    return status
def load_vector_store():
    """Download the knowledge base from the dataset; ``None`` on failure.

    NOTE(review): this deliberately shadows the ``load_vector_store``
    imported at the top of the file — confirm the override is intended.
    """
    try:
        from src.knowledge_base.dataset import DatasetManager
        manager = DatasetManager()
        ok, store = manager.download_vector_store()
        if not ok:
            print(f"Error loading knowledge base: {store}")
            return None
        return store
    except Exception as exc:
        print(f"Error loading knowledge base: {str(exc)}")
        return None
# Build the Gradio interface: a chat tab plus a model-training tab.
with gr.Blocks() as demo:
    with gr.Tabs():
        with gr.Tab("Chat"):
            gr.Markdown("# ⚖️ Status Law Assistant")
            conversation_id = gr.State(None)

            with gr.Row():
                # Left: the chat area.
                with gr.Column(scale=3):
                    chatbot = gr.Chatbot(
                        label="Chat",
                        bubble_full_width=False,
                        avatar_images=["user.png", "assistant.png"],  # optional
                    )
                    with gr.Row():
                        msg = gr.Textbox(
                            label="Your question",
                            placeholder="Enter your question...",
                            scale=4,
                        )
                        submit_btn = gr.Button("Send", variant="primary")

                # Right: knowledge-base management and generation settings.
                with gr.Column(scale=1):
                    gr.Markdown("### Knowledge Base Management")
                    build_kb_btn = gr.Button("Create/Update Knowledge Base", variant="primary")
                    kb_status = gr.Textbox(label="Knowledge Base Status", interactive=False)

                    gr.Markdown("### Generation Settings")
                    max_tokens = gr.Slider(
                        minimum=1,
                        maximum=2048,
                        value=512,
                        step=1,
                        label="Maximum Response Length",
                        info="Limits the number of tokens in response. More tokens = longer response",
                    )
                    temperature = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        value=0.7,
                        step=0.1,
                        label="Temperature",
                        info="Controls creativity. Lower value = more predictable responses",
                    )
                    top_p = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.95,
                        step=0.05,
                        label="Top-p",
                        info="Controls diversity. Lower value = more focused responses",
                    )
                    clear_btn = gr.Button("Clear Chat History")

            def respond_and_clear(
                message,
                history,
                conversation_id,
                max_tokens,
                temperature,
                top_p,
            ):
                """Delegate to respond() and clear the input box on each update."""
                stream = respond(
                    message,
                    history,
                    conversation_id,
                    DEFAULT_SYSTEM_MESSAGE,
                    max_tokens,
                    temperature,
                    top_p,
                )
                for chat_update, conv_id in stream:
                    # Third output ("") keeps the msg textbox empty.
                    yield chat_update, conv_id, ""

            # Identical wiring for Enter-key submit and the Send button.
            chat_inputs = [msg, chatbot, conversation_id, max_tokens, temperature, top_p]
            chat_outputs = [chatbot, conversation_id, msg]
            msg.submit(respond_and_clear, chat_inputs, chat_outputs)
            submit_btn.click(respond_and_clear, chat_inputs, chat_outputs)
            build_kb_btn.click(build_kb, None, kb_status)
            clear_btn.click(lambda: ([], None), None, [chatbot, conversation_id])

        with gr.Tab("Model Training"):
            gr.Markdown("### Model Training Interface")
            with gr.Row():
                # Left: fine-tuning hyperparameters and launch button.
                with gr.Column():
                    epochs = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Number of Epochs")
                    batch_size = gr.Slider(minimum=1, maximum=32, value=4, step=1, label="Batch Size")
                    learning_rate = gr.Slider(minimum=1e-6, maximum=1e-3, value=2e-4, label="Learning Rate")
                    train_btn = gr.Button("Start Training", variant="primary")
                    training_output = gr.Textbox(label="Training Status", interactive=False)
                # Right: chat-history analysis report.
                with gr.Column():
                    analysis_btn = gr.Button("Generate Chat Analysis")
                    analysis_output = gr.Markdown()

            train_btn.click(
                start_finetune_action,
                inputs=[epochs, batch_size, learning_rate],
                outputs=[training_output],
            )
            analysis_btn.click(
                generate_chat_analysis,
                inputs=[],
                outputs=[analysis_output],
            )
# Application entry point.
if __name__ == "__main__":
    # Warn when the knowledge base is absent; the app still launches so the
    # base can be created through the interface.
    if not load_vector_store():
        print("Knowledge base not found. Please create it through the interface.")
    demo.launch()