# ShadowCore-v2 / chat.py
# brscftc's picture
# Upload 4 files
# e49356a verified
import torch
import torch.nn.functional as F
from model import GenoLiteHybrid
# =========================================================
# CONFIG
# =========================================================
# Run on the GPU when one is visible to torch, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Length of each piece the input is cut into before inference; the chat
# front-end only accepts inputs whose length is a multiple of this.
CHUNK_SIZE = 64

# Input vocabulary: character -> token id. Ids follow the order of the
# string below (U=0, D=1, -=2, ..., F=10).
TOKEN_MAP = {char: token_id for token_id, char in enumerate("UD-+JRLTCHF")}

# Class index -> printable label; each of the 10 classes is labelled with
# its own index as a string.
ID2LABEL = {class_idx: str(class_idx) for class_idx in range(10)}
# =========================================================
# LOAD MODEL
# =========================================================
# Build the network on DEVICE, then read the weights from "model.pt" in the
# current working directory, mapping tensors onto the same device.
# NOTE(review): depending on the installed torch version's `weights_only`
# default, torch.load may fall back to full pickle loading — only use
# checkpoint files you trust.
model = GenoLiteHybrid().to(DEVICE)
checkpoint = torch.load("model.pt", map_location=DEVICE)

# ---------------------------------------------------------
# RAW OR FULL CHECKPOINT
# ---------------------------------------------------------
# A "full" checkpoint is a dict that keeps the weights under the
# "model_state_dict" key; anything else is loaded as a bare state_dict.
if not (isinstance(checkpoint, dict) and "model_state_dict" in checkpoint):
    model.load_state_dict(checkpoint)
    print("\nLoaded raw state_dict.")
else:
    model.load_state_dict(checkpoint["model_state_dict"])
    print("\nLoaded full checkpoint.")

# Inference only: put every submodule into evaluation mode.
model.eval()

print("\n===================================")
print(" MODEL LOADED")
print("===================================\n")
# =========================================================
# ENCODE
# =========================================================
def encode(seq):
    """Turn *seq* into a 1-D ``torch.long`` tensor of TOKEN_MAP ids.

    Each character is looked up in TOKEN_MAP in order; a character that is
    not in the vocabulary raises ``KeyError``.
    """
    token_ids = list(map(TOKEN_MAP.__getitem__, seq))
    return torch.tensor(token_ids, dtype=torch.long)
# =========================================================
# CHUNKING
# =========================================================
def split_chunks(sequence, size=None):
    """Split *sequence* into consecutive pieces of length *size*.

    Args:
        sequence: any sliceable sequence (e.g. a ``str``).
        size: chunk length; ``None`` (the default) means CHUNK_SIZE.
            Must be positive.

    Returns:
        A list of slices in their original order. The last chunk is shorter
        than *size* when ``len(sequence)`` is not a multiple of it; an empty
        input yields ``[]``.

    Raises:
        ValueError: if *size* is zero or negative (a negative step would
            otherwise silently produce no chunks).
    """
    if size is None:
        size = CHUNK_SIZE
    if size <= 0:
        raise ValueError(f"chunk size must be positive, got {size}")
    return [sequence[start:start + size] for start in range(0, len(sequence), size)]
