File size: 2,642 Bytes
398a289
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import time

import torch

from src.models.fusion_inference import FusionClaimVerifier
from src.models.fusion_inference import _build_retrieval_features_train_compatible

# Profiling script: times the one-off construction of FusionClaimVerifier and
# then each stage of a single prediction for one hard-coded claim.
#
# Every interval is measured with time.perf_counter(), the monotonic
# high-resolution clock meant for benchmarking; time.time() follows the wall
# clock and can jump if the system clock is adjusted while the script runs.

_BANNER = "===================================="

# Construction time in seconds, keyed by component name.
load_times = {}
# Per-stage inference time in seconds, keyed by stage name.
infer_times = {}

print(_BANNER)
print("    MODEL LOADING SPEED PROFILE     ")
print(_BANNER)

print("1. Initializing FusionClaimVerifier ... (This loads ALL models)")

# Start the clock immediately before the constructor so only the constructor
# call itself is attributed to the init time.
load_started = time.perf_counter()
verifier = FusionClaimVerifier(
    fusion_model_path="models/fusion_model.pt",
    opensearch_index="news_kb",  # OPENSEARCH_INDEX_NAME
    llm_model_path="models/lora_llm",
    retriever_model_path="AITeamVN/Vietnamese_Embedding",
    device="cpu",  # Test full CPU
    llm_evidence_top_k=5,
    debug=True,
)
init_time = time.perf_counter() - load_started
load_times["FusionClaimVerifier"] = init_time
print(f"[LOAD] FusionClaimVerifier Init Time: {init_time:.2f} seconds")

# Explicitly measure components for one prediction
print(f"\n{_BANNER}")
print("    INFERENCE SPEED PROFILE         ")
print(_BANNER)

claim_text = "giá vàng sẽ tăng 200% vào ngày mai"

# 1. Query Expansion & Retrieval
print("\n1. Running Retrieval (BM25 + Semantic Vector Search)")
stage_started = time.perf_counter()
docs = verifier.retriever.retrieve(claim_text, top_k=verifier.top_k)
retrieval_time = time.perf_counter() - stage_started
infer_times["retrieval"] = retrieval_time
print(f"   [INFER] Retrieval Time: {retrieval_time:.2f} seconds")

# 2. LLM Inference
print("\n2. Running LLM Inference (Prompt + Context -> Logits)")
# Only the first llm_evidence_top_k retrieved documents are given to the LLM.
evidences = [d.text for d in docs[: verifier.llm_evidence_top_k]]
# Nothing but the scoring call (and the move to verifier.device) sits between
# the two clock reads, so import or setup work cannot leak into this number.
stage_started = time.perf_counter()
with torch.inference_mode():
    llm_logits = verifier.llm.score_logits([claim_text], [evidences]).to(
        verifier.device
    )
llm_time = time.perf_counter() - stage_started
infer_times["llm"] = llm_time
print(f"   [INFER] LLM Inference Time: {llm_time:.2f} seconds")

# 3. Fusion Layer
print("\n3. Running Fusion Layer")

# We rebuild the features just for fusion measurement. The rebuild is timed on
# its own and kept out of the fusion stage and out of the reported total, so
# the three stage lines add up to the total line.
# NOTE(review): presumably this helper repeats the retrieval from step 1. If a
# real prediction performs this feature construction as a separate step, add
# rebuild_time back into the total - confirm against FusionClaimVerifier.
stage_started = time.perf_counter()
retrieval_features_np, _, _ = _build_retrieval_features_train_compatible(
    verifier.retriever, claim_text, verifier.top_k
)
rebuild_time = time.perf_counter() - stage_started
infer_times["feature_rebuild"] = rebuild_time
print(
    f"   [INFER] Feature Rebuild Time (not counted in total): "
    f"{rebuild_time:.2f} seconds"
)

stage_started = time.perf_counter()
with torch.inference_mode():
    # unsqueeze(0) adds a leading dimension of size 1 before encoding.
    ret_feat = torch.tensor(
        retrieval_features_np, dtype=torch.float32, device=verifier.device
    ).unsqueeze(0)
    retrieval_encoded = verifier.retrieval_encoder(ret_feat)
    fusion_output = verifier.fusion(llm_logits, retrieval_encoded)
fusion_time = time.perf_counter() - stage_started
infer_times["fusion"] = fusion_time
print(f"   [INFER] Fusion Time: {fusion_time:.2f} seconds")

# Sum of the three measured stages; console output and the measurement-only
# feature rebuild are deliberately excluded.
total_time = retrieval_time + llm_time + fusion_time
infer_times["total"] = total_time

print(f"\n{_BANNER}")
print(f"TOTAL PREDICTION TIME: {total_time:.2f} seconds")
print(f"{_BANNER}\n")