| import torch
|
| import torch.nn.functional as F
|
|
|
| from model import GenoLiteHybrid
|
|
|
|
|
|
|
|
|
|
|
# Run on the GPU when PyTorch reports CUDA support, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
# Number of tokens per chunk; inputs are split into pieces of this length.
CHUNK_SIZE = 64

# Input vocabulary: each accepted character maps to its integer token id,
# assigned in the order the characters appear in this string
# ("U" -> 0, "D" -> 1, ..., "F" -> 10).
TOKEN_MAP = {token: token_id for token_id, token in enumerate("UD-+JRLTCHF")}
|
|
|
# Class index -> printable label; each label is just the index written as a string.
ID2LABEL = {class_id: str(class_id) for class_id in range(10)}
|
|
|
|
|
|
|
|
|
|
|
# Build the network on the selected device and restore its trained weights.
model = GenoLiteHybrid().to(DEVICE)

# NOTE(review): torch.load unpickles the file, so only load checkpoints from
# a trusted source.
checkpoint = torch.load("model.pt", map_location=DEVICE)

# Two on-disk layouts are accepted: a training checkpoint dict that keeps the
# weights under "model_state_dict", or a bare state_dict saved directly.
_is_full_checkpoint = isinstance(checkpoint, dict) and "model_state_dict" in checkpoint
model.load_state_dict(
    checkpoint["model_state_dict"] if _is_full_checkpoint else checkpoint
)
print("\nLoaded full checkpoint." if _is_full_checkpoint else "\nLoaded raw state_dict.")

# Switch to inference mode (affects layers such as dropout/batch-norm, if present).
model.eval()

print("\n===================================")
print(" MODEL LOADED")
print("===================================\n")
|
|
|
|
|
|
|
|
|
|
|
def encode(seq):
    """Convert a token string into a 1-D ``torch.long`` tensor of token ids.

    Each character of *seq* is looked up in ``TOKEN_MAP``; the result has one
    element per character, in the same order.

    Raises:
        KeyError: If *seq* contains a character that is not in ``TOKEN_MAP``.
    """
    token_ids = [TOKEN_MAP[symbol] for symbol in seq]
    return torch.tensor(token_ids, dtype=torch.long)
|
|
|
|
|
|
|
|
|
|
|
def split_chunks(sequence, chunk_size=None):
    """Split *sequence* into consecutive pieces of at most *chunk_size* items.

    Args:
        sequence: Any sliceable sequence (e.g. ``str`` or ``list``).
        chunk_size: Length of each piece. Defaults to the module-level
            ``CHUNK_SIZE``, which is looked up at call time.

    Returns:
        A list of slices, in order. The last piece is shorter when
        ``len(sequence)`` is not a multiple of *chunk_size*. An empty
        *sequence* yields ``[]``.

    Raises:
        ValueError: If *chunk_size* is not positive.
    """
    if chunk_size is None:
        chunk_size = CHUNK_SIZE
    # range() with a zero step raises, and a negative step silently yields no
    # chunks; reject both with one clear message.
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    return [
        sequence[start:start + chunk_size]
        for start in range(0, len(sequence), chunk_size)
    ]
|
|
|
|
|
|
|
|
|
|
|
def _relative_shares(scores):
    """Normalize non-negative *scores* so that they sum to 1.

    When every score is exactly zero, a uniform split is returned instead of
    dividing by zero. NaN inputs still propagate to NaN outputs.
    """
    total = sum(scores)
    if total == 0:
        return [1.0 / len(scores)] * len(scores)
    return [score / total for score in scores]


def analyze_chunk(sequence):
    """Classify one token chunk and measure how active each expert branch was.

    The forward pass is assembled by hand from submodules of the global
    ``model``: embedding, then the four experts (cnn, gru, transformer, mamba)
    each applied to the same embedding, then concatenation on the last
    dimension, ``model.fusion``, a mean over dimension 1, ``model.classifier``
    and a softmax.
    NOTE(review): this presumably mirrors ``GenoLiteHybrid.forward``. Confirm
    against model.py, because any divergence changes the predictions.

    Args:
        sequence: String whose characters are all keys of ``TOKEN_MAP``.

    Returns:
        A dict with these keys:

        - ``"prediction"``: label of the arg-max class (via ``ID2LABEL``).
        - ``"probs"``: CPU tensor of softmax class probabilities for this chunk.
        - ``"cnn"``, ``"gru"``, ``"tf"``, ``"mamba"``: each expert's mean
          absolute activation as a share of the four experts' total. The four
          values sum to 1.
    """
    # Add a batch dimension of size 1 and move the ids to the model's device.
    x = encode(sequence).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        emb = model.embedding(x)

        # Every expert consumes the same embedding independently.
        cnn_out = model.cnn(emb)
        gru_out = model.gru(emb)
        tf_out = model.transformer(emb)
        mamba_out = model.mamba(emb)

        # Mean |activation| serves as a rough per-expert "activity" score.
        cnn_w, gru_w, tf_w, mamba_w = _relative_shares([
            out.abs().mean().item()
            for out in (cnn_out, gru_out, tf_out, mamba_out)
        ])

        fused = model.fusion(
            torch.cat([cnn_out, gru_out, tf_out, mamba_out], dim=-1)
        )
        pooled = fused.mean(dim=1)
        logits = model.classifier(pooled)
        probs = F.softmax(logits, dim=-1)

    pred = probs.argmax(dim=-1).item()

    return {
        # Fall back to the raw index if the classifier emits more classes
        # than ID2LABEL names.
        "prediction": ID2LABEL.get(pred, str(pred)),
        "probs": probs[0].cpu(),
        "cnn": cnn_w,
        "gru": gru_w,
        "tf": tf_w,
        "mamba": mamba_w,
    }
