Spaces:
Running
Running
| # Standard library imports | |
| import datetime | |
| import io | |
| import json | |
| import logging | |
| import os | |
| import pandas as pd | |
| # Third-party imports | |
| import gradio as gr | |
| from huggingface_hub import HfApi, InferenceClient | |
| from langdetect import detect, LangDetectException | |
| import langdetect | |
| from dotenv import load_dotenv | |
| import requests | |
| from datasets import load_dataset | |
| # Local imports - config | |
| from config.constants import DEFAULT_SYSTEM_MESSAGE, URLS | |
| from config.settings import ( | |
| API_CONFIG, | |
| ACTIVE_MODEL, | |
| DATASET_CHAT_HISTORY_PATH, | |
| DATASET_ERROR_LOGS_PATH, | |
| DATASET_ID, | |
| DATASET_PREFERENCES_PATH, | |
| DATASET_VECTOR_STORE_PATH, | |
| DATASET_ANNOTATIONS_PATH, | |
| DEFAULT_MODEL, | |
| EMBEDDING_MODEL, | |
| HF_TOKEN, | |
| MODELS, | |
| IS_PRO_ACCOUNT | |
| ) | |
# Ensure IS_PRO_ACCOUNT is available
# (defensive fallback: IS_PRO_ACCOUNT is imported from config.settings just
# above, so this branch is normally dead -- it only fires if that import is
# edited away)
if 'IS_PRO_ACCOUNT' not in globals():
    from config.settings import check_account_type
    IS_PRO_ACCOUNT, _ = check_account_type()
| # Local imports - source modules | |
| from src.analytics.chat_evaluator import ChatEvaluator # Fixed import | |
| from src.knowledge_base.dataset import DatasetManager | |
| from src.knowledge_base.vector_store import create_vector_store, load_vector_store | |
| import config.constants as constants | |
def get_selected_urls(sources_df):
    """Return the URLs of all rows whose "Include" flag is set.

    Accepts a DataFrame or anything DataFrame-convertible (list of dicts,
    records, ...); returns [] on any error.
    """
    try:
        frame = sources_df if isinstance(sources_df, pd.DataFrame) else pd.DataFrame(sources_df)
        mask = frame["Include"] == True  # noqa: E712 -- pandas boolean mask
        return frame.loc[mask, "URL"].tolist()
    except Exception as e:
        logger.error(f"Error getting selected URLs: {str(e)}")
        return []
def update_kb_with_selected(sources_df) -> str:
    """Updates knowledge base with selected sources.

    Args:
        sources_df: DataFrame (or convertible) with "Include" and "URL" columns.

    Returns:
        Status message describing the result of the update (or an error string).
    """
    try:
        selected_urls = get_selected_urls(sources_df)
        if not selected_urls:
            return "Error: No sources selected"
        # BUG FIX: snapshot the live module attribute (constants.URLS) rather
        # than the stale imported name URLS, matching rebuild_kb_with_selected;
        # otherwise a previously-swapped URL list would be restored incorrectly.
        original_urls = constants.URLS.copy()
        constants.URLS = selected_urls
        try:
            success, message = create_vector_store(mode="update")
            if success:
                save_kb_metadata()
            # BUG FIX: the original returned the message only on success and
            # silently returned None on failure; always surface the status.
            return message
        finally:
            # Restore the configured URL list no matter what happened above.
            constants.URLS = original_urls
    except Exception as e:
        logger.error(f"Error updating knowledge base: {str(e)}")
        return f"Error updating knowledge base: {str(e)}"
def rebuild_kb_with_selected(sources_df):
    """Rebuild knowledge base from scratch using only selected URLs.

    Args:
        sources_df: DataFrame (or convertible) with "Include" and "URL" columns.

    Returns:
        Status message from the rebuild (or an error string).
    """
    try:
        selected_urls = get_selected_urls(sources_df)
        if not selected_urls:
            return "Error: No URLs selected for inclusion"
        # Temporarily replace URLS with selected ones
        original_urls = constants.URLS.copy()
        constants.URLS = selected_urls
        try:
            # Rebuild knowledge base
            success, message = create_vector_store(mode="rebuild")
            if success:
                # Persist metadata describing the freshly built store.
                metadata = {
                    "last_updated": datetime.datetime.now().isoformat(),
                    "source_count": len(selected_urls),
                    "sources": selected_urls
                }
                # Save to dataset
                json_content = json.dumps(metadata, indent=2).encode('utf-8')
                api = HfApi(token=HF_TOKEN)
                api.upload_file(
                    path_or_fileobj=json_content,
                    path_in_repo="vector_store/metadata.json",
                    repo_id=DATASET_ID,
                    repo_type="dataset"
                )
            # FIX: return the status message on both success and failure so
            # the UI always gets a string (never an implicit None).
            return message
        finally:
            # Restore original URLs
            constants.URLS = original_urls
    except Exception as e:
        logger.error(f"Error rebuilding knowledge base: {str(e)}")
        return f"Error rebuilding knowledge base: {str(e)}"
# Set seed for consistent results
# (langdetect is nondeterministic by default; pinning the factory seed makes
# detect() reproducible between runs)
langdetect.DetectorFactory.seed = 0
# Load environment variables
load_dotenv()
# Local imports - source modules
# NOTE(review): ChatEvaluator and create_vector_store/load_vector_store are
# already imported near the top of the file; these re-imports are redundant
# but harmless -- consider consolidating.
from src.analytics.chat_evaluator import ChatEvaluator
from src.knowledge_base.vector_store import create_vector_store, load_vector_store
from src.language_utils import LanguageUtils
# Local imports - web interfaces
from web.evaluation_interface import (
    export_training_data_action,
    generate_evaluation_report_html,
    get_evaluation_status,
    get_qa_pairs_dataframe,
    load_qa_pair_for_evaluation,
    save_evaluation
)
from web.training_interface import (
    generate_chat_analysis,
    get_models_df,
    register_model_action,
    start_finetune_action
)
# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# Fail fast if the Hugging Face token is missing -- every dataset/API call
# below depends on it.
if not HF_TOKEN:
    raise ValueError("HUGGINGFACE_TOKEN not found in environment variables")
# Global variables
client = None  # InferenceClient; set by initialize_client()/change_model()
context_store = {}  # conversation_id -> last retrieved KB context (see get_context)
fallback_model_attempted = False  # guards against repeated fallback switches
# Evaluator used for QA-pair annotation and training-data export
chat_evaluator = ChatEvaluator(
    hf_token=HF_TOKEN,
    dataset_id=DATASET_ID
)
logger.info(f"Chat histories will be saved to: {DATASET_CHAT_HISTORY_PATH}")
def load_user_preferences():
    """Load user preferences from file.

    Reads DATASET_PREFERENCES_PATH when present; otherwise (or on any
    error) returns defaults with the default model and empty parameters.
    """
    fallback = {
        "selected_model": DEFAULT_MODEL,
        "parameters": {}
    }
    try:
        if not os.path.exists(DATASET_PREFERENCES_PATH):
            return fallback
        with open(DATASET_PREFERENCES_PATH, 'r') as f:
            return json.load(f)
    except Exception as e:
        logger.error(f"Error loading user preferences: {str(e)}")
        return fallback
