#!/usr/bin/env python3
"""chat.py - use a trained semantic mapper + embedding decoder interactively.

Per user turn:
    text --SentenceTransformer--> embedding --SemanticMapper--> mapped
    embedding --EmbeddingDecoder (greedy)--> T5 token ids --> reply text
"""
import torch
import torch.nn as nn
from sentence_transformers import SentenceTransformer
from transformers import T5Tokenizer

# ===== CONFIG =====
MAPPER_PTH = "semantic_mapper.pth"
DECODER_PTH = "embedding_decoder.pth"
MODEL_NAME = "Snowflake/snowflake-arctic-embed-l-v2.0"
MAX_LEN = 4096  # upper bound on generated tokens per reply
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ===== LOAD TOKENIZER =====
tokenizer = T5Tokenizer.from_pretrained("t5-small")
pad_id = tokenizer.pad_token_id
eos_id = tokenizer.eos_token_id


# ===== MODEL CLASSES (same defs as training) =====
class SemanticMapper(nn.Module):
    """Two-layer MLP mapping a ``dim``-sized vector to another ``dim``-sized vector.

    The layer layout must stay identical to the training definition so the
    checkpoint's state-dict keys (``net.0.*``, ``net.2.*``) line up.
    """

    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * 2),
            nn.ReLU(),
            nn.Linear(dim * 2, dim),
        )

    def forward(self, x):
        return self.net(x)


class EmbeddingDecoder(nn.Module):
    """GRU decoder that generates token ids conditioned on one embedding per row.

    The embedding seeds the initial GRU hidden state (via ``bridge`` + tanh) and
    is also concatenated to every step's token embedding as GRU input.
    """

    def __init__(self, input_dim, hidden_dim, vocab_size, p=0.2):
        super().__init__()
        self.bridge = nn.Linear(input_dim, hidden_dim)  # emb -> h0
        self.embed = nn.Embedding(vocab_size, hidden_dim)  # token -> hidden
        self.gru = nn.GRU(hidden_dim + input_dim, hidden_dim, batch_first=True)
        self.ln = nn.LayerNorm(hidden_dim)
        self.fc = nn.Linear(hidden_dim, vocab_size, bias=True)
        self.drop = nn.Dropout(p)
        # Tie the output projection to the input embedding table
        # (both weights are [vocab_size, hidden_dim]).
        self.fc.weight = self.embed.weight

    @torch.no_grad()
    def greedy_decode(self, emb_vec, max_len, start_id, eos_id):
        """Greedily decode up to ``max_len`` tokens for each row of ``emb_vec``.

        Args:
            emb_vec: 2-D tensor ``[B, input_dim]`` of conditioning embeddings.
            max_len: maximum number of generation steps.
            start_id: token id fed as the first decoder input for every row.
            eos_id: end-of-sequence id; decoding stops early once every row
                has emitted it.

        Returns:
            LongTensor ``[B, T]`` with ``T <= max_len`` (``T == 0`` when
            ``max_len <= 0``). After a row emits ``eos_id`` its remaining
            positions are filled with the module-level ``pad_id``, so batched
            outputs carry no tokens past EOS.

        Dropout only acts in training mode; call ``.eval()`` for deterministic
        output.
        """
        batch, _ = emb_vec.shape
        device = emb_vec.device
        h = torch.tanh(self.bridge(emb_vec)).unsqueeze(0)  # [1, B, H] initial GRU state
        inp = torch.full((batch, 1), start_id, dtype=torch.long, device=device)
        finished = torch.zeros(batch, dtype=torch.bool, device=device)
        out_ids = []
        for _ in range(max_len):
            token_h = self.drop(self.embed(inp))  # [B, 1, H]
            step_in = torch.cat([token_h, emb_vec.unsqueeze(1)], dim=-1)
            out, h = self.gru(step_in, h)
            out = self.ln(out.squeeze(1))
            logits = self.fc(self.drop(out))
            logits[:, pad_id] = -1e9  # never emit padding as a real token
            next_id = torch.argmax(logits, dim=-1)
            # Rows that already produced EOS keep emitting padding only.
            next_id = next_id.masked_fill(finished, pad_id)
            out_ids.append(next_id.unsqueeze(1))
            finished |= next_id == eos_id
            if finished.all():
                break
            inp = next_id.unsqueeze(1)
        if not out_ids:
            # max_len <= 0: torch.cat would fail on an empty list.
            return torch.empty((batch, 0), dtype=torch.long, device=device)
        return torch.cat(out_ids, dim=1)


# ===== LOAD MODELS =====
# NOTE(review): depending on the torch version / weights_only default,
# torch.load may unpickle arbitrary objects -- only load trusted checkpoints.
mapper_ckpt = torch.load(MAPPER_PTH, map_location=DEVICE)
mapper = SemanticMapper(mapper_ckpt["dim"]).to(DEVICE)
mapper.load_state_dict(mapper_ckpt["state_dict"])
mapper.eval()

dec_ckpt = torch.load(DECODER_PTH, map_location=DEVICE)
# Take the GRU hidden size from the saved weights (bridge.weight is
# [hidden_dim, input_dim]) instead of hard-coding 512, so checkpoints trained
# with a different width also load.
dec_hidden = dec_ckpt["state_dict"]["bridge.weight"].shape[0]
decoder = EmbeddingDecoder(dec_ckpt["dim"], dec_hidden, dec_ckpt["vocab_size"]).to(DEVICE)
decoder.load_state_dict(dec_ckpt["state_dict"])
decoder.eval()

embedder = SentenceTransformer(MODEL_NAME, device=DEVICE)


# ===== CHAT LOOP =====
def chat():
    """Interactive loop; ends on 'quit'/'exit', an empty line, EOF or Ctrl-C."""
    print("Chat ready. Type 'quit' to exit.")
    while True:
        try:
            user = input("User: ").strip()
        except (EOFError, KeyboardInterrupt):
            print()  # finish the prompt line before leaving
            break
        if not user or user.lower() in {"quit", "exit"}:
            break
        with torch.no_grad():
            # encode() may hand back an inference-mode tensor; detach().clone()
            # yields an ordinary tensor the mapper can consume.
            x = embedder.encode([user], convert_to_tensor=True, device=DEVICE).detach().clone()
            # no_grad: the mapper's parameters require grad, so without it every
            # turn would build an autograd graph that is never used.
            y_pred = mapper(x)
            ids = decoder.greedy_decode(
                y_pred, max_len=MAX_LEN, start_id=pad_id, eos_id=eos_id
            )[0].tolist()
        reply = tokenizer.decode(ids, skip_special_tokens=True)
        print("Bot:", reply)


if __name__ == "__main__":
    chat()