"""Interactive, chunk-wise inference for a trained GenoLiteHybrid checkpoint.

Input strings are validated against TOKEN_MAP, split into CHUNK_SIZE-long
chunks, and each chunk is run through the model's expert branches. Per-chunk
and averaged class probabilities and expert-activity shares are printed.
"""

import torch
import torch.nn.functional as F

from model import GenoLiteHybrid

# =========================================================
# CONFIG
# =========================================================

DEVICE = torch.device(
    "cuda" if torch.cuda.is_available() else "cpu"
)

# Number of tokens the model receives per forward pass.
CHUNK_SIZE = 64

# Character -> token id. Any character not listed here is rejected.
TOKEN_MAP = {
    "U": 0,
    "D": 1,
    "-": 2,
    "+": 3,
    "J": 4,
    "R": 5,
    "L": 6,
    "T": 7,
    "C": 8,
    "H": 9,
    "F": 10,
}

# Class index -> printed label.
ID2LABEL = {
    0: "0",
    1: "1",
    2: "2",
    3: "3",
    4: "4",
    5: "5",
    6: "6",
    7: "7",
    8: "8",
    9: "9",
}

# =========================================================
# LOAD MODEL
# =========================================================

model = GenoLiteHybrid().to(DEVICE)

# NOTE(review): torch.load unpickles the file, so only trusted checkpoints
# should be loaded. Consider weights_only=True if the installed torch
# version supports it and the checkpoint holds plain tensors.
checkpoint = torch.load(
    "model.pt",
    map_location=DEVICE,
)

# ---------------------------------------------------------
# RAW OR FULL CHECKPOINT
# ---------------------------------------------------------

if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
    model.load_state_dict(checkpoint["model_state_dict"])
    print("\nLoaded full checkpoint.")
else:
    model.load_state_dict(checkpoint)
    print("\nLoaded raw state_dict.")

model.eval()

print("\n===================================")
print(" MODEL LOADED")
print("===================================\n")


# =========================================================
# ENCODE
# =========================================================

def encode(seq):
    """Return *seq* as a 1-D ``torch.long`` tensor of TOKEN_MAP ids.

    Raises:
        KeyError: if *seq* contains a character missing from TOKEN_MAP.
    """
    return torch.tensor(
        [TOKEN_MAP[c] for c in seq],
        dtype=torch.long,
    )


# =========================================================
# CHUNKING
# =========================================================

def split_chunks(sequence):
    """Split *sequence* into consecutive slices of CHUNK_SIZE items.

    The last slice is shorter when ``len(sequence)`` is not a multiple of
    CHUNK_SIZE; an empty sequence yields an empty list.
    """
    return [
        sequence[start:start + CHUNK_SIZE]
        for start in range(0, len(sequence), CHUNK_SIZE)
    ]


# =========================================================
# SINGLE CHUNK INFERENCE
# =========================================================

def analyze_chunk(sequence):
    """Run one chunk through the model and summarise the result.

    Args:
        sequence: string made only of TOKEN_MAP characters.

    Returns:
        dict with keys ``prediction`` (label from ID2LABEL), ``probs``
        (CPU tensor of class probabilities for this chunk) and ``cnn`` /
        ``gru`` / ``tf`` / ``mamba`` (each branch's share of the summed
        mean absolute activation; all 0.0 when every branch output is zero).
    """
    x = encode(sequence).unsqueeze(0).to(DEVICE)  # add a batch dimension

    with torch.no_grad():

        # ---------------------------------------------
        # EMBEDDING
        # ---------------------------------------------

        emb = model.embedding(x)

        # ---------------------------------------------
        # EXPERTS
        # ---------------------------------------------

        # NOTE(review): assumes each branch returns a single tensor (e.g.
        # model.gru is a wrapper, not a bare nn.GRU returning a tuple) and
        # that the outputs can be concatenated on the last dim - confirm
        # against model.py.
        cnn_out = model.cnn(emb)
        gru_out = model.gru(emb)
        tf_out = model.transformer(emb)
        mamba_out = model.mamba(emb)

        # ---------------------------------------------
        # EXPERT ACTIVITY
        # ---------------------------------------------

        cnn_score = cnn_out.abs().mean().item()
        gru_score = gru_out.abs().mean().item()
        tf_score = tf_out.abs().mean().item()
        mamba_score = mamba_out.abs().mean().item()

        total = cnn_score + gru_score + tf_score + mamba_score

        if total > 0:
            cnn_w = cnn_score / total
            gru_w = gru_score / total
            tf_w = tf_score / total
            mamba_w = mamba_score / total
        else:
            # All branches silent: report zero activity instead of raising
            # ZeroDivisionError.
            cnn_w = gru_w = tf_w = mamba_w = 0.0

        # ---------------------------------------------
        # FINAL PRED
        # ---------------------------------------------

        fused = torch.cat(
            [cnn_out, gru_out, tf_out, mamba_out],
            dim=-1,
        )
        fused = model.fusion(fused)

        pooled = fused.mean(dim=1)
        logits = model.classifier(pooled)
        probs = F.softmax(logits, dim=-1)

        pred = probs.argmax(dim=-1).item()

    return {
        "prediction": ID2LABEL[pred],
        "probs": probs[0].cpu(),
        "cnn": cnn_w,
        "gru": gru_w,
        "tf": tf_w,
        "mamba": mamba_w,
    }


# =========================================================
# FULL ANALYSIS
# =========================================================