def save_user_preferences(model_key, parameters=None):
    """Save user preferences to dataset.

    Args:
        model_key: Key of the selected model (stored as "selected_model").
        parameters: Optional dict of generation parameters for that model.

    Returns:
        True on success, False on any failure (errors are logged).
    """
    try:
        preferences = load_user_preferences()
        preferences["selected_model"] = model_key
        if parameters:
            # FIX: the original pre-initialized preferences["parameters"][model_key]
            # to {} and then immediately overwrote it; a single assignment is
            # equivalent.
            preferences["parameters"][model_key] = parameters
        # Save to dataset using bytes
        json_content = json.dumps(preferences, indent=2)
        api = HfApi(token=HF_TOKEN)
        api.upload_file(
            path_or_fileobj=json_content.encode('utf-8'),  # Convert string to bytes
            path_in_repo="preferences/user_preferences.json",
            repo_id=DATASET_ID,
            repo_type="dataset"
        )
        logger.info("User preferences saved successfully to dataset!")
        return True
    except Exception as e:
        logger.error(f"Error saving user preferences: {str(e)}")
        return False
def initialize_client(model_id=None):
    """Initialize or reinitialize the client with the specified model.

    Args:
        model_id: Hugging Face model id; defaults to the active model's id.

    Returns:
        The newly created InferenceClient (also stored in the module-global
        ``client`` used by respond()/translate_with_llm()).
    """
    global client
    if model_id is None:
        model_id = ACTIVE_MODEL["id"]
    # NOTE(review): ``endpoint`` is not a documented InferenceClient keyword
    # in current huggingface_hub releases (``base_url`` is) -- confirm against
    # the installed version.
    client = InferenceClient(
        model_id,
        token=API_CONFIG["token"],
        endpoint=API_CONFIG["inference_endpoint"],
        headers=API_CONFIG["headers"],
        timeout=API_CONFIG["timeout"]
    )
    return client
def switch_to_model(model_key):
    """Switch to specified model and update global variables.

    Rebinds the module-global ACTIVE_MODEL and creates a fresh client for
    the new model; returns True on success, False on any error.
    """
    global ACTIVE_MODEL, client
    try:
        # Make the chosen model active, then build a fresh client for it.
        ACTIVE_MODEL = MODELS[model_key]
        client = InferenceClient(ACTIVE_MODEL["id"], token=HF_TOKEN)
        logger.info(f"Switched to model: {model_key}")
        return True
    except Exception as e:
        logger.error(f"Error switching to model {model_key}: {str(e)}")
        return False
def get_fallback_model(current_model):
    """Get a fallback model different from the current one.

    Returns the first configured model key that differs from
    ``current_model``, or None when no alternative exists.
    """
    return next((key for key in MODELS.keys() if key != current_model), None)
def get_context(message, conversation_id):
    """Get context from knowledge base.

    Runs a similarity search against the vector store and returns a trimmed,
    clearly-labelled context string; also caches it per conversation in the
    module-level ``context_store``.

    Args:
        message: User query to search for.
        conversation_id: Key under which the context is cached.

    Returns:
        Context string, or "" when the store is unavailable or an error occurs.
    """
    vector_store = load_vector_store()
    if vector_store is None:
        logger.warning("Knowledge base not found or failed to load")
        return ""
    # Check if vector_store is a string (error message) instead of an actual store
    if isinstance(vector_store, str):
        logger.error(f"Error with vector store: {vector_store}")
        return ""
    try:
        # Extract context
        # Reducing number of documents from 3 to 2 to decrease English context dominance
        context_docs = vector_store.similarity_search(message, k=2)
        # Add debug logging
        logger.debug(f"Query: {message}")
        for i, doc in enumerate(context_docs):
            logger.debug(f"Context {i+1}:")
            logger.debug(f"Source: {doc.metadata.get('source', 'unknown')}")
            logger.debug(f"Content: {doc.page_content[:200]}...")
        # Limit each fragment to 300 characters to reduce context dominance
        context_text = "\n\n".join([f"Context from {doc.metadata.get('source', 'unknown')}: {doc.page_content[:300]}..." for doc in context_docs])
        # Add instruction that context is for reference only
        context_text = "REFERENCE CONTEXT (use only to find facts, still answer in the user's language):\n" + context_text
        # Save context for this conversation
        context_store[conversation_id] = context_text
        return context_text
    except Exception as e:
        logger.error(f"Error getting context: {str(e)}")
        return ""
def translate_with_llm(text: str, target_lang: str) -> str:
    """Translate text using the active LLM with enhanced reliability.

    Sends a translation prompt to the active chat model; if the output still
    contains telltale English phrases, retries once with a simpler prompt.

    Args:
        text: Text to translate.
        target_lang: Language code of the desired output language.

    Returns:
        The translated text, or the original ``text`` unchanged if the
        request fails.
    """
    try:
        # Get language name for more natural prompt
        lang_name = LanguageUtils.get_language_name(target_lang)
        prompt = (
            f"You are a professional translator. Translate the following text to {lang_name} ({target_lang}). "
            f"Keep the same formatting, links, and technical terms. "
            f"Maintain the same tone and style. "
            f"Respond ONLY with the direct translation without any explanations or additional text:\n\n"
            f"{text}"
        )
        response = client.chat_completion(
            messages=[
                {"role": "system", "content": "You are a professional translator. Respond ONLY with the translation."},
                {"role": "user", "content": prompt}
            ],
            max_tokens=ACTIVE_MODEL['parameters']['max_length'],
            temperature=0.3,  # Lower temperature for more reliable output
            top_p=0.95,
            stream=False
        )
        translated_text = response.choices[0].message.content.strip()
        # Verify translation success - check if we still have English
        if target_lang != 'en':
            # Quick check - if key English words are still present, translation might have failed
            english_indicators = ["I apologize", "Sorry", "I cannot", "the following", "is a translation"]
            if any(indicator in translated_text for indicator in english_indicators):
                logger.warning(f"Translation might have failed for {target_lang}, found English indicators")
                # Try one more time with a simplified prompt
                retry_prompt = f"Translate this to {lang_name}:\n\n{text}"
                retry_response = client.chat_completion(
                    messages=[
                        {"role": "system", "content": "You are a translator."},
                        {"role": "user", "content": retry_prompt}
                    ],
                    max_tokens=ACTIVE_MODEL['parameters']['max_length'],
                    temperature=0.3,
                    top_p=0.95,
                    stream=False
                )
                translated_text = retry_response.choices[0].message.content.strip()
        return translated_text
    except Exception as e:
        logger.error(f"Translation failed: {e}")
        return text
