| | import os |
| | import json |
| | import pickle |
| | import faiss |
| | from tqdm.auto import tqdm |
| | from pathlib import Path |
| | from sentence_transformers import SentenceTransformer |
| | from tf_data_pipeline import TFDataPipeline |
| | from chatbot_config import ChatbotConfig |
| | from logger_config import config_logger |
| |
|
# Module logger, built by the project's logger_config helper.
logger = config_logger(__name__)


# Turn off the tokenizers library's parallelism switch for this process.
# This unconditionally overrides any value already present in the environment.
os.environ.update(TOKENIZERS_PARALLELISM="false")
| |
|
def _load_or_create_config(config_json):
    """Return a ChatbotConfig loaded from *config_json*, or write and return the default.

    Args:
        config_json: Path to the JSON config file.

    Raises:
        Exception: whatever writing the default config raised (logged, then re-raised).
    """
    if config_json.exists():
        with open(config_json, "r", encoding="utf-8") as f:
            config_dict = json.load(f)
        config = ChatbotConfig.from_dict(config_dict)
        logger.info(f"Loaded ChatbotConfig from {config_json}")
        return config

    config = ChatbotConfig()
    logger.warning("No config.json found. Using default ChatbotConfig.")
    try:
        with open(config_json, "w", encoding="utf-8") as f:
            json.dump(config.to_dict(), f, indent=2)
        logger.info(f"Default ChatbotConfig saved to {config_json}")
    except Exception as e:
        logger.error(f"Failed to save default ChatbotConfig: {e}")
        raise
    return config


def _load_query_cache(cache_file):
    """Return the pickled query-embedding cache at *cache_file*, or ``{}`` if absent/corrupt."""
    if not os.path.exists(cache_file):
        logger.info("No existing query embeddings cache found. Starting fresh.")
        return {}
    # SECURITY NOTE: pickle.load can execute arbitrary code embedded in the file.
    # Only load caches written by this script; never point this at untrusted data.
    try:
        with open(cache_file, 'rb') as f:
            cache = pickle.load(f)
    except (pickle.UnpicklingError, EOFError) as e:
        # A truncated/corrupt cache is only a lost optimisation, not a fatal error.
        logger.warning(f"Query embeddings cache at {cache_file} is unreadable ({e}). Starting fresh.")
        return {}
    logger.info(f"Loaded query embeddings cache with {len(cache)} entries.")
    return cache


def _save_query_cache(cache, cache_file):
    """Pickle *cache* to *cache_file* atomically (temp file + os.replace)."""
    tmp_file = f"{cache_file}.tmp"
    with open(tmp_file, 'wb') as f:
        pickle.dump(cache, f)
    # os.replace is atomic on the same filesystem, so an interrupted run can
    # never leave a half-written cache behind at cache_file.
    os.replace(tmp_file, cache_file)
    logger.info(f"Query embeddings cache saved at {cache_file}.")


def _response_pool_path(index_path):
    """Return the JSON path stored next to *index_path*: '<stem>_responses.json'."""
    index = Path(index_path)
    # with_name only touches the final component, unlike str.replace, which would
    # also rewrite any '.index' occurring earlier in the path.
    return str(index.with_name(f"{index.stem}_responses.json"))


def main():
    """Build the production FAISS response index and persist the query-embedding cache.

    Steps:
      1. Ensure the working directories exist.
      2. Load ``models/config.json``, or write the default ChatbotConfig there.
      3. Load the SentenceTransformer named by ``config.pretrained_model``.
      4. Load dialogues from the processed JSON training data, if present.
      5. If there are dialogues: collect responses, save them as JSON next to the
         index, compute and index their embeddings, and save the FAISS index.
      6. Save the query-embedding cache.
    """
    MODELS_DIR = 'models'
    PROCESSED_DATA_DIR = 'processed_outputs'
    CACHE_DIR = os.path.join(MODELS_DIR, 'query_embeddings_cache')
    TOKENIZER_DIR = os.path.join(MODELS_DIR, 'tokenizer')
    FAISS_INDICES_DIR = os.path.join(MODELS_DIR, 'faiss_indices')
    TF_RECORD_DIR = 'training_data'
    FAISS_INDEX_PRODUCTION_PATH = os.path.join(FAISS_INDICES_DIR, 'faiss_index_production.index')
    JSON_TRAINING_DATA_PATH = os.path.join(PROCESSED_DATA_DIR, 'taskmaster_only.json')
    CACHE_FILE = os.path.join(CACHE_DIR, 'query_embeddings_cache.pkl')

    for directory in (MODELS_DIR, PROCESSED_DATA_DIR, CACHE_DIR,
                      TOKENIZER_DIR, FAISS_INDICES_DIR, TF_RECORD_DIR):
        os.makedirs(directory, exist_ok=True)

    config = _load_or_create_config(Path(MODELS_DIR) / "config.json")

    encoder = SentenceTransformer(config.pretrained_model)
    logger.info(f"Initialized SentenceTransformer model: {config.pretrained_model}")

    if Path(JSON_TRAINING_DATA_PATH).exists():
        dialogues = TFDataPipeline.load_json_training_data(JSON_TRAINING_DATA_PATH)
        logger.info(f"Loaded {len(dialogues)} dialogues.")
    else:
        logger.warning(f"No dialogues found at {JSON_TRAINING_DATA_PATH}.")
        dialogues = []

    query_embeddings_cache = _load_query_cache(CACHE_FILE)

    # The pipeline is handed the index *path*; it manages the FAISS index itself.
    data_pipeline = TFDataPipeline(
        config=config,
        tokenizer=encoder.tokenizer,
        encoder=encoder,
        response_pool=[],
        query_embeddings_cache=query_embeddings_cache,
        index_type='IndexFlatIP',
        faiss_index_file_path=FAISS_INDEX_PRODUCTION_PATH
    )

    if dialogues:
        response_pool = data_pipeline.collect_responses_with_domain(dialogues)
        data_pipeline.response_pool = response_pool

        response_pool_path = _response_pool_path(FAISS_INDEX_PRODUCTION_PATH)
        with open(response_pool_path, 'w', encoding='utf-8') as f:
            json.dump(response_pool, f, indent=2)
        logger.info(f"Response pool saved to {response_pool_path}.")

        data_pipeline.compute_and_index_response_embeddings()
        data_pipeline.save_faiss_index(FAISS_INDEX_PRODUCTION_PATH)
        logger.info(f"FAISS index saved at {FAISS_INDEX_PRODUCTION_PATH}.")
    else:
        logger.warning("No responses to embed. Skipping FAISS indexing.")

    # NOTE(review): this saves the dict object that was passed to TFDataPipeline.
    # It assumes the pipeline adds entries to that same dict in place rather than
    # rebinding its own attribute. Confirm against TFDataPipeline.
    _save_query_cache(query_embeddings_cache, CACHE_FILE)

    logger.info("Pipeline completed successfully.")


if __name__ == "__main__":
    main()
| |
|