| | import os |
| | import json |
| | from sentence_transformers import SentenceTransformer |
| | from chatbot_config import ChatbotConfig |
| | from chatbot_model import RetrievalChatbot |
| | from response_quality_checker import ResponseQualityChecker |
| | from chatbot_validator import ChatbotValidator |
| | from plotter import Plotter |
| | from environment_setup import EnvironmentSetup |
| | from logger_config import config_logger |
| | from tf_data_pipeline import TFDataPipeline |
| |
|
# Module-level logger, configured through the project's shared logging setup.
logger = config_logger(__name__)
| |
|
def run_chatbot_validation(environment=None):
    """Validate a trained retrieval chatbot end to end.

    Loads the saved ChatbotConfig, sentence encoder, FAISS index, and
    response pool, runs an automated validation pass, plots the resulting
    metrics, and finally starts an interactive chat session.

    Args:
        environment: Which index pair to load, "test" or "production".
            When None (the default), the ENVIRONMENT environment variable
            is consulted, falling back to "production" — matching the
            previous hard-coded behavior.

    Returns:
        None. Logs and returns early on any unrecoverable failure.
    """
    # Initialize runtime environment (training dirs, device setup, etc. —
    # exact responsibilities live in EnvironmentSetup).
    env = EnvironmentSetup()
    env.initialize()

    MODEL_DIR = "models"
    FAISS_INDICES_DIR = os.path.join(MODEL_DIR, "faiss_indices")

    # Resolve which FAISS index / response pool pair to use. Previously this
    # was hard-coded to "production"; it is now overridable via argument or
    # the ENVIRONMENT variable, with identical defaults.
    if environment is None:
        environment = os.environ.get("ENVIRONMENT", "production")
    index_filename = (
        "faiss_index_test.index" if environment == "test"
        else "faiss_index_production.index"
    )
    FAISS_INDEX_PATH = os.path.join(FAISS_INDICES_DIR, index_filename)
    RESPONSE_POOL_PATH = FAISS_INDEX_PATH.replace(".index", "_responses.json")

    # Load the persisted config if one exists; otherwise use defaults.
    config_path = os.path.join(MODEL_DIR, "config.json")
    if os.path.exists(config_path):
        with open(config_path, "r", encoding="utf-8") as f:
            config = ChatbotConfig.from_dict(json.load(f))
        logger.info(f"Loaded ChatbotConfig from {config_path}")
    else:
        config = ChatbotConfig()
        logger.warning("No config.json found. Using default ChatbotConfig.")

    # Load the sentence encoder used for query/response embeddings.
    try:
        encoder = SentenceTransformer(config.pretrained_model)
        logger.info(f"Loaded SentenceTransformer model: {config.pretrained_model}")
    except Exception:
        # logger.exception preserves the traceback, unlike logger.error(f"...{e}").
        logger.exception("Failed to load SentenceTransformer")
        return

    # Build the data pipeline and load + validate the FAISS artifacts.
    try:
        data_pipeline = TFDataPipeline(
            config=config,
            tokenizer=encoder.tokenizer,
            encoder=encoder,
            response_pool=[],
            query_embeddings_cache={},
            index_type='IndexFlatIP',
            faiss_index_file_path=FAISS_INDEX_PATH
        )

        # Fail fast if either artifact is missing rather than erroring mid-load.
        if not os.path.exists(FAISS_INDEX_PATH) or not os.path.exists(RESPONSE_POOL_PATH):
            logger.error("FAISS index or response pool file is missing.")
            return

        data_pipeline.load_faiss_index(FAISS_INDEX_PATH)
        logger.info(f"FAISS index loaded from {FAISS_INDEX_PATH}.")

        with open(RESPONSE_POOL_PATH, "r", encoding="utf-8") as f:
            data_pipeline.response_pool = json.load(f)
        logger.info(f"Response pool loaded from {RESPONSE_POOL_PATH}.")
        logger.info(f"Total responses in pool: {len(data_pipeline.response_pool)}")

        # Sanity-check that index and response pool are mutually consistent.
        data_pipeline.validate_faiss_index()
        logger.info("FAISS index and response pool validated successfully.")
    except Exception:
        logger.exception("Failed to load or validate FAISS index")
        return

    # Run the automated validation pass over a handful of examples.
    try:
        chatbot = RetrievalChatbot.load_model(load_dir=MODEL_DIR, mode="inference")
        quality_checker = ResponseQualityChecker(data_pipeline=data_pipeline)
        validator = ChatbotValidator(chatbot=chatbot, quality_checker=quality_checker)
        logger.info("ResponseQualityChecker and ChatbotValidator initialized.")

        validation_metrics = validator.run_validation(num_examples=5)
        logger.info(f"Validation Metrics: {validation_metrics}")
    except Exception:
        logger.exception("Validation process failed")
        return

    # Plot metrics. A plotting failure is non-fatal, so no early return here.
    try:
        plotter = Plotter(save_dir=env.training_dirs["plots"])
        plotter.plot_validation_metrics(validation_metrics)
        logger.info("Validation metrics plotted successfully.")
    except Exception:
        logger.exception("Failed to plot validation metrics")

    # Finish with an interactive session against the validated model.
    try:
        logger.info("\nStarting interactive chat session...")
        chatbot.run_interactive_chat(quality_checker=quality_checker, show_alternatives=True)
    except Exception:
        logger.exception("Interactive chat session failed")
| | |
| | |
| | if __name__ == "__main__": |
| | run_chatbot_validation() |