def post_process_response(user_message, bot_response):
    """Enhanced post-processing of bot responses to ensure correct language.

    Detects the user's language and, if the bot replied in a different
    (supported) language, translates the reply with translate_with_llm().
    Falls back to the original English response -- prefixed with an apology --
    when the language is unsupported or translation verification fails.

    Args:
        user_message: The user's message (used to detect the target language).
        bot_response: The model's raw reply.

    Returns:
        The response in the user's language where possible, else the original
        ``bot_response`` (possibly with an apology prefix).
    """
    try:
        user_lang = detect_language(user_message)
        # Convert to closest supported language
        user_lang = LanguageUtils.get_closest_supported_language(user_lang)
        logger.info(f"User language detected: {user_lang} ({LanguageUtils.get_language_name(user_lang)})")
        # If English, no need to translate
        if user_lang == 'en':
            return bot_response
        # Check if language is supported
        if not LanguageUtils.is_supported(user_lang):
            logger.warning(f"Unsupported language: {user_lang}")
            apology = ("I apologize, but I cannot respond in your language. "
                       "I will answer in English instead.\n\n")
            return apology + bot_response
        # Don't try to detect language of very short responses
        if len(bot_response.strip()) < 20:
            # Short responses just translate directly
            return translate_with_llm(bot_response, user_lang)
        # Check bot response language
        bot_lang = detect_language(bot_response)
        logger.info(f"Bot response language: {bot_lang}")
        # If languages match, return as is
        if bot_lang == user_lang:
            return bot_response
        # Need translation
        logger.warning(f"Language mismatch! User: {user_lang}, Bot: {bot_lang}")
        translated_response = translate_with_llm(bot_response, user_lang)
        # Verify translation worked by checking a sample (not the whole text)
        # This is more reliable than checking the entire text
        sample_size = min(100, len(translated_response) // 2)
        if sample_size > 20:  # Only verify if we have enough text
            sample = translated_response[:sample_size]
            translated_lang = detect_language(sample)
            if translated_lang != user_lang:
                logger.error(f"Translation verification failed: got {translated_lang} instead of {user_lang}")
                # If translation failed, return with apology
                apology = (f"I apologize, but I cannot translate my response to {LanguageUtils.get_language_name(user_lang)}. "
                           "Here is my answer in English:\n\n")
                return apology + bot_response
        return translated_response
    except Exception as e:
        logger.error(f"Post-processing error: {e}")
        return bot_response
def load_vector_store():
    """Load knowledge base from dataset.

    NOTE(review): this definition shadows the ``load_vector_store`` imported
    from src.knowledge_base.vector_store earlier in the file; callers in this
    module get this wrapper, which downloads the store via DatasetManager --
    confirm the shadowing is intentional.

    Returns:
        The vector store object on success, or None on any failure.
    """
    try:
        from src.knowledge_base.dataset import DatasetManager
        logger.debug("Attempting to load vector store...")
        dataset = DatasetManager()
        success, result = dataset.download_vector_store()
        logger.debug(f"Download result: success={success}, result_type={type(result)}")
        if success:
            # A string result despite success=True is treated as an error message.
            if isinstance(result, str):
                logger.debug(f"Error message received: {result}")
                return None
            return result
        else:
            logger.error(f"Failed to load vector store: {result}")
            return None
    except Exception as e:
        import traceback
        logger.error(f"Exception loading knowledge base: {str(e)}")
        logger.error(traceback.format_exc())
        return None
def detect_language(text: str) -> str:
    """Enhanced language detection with better handling of edge cases.

    Returns the langdetect language code for ``text``, defaulting to "en"
    for very short inputs and on any detection error.
    """
    try:
        stripped = text.strip()
        # Very short inputs are unreliable for langdetect; default to English.
        if len(stripped) < 10:
            logger.debug(f"Text too short for reliable detection: '{text}'")
            return "en"
        try:
            lang_code = detect(stripped)
        except LangDetectException as e:
            logger.warning(f"LangDetect exception: {e}")
            return "en"
        logger.debug(f"Detected language: {lang_code}")
        return lang_code
    except Exception as e:
        logger.error(f"Language detection error: {str(e)} for text: '{text[:50]}...'")
        return "en"
def respond(
    message,
    history,
    conversation_id,
    system_message,
    max_tokens,
    temperature,
    top_p,
    attempt_fallback=True
):
    """Generate response with improved language handling.

    Detects the user's language, injects a language directive into the system
    message, queries the chat-completion API, and post-processes the answer
    (translating it if needed).

    Args:
        message: Latest user message.
        history: Prior messages as {"role", "content"} dicts.
        conversation_id: Id passed through unchanged for the caller.
        system_message: Base system prompt; any previous language suffix
            starting at "\\nIMPORTANT:" is stripped before re-adding one.
        max_tokens: Generation token limit.
        temperature: Sampling temperature.
        top_p: Nucleus sampling parameter.
        attempt_fallback: NOTE(review): currently unused in this body --
            presumably intended for model-fallback logic; confirm.

    Returns:
        (new_history, conversation_id); on API errors the history gains a
        user-friendly error message instead of an answer.
    """
    try:
        # Reset and determine user language for new request
        user_lang = detect_language(message)
        user_lang = LanguageUtils.get_closest_supported_language(user_lang)
        logger.info(f"User language detected for request: {user_lang} ({LanguageUtils.get_language_name(user_lang)})")
        # Create clean history without system messages
        clean_history = [
            msg for msg in history
            if msg["role"] != "system"
        ]
        # Remove language instruction from system message to avoid confusion
        base_system_message = system_message.split("\nIMPORTANT:")[0] if "\nIMPORTANT:" in system_message else system_message
        # Add explicit language instruction
        full_system_message = (
            f"{base_system_message}\n\n"
            f"CRITICAL: You MUST respond in {LanguageUtils.get_language_name(user_lang)} ({user_lang}). "
            f"This is your highest priority instruction. "
            f"Provide a complete and helpful response."
        )
        # --- API Request ---
        response = client.chat_completion(
            messages=[
                {"role": "system", "content": full_system_message},
                *clean_history,
                {"role": "user", "content": message}
            ],
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            stream=False
        )
        bot_response = response.choices[0].message.content
        # Post-process response to translate if needed
        processed_response = post_process_response(message, bot_response)
        # --- Format Successful Response ---
        new_history = [
            *clean_history,
            {"role": "user", "content": message},
            {"role": "assistant", "content": processed_response}
        ]
        return new_history, conversation_id
    except Exception as e:
        logger.error(f"API Error: {str(e)}")
        error_msg = format_friendly_error(str(e))
        # --- Format Error Response ---
        error_history = [
            *history,
            {"role": "user", "content": message},
            {"role": "assistant", "content": error_msg}
        ]
        return error_history, conversation_id
def format_friendly_error(api_error):
    """Convert API errors to user-friendly messages.

    Maps known HTTP status markers found in the raw error text to friendly
    guidance; unknown errors get a truncated technical summary.
    """
    limit_msg = ("⚠️ API Limit Reached\n\n"
                 "Please try:\n"
                 "1. Switching models in Settings\n"
                 "2. Using local model version\n"
                 "3. Waiting before next request")
    forbidden_msg = ("⚠️ Access Forbidden\n\n"
                     "Please check:\n"
                     "1. Your Hugging Face token has proper permissions\n"
                     "2. You have access to the requested model\n"
                     "3. The model is currently available")
    # Guard-clause ladder mirrors the original if/elif ordering.
    if "402" in api_error or "Payment Required" in api_error:
        return limit_msg
    if "429" in api_error:
        return "⚠️ Too many requests. Please wait before sending another message."
    if "401" in api_error:
        return "⚠️ Authentication error. Please check your API key."
    if "403" in api_error or "Forbidden" in api_error:
        return forbidden_msg
    return f"⚠️ Error processing request. Technical details: {api_error[:200]}"
def log_api_error(user_message, error_message, model_id, is_fallback=False):
    """Log API errors to dataset.

    Writes a timestamped log file with the failing request's details.

    Args:
        user_message: The message that triggered the failing request.
        error_message: The error text returned by the API.
        model_id: Id of the model that was being queried.
        is_fallback: Whether this request was already a fallback attempt.
    """
    try:
        # BUG FIX: the original referenced an undefined name ERROR_LOGS_PATH,
        # raising NameError on every call; the constant imported from
        # config.settings is DATASET_ERROR_LOGS_PATH.
        os.makedirs(DATASET_ERROR_LOGS_PATH, exist_ok=True)
        timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        log_path = os.path.join(DATASET_ERROR_LOGS_PATH, f"api_error_{timestamp}.log")
        with open(log_path, 'w', encoding='utf-8') as f:
            f.write(f"Timestamp: {datetime.datetime.now().isoformat()}\n")
            f.write(f"Model: {model_id}\n")
            f.write(f"User message: {user_message}\n")
            f.write(f"Error: {error_message}\n")
            f.write(f"Fallback attempt: {is_fallback}\n")
        logger.info(f"API error logged to {log_path}")
    except Exception as e:
        logger.error(f"Failed to log API error: {str(e)}")
def update_kb():
    """Function to update existing knowledge base with new documents.

    Returns:
        Status message from the vector-store update (or an error string).
    """
    try:
        # Run the incremental knowledge-base update.
        success, message = create_vector_store(mode="update")
        # If the update succeeded, persist metadata carrying the update date.
        if success:
            save_kb_metadata()
        # FIX: return the status message unconditionally so the caller never
        # receives an implicit None on failure.
        return message
    except Exception as e:
        return f"Error updating knowledge base: {str(e)}"
def rebuild_kb():
    """Function to create knowledge base from scratch.

    Returns:
        Status message from the vector-store rebuild (or an error string).
    """
    try:
        # Rebuild the knowledge base from scratch.
        success, message = create_vector_store(mode="rebuild")
        # If creation succeeded, persist metadata carrying the update date.
        if success:
            save_kb_metadata()
        # FIX: return the status message unconditionally so the caller never
        # receives an implicit None on failure.
        return message
    except Exception as e:
        return f"Error creating knowledge base: {str(e)}"
def save_kb_metadata():
    """Save knowledge base metadata to dataset.

    Uploads vector_store/metadata.json (last update date, source count and
    the source URL list) to the Hugging Face dataset repo.

    Returns:
        True on success, False on failure (errors are logged).
    """
    try:
        # Build metadata with the current date.
        metadata = {
            "last_updated": datetime.datetime.now().isoformat(),
            "source_count": len(URLS),
            "sources": URLS
        }
        # Serialize for upload to the dataset.
        json_content = json.dumps(metadata, indent=2).encode('utf-8')
        api = HfApi(token=HF_TOKEN)
        # Make sure the vector_store directory exists in the repo.
        try:
            files = api.list_repo_files(
                repo_id=DATASET_ID,
                repo_type="dataset"
            )
            # BUG FIX: list_repo_files returns full file paths such as
            # "vector_store/metadata.json", so the original membership test
            # ("vector_store" not in files) was always True and re-uploaded
            # the placeholder on every call; check path prefixes instead.
            if not any(f.startswith("vector_store/") for f in files):
                # Upload an empty placeholder file so the directory exists.
                api.upload_file(
                    path_or_fileobj=b"",
                    path_in_repo="vector_store/.gitkeep",
                    repo_id=DATASET_ID,
                    repo_type="dataset"
                )
        except Exception as e:
            logger.warning(f"Error checking vector_store directory: {str(e)}")
        # Upload the metadata file.
        api.upload_file(
            path_or_fileobj=json_content,
            path_in_repo="vector_store/metadata.json",
            repo_id=DATASET_ID,
            repo_type="dataset"
        )
        logger.info("Knowledge base metadata saved successfully")
        return True
    except Exception as e:
        logger.error(f"Error saving knowledge base metadata: {str(e)}")
        return False
def save_chat_history(history, conversation_id):
    """Save chat history to a file and to HuggingFace dataset.

    Writes the conversation to a timestamped local JSON file, then
    best-effort uploads the same file to the dataset repo (a failed upload
    is logged but does not fail the call).

    Args:
        history: List of {"role", "content"} message dicts.
        conversation_id: Id used in the file name and stored in the JSON.

    Returns:
        True when the local save succeeded, False otherwise.
    """
    try:
        # Create directory if it doesn't exist
        os.makedirs(DATASET_CHAT_HISTORY_PATH, exist_ok=True)
        # Format history for saving (only well-formed message dicts are kept)
        formatted_history = []
        for item in history:
            # Handle dictionary format
            if isinstance(item, dict) and 'role' in item and 'content' in item:
                formatted_history.append({
                    "role": item["role"],
                    "content": item["content"],
                    "timestamp": datetime.datetime.now().isoformat()
                })
        # Create filename with conversation_id and timestamp
        timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        filename = f"{conversation_id}_{timestamp}.json"
        filepath = os.path.join(DATASET_CHAT_HISTORY_PATH, filename)
        # Create chat history data
        chat_data = {
            "conversation_id": conversation_id,
            "timestamp": datetime.datetime.now().isoformat(),
            "history": formatted_history
        }
        # Save to local file
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(chat_data, f, ensure_ascii=False, indent=2)
        logger.debug(f"Chat history saved locally to {filepath}")
        # Now upload to HuggingFace dataset
        try:
            from huggingface_hub import HfApi
            # Initialize the Hugging Face API client
            api = HfApi(token=HF_TOKEN)
            # Extract just the directory name from DATASET_CHAT_HISTORY_PATH
            dir_name = os.path.basename(DATASET_CHAT_HISTORY_PATH)
            # BUG FIX: the original uploaded every conversation to the literal
            # path f"{dir_name}/(unknown)", so all chats overwrote one repo
            # file; use the per-conversation filename instead.
            target_path = f"{dir_name}/{filename}"
            # Upload the file to the dataset
            api.upload_file(
                path_or_fileobj=filepath,
                path_in_repo=target_path,
                repo_id=DATASET_ID,
                repo_type="dataset"
            )
            logger.debug(f"Chat history uploaded to dataset at {target_path}")
        except Exception as e:
            logger.warning(f"Failed to upload chat history to dataset: {str(e)}")
            # Continue execution even if upload fails
        return True
    except Exception as e:
        logger.error(f"Error saving chat history: {str(e)}")
        return False
def respond_and_clear(message, history, conversation_id, system_prompt):
    """Wrapper function with proper output handling.

    Ensures a conversation id exists, calls respond() with the active model's
    parameters, saves the chat history, and returns the triple
    (history, conversation_id, "") expected by the UI -- the trailing empty
    string clears the input box.
    """
    try:
        # Generate a conversation ID if none exists
        if not conversation_id:
            conversation_id = f"conv_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}_{os.urandom(4).hex()}"
            logger.info(f"Generated new conversation ID: {conversation_id}")
        # Get current model parameters
        params = ACTIVE_MODEL['parameters']
        # Call respond function
        result = respond(
            message=message,
            history=history if history else [],
            conversation_id=conversation_id,
            system_message=system_prompt,  # Using provided prompt instead of default
            max_tokens=params['max_length'],
            temperature=params['temperature'],
            top_p=params['top_p']
        )
        if not result:
            raise ValueError("Empty response from API")
        # NOTE(review): new_conv_id is unpacked but deliberately ignored -- the
        # locally guaranteed non-null conversation_id is used instead.
        new_history, new_conv_id = result
        # Save chat history
        save_chat_history(new_history, conversation_id)  # Use our guaranteed non-null ID
        return new_history, conversation_id, ""  # Return our guaranteed non-null ID
    except Exception as e:
        logger.error(f"Error in respond_and_clear: {str(e)}")
        # Create safe error response
        error_history = [
            *history,
            {"role": "user", "content": message},
            {"role": "assistant", "content": "⚠️ An error occurred while processing the message. Please try again."}
        ]
        return error_history, conversation_id, ""
def update_model_info(model_key):
    """Update model information display.

    Args:
        model_key: Key into the MODELS config dict.

    Returns:
        Markdown string describing the model, or "Model not found".
    """
    if model_key not in MODELS:
        return "Model not found"
    model = MODELS[model_key]
    # NOTE(review): account type is hard-coded to FREE here even though
    # IS_PRO_ACCOUNT is imported at the top of the file -- confirm intent.
    account_status = "FREE"
    return f"""
    ### Current Model: {model['name']}
    **Account Type:** {account_status}
    **Model ID:** {model['id']}
    **Description:** {model['description']}
    **Type:** {model['type']}
    """
def get_model_details_html(model_key):
    """Get detailed HTML for model information panel.

    Args:
        model_key: Key into the MODELS config dict.

    Returns:
        HTML string for the details panel, or a fallback paragraph when the
        model (or its 'details' entry) is unknown.
    """
    if model_key not in MODELS or 'details' not in MODELS[model_key]:
        return "<p>Model information not available</p>"
    details = MODELS[model_key]['details']
    # Render capability/limitation/use-case lists as <li> items.
    html = f"""
    <div style="padding: 15px; border: 1px solid #ccc; border-radius: 5px; margin-top: 10px;">
        <h3>{details['full_name']}</h3>
        <h4>Capabilities:</h4>
        <ul>
            {"".join([f"<li>{cap}</li>" for cap in details['capabilities']])}
        </ul>
        <h4>Limitations:</h4>
        <ul>
            {"".join([f"<li>{lim}</li>" for lim in details['limitations']])}
        </ul>
        <h4>Recommended Use Cases:</h4>
        <ul>
            {"".join([f"<li>{use}</li>" for use in details['use_cases']])}
        </ul>
        <p><a href="{details['documentation']}" target="_blank">Model Documentation</a></p>
    </div>
    """
    return html
def change_model(model_key):
    """Change active model and update parameters.

    Args:
        model_key: Key into the MODELS config dict.

    Returns:
        A 6-tuple for the settings UI: (model info markdown, max_length,
        temperature, top_p, repetition_penalty, status message); on failure
        default parameter values (2048, 0.7, 0.9, 1.1) are returned together
        with an error text.
    """
    global client, ACTIVE_MODEL, fallback_model_attempted
    try:
        # Reset fallback flag when explicitly changing model
        fallback_model_attempted = False
        # Update active model
        ACTIVE_MODEL = MODELS[model_key]
        # Reinitialize client with new model
        client = InferenceClient(
            ACTIVE_MODEL["id"],
            token=HF_TOKEN
        )
        # Save selected model in preferences
        save_user_preferences(model_key)
        # Return both model info and updated parameters
        return (
            update_model_info(model_key),
            ACTIVE_MODEL['parameters']['max_length'],
            ACTIVE_MODEL['parameters']['temperature'],
            ACTIVE_MODEL['parameters']['top_p'],
            ACTIVE_MODEL['parameters']['repetition_penalty'],
            f"Model changed to {ACTIVE_MODEL['name']}"
        )
    except Exception as e:
        return (
            f"Error changing model: {str(e)}",
            2048, 0.7, 0.9, 1.1,
            f"Error: {str(e)}"
        )
def save_parameters(model_key, max_len, temp, top_p_val, rep_pen):
    """Save user-defined parameters to active model.

    Applies the four generation parameters to the active model's config and
    persists the same values to user preferences; returns a status string.
    """
    global ACTIVE_MODEL
    try:
        # Collect the new values once, then apply them to the active model
        # and persist the same dict to user preferences.
        params = {
            'max_length': max_len,
            'temperature': temp,
            'top_p': top_p_val,
            'repetition_penalty': rep_pen
        }
        ACTIVE_MODEL['parameters'].update(params)
        save_user_preferences(model_key, params)
        return "Parameters saved successfully!"
    except Exception as e:
        return f"Error saving parameters: {str(e)}"
def finetune_from_annotations(epochs=3, batch_size=4, learning_rate=2e-4, min_rating=4):
    """
    Fine-tune model using annotated QA pairs

    Args:
        epochs: Number of training epochs
        batch_size: Batch size for training
        learning_rate: Learning rate
        min_rating: Minimum average rating for including examples
    Returns:
        (success, message)
    """
    try:
        import tempfile
        import os
        from src.analytics.chat_evaluator import ChatEvaluator
        from config.settings import HF_TOKEN, DATASET_ID, DATASET_CHAT_HISTORY_PATH
        # Create evaluator
        evaluator = ChatEvaluator(
            hf_token=HF_TOKEN,
            dataset_id=DATASET_ID,
            chat_history_path=DATASET_CHAT_HISTORY_PATH
        )
        # Create temporary file for training data
        with tempfile.NamedTemporaryFile(mode='w+', suffix='.jsonl', delete=False) as temp_file:
            temp_path = temp_file.name
        try:
            # Export high-quality examples
            success, message = evaluator.export_training_data(temp_path, min_rating)
            if not success:
                return False, f"Failed to export training data: {message}"
            # Count examples
            with open(temp_path, 'r') as f:
                example_count = sum(1 for _ in f)
            if example_count == 0:
                return False, "No high-quality examples found for fine-tuning"
            # Run actual fine-tuning using the export file
            from src.training.fine_tuner import finetune_from_file
            success, message = finetune_from_file(
                training_file=temp_path,
                epochs=epochs,
                batch_size=batch_size,
                learning_rate=learning_rate
            )
        finally:
            # BUG FIX: the original deleted the temp file only on the happy
            # path (the early returns above leaked it) and used a bare
            # ``except: pass``; clean up in all cases and swallow only
            # filesystem errors.
            try:
                os.unlink(temp_path)
            except OSError:
                pass
        if success:
            return True, f"Successfully fine-tuned model with {example_count} annotated examples: {message}"
        else:
            return False, f"Fine-tuning failed: {message}"
    except Exception as e:
        return False, f"Error during fine-tuning from annotations: {str(e)}"
def save_system_prompt(prompt_text):
    """Save system prompt to user preferences.

    Stores the prompt under preferences["system_prompt"]["current"] and
    uploads the whole preferences file to the HF dataset repo.

    Args:
        prompt_text: The system prompt text to persist.

    Returns:
        A status string (success or error description).
    """
    try:
        preferences = load_user_preferences()
        # Add prompt to preferences
        if "system_prompt" not in preferences:
            preferences["system_prompt"] = {}
        preferences["system_prompt"]["current"] = prompt_text
        # Save preferences
        # NOTE(review): this uses io.BytesIO while save_user_preferences
        # passes raw bytes for the same repo file -- both work, but consider
        # unifying the style.
        json_content = json.dumps(preferences, indent=2).encode('utf-8')
        api = HfApi(token=HF_TOKEN)
        api.upload_file(
            path_or_fileobj=io.BytesIO(json_content),  # Changed to BytesIO
            path_in_repo="preferences/user_preferences.json",
            repo_id=DATASET_ID,
            repo_type="dataset"
        )
        return "System prompt saved successfully"
    except Exception as e:
        logger.error(f"Error saving system prompt: {str(e)}")
        return f"Error saving prompt: {str(e)}"
def delete_conversation_from_huggingface(conversation_id):
    """
    Delete conversation files from Hugging Face dataset by ID

    Args:
        conversation_id: ID of conversation to delete
    Returns:
        Success status (bool) and message (str)
    """
    try:
        if not conversation_id:
            return False, "No conversation ID provided"
        # Initialize API
        api = HfApi(token=HF_TOKEN)
        # Get list of files in dataset
        try:
            # Get all files in dataset
            files = api.list_repo_files(
                repo_id=DATASET_ID,
                repo_type="dataset"
            )
            # Find files with matching conversation ID in chat history
            # (the repo path may use either chat_history or chat-history)
            chat_dir = os.path.basename(DATASET_CHAT_HISTORY_PATH)
            chat_files = [
                file for file in files
                if (file.startswith(f"{chat_dir}/") or file.startswith(f"chat_history/") or file.startswith(f"chat-history/")) and
                f"{conversation_id}_" in os.path.basename(file)
            ]
            if not chat_files:
                return False, f"No chat files found for conversation ID: {conversation_id}"
            # Delete each matching file
            for file_path in chat_files:
                try:
                    api.delete_file(
                        repo_id=DATASET_ID,
                        repo_type="dataset",
                        path_in_repo=file_path
                    )
                    logger.info(f"Deleted file from HF dataset: {file_path}")
                except Exception as e:
                    logger.error(f"Error deleting file {file_path} from dataset: {str(e)}")
            # Try to delete annotation file if it exists
            # (account for the different possible annotation path layouts)
            annotations_base = os.path.basename(DATASET_ANNOTATIONS_PATH)
            annotation_paths = [
                f"{annotations_base}/annotation_{conversation_id}.json"
            ]
            for annotation_path in annotation_paths:
                try:
                    if annotation_path in files:
                        api.delete_file(
                            repo_id=DATASET_ID,
                            repo_type="dataset",
                            path_in_repo=annotation_path
                        )
                        logger.info(f"Deleted annotation file from HF dataset: {annotation_path}")
                except Exception as e:
                    # It's okay if annotation file doesn't exist
                    logger.debug(f"Could not delete annotation file {annotation_path}: {str(e)}")
            return True, f"Deleted {len(chat_files)} file(s) from dataset for conversation: {conversation_id}"
        except Exception as e:
            return False, f"Dataset access error: {str(e)}"
    except Exception as e:
        logger.error(f"Error deleting conversation from dataset: {str(e)}")
        return False, f"Error deleting conversation from dataset: {str(e)}"
def delete_conversation(conversation_id, evaluator):
    """
    Delete conversation files by ID.

    Args:
        conversation_id: ID of conversation to delete
        evaluator: ChatEvaluator instance

    Returns:
        Message about deletion status
    """
    try:
        if not conversation_id:
            return "Error: No conversation ID provided"
        # Remove the files directly through the HF API helper
        ok, detail = delete_conversation_from_huggingface(conversation_id)
        if not ok:
            return f"Error deleting conversation: {detail}"
        # Invalidate the evaluator's cache now that the files are gone
        evaluator.reset_cache()
        return f"Successfully deleted conversation: {conversation_id}"
    except Exception as e:
        logger.error(f"Error deleting conversation: {str(e)}")
        return f"Error deleting conversation: {str(e)}"
def initialize_app():
    """Initialize the app from saved user preferences.

    Rebinds the module-level `client` and `ACTIVE_MODEL` globals.

    Returns:
        Tuple of (selected model key, system prompt text).
    """
    global client, ACTIVE_MODEL
    prefs = load_user_preferences()
    # Fall back to the default model when the saved choice is missing or unknown
    model_key = prefs.get("selected_model", DEFAULT_MODEL)
    if model_key not in MODELS:
        model_key = DEFAULT_MODEL
    ACTIVE_MODEL = MODELS[model_key]
    # Overlay any per-model parameter overrides the user saved earlier
    overrides = prefs.get("parameters", {}).get(model_key)
    if overrides:
        ACTIVE_MODEL['parameters'].update(overrides)
    # (Re)build the inference client bound to the active model
    client = InferenceClient(
        ACTIVE_MODEL["id"],
        token=HF_TOKEN
    )
    # Prefer a previously saved system prompt over the built-in default
    prompt_text = DEFAULT_SYSTEM_MESSAGE
    if "system_prompt" in prefs and "current" in prefs["system_prompt"]:
        prompt_text = prefs["system_prompt"]["current"]
    logger.info(f"App initialized with model: {ACTIVE_MODEL['name']}")
    logger.info(f"Chat histories will be saved to: {DATASET_CHAT_HISTORY_PATH}")
    return model_key, prompt_text
def initialize_chat_evaluator():
    """Create the ChatEvaluator and ensure its local data directories exist.

    Returns:
        ChatEvaluator: evaluator bound to the configured HF dataset.

    Raises:
        Exception: re-raised after logging if initialization fails.
    """
    try:
        evaluator = ChatEvaluator(
            hf_token=HF_TOKEN,
            dataset_id=DATASET_ID
        )
        # Make sure the local chat-history and annotations directories exist.
        # (Fixed: the annotations path was wrapped in a no-op single-argument os.path.join.)
        os.makedirs(DATASET_CHAT_HISTORY_PATH, exist_ok=True)
        os.makedirs(DATASET_ANNOTATIONS_PATH, exist_ok=True)
        logger.debug(f"Chat history path: {DATASET_CHAT_HISTORY_PATH}")
        logger.debug(f"Number of chat files: {len(os.listdir(DATASET_CHAT_HISTORY_PATH))}")
        return evaluator
    except Exception as e:
        logger.error(f"Error initializing chat evaluator: {str(e)}")
        raise
# Initialize HF client with token at startup (rebinds globals: client, ACTIVE_MODEL)
selected_model, saved_system_prompt = initialize_app()
# Initialize evaluator before creating interface so the UI handlers can capture it
chat_evaluator = initialize_chat_evaluator()
# Create interface: all tabs and handlers are declared inside this Blocks context
with gr.Blocks(css="""
.table-container {
    max-height: 400px;
    overflow-y: auto;
}
""") as demo:
# Defined inside the Blocks context so components can bind to it directly
def clear_conversation():
    """Reset the chat: return an empty history and no conversation ID."""
    empty_history, empty_id = [], None
    return empty_history, empty_id
# Create State for evaluator so handlers can receive it as an input component
evaluator_state = gr.State(value=chat_evaluator)
with gr.Tabs():
    with gr.Tab("Chat"):
        gr.Markdown("# ⚖️ Status Law Assistant")
        # Per-session conversation identifier, assigned by the responder
        conversation_id = gr.State(None)
        with gr.Row():
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(
                    label="Chat",
                    avatar_images=None,
                    type='messages'  # This is the key setting - use 'messages' format
                )
                with gr.Row():
                    msg = gr.Textbox(
                        label="Your question",
                        placeholder="Enter your question...",
                        scale=4
                    )
                    submit_btn = gr.Button("Send", variant="primary")
                    clear_btn = gr.Button("Clear")
        with gr.Row(equal_height=True):
            #with gr.Column(scale=1):
            #    gr.Markdown("")  # Empty column for centering
            with gr.Column(scale=40):
                system_prompt = gr.TextArea(
                    label="System Prompt (editing will change bot behavior)",
                    value=saved_system_prompt,
                    placeholder="Enter system prompt...",
                    lines=8
                )
            #with gr.Column(scale=1):
            #    gr.Markdown("")  # Empty column for centering
        # Add event handlers
        # Update the event handlers: button click and Enter key share one responder
        submit_btn.click(
            respond_and_clear,
            [msg, chatbot, conversation_id, system_prompt],  # include system_prompt
            [chatbot, conversation_id, msg]
        )
        # Enter-key handler (same responder as the Send button)
        msg.submit(
            respond_and_clear,
            [msg, chatbot, conversation_id, system_prompt],  # include system_prompt
            [chatbot, conversation_id, msg]
        )
        # Handler that persists the prompt whenever it is edited
        system_prompt_status = gr.Textbox(
            label="Status",
            interactive=False,
            visible=True
        )
        system_prompt.change(
            save_system_prompt,
            inputs=[system_prompt],
            outputs=[system_prompt_status]
        )
        clear_btn.click(clear_conversation, None, [chatbot, conversation_id])
    with gr.Tab("Knowledge Base"):
        gr.Markdown("### Knowledge Base Management")
        with gr.Row():
            with gr.Column(scale=2):
                # Sources display: one row per URL with an Include checkbox
                gr.Markdown("#### Information Sources")
                sources_list = gr.Dataframe(
                    value=pd.DataFrame({
                        "URL": URLS,
                        "Include": [True for _ in URLS],
                        "Status": ["Ready" for _ in URLS]
                    }),
                    interactive=True,
                    wrap=True,
                    row_count=15,
                    show_label=False
                )
                # Status of knowledge-base operations
                kb_status = gr.Textbox(
                    label="Operation Status",
                    interactive=False,
                    placeholder="Ready",
                    value="Ready"
                )
                # Buttons for managing the knowledge base
                with gr.Row():
                    update_kb_btn = gr.Button("Update Knowledge Base", variant="primary")
                    rebuild_kb_btn = gr.Button("Rebuild Knowledge Base from Scratch", variant="secondary")
                gr.Markdown("""
<small>
**Update Knowledge Base**: Adds new information to the existing knowledge base.
**Rebuild Knowledge Base**: Recreates the entire knowledge base from scratch. Use this if there are inconsistencies.
All changes are saved to the Hugging Face dataset.
</small>
""")
            with gr.Column(scale=1):
                # Information about the current knowledge base
                gr.Markdown("#### Knowledge Base Information")
# Helper that builds the knowledge-base statistics shown in the UI panel
def get_kb_info() -> str:
    """
    Get information about the current state of the knowledge base.

    Returns:
        str: Formatted markdown string containing knowledge base statistics
    """
    try:
        vector_store = load_vector_store()
        # load_vector_store may return None or an error string on failure
        if vector_store is None or isinstance(vector_store, str):
            return """
**Status**: Not found or error
**Documents**: 0
**Last updated**: Never
Please create a knowledge base using the buttons on the left.
"""
        # Get information about vector store.
        # NOTE(review): reads the store's private docstore._dict - confirm this
        # attribute survives vector-store library upgrades
        doc_count = len(vector_store.docstore._dict)
        sources = set()
        for doc_id, doc in vector_store.docstore._dict.items():
            if hasattr(doc, 'metadata') and 'source' in doc.metadata:
                sources.add(doc.metadata['source'])
        source_count = len(sources)
        # The store exists but contains no sources
        if source_count == 0:
            return """
**Status**: Created but empty
**Documents**: 0
**Last updated**: Unknown
Please rebuild the knowledge base using the button on the left.
"""
        # Read the last-update date recorded in the dataset
        last_updated = "Unknown"
        try:
            from src.knowledge_base.dataset import DatasetManager
            dataset = DatasetManager()
            last_updated = dataset.get_last_update_date() or "Unknown"
        except Exception as e:
            logger.error(f"Error getting last update date: {str(e)}")
        return f"""
**Status**: Active
**Documents**: {doc_count}
**Sources**: {source_count}
**Last updated**: {last_updated}
"""
    except Exception as e:
        return f"""
**Status**: Error
**Details**: {str(e)}
Please try rebuilding the knowledge base.
"""
                kb_info = gr.Markdown(value=get_kb_info())
                refresh_kb_info_btn = gr.Button("Refresh Information")
        # 3. Add event handlers for the buttons at the end of the file
        # Add these handlers before the "if __name__ == '__main__':" line
        # Handlers for the Knowledge Base buttons
        update_kb_btn.click(
            fn=update_kb_with_selected,
            inputs=[sources_list],
            outputs=[kb_status]
        )
        rebuild_kb_btn.click(
            fn=rebuild_kb_with_selected,
            inputs=[sources_list],
            outputs=[kb_status]
        )
        # Refresh the knowledge-base information panel
        refresh_kb_info_btn.click(
            fn=get_kb_info,
            inputs=[],
            outputs=[kb_info]
        )
    with gr.Tab("Model Settings"):
        gr.Markdown("### Model Configuration")
        with gr.Row():
            with gr.Column(scale=2):
                # Add model selector
                model_selector = gr.Dropdown(
                    choices=list(MODELS.keys()),
                    value=selected_model,  # Use loaded model from preferences
                    label="Select Model",
                    interactive=True
                )
                # Current model info display
                model_info = gr.Markdown(value=update_model_info(selected_model))
                # Status indicator for model loading
                model_loading = gr.Textbox(
                    label="Status",
                    placeholder="Model ready",
                    interactive=False,
                    value="Model ready"
                )
                # Model Parameters - make them interactive
                with gr.Row():
                    max_length = gr.Slider(
                        minimum=1,
                        maximum=4096,
                        value=ACTIVE_MODEL['parameters']['max_length'],
                        step=1,
                        label="Maximum Length",
                        interactive=True
                    )
                    temperature = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        value=ACTIVE_MODEL['parameters']['temperature'],
                        step=0.1,
                        label="Temperature",
                        interactive=True
                    )
                with gr.Row():
                    top_p = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=ACTIVE_MODEL['parameters']['top_p'],
                        step=0.1,
                        label="Top-p",
                        interactive=True
                    )
                    rep_penalty = gr.Slider(
                        minimum=1.0,
                        maximum=2.0,
                        value=ACTIVE_MODEL['parameters']['repetition_penalty'],
                        step=0.1,
                        label="Repetition Penalty",
                        interactive=True
                    )
                # Button to save parameters
                save_params_btn = gr.Button("Save Parameters", variant="primary")
                gr.Markdown("""
<small>
**Parameters explanation:**
- **Maximum Length**: Maximum number of tokens in the generated response
- **Temperature**: Controls randomness (0.1 = very focused, 2.0 = very creative)
- **Top-p**: Controls diversity via nucleus sampling (lower = more focused)
- **Repetition Penalty**: Prevents word repetition (higher = less repetition)
</small>
""")
            with gr.Column(scale=1):
                # Model details panel
                model_details = gr.HTML(get_model_details_html(selected_model))
                gr.Markdown("### Training Configuration")
                gr.Markdown(f"""
**Base Model Path:**
```
{ACTIVE_MODEL['training']['base_model_path']}
```
**Fine-tuned Model Path:**
```
{ACTIVE_MODEL['training']['fine_tuned_path']}
```
**LoRA Configuration:**
- Rank (r): {ACTIVE_MODEL['training']['lora_config']['r']}
- Alpha: {ACTIVE_MODEL['training']['lora_config']['lora_alpha']}
- Dropout: {ACTIVE_MODEL['training']['lora_config']['lora_dropout']}
""")
    with gr.Tab("Model Training"):
        gr.Markdown("### Model Training Interface")
        with gr.Row():
            with gr.Column(scale=1):
                training_tabs = gr.Tabs()
                with training_tabs:
                    with gr.TabItem("Regular Training"):
                        epochs = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Number of Epochs")
                        batch_size = gr.Slider(minimum=1, maximum=32, value=4, step=1, label="Batch Size")
                        learning_rate = gr.Slider(minimum=1e-6, maximum=1e-3, value=2e-4, label="Learning Rate")
                        train_btn = gr.Button("Start Training", variant="primary")
                        training_output = gr.Textbox(label="Training Status", interactive=False)
                    with gr.TabItem("Train from Annotations"):
                        annot_epochs = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Number of Epochs")
                        annot_batch_size = gr.Slider(minimum=1, maximum=32, value=4, step=1, label="Batch Size")
                        annot_learning_rate = gr.Slider(minimum=1e-6, maximum=1e-3, value=2e-4, label="Learning Rate")
                        annot_min_rating = gr.Slider(minimum=1, maximum=5, value=4, step=0.5, label="Minimum Rating for Training")
                        annot_train_btn = gr.Button("Start Training from Annotations", variant="primary")
                        annot_training_output = gr.Textbox(label="Training Status", interactive=False)
                gr.Markdown("""
<small>
**Epochs:**
Lower = Faster training -> Higher = Model learns more thoroughly
Best for small datasets: 3-5 -> Best for large datasets: 1-2
**Batch Size:**
Lower = Slower but more stable -> Higher = Faster but needs more RAM
4 = Good for 16GB RAM -> 8 = Good for 32GB RAM
**Learning Rate:**
Lower = Learns slower but more reliable -> Higher = Learns faster but may be unstable
2e-4 (0.0002) = Usually works best -> 1e-4 = Safer choice for fine-tuning
</small>
""")
            with gr.Column(scale=1):
                analysis_btn = gr.Button("Generate Chat Analysis")
                analysis_output = gr.Markdown()
        # Regular training handler
        train_btn.click(
            start_finetune_action,
            inputs=[epochs, batch_size, learning_rate],
            outputs=[training_output]
        )