# =========================================================
# SINGLE CHUNK INFERENCE
# =========================================================
def analyze_chunk(sequence):
    """Classify one chunk and measure how active each expert was.

    The model's stages are called one by one (embedding -> four experts ->
    concat -> fusion -> mean over dim 1 -> classifier -> softmax) instead of
    going through ``model(x)``, so that each expert's output can be inspected.
    NOTE(review): GenoLiteHybrid.forward is not visible here; confirm this
    hand-written pipeline still matches it.

    Args:
        sequence: string of TOKEN_MAP characters (validated by the caller).

    Returns:
        dict with keys
          "prediction": ID2LABEL label of the arg-max class (the raw index
              as a string if ID2LABEL has no entry for it),
          "probs": 1-D CPU tensor of this chunk's class probabilities,
          "cnn", "gru", "tf", "mamba": each expert's share of the summed
              mean-absolute activation (the shares sum to 1 up to rounding).
    """
    # Shape (1, len(sequence)): a batch containing this single chunk.
    x = encode(sequence).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        emb = model.embedding(x)

        # Expert outputs, kept in the original concatenation order
        # (cnn, gru, tf, mamba) — the fusion layer presumably depends on it.
        outputs = {
            "cnn": model.cnn(emb),
            "gru": model.gru(emb),
            "tf": model.transformer(emb),
            "mamba": model.mamba(emb),
        }

        # Expert "activity" = mean absolute activation, normalised to shares.
        scores = {name: out.abs().mean().item() for name, out in outputs.items()}
        total = sum(scores.values())
        if total == 0:
            # Every expert produced an all-zero output: avoid a
            # ZeroDivisionError and report an even split instead.
            weights = {name: 1.0 / len(scores) for name in scores}
        else:
            weights = {name: score / total for name, score in scores.items()}

        fused = model.fusion(torch.cat(list(outputs.values()), dim=-1))
        # Average over dim 1 — presumably the sequence axis, since x is
        # (1, len); confirm against the model.
        pooled = fused.mean(dim=1)
        logits = model.classifier(pooled)
        probs = F.softmax(logits, dim=-1)
        pred = probs.argmax(dim=-1).item()

    return {
        # .get: don't crash if the classifier has more classes than ID2LABEL.
        "prediction": ID2LABEL.get(pred, str(pred)),
        "probs": probs[0].cpu(),
        **weights,  # "cnn", "gru", "tf", "mamba" in that order
    }
