Spaces:
Build error
Build error
| import os | |
| import json | |
| from tqdm.auto import tqdm | |
| from chatbot_config import ChatbotConfig | |
| from chatbot_model import RetrievalChatbot | |
| from sentence_transformers import SentenceTransformer | |
| from tf_data_pipeline import TFDataPipeline | |
| from response_quality_checker import ResponseQualityChecker | |
| from environment_setup import EnvironmentSetup | |
| from logger_config import config_logger | |
# Module-level logger; only WARNING and above are emitted by this script.
logger = config_logger(__name__)
logger.setLevel("WARNING")

# Ask TensorFlow's native (C++) layer to suppress its log output.
# NOTE(review): this is set after the project imports above; if any of them
# already imported TensorFlow the setting may come too late -- confirm.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

# NOTE(review): this builds a disabled tqdm instance and discards it; it does
# not appear to turn progress bars off globally -- confirm the intent.
tqdm(disable=True)
def _resolve_index_paths(model_dir: str, environment: str) -> tuple:
    """Return ``(faiss_index_path, response_pool_path)`` for *environment*.

    ``"test"`` selects the test index; any other value selects the production
    index. The response pool sits next to the index file, with the ``.index``
    extension swapped for ``_responses.json``.
    """
    indices_dir = os.path.join(model_dir, "faiss_indices")
    if environment == "test":
        index_name = "faiss_index_test.index"
    else:
        index_name = "faiss_index_production.index"
    index_path = os.path.join(indices_dir, index_name)
    # splitext only touches the trailing extension, unlike str.replace which
    # would rewrite every ".index" occurring anywhere in the path.
    index_root, _ = os.path.splitext(index_path)
    return index_path, index_root + "_responses.json"


def _load_chatbot_config(model_dir: str):
    """Return the ChatbotConfig from ``<model_dir>/config.json``, or a default one.

    Errors raised while reading or parsing the file propagate to the caller.
    """
    config_path = os.path.join(model_dir, "config.json")
    if not os.path.exists(config_path):
        config = ChatbotConfig()
        logger.warning("No config.json found. Using default ChatbotConfig.")
        return config

    with open(config_path, "r", encoding="utf-8") as f:
        config_dict = json.load(f)
    config = ChatbotConfig.from_dict(config_dict)
    logger.info(f"Loaded ChatbotConfig from {config_path}")
    return config


def run_chatbot_chat(environment: str = "production") -> None:
    """Load the chatbot artifacts from ``models/`` and run an interactive chat.

    Stages, in order: initialize the environment, load the config, load the
    SentenceTransformer named by ``config.pretrained_model``, build the
    TFDataPipeline and load the FAISS index plus JSON response pool into it,
    validate the index, then load the RetrievalChatbot and start the chat.
    A failure in any stage after environment setup is logged with
    ``logger.error`` and the function returns without raising.

    Args:
        environment: ``"test"`` uses the test FAISS index; any other value
            uses the production index. Defaults to ``"production"``, which
            was the previously hard-coded setting.
    """
    env = EnvironmentSetup()
    env.initialize()

    model_dir = "models"
    faiss_index_path, response_pool_path = _resolve_index_paths(model_dir, environment)

    # Load the config. A malformed config.json is reported the same way as
    # every other stage failure instead of escaping as an uncaught exception.
    try:
        config = _load_chatbot_config(model_dir)
    except Exception as e:
        logger.error(f"Failed to load ChatbotConfig: {e}")
        return

    # Init SentenceTransformer
    try:
        encoder = SentenceTransformer(config.pretrained_model)
        logger.info(f"Loaded SentenceTransformer model: {config.pretrained_model}")
    except Exception as e:
        logger.error(f"Failed to load SentenceTransformer: {e}")
        return

    # Load FAISS index and response pool
    try:
        # The pipeline is built before the existence check on purpose: the
        # original ordering is preserved because the constructor's side
        # effects are not visible from this file.
        data_pipeline = TFDataPipeline(
            config=config,
            tokenizer=encoder.tokenizer,
            encoder=encoder,
            response_pool=[],
            query_embeddings_cache={},
            index_type='IndexFlatIP',
            faiss_index_file_path=faiss_index_path,
        )

        if not os.path.exists(faiss_index_path) or not os.path.exists(response_pool_path):
            logger.error("FAISS index or response pool file is missing.")
            return

        data_pipeline.load_faiss_index(faiss_index_path)
        logger.info(f"FAISS index loaded from {faiss_index_path}.")

        with open(response_pool_path, "r", encoding="utf-8") as f:
            data_pipeline.response_pool = json.load(f)
        logger.info(f"Response pool loaded from {response_pool_path}.")
        logger.info(f"Total responses in pool: {len(data_pipeline.response_pool)}")

        # Validate dimension consistency
        data_pipeline.validate_faiss_index()
        logger.info("FAISS index and response pool validated successfully.")
    except Exception as e:
        logger.error(f"Failed to load or validate FAISS index: {e}")
        return

    # Run interactive chat
    try:
        chatbot = RetrievalChatbot.load_model(load_dir=model_dir, mode="inference")
        quality_checker = ResponseQualityChecker(data_pipeline=data_pipeline)
        logger.info("\nStarting interactive chat session...")
        chatbot.run_interactive_chat(quality_checker=quality_checker, show_alternatives=False)
    except Exception as e:
        logger.error(f"Interactive chat session failed: {e}")
# Script entry point: start the chat session only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    run_chatbot_chat()