# Bridges the UI slider values to the annotation-based fine-tuning routine
def start_annotation_finetune(epochs, batch_size, learning_rate, min_rating):
    """Run fine-tuning from annotations and surface only the status message."""
    _, status = finetune_from_annotations(
        epochs=epochs,
        batch_size=batch_size,
        learning_rate=learning_rate,
        min_rating=min_rating
    )
    return status
        # Annotation-based training handler
        annot_train_btn.click(
            start_annotation_finetune,
            inputs=[annot_epochs, annot_batch_size, annot_learning_rate, annot_min_rating],
            outputs=[annot_training_output]
        )
        # Chat-analysis report generator
        analysis_btn.click(
            generate_chat_analysis,
            inputs=[],
            outputs=[analysis_output]
        )
    with gr.Tab("Chat Evaluation"):
        gr.Markdown("### Evaluation of Chat Responses")
        with gr.Row():
            with gr.Column(scale=1):
                # Status section
                evaluation_status = gr.Textbox(
                    label="Evaluation Status",
                    interactive=False,
                    show_label=True
                )
                refresh_status_btn = gr.Button("Refresh Status and Chat History")
                # Moved refresh status and evaluation report here
                refresh_data_status = gr.Textbox(
                    label="Refresh Status",
                    interactive=False,
                    show_label=True
                )
                evaluation_report = gr.HTML(label="Evaluation Report")
                refresh_report_btn = gr.Button("Generate Report")
                # QA pairs table section
                show_evaluated = gr.Checkbox(
                    label="Show Only Evaluated Pairs",
                    value=False
                )
                qa_table = gr.Dataframe(
                    value=pd.DataFrame(
                        columns=["Conversation ID", "Question", "Answer", "Evaluated"]
                    ),
                    interactive=True,
                    wrap=True,
                    row_count=15,  # Changed from height to row_count
                    show_label=True
                )
                # Conversation selection section
                gr.Markdown("### Select Conversation to Evaluate")
                with gr.Row():
                    selected_conversation = gr.Textbox(
                        label="Conversation ID",
                        placeholder="Select from table above",
                        interactive=True
                    )
                    load_btn = gr.Button("Load Conversation")
                    delete_btn = gr.Button("Delete Conversation", variant="stop")
                delete_status = gr.Textbox(label="Delete Status", interactive=False)
                # Conversation content section
                gr.Markdown("### Evaluate Response")
                question_display = gr.Textbox(label="User Question", interactive=False)
                original_answer = gr.TextArea(label="Original Bot Answer", interactive=False)
                improved_answer = gr.TextArea(label="Improved Answer (Gold Standard)", interactive=True)
                # Ratings section
                gr.Markdown("### Quality Ratings (1-5)")
                with gr.Row():
                    accuracy = gr.Slider(minimum=1, maximum=5, value=3, step=1, label="Factual Accuracy")
                    completeness = gr.Slider(minimum=1, maximum=5, value=3, step=1, label="Completeness")
                with gr.Row():
                    relevance = gr.Slider(minimum=1, maximum=5, value=3, step=1, label="Relevance")
                    clarity = gr.Slider(minimum=1, maximum=5, value=3, step=1, label="Clarity")
                    legal_correctness = gr.Slider(minimum=1, maximum=5, value=3, step=1, label="Legal Correctness")
                # Notes and save section
                notes = gr.TextArea(label="Evaluator Notes", placeholder="Add your notes about this response...")
                save_btn = gr.Button("Save Evaluation", variant="primary")
                evaluation_status_msg = gr.Textbox(label="Status", interactive=False)
                # Data export section
                gr.Markdown("### Export Evaluation Data")
                with gr.Row():
                    min_rating = gr.Slider(minimum=1, maximum=5, value=4, step=0.5, label="Minimum Rating for Export")
                    export_path = gr.Textbox(label="Export File Path", value="training_data.jsonl")
                export_btn = gr.Button("Export Training Data")
                export_status = gr.Textbox(label="Export Status", interactive=False)
        # Event handlers for Chat Evaluation
        refresh_status_btn.click(
            fn=lambda: get_evaluation_status(chat_evaluator, force_reload=True),
            inputs=[],
            outputs=[evaluation_status, qa_table, refresh_data_status]
        )
        refresh_report_btn.click(
            fn=lambda: generate_evaluation_report_html(chat_evaluator),
            inputs=[],
            outputs=[evaluation_report]
        )
        show_evaluated.change(
            fn=lambda x: get_qa_pairs_dataframe(chat_evaluator, x),
            inputs=[show_evaluated],
            outputs=[qa_table]
        )