|
|
|
|
|
|
|
|
|
|
|
def _print_banner(title):
    """Print *title* framed by the '=' rule lines used throughout the report."""
    print("\n===================================")
    print(title)
    print("===================================\n")


def _print_section(title):
    """Print *title* framed by the '-' rule lines used for report sub-sections."""
    print("\n-----------------------------------")
    print(title)
    print("-----------------------------------\n")


def _print_class_probs(probs):
    """Print one ``label: probability`` line for every class in *probs*."""
    for class_id, prob in enumerate(probs.tolist()):
        print(f"{ID2LABEL.get(class_id, str(class_id))}: {prob:.4f}")


def analyze_sequence(sequence):
    """Validate a raw token string, classify it chunk by chunk, and print a report.

    The input is stripped and upper-cased, then checked against ``TOKEN_MAP``.
    Its length must be a positive multiple of ``CHUNK_SIZE``. A valid input is
    split into ``CHUNK_SIZE`` pieces and each piece goes through
    ``analyze_chunk``. Per-chunk results are printed. The final decision is
    the arg-max of the class probabilities averaged over all chunks.

    Validation failures are printed and the function returns early.

    Args:
        sequence: Raw user input.

    Returns:
        None. All output goes to stdout.
    """
    sequence = sequence.strip().upper()

    invalid = sorted({c for c in sequence if c not in TOKEN_MAP})
    if invalid:
        allowed = "/".join(TOKEN_MAP)
        print(f"\nOnly {allowed} allowed (got: {' '.join(invalid)}).\n")
        return

    length = len(sequence)

    if length < CHUNK_SIZE:
        _print_banner(" LENGTH ERROR")
        print("Input too short.\n")
        print(f"Current Length : {length}")
        print(f"Missing Chars : {CHUNK_SIZE - length}")
        print(f"Required Length: {CHUNK_SIZE}")
        print("\n===================================\n")
        return

    if length % CHUNK_SIZE != 0:
        # Smallest multiple of CHUNK_SIZE that is larger than the input.
        next_valid = (length // CHUNK_SIZE + 1) * CHUNK_SIZE
        _print_banner(" LENGTH ERROR")
        print(f"Sequence length must be a multiple of {CHUNK_SIZE}.\n")
        print(f"Current Length : {length}")
        print(f"Next Valid Size: {next_valid}")
        print(f"Missing Chars : {next_valid - length}")
        print("\n===================================\n")
        return

    chunks = split_chunks(sequence)

    _print_banner(" ANALYZING INPUT")
    print(f"Total Length : {length}")
    print(f"Chunks : {len(chunks)}")

    results = []
    for idx, chunk in enumerate(chunks, start=1):
        result = analyze_chunk(chunk)
        results.append(result)

        _print_section(f"Chunk {idx}")
        print(chunk)
        print("\nPrediction:")
        print(result["prediction"])
        print("\nProbabilities:\n")
        _print_class_probs(result["probs"])

    # results is non-empty: the length checks above guarantee >= 1 chunk.
    # Averaging via stack() works for any class count, not only 10.
    avg_probs = torch.stack([r["probs"] for r in results]).mean(dim=0)
    avg_activity = {
        key: sum(r[key] for r in results) / len(results)
        for key in ("cnn", "gru", "tf", "mamba")
    }
    final_pred = avg_probs.argmax().item()

    _print_banner(" FINAL RESULT")
    print(f"FINAL DECISION: {ID2LABEL.get(final_pred, str(final_pred))}")

    _print_section("Average Probabilities")
    _print_class_probs(avg_probs)

    _print_section("Average Expert Activity")
    print(f"CNN : {avg_activity['cnn']:.4f}")
    print(f"GRU : {avg_activity['gru']:.4f}")
    print(f"Transformer : {avg_activity['tf']:.4f}")
    print(f"Mamba : {avg_activity['mamba']:.4f}")

    print("\n===================================\n")
|
|
|
|
|
|
|
|
|
|
|
# Interactive loop: read one token sequence per line until EXIT or end of input.
# The intro text is derived from TOKEN_MAP / CHUNK_SIZE so it cannot drift
# from what analyze_sequence actually accepts.
print(f"Type token sequence ({'/'.join(TOKEN_MAP)}).")
print(f"Length must be {CHUNK_SIZE} or multiples of {CHUNK_SIZE}.")
print("Type EXIT to quit.\n")

while True:
    try:
        seq = input("logs > ")
    except (EOFError, KeyboardInterrupt):
        # Ctrl-D, end of piped input, or Ctrl-C: exit the same way as EXIT
        # instead of dumping a traceback.
        print("\nBye.\n")
        break

    if seq.strip().upper() == "EXIT":
        print("\nBye.\n")
        break

    analyze_sequence(seq)
|
|
|