Spaces:
Running
Running
| import gradio as gr | |
| import os | |
| from huggingface_hub import InferenceClient | |
| from config.constants import DEFAULT_SYSTEM_MESSAGE | |
| from config.settings import HF_TOKEN, MODEL_CONFIG, EMBEDDING_MODEL | |
| from src.knowledge_base.vector_store import create_vector_store, load_vector_store | |
| from web.training_interface import ( | |
| get_models_df, | |
| generate_chat_analysis, | |
| register_model_action, | |
| start_finetune_action | |
| ) | |
# Fail fast: every downstream Hugging Face API call needs authentication,
# so refuse to start without a token.
if not HF_TOKEN:
    raise ValueError("HUGGINGFACE_TOKEN not found in environment variables")

# Inference client bound to the configured chat model.
client = InferenceClient(
    MODEL_CONFIG["id"],
    token=HF_TOKEN,
)

# Most recent retrieved context per conversation, keyed by conversation id.
context_store = {}
def get_context(message, conversation_id):
    """Retrieve knowledge-base context relevant to *message*.

    Runs a similarity search (top 3 documents) against the vector store,
    caches the formatted context under *conversation_id* in
    ``context_store``, and returns it. Returns an empty string whenever
    the store is unavailable or the lookup fails.
    """
    store = load_vector_store()
    if store is None:
        print("Knowledge base not found or failed to load")
        return ""
    # Defensive: some failure paths hand back an error string instead of
    # an actual vector store object.
    if isinstance(store, str):
        print(f"Error with vector store: {store}")
        return ""
    try:
        docs = store.similarity_search(message, k=3)
        snippets = [
            f"From {doc.metadata.get('source', 'unknown')}: {doc.page_content}"
            for doc in docs
        ]
        context_text = "\n\n".join(snippets)
        # Remember the context used for this conversation.
        context_store[conversation_id] = context_text
        return context_text
    except Exception as exc:
        print(f"Error getting context: {str(exc)}")
        return ""
def load_vector_store():
    """Download the knowledge-base vector store from the dataset.

    NOTE(review): this definition shadows the ``load_vector_store``
    imported at the top of the file from
    ``src.knowledge_base.vector_store`` — every caller in this module
    gets this version; confirm the import is still intended.

    Returns the vector store object, or ``None`` on any failure.
    """
    try:
        from src.knowledge_base.dataset import DatasetManager

        print("Debug - Attempting to load vector store...")
        manager = DatasetManager()
        ok, payload = manager.download_vector_store()
        print(f"Debug - Download result: success={ok}, result_type={type(payload)}")
        if not ok:
            print(f"Debug - Failed to load vector store: {payload}")
            return None
        # A string payload on "success" is an error message in disguise.
        if isinstance(payload, str):
            print(f"Debug - Error message received: {payload}")
            return None
        return payload
    except Exception as exc:
        import traceback
        print(f"Exception loading knowledge base: {str(exc)}")
        print(traceback.format_exc())
        return None