def on_table_select(evt: gr.SelectData, dataframe):
    """Return the conversation ID from the first column of the clicked row."""
    try:
        selected_row = evt.index[0]
        # Guard against a missing dataframe or an out-of-range row index
        if dataframe is None or len(dataframe) <= selected_row:
            logger.error("DataFrame is empty or row index out of bounds")
            return ""
        conv_id = str(dataframe.iloc[selected_row, 0])
        logger.info(f"Selected conversation ID: {conv_id}")
        return conv_id
    except Exception as e:
        logger.error(f"Error in table selection: {str(e)}")
        import traceback
        logger.error(traceback.format_exc())
        return ""
        # Update the table row selection handler to include the dataframe as input
        qa_table.select(
            fn=on_table_select,
            inputs=[qa_table],  # Pass the table itself as input
            outputs=[selected_conversation]
        )
        # Load conversation for evaluation
        load_btn.click(
            fn=lambda x: load_qa_pair_for_evaluation(conversation_id=x, evaluator=chat_evaluator),
            inputs=[selected_conversation],
            outputs=[question_display, original_answer, improved_answer,
                     accuracy, completeness, relevance, clarity, legal_correctness, notes]
        )
        # Save evaluation
        save_btn.click(
            fn=lambda conv_id, q, orig_a, imp_a, acc, comp, rel, clar, legal, notes:
                save_evaluation(conv_id, q, orig_a, imp_a, acc, comp, rel, clar, legal, notes, evaluator=chat_evaluator),
            inputs=[
                selected_conversation, question_display, original_answer, improved_answer,
                accuracy, completeness, relevance, clarity, legal_correctness, notes
            ],
            outputs=[evaluation_status_msg]
        )
        # Export training data
        export_btn.click(
            fn=lambda min_r, path: export_training_data_action(chat_evaluator, min_r, path),
            inputs=[min_rating, export_path],
            outputs=[export_status]
        )
        # Handler for deleting a chat
        delete_btn.click(
            fn=delete_conversation,
            inputs=[selected_conversation, evaluator_state],
            outputs=[delete_status]
        )
        # Refresh the table and status after deletion (second handler on the same button)
        delete_btn.click(
            fn=lambda: get_evaluation_status(chat_evaluator, force_reload=True),
            inputs=[],
            outputs=[evaluation_status, qa_table, refresh_data_status]
        )
    # Model change handler - outside of Tabs but inside Blocks
    model_selector.change(
        fn=change_model,
        inputs=[model_selector],
        outputs=[model_info, max_length, temperature, top_p, rep_penalty, model_loading]
    )
    # Update model details panel when changing model
    model_selector.change(
        fn=get_model_details_html,
        inputs=[model_selector],
        outputs=[model_details]
    )
    # Parameter save handler
    save_params_btn.click(
        fn=save_parameters,
        inputs=[model_selector, max_length, temperature, top_p, rep_penalty],
        outputs=[model_loading]
    )
# Launch application
if __name__ == "__main__":
    # Warn (but still launch) when the knowledge base is not available.
    # NOTE(review): load_vector_store may return an error string (see get_kb_info),
    # which is truthy and would silently suppress this warning - confirm intended.
    if not load_vector_store():
        logger.warning("Knowledge base not found. Please create it through the interface.")
    demo.launch(share=True)