# =========================================================
# FULL ANALYSIS
# =========================================================
def _length_is_valid(length):
    """Return True if *length* is a positive multiple of CHUNK_SIZE.

    Otherwise print a LENGTH ERROR report (current length, how many
    characters are missing, and the target size) and return False.
    """
    if length < CHUNK_SIZE:
        missing = CHUNK_SIZE - length
        print("\n===================================")
        print(" LENGTH ERROR")
        print("===================================\n")
        print("Input too short.\n")
        print(f"Current Length : {length}")
        print(f"Missing Chars : {missing}")
        print(f"Required Length: {CHUNK_SIZE}")
        print("\n===================================\n")
        return False

    if length % CHUNK_SIZE != 0:
        # Smallest multiple of CHUNK_SIZE strictly above the current length.
        next_valid = (length // CHUNK_SIZE + 1) * CHUNK_SIZE
        missing = next_valid - length
        print("\n===================================")
        print(" LENGTH ERROR")
        print("===================================\n")
        print(f"Sequence length must be a multiple of {CHUNK_SIZE}.\n")
        print(f"Current Length : {length}")
        print(f"Next Valid Size: {next_valid}")
        print(f"Missing Chars : {missing}")
        print("\n===================================\n")
        return False

    return True


def _print_probs(probs):
    """Print one ``label: probability`` line for every class in *probs*."""
    for class_idx, prob in enumerate(probs.tolist()):
        # .get: fall back to the raw index if the model has more classes
        # than ID2LABEL names.
        print(f"{ID2LABEL.get(class_idx, str(class_idx))}: {prob:.4f}")


def analyze_sequence(sequence):
    """Validate *sequence*, classify it chunk by chunk, and print a report.

    The input is stripped and upper-cased. It must consist only of TOKEN_MAP
    characters, and its length must be a positive multiple of CHUNK_SIZE.
    Each CHUNK_SIZE-long chunk is classified independently. The final
    decision is the arg-max of the per-chunk probabilities averaged over all
    chunks. Average expert activity is reported alongside it.

    Invalid input is reported on stdout and the function returns early.
    It always returns None.
    """
    sequence = sequence.strip().upper()

    # -----------------------------------------------------
    # VALIDATION
    # -----------------------------------------------------
    bad_chars = sorted({c for c in sequence if c not in TOKEN_MAP})
    if bad_chars:
        # List the real vocabulary (the old message wrongly said A/T/G/C).
        print(f"\nOnly {'/'.join(TOKEN_MAP)} allowed.")
        print(f"Invalid: {', '.join(map(repr, bad_chars))}\n")
        return

    # -----------------------------------------------------
    # LENGTH CHECK
    # -----------------------------------------------------
    if not _length_is_valid(len(sequence)):
        return

    # -----------------------------------------------------
    # CHUNKING
    # -----------------------------------------------------
    chunks = split_chunks(sequence)
    print("\n===================================")
    print(" ANALYZING INPUT")
    print("===================================\n")
    print(f"Total Length : {len(sequence)}")
    print(f"Chunks : {len(chunks)}")

    # -----------------------------------------------------
    # PROCESS CHUNKS
    # -----------------------------------------------------
    results = []
    for chunk_no, chunk in enumerate(chunks, start=1):
        result = analyze_chunk(chunk)
        results.append(result)
        print("\n-----------------------------------")
        print(f"Chunk {chunk_no}")
        print("-----------------------------------\n")
        print(chunk)
        print("\nPrediction:")
        print(result["prediction"])
        print("\nProbabilities:\n")
        _print_probs(result["probs"])  # all classes, not just the first 3

    # -----------------------------------------------------
    # AVERAGES
    # -----------------------------------------------------
    n_chunks = len(results)
    # sum() starts from int 0, so the averaged tensor takes its size from
    # the model's output instead of a hard-coded class count.
    avg_probs = sum(r["probs"] for r in results) / n_chunks
    avg_activity = {
        key: sum(r[key] for r in results) / n_chunks
        for key in ("cnn", "gru", "tf", "mamba")
    }

    # -----------------------------------------------------
    # FINAL DECISION
    # -----------------------------------------------------
    final_pred = avg_probs.argmax().item()
    print("\n===================================")
    print(" FINAL RESULT")
    print("===================================\n")
    print(f"FINAL DECISION: {ID2LABEL.get(final_pred, str(final_pred))}")
    print("\n-----------------------------------")
    print("Average Probabilities")
    print("-----------------------------------\n")
    _print_probs(avg_probs)
    print("\n-----------------------------------")
    print("Average Expert Activity")
    print("-----------------------------------\n")
    print(f"CNN : {avg_activity['cnn']:.4f}")
    print(f"GRU : {avg_activity['gru']:.4f}")
    print(f"Transformer : {avg_activity['tf']:.4f}")
    print(f"Mamba : {avg_activity['mamba']:.4f}")
    print("\n===================================\n")
# =========================================================
# CHAT LOOP
# =========================================================
# Interactive front-end: read one sequence per line and analyze it until
# the user types EXIT (case-insensitive) or closes the input stream.
# The instructions are derived from TOKEN_MAP / CHUNK_SIZE so they cannot
# drift from what analyze_sequence actually accepts (the old text asked
# for a "DNA sequence", but A and G are not in the vocabulary).
print("Type a sequence using only: " + " ".join(TOKEN_MAP))
print(f"Length must be {CHUNK_SIZE} or multiples of {CHUNK_SIZE}.")
print("Type EXIT to quit.\n")
while True:
    try:
        seq = input("logs > ")
    except (EOFError, KeyboardInterrupt):
        # Ctrl-D / Ctrl-C: exit cleanly instead of dumping a traceback.
        print("\nBye.\n")
        break
    if seq.strip().upper() == "EXIT":
        print("\nBye.\n")
        break
    if not seq.strip():
        # Ignore blank lines rather than reporting a LENGTH ERROR for them.
        continue
    analyze_sequence(seq)