def respond(
    message,
    history,
    conversation_id,
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Generate an assistant reply and yield the updated chat history.

    Builds an OpenAI-style message list from the system prompt (augmented
    with retrieved knowledge-base context), the prior Gradio history, and
    the new user *message*, then calls the inference API once
    (non-streaming) and yields ``(history, conversation_id)``.

    Args:
        message: Current user message text.
        history: Gradio chat history as (user, assistant) pairs, or None.
        conversation_id: Existing conversation id; a fresh UUID is created
            when falsy.
        system_message: Base system prompt.
        max_tokens / temperature / top_p: Sampling parameters forwarded to
            the API.

    Yields:
        (updated_history, conversation_id) exactly once; on API failure
        the assistant slot carries an error message instead.
    """
    # Create an ID for a brand-new conversation.
    if not conversation_id:
        import uuid
        conversation_id = str(uuid.uuid4())

    # Get context from knowledge base and fold it into the system prompt.
    context = get_context(message, conversation_id)
    messages = [{"role": "system", "content": system_message}]
    if context:
        messages[0]["content"] += f"\n\nContext for response:\n{context}"

    # Debug: print the history format
    print("Debug - Processing history format:", history)

    # Convert history from Gradio (user, assistant) pairs to OpenAI format.
    if history:
        try:
            for item in history:
                # Fix: require an actual 2-element sequence. The original
                # only checked len(item) == 2, so a dict-shaped entry
                # would unpack its *keys* into the prompt.
                if isinstance(item, (list, tuple)) and len(item) == 2:
                    user_msg, assistant_msg = item
                    messages.append({"role": "user", "content": user_msg})
                    messages.append({"role": "assistant", "content": assistant_msg})
        except Exception as e:
            print(f"Error processing history: {str(e)}")
            # Continue with whatever history converted cleanly so far.

    # Add current user message
    messages.append({"role": "user", "content": message})
    print("Debug - API messages:", messages)

    try:
        # Non-streaming version for debugging
        full_response = client.chat_completion(
            messages,
            max_tokens=max_tokens,
            stream=False,
            temperature=temperature,
            top_p=top_p,
        )
        response = full_response.choices[0].message.content
        print(f"Debug - Full response from API: {response}")
        # Return complete response immediately
        final_history = history.copy() if history else []
        final_history.append((message, response))
        yield final_history, conversation_id
    except Exception as e:
        print(f"Debug - Error during API call: {str(e)}")
        error_history = history.copy() if history else []
        error_history.append((message, f"An error occurred: {str(e)}"))
        yield error_history, conversation_id
def update_kb():
    """Add new documents to the existing knowledge base.

    Returns the status message produced by the vector-store builder, or
    an error description if the update raised.
    """
    try:
        _, status = create_vector_store(mode="update")
        return status
    except Exception as exc:
        return f"Error updating knowledge base: {str(exc)}"
def rebuild_kb():
    """Rebuild the knowledge base from scratch.

    Returns the status message produced by the vector-store builder, or
    an error description if the rebuild raised.
    """
    try:
        _, status = create_vector_store(mode="rebuild")
        return status
    except Exception as exc:
        return f"Error creating knowledge base: {str(exc)}"
def respond_and_clear(message, history, conversation_id):
    """Handle a chat submission and clear the input textbox.

    Pulls sampling parameters from MODEL_CONFIG, delegates to respond(),
    and returns ``(new_history, conversation_id, "")`` — the trailing
    empty string clears the message input in the UI.
    """
    # Get model parameters from config
    max_tokens = MODEL_CONFIG['parameters']['max_length']
    temperature = MODEL_CONFIG['parameters']['temperature']
    top_p = MODEL_CONFIG['parameters']['top_p']

    # Print debug information to help diagnose the issue
    print("Debug - Message type:", type(message), "Content:", message)
    print("Debug - History type:", type(history), "Content:", history)

    try:
        # Get response generator
        response_generator = respond(
            message=message,
            history=history if history else [],
            conversation_id=conversation_id,
            system_message=DEFAULT_SYSTEM_MESSAGE,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p
        )
        # respond() yields exactly once with the completed history.
        new_history, conv_id = next(response_generator)
        print("Debug - Final history:", new_history)
        return new_history, conv_id, ""  # Clear message input
    except Exception as e:
        print(f"Error in respond_and_clear: {str(e)}")
        # Fix: guard against history being None — the original
        # `history + [...]` raised TypeError on the very first message.
        return (history or []) + [(message, f"An error occurred: {str(e)}")], conversation_id, ""
# Create interface
# Three-tab Gradio app: Chat (RAG chat + knowledge-base management),
# Model Settings (read-only config display), Model Training (fine-tune UI).
with gr.Blocks() as demo:
    with gr.Tabs():
        with gr.Tab("Chat"):
            gr.Markdown("# ⚖️ Status Law Assistant")
            # Per-session conversation id, filled in on first message.
            conversation_id = gr.State(None)
            with gr.Row():
                with gr.Column(scale=3):
                    chatbot = gr.Chatbot(
                        label="Chat",
                        avatar_images=["user.png", "assistant.png"]
                    )
                    with gr.Row():
                        msg = gr.Textbox(
                            label="Your question",
                            placeholder="Enter your question...",
                            scale=4
                        )
                        submit_btn = gr.Button("Send", variant="primary")
                        clear_btn = gr.Button("Clear")  # Add clear button
                with gr.Column(scale=1):
                    gr.Markdown("### Knowledge Base Management")
                    gr.Markdown("""
                    - **Update**: Add new documents to existing base
                    - **Rebuild**: Create new base from scratch
                    """)
                    with gr.Row():
                        update_kb_btn = gr.Button("📝 Update Base", variant="secondary", scale=1)
                        rebuild_kb_btn = gr.Button("🔄 Rebuild Base", variant="primary", scale=1)
                    kb_status = gr.Textbox(
                        label="Status",
                        placeholder="Knowledge base status will appear here...",
                        interactive=False
                    )
            # Wire events: sending a message updates the chat, the
            # conversation id, and clears the textbox.
            submit_btn.click(
                respond_and_clear,
                [msg, chatbot, conversation_id],
                [chatbot, conversation_id, msg]
            )
            update_kb_btn.click(update_kb, None, kb_status)
            rebuild_kb_btn.click(rebuild_kb, None, kb_status)
            # Clearing resets both the visible chat and the conversation id.
            clear_btn.click(lambda: ([], None), None, [chatbot, conversation_id])
        with gr.Tab("Model Settings"):
            gr.Markdown("### Model Configuration")
            with gr.Row():
                with gr.Column(scale=2):
                    # Model Information (read-only, sourced from MODEL_CONFIG)
                    gr.Markdown(f"""
                    **Current Model:** {MODEL_CONFIG['name']}
                    **Model ID:** `{MODEL_CONFIG['id']}`
                    **Description:** {MODEL_CONFIG['description']}
                    **Type:** {MODEL_CONFIG['type']}
                    **Embeddings Model:** `{EMBEDDING_MODEL}`
                    *Used for vector store creation and similarity search*
                    """)
                    gr.Markdown("### Model Parameters")
                    # Sliders are display-only (interactive=False): they show
                    # the configured values and cannot be changed here.
                    with gr.Row():
                        max_length = gr.Slider(
                            minimum=1,
                            maximum=4096,
                            value=MODEL_CONFIG['parameters']['max_length'],
                            step=1,
                            label="Maximum Length",
                            interactive=False
                        )
                        temperature = gr.Slider(
                            minimum=0.1,
                            maximum=2.0,
                            value=MODEL_CONFIG['parameters']['temperature'],
                            step=0.1,
                            label="Temperature",
                            interactive=False
                        )
                    with gr.Row():
                        top_p = gr.Slider(
                            minimum=0.1,
                            maximum=1.0,
                            value=MODEL_CONFIG['parameters']['top_p'],
                            step=0.1,
                            label="Top-p",
                            interactive=False
                        )
                        rep_penalty = gr.Slider(
                            minimum=1.0,
                            maximum=2.0,
                            value=MODEL_CONFIG['parameters']['repetition_penalty'],
                            step=0.1,
                            label="Repetition Penalty",
                            interactive=False
                        )
                    gr.Markdown("""
                    <small>
                    **Parameters explanation:**
                    - **Maximum Length**: Maximum number of tokens in the generated response
                    - **Temperature**: Controls randomness (0.1 = very focused, 2.0 = very creative)
                    - **Top-p**: Controls diversity via nucleus sampling (lower = more focused)
                    - **Repetition Penalty**: Prevents word repetition (higher = less repetition)
                    </small>
                    """)
                with gr.Column(scale=1):
                    gr.Markdown("### Training Configuration")
                    gr.Markdown(f"""
                    **Base Model Path:**
                    ```
                    {MODEL_CONFIG['training']['base_model_path']}
                    ```
                    **Fine-tuned Model Path:**
                    ```
                    {MODEL_CONFIG['training']['fine_tuned_path']}
                    ```
                    **LoRA Configuration:**
                    - Rank (r): {MODEL_CONFIG['training']['lora_config']['r']}
                    - Alpha: {MODEL_CONFIG['training']['lora_config']['lora_alpha']}
                    - Dropout: {MODEL_CONFIG['training']['lora_config']['lora_dropout']}
                    """)
        with gr.Tab("Model Training"):
            gr.Markdown("### Model Training Interface")
            with gr.Row():
                with gr.Column():
                    # Hyperparameters forwarded to start_finetune_action.
                    epochs = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Number of Epochs")
                    batch_size = gr.Slider(minimum=1, maximum=32, value=4, step=1, label="Batch Size")
                    learning_rate = gr.Slider(minimum=1e-6, maximum=1e-3, value=2e-4, label="Learning Rate")
                    train_btn = gr.Button("Start Training", variant="primary")
                    training_output = gr.Textbox(label="Training Status", interactive=False)
                with gr.Column():
                    analysis_btn = gr.Button("Generate Chat Analysis")
                    analysis_output = gr.Markdown()
            train_btn.click(
                start_finetune_action,
                inputs=[epochs, batch_size, learning_rate],
                outputs=[training_output]
            )
            analysis_btn.click(
                generate_chat_analysis,
                inputs=[],
                outputs=[analysis_output]
            )
# Launch application
if __name__ == "__main__":
    # Check knowledge base availability in dataset; the app still starts
    # without one — the user can build it from the Chat tab.
    if not load_vector_store():
        print("Knowledge base not found. Please create it through the interface.")
    demo.launch()