def _print_probabilities(probs):
    """Print one ``label: probability`` line per class in *probs*."""
    for class_id, value in enumerate(probs.tolist()):
        print(f"{ID2LABEL[class_id]}: {value:.4f}")


def analyze_sequence(sequence):
    """Validate *sequence*, analyse it chunk by chunk and print a report.

    The input is stripped and upper-cased. It must contain only TOKEN_MAP
    characters and its length must be a non-zero multiple of CHUNK_SIZE;
    otherwise an explanatory message is printed and nothing is analysed.

    Returns:
        None. All results are written to stdout.
    """
    sequence = sequence.strip().upper()

    # -----------------------------------------------------
    # VALIDATION
    # -----------------------------------------------------

    if not all(c in TOKEN_MAP for c in sequence):
        # Build the message from TOKEN_MAP so it cannot drift from the
        # alphabet that is actually accepted.
        allowed = "/".join(TOKEN_MAP)
        print(f"\nOnly {allowed} allowed.\n")
        return

    # -----------------------------------------------------
    # LENGTH CHECK
    # -----------------------------------------------------

    length = len(sequence)

    if length < CHUNK_SIZE:
        missing = CHUNK_SIZE - length

        print("\n===================================")
        print(" LENGTH ERROR")
        print("===================================\n")
        print("Input too short.\n")
        print(f"Current Length : {length}")
        print(f"Missing Chars : {missing}")
        print(f"Required Length: {CHUNK_SIZE}")
        print("\n===================================\n")
        return

    # -----------------------------------------------------
    # MULTIPLE CHECK
    # -----------------------------------------------------

    if length % CHUNK_SIZE != 0:
        next_valid = (length // CHUNK_SIZE + 1) * CHUNK_SIZE
        missing = next_valid - length

        print("\n===================================")
        print(" LENGTH ERROR")
        print("===================================\n")
        print(
            f"Sequence length must be "
            f"a multiple of {CHUNK_SIZE}.\n"
        )
        print(f"Current Length : {length}")
        print(f"Next Valid Size: {next_valid}")
        print(f"Missing Chars : {missing}")
        print("\n===================================\n")
        return

    # -----------------------------------------------------
    # CHUNKING
    # -----------------------------------------------------

    chunks = split_chunks(sequence)

    print("\n===================================")
    print(" ANALYZING INPUT")
    print("===================================\n")
    print(f"Total Length : {len(sequence)}")
    print(f"Chunks : {len(chunks)}")

    # -----------------------------------------------------
    # AGGREGATION
    # -----------------------------------------------------

    # Sized from ID2LABEL rather than a literal so the two stay in sync.
    total_probs = torch.zeros(len(ID2LABEL))
    total_cnn = 0
    total_gru = 0
    total_tf = 0
    total_mamba = 0

    # -----------------------------------------------------
    # PROCESS CHUNKS
    # -----------------------------------------------------

    for number, chunk in enumerate(chunks, start=1):
        result = analyze_chunk(chunk)

        total_probs += result["probs"]
        total_cnn += result["cnn"]
        total_gru += result["gru"]
        total_tf += result["tf"]
        total_mamba += result["mamba"]

        print("\n-----------------------------------")
        print(f"Chunk {number}")
        print("-----------------------------------\n")
        print(chunk)
        print("\nPrediction:")
        print(result["prediction"])
        print("\nProbabilities:\n")
        # Show every class so the predicted class's probability is
        # always visible, not just the first three.
        _print_probabilities(result["probs"])

    # -----------------------------------------------------
    # AVERAGES
    # -----------------------------------------------------

    # len(chunks) >= 1 here: the length checks above reject shorter input.
    total_probs /= len(chunks)
    total_cnn /= len(chunks)
    total_gru /= len(chunks)
    total_tf /= len(chunks)
    total_mamba /= len(chunks)

    # -----------------------------------------------------
    # FINAL DECISION
    # -----------------------------------------------------

    final_pred = total_probs.argmax().item()

    print("\n===================================")
    print(" FINAL RESULT")
    print("===================================\n")
    print(
        f"FINAL DECISION: "
        f"{ID2LABEL[final_pred]}"
    )

    print("\n-----------------------------------")
    print("Average Probabilities")
    print("-----------------------------------\n")
    _print_probabilities(total_probs)

    print("\n-----------------------------------")
    print("Average Expert Activity")
    print("-----------------------------------\n")
    print(f"CNN : {total_cnn:.4f}")
    print(f"GRU : {total_gru:.4f}")
    print(f"Transformer : {total_tf:.4f}")
    print(f"Mamba : {total_mamba:.4f}")
    print("\n===================================\n")


# =========================================================
# CHAT LOOP
# =========================================================

def main():
    """Read sequences from stdin and analyse each until EXIT or end of input."""
    # NOTE(review): the prompt says "DNA" but TOKEN_MAP is not a nucleotide
    # alphabet - confirm the intended wording.
    print("Type DNA sequence.")
    print(f"Length must be {CHUNK_SIZE} or multiples of {CHUNK_SIZE}.")
    print("Type EXIT to quit.\n")

    while True:
        try:
            seq = input("logs > ")
        except (EOFError, KeyboardInterrupt):
            # Closed stdin or Ctrl-C ends the session like EXIT does.
            print("\nBye.\n")
            break

        if seq.strip().upper() == "EXIT":
            print("\nBye.\n")
            break

        analyze_sequence(seq)


if __name__ == "__main__":
    main()