"""Profile model-loading and single-claim inference latency of FusionClaimVerifier.

Running this script:

1. Constructs a ``FusionClaimVerifier`` with ``device="cpu"`` and reports how
   long construction takes.
2. Pushes one hard-coded claim through three separately timed stages:
   ``verifier.retriever.retrieve``, ``verifier.llm.score_logits`` and
   ``verifier.retrieval_encoder`` + ``verifier.fusion``.

All durations are taken with ``time.perf_counter()``, which is monotonic and
intended for measuring short intervals; ``time.time()`` can jump when the
system clock is adjusted.
"""

import time

# Imported up front so that no import cost can leak into a timed region.
import torch

from src.models.fusion_inference import (
    FusionClaimVerifier,
    _build_retrieval_features_train_compatible,
)

# Load durations in seconds, keyed by component name.
load_times = {}

print("====================================")
print(" MODEL LOADING SPEED PROFILE ")
print("====================================")

t0 = time.perf_counter()
print("1. Initializing FusionClaimVerifier ... (This loads ALL models)")
verifier = FusionClaimVerifier(
    fusion_model_path="models/fusion_model.pt",
    opensearch_index="news_kb",  # OPENSEARCH_INDEX_NAME
    llm_model_path="models/lora_llm",
    retriever_model_path="AITeamVN/Vietnamese_Embedding",
    device="cpu",  # Test full CPU
    llm_evidence_top_k=5,
    debug=True,
)
init_time = time.perf_counter() - t0
load_times["fusion_claim_verifier_init"] = init_time
print(f"[LOAD] FusionClaimVerifier Init Time: {init_time:.2f} seconds")

# Explicitly measure components for one prediction
print("\n====================================")
print(" INFERENCE SPEED PROFILE ")
print("====================================")

claim_text = "giá vàng sẽ tăng 200% vào ngày mai"

# 1. Query Expansion & Retrieval
print("\n1. Running Retrieval (BM25 + Semantic Vector Search)")
t_retr_0 = time.perf_counter()
docs = verifier.retriever.retrieve(claim_text, top_k=verifier.top_k)
t_retr_1 = time.perf_counter()
retrieval_time = t_retr_1 - t_retr_0
print(f" [INFER] Retrieval Time: {retrieval_time:.2f} seconds")

# 2. LLM Inference
print("\n2. Running LLM Inference (Prompt + Context -> Logits)")
# Only the first `llm_evidence_top_k` retrieved documents are passed to the LLM.
evidences = [d.text for d in docs[: verifier.llm_evidence_top_k]]
t_llm_0 = time.perf_counter()
with torch.inference_mode():
    llm_logits = verifier.llm.score_logits([claim_text], [evidences]).to(
        verifier.device
    )
t_llm_1 = time.perf_counter()
llm_time = t_llm_1 - t_llm_0
print(f" [INFER] LLM Inference Time: {llm_time:.2f} seconds")

# 3. Fusion Layer
print("\n3. Running Fusion Layer")
# We rebuild the features just for fusion measurement. This rebuild is
# deliberately left outside every timed window, so it is not counted in the
# fusion time or in the reported total.
retrieval_features_np, _, _ = _build_retrieval_features_train_compatible(
    verifier.retriever, claim_text, verifier.top_k
)
t_fuse_0 = time.perf_counter()
with torch.inference_mode():
    # unsqueeze(0) adds a leading dimension of size 1 before encoding.
    ret_feat = torch.tensor(
        retrieval_features_np, dtype=torch.float32, device=verifier.device
    ).unsqueeze(0)
    retrieval_encoded = verifier.retrieval_encoder(ret_feat)
    fusion_output = verifier.fusion(llm_logits, retrieval_encoded)
t_fuse_1 = time.perf_counter()
fusion_time = t_fuse_1 - t_fuse_0
print(f" [INFER] Fusion Time: {fusion_time:.2f} seconds")

# Sum the measured stages rather than taking end-to-end wall time: the latter
# would also include the untimed feature rebuild above and console output.
total_time = retrieval_time + llm_time + fusion_time

print("\n====================================")
print(f"TOTAL PREDICTION TIME: {total_time:.2f} seconds")
print("====================================\n")