Fake_Crypto_Claim_Detector / scripts /profile_inference.py
ktoan911's picture
Upload folder using huggingface_hub
398a289 verified
import time
from src.models.fusion_inference import FusionClaimVerifier
# Load durations in seconds, keyed by component name, so they can be
# inspected after the run instead of only being printed.
load_times = {}

print("====================================")
print(" MODEL LOADING SPEED PROFILE ")
print("====================================")

# time.perf_counter() is a monotonic, high-resolution clock, so the measured
# interval cannot be skewed by system clock adjustments the way time.time() can.
t0 = time.perf_counter()
print("1. Initializing FusionClaimVerifier ... (This loads ALL models)")
verifier = FusionClaimVerifier(
    fusion_model_path="models/fusion_model.pt",
    opensearch_index="news_kb",  # OPENSEARCH_INDEX_NAME
    llm_model_path="models/lora_llm",
    retriever_model_path="AITeamVN/Vietnamese_Embedding",
    device="cpu",  # Test full CPU
    llm_evidence_top_k=5,
    debug=True,
)
init_time = time.perf_counter() - t0
load_times["FusionClaimVerifier"] = init_time
print(f"[LOAD] FusionClaimVerifier Init Time: {init_time:.2f} seconds")
# Explicitly measure components for one prediction.
#
# These imports are done up front, before any timer starts, so that import
# cost is never attributed to an inference stage.
import torch

from src.models.fusion_inference import _build_retrieval_features_train_compatible

print("\n====================================")
print(" INFERENCE SPEED PROFILE ")
print("====================================")

claim_text = "giá vàng sẽ tăng 200% vào ngày mai"

# 1. Query Expansion & Retrieval
print("\n1. Running Retrieval (BM25 + Semantic Vector Search)")
t_retr_0 = time.perf_counter()
docs = verifier.retriever.retrieve(claim_text, top_k=verifier.top_k)
t_retr_1 = time.perf_counter()
retrieval_time = t_retr_1 - t_retr_0
print(f" [INFER] Retrieval Time: {retrieval_time:.2f} seconds")

# 2. LLM Inference
print("\n2. Running LLM Inference (Prompt + Context -> Logits)")
# Evidence selection is kept outside the timed region, as in the stage
# definition: only the scoring call (plus device transfer) is measured.
evidences = [d.text for d in docs[: verifier.llm_evidence_top_k]]
t_llm_0 = time.perf_counter()
with torch.inference_mode():
    llm_logits = verifier.llm.score_logits([claim_text], [evidences]).to(
        verifier.device
    )
t_llm_1 = time.perf_counter()
llm_time = t_llm_1 - t_llm_0
print(f" [INFER] LLM Inference Time: {llm_time:.2f} seconds")

# 3. Fusion Layer
print("\n3. Running Fusion Layer")
# We rebuild the features just for the fusion measurement. This rebuild is
# deliberately NOT timed and NOT counted in the total below: it exists only
# to produce the input for the fusion stage.
retrieval_features_np, _, _ = _build_retrieval_features_train_compatible(
    verifier.retriever, claim_text, verifier.top_k
)
t_fuse_0 = time.perf_counter()
with torch.inference_mode():
    ret_feat = torch.tensor(
        retrieval_features_np, dtype=torch.float32, device=verifier.device
    ).unsqueeze(0)
    retrieval_encoded = verifier.retrieval_encoder(ret_feat)
    fusion_output = verifier.fusion(llm_logits, retrieval_encoded)
t_fuse_1 = time.perf_counter()
fusion_time = t_fuse_1 - t_fuse_0
print(f" [INFER] Fusion Time: {fusion_time:.2f} seconds")

# The total is the sum of the measured stages. Using a single start-to-end
# wall-clock span would also count the untimed feature rebuild above and the
# console output, making the total larger than the stages it claims to sum.
total_time = retrieval_time + llm_time + fusion_time
print("\n====================================")
print(f"TOTAL PREDICTION TIME: {total_time:.2f} seconds")
print("====================================\n")