import math
from dataclasses import dataclass
from typing import Literal, Optional, Tuple

import faiss
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Global configuration
world_size = 1
rank = 0
block_size = 128
gemm_impl: Literal["bf16", "fp8"] = "bf16"
attn_impl: Literal["naive", "absorb"] = "absorb"


@dataclass
class ModelArgs:
    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    n_kv_heads: int = 8
    vocab_size: int = 32000
    multiple_of: int = 256
    ffn_dim_multiplier: Optional[float] = None
    max_seq_len: int = 4096
    original_seq_len: int = 4096
    rope_theta: float = 10000.0
    rope_factor: float = 1.0
    beta_fast: float = 32.0
    beta_slow: float = 1.0
    mscale: float = 0.707
    q_lora_rank: int = 0
    kv_lora_rank: int = 512  # must be > 0: MLA projects K/V through this latent
    qk_nope_head_dim: int = 128
    qk_rope_head_dim: int = 64
    v_head_dim: int = 128
    n_routed_experts: int = 8
    n_activated_experts: int = 2
    n_expert_groups: int = 1
    n_limited_groups: int = 1
    score_func: str = "softmax"
    route_scale: float = 1.0
    n_dense_layers: int = 0
    moe_inter_dim: Optional[int] = None
    n_shared_experts: int = 1
    max_batch_size: int = 32
    dtype: str = "bf16"

    def __post_init__(self):
        # Derive the FFN hidden width, rounded up to a multiple of `multiple_of`.
        if self.ffn_dim_multiplier is None:
            self.inter_dim = int(2 * self.dim / 3)
            self.inter_dim = self.multiple_of * ((self.inter_dim + self.multiple_of - 1) // self.multiple_of)
        else:
            self.inter_dim = int(2 * self.dim * self.ffn_dim_multiplier)
        if self.moe_inter_dim is None:
            self.moe_inter_dim = int(2 * self.dim / 3)
            self.moe_inter_dim = self.multiple_of * ((self.moe_inter_dim + self.multiple_of - 1) // self.multiple_of)


# Embedding layer
class ParallelEmbedding(nn.Module):
    """Embedding sharded over the vocabulary dimension across ranks."""

    def __init__(self, vocab_size: int, dim: int):
        super().__init__()
        self.vocab_size = vocab_size
        self.dim = dim
        assert vocab_size % world_size == 0
        self.part_vocab_size = vocab_size // world_size
        self.vocab_start_idx = rank * self.part_vocab_size
        self.vocab_end_idx = self.vocab_start_idx + self.part_vocab_size
        self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if world_size > 1:
            # Zero out tokens owned by other ranks; the all-reduce restores them.
            mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
            x = x - self.vocab_start_idx
            x[mask] = 0
        y = F.embedding(x, self.weight)
        if world_size > 1:
            y[mask] = 0
            torch.distributed.all_reduce(y)
        return y


# Linear layer
def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # element_size() > 1 means the weight is not FP8-quantized.
    if weight.element_size() > 1:
        return F.linear(x, weight, bias)
    elif gemm_impl == "bf16":
        weight = weight_dequant(weight, weight.scale)
        return F.linear(x, weight, bias)
    else:
        x, scale = act_quant(x, block_size)
        y = fp8_gemm(x, scale, weight, weight.scale)
        if bias is not None:
            y += bias
        return y


class Linear(nn.Module):
    dtype = torch.bfloat16

    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype=None):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype or Linear.dtype))
        if self.weight.element_size() == 1:
            # FP8 weights carry a per-block float32 dequantization scale.
            scale_out_features = (out_features + block_size - 1) // block_size
            scale_in_features = (in_features + block_size - 1) // block_size
            self.weight.scale = self.scale = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float32))
        else:
            self.register_parameter("scale", None)
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter("bias", None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return linear(x, self.weight, self.bias)
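
# Hedged sanity check (not part of the model): verifies the FFN width
# arithmetic in ModelArgs.__post_init__ for the default configuration,
# int(2 * 4096 / 3) = 2730, rounded up to a multiple of 256 -> 2816.
def _check_inter_dim():
    demo = ModelArgs()
    assert demo.inter_dim == 2816
    assert demo.moe_inter_dim == 2816
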
class ColumnParallelLinear(Linear):
    """Linear layer sharded over the output dimension."""

    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype=None):
        assert out_features % world_size == 0
        self.part_out_features = out_features // world_size
        super().__init__(in_features, self.part_out_features, bias, dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return linear(x, self.weight, self.bias)


class RowParallelLinear(Linear):
    """Linear layer sharded over the input dimension; output is all-reduced."""

    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype=None):
        assert in_features % world_size == 0
        self.part_in_features = in_features // world_size
        super().__init__(self.part_in_features, out_features, bias, dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = linear(x, self.weight)
        if world_size > 1:
            torch.distributed.all_reduce(y)
        if self.bias is not None:
            y += self.bias
        return y


# Normalization layer
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor):
        x = x.float()
        y = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return y.type_as(self.weight) * self.weight


# Positional encoding (RoPE with YaRN-style extension beyond the original context)
def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
    dim = args.qk_rope_head_dim
    seqlen = args.max_seq_len
    beta_fast = args.beta_fast
    beta_slow = args.beta_slow
    base = args.rope_theta
    factor = args.rope_factor

    def find_correction_dim(num_rotations, dim, base, max_seq_len):
        return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base))

    def find_correction_range(low_rot, high_rot, dim, base, max_seq_len):
        low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len))
        high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len))
        return max(low, 0), min(high, dim - 1)

    def linear_ramp_factor(min, max, dim):
        if min == max:
            max += 0.001
        linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
        ramp_func = torch.clamp(linear_func, 0, 1)
        return ramp_func

    freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    if seqlen > args.original_seq_len:
        # Interpolate low-frequency dims, leave high-frequency dims unchanged.
        low, high = find_correction_range(beta_fast, beta_slow, dim, base, args.original_seq_len)
        smooth = 1 - linear_ramp_factor(low, high, dim // 2)
        freqs = freqs / factor * (1 - smooth) + freqs * smooth

    t = torch.arange(seqlen)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis


def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    dtype = x.dtype
    # Pair up the last dimension and rotate each pair in the complex plane.
    x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2))
    freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1))
    y = torch.view_as_real(x * freqs_cis).flatten(3)
    return y.to(dtype)
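
# Hedged sketch: rotary embedding is a per-pair rotation by a unit complex
# number, so it preserves the norm of every head vector. This demo checks that
# property on random data; the small ModelArgs values here are illustrative.
def _check_rope_norm_preservation():
    args = ModelArgs(qk_rope_head_dim=8, max_seq_len=16, original_seq_len=16)
    freqs = precompute_freqs_cis(args)[:4]          # first 4 positions
    x = torch.randn(1, 4, 2, 8, dtype=torch.float32)  # (bsz, seqlen, heads, head_dim)
    y = apply_rotary_emb(x, freqs)
    assert torch.allclose(x.norm(dim=-1), y.norm(dim=-1), atol=1e-3)
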
# Multi-Head Latent Attention (MLA)
class MLA(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.dim = args.dim
        self.n_heads = args.n_heads
        self.n_local_heads = args.n_heads // world_size
        self.q_lora_rank = args.q_lora_rank
        self.kv_lora_rank = args.kv_lora_rank
        self.qk_nope_head_dim = args.qk_nope_head_dim
        self.qk_rope_head_dim = args.qk_rope_head_dim
        self.qk_head_dim = args.qk_nope_head_dim + args.qk_rope_head_dim
        self.v_head_dim = args.v_head_dim
        if self.q_lora_rank == 0:
            self.wq = ColumnParallelLinear(self.dim, self.n_heads * self.qk_head_dim)
        else:
            # Low-rank query projection: down-project, normalize, up-project.
            self.wq_a = Linear(self.dim, self.q_lora_rank)
            self.q_norm = RMSNorm(self.q_lora_rank)
            self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.qk_head_dim)
        self.wkv_a = Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)
        self.kv_norm = RMSNorm(self.kv_lora_rank)
        self.wkv_b = ColumnParallelLinear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))
        self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim)
        self.softmax_scale = self.qk_head_dim ** -0.5
        if args.max_seq_len > args.original_seq_len:
            mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0
            self.softmax_scale = self.softmax_scale * mscale * mscale
        if attn_impl == "naive":
            # Cache full per-head keys and values.
            self.register_buffer("k_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.qk_head_dim), persistent=False)
            self.register_buffer("v_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.v_head_dim), persistent=False)
        else:
            # Cache only the compressed KV latent and the rotary key part.
            self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), persistent=False)
            self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), persistent=False)

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
        bsz, seqlen, _ = x.size()
        end_pos = start_pos + seqlen
        if self.q_lora_rank == 0:
            q = self.wq(x)
        else:
            q = self.wq_b(self.q_norm(self.wq_a(x)))
        q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        q_pe = apply_rotary_emb(q_pe, freqs_cis)
        kv = self.wkv_a(x)
        kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)
        if attn_impl == "naive":
            q = torch.cat([q_nope, q_pe], dim=-1)
            kv = self.wkv_b(self.kv_norm(kv))
            kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
            k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
            k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
            self.k_cache[:bsz, start_pos:end_pos] = k
            self.v_cache[:bsz, start_pos:end_pos] = v
            scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos]) * self.softmax_scale
        else:
            # "Absorb" wkv_b into the query so attention runs in the latent space.
            wkv_b = self.wkv_b.weight if self.wkv_b.scale is None else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size)
            wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
            q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
            self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv)
            self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
            scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +
                      torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale
        if mask is not None:
            scores += mask.unsqueeze(1)
        scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)
        if attn_impl == "naive":
            x = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos])
        else:
            x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
            x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])
        x = self.wo(x.flatten(2))
        return x


# MLP layer
class MLP(nn.Module):
    def __init__(self, dim: int, inter_dim: int):
        super().__init__()
        self.w1 = ColumnParallelLinear(dim, inter_dim)
        self.w2 = RowParallelLinear(inter_dim, dim)
        self.w3 = ColumnParallelLinear(dim, inter_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU: gated activation, then down-projection.
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
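
# Hedged arithmetic sketch: the "absorb" path caches the compressed latent
# (kv_lora_rank + qk_rope_head_dim floats per token) instead of full per-head
# keys and values (n_heads * (qk_head_dim + v_head_dim)). The default numbers
# below are illustrative, not a claim about any specific deployed model.
def _cache_sizes_per_token(n_heads=128, qk_head_dim=192, v_head_dim=128,
                           kv_lora_rank=512, qk_rope_head_dim=64):
    naive = n_heads * (qk_head_dim + v_head_dim)  # 40960 floats per token
    absorb = kv_lora_rank + qk_rope_head_dim      # 576 floats per token
    return naive, absorb
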
# Mixture of Experts (MoE) components
class Gate(nn.Module):
    """Routes each token to its top-k experts."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.dim = args.dim
        self.topk = args.n_activated_experts
        self.n_groups = args.n_expert_groups
        self.topk_groups = args.n_limited_groups
        self.score_func = args.score_func
        self.route_scale = args.route_scale
        self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))
        # The routing bias is only used by the 7168-dim configuration.
        self.bias = nn.Parameter(torch.empty(args.n_routed_experts)) if self.dim == 7168 else None

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        scores = linear(x, self.weight)
        if self.score_func == "softmax":
            scores = scores.softmax(dim=-1, dtype=torch.float32)
        else:
            scores = scores.sigmoid()
        original_scores = scores
        if self.bias is not None:
            scores = scores + self.bias
        if self.n_groups > 1:
            # Group-limited routing: keep only the best expert groups per token.
            scores = scores.view(x.size(0), self.n_groups, -1)
            if self.bias is None:
                group_scores = scores.amax(dim=-1)
            else:
                group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1)
            indices = group_scores.topk(self.topk_groups, dim=-1)[1]
            mask = torch.zeros_like(scores[..., 0]).scatter_(1, indices, True)
            scores = (scores * mask.unsqueeze(-1)).flatten(1)
        indices = torch.topk(scores, self.topk, dim=-1)[1]
        weights = original_scores.gather(1, indices)
        if self.score_func == "sigmoid":
            weights /= weights.sum(dim=-1, keepdim=True)
        weights *= self.route_scale
        return weights.type_as(x), indices


class Expert(nn.Module):
    def __init__(self, dim: int, inter_dim: int):
        super().__init__()
        self.w1 = Linear(dim, inter_dim)
        self.w2 = Linear(inter_dim, dim)
        self.w3 = Linear(dim, inter_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class MoE(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.dim = args.dim
        assert args.n_routed_experts % world_size == 0
        self.n_routed_experts = args.n_routed_experts
        self.n_local_experts = args.n_routed_experts // world_size
        self.n_activated_experts = args.n_activated_experts
        self.experts_start_idx = rank * self.n_local_experts
        self.experts_end_idx = self.experts_start_idx + self.n_local_experts
        self.gate = Gate(args)
        # Only the experts owned by this rank are materialized.
        self.experts = nn.ModuleList([Expert(args.dim, args.moe_inter_dim) if self.experts_start_idx <= i < self.experts_end_idx else None
                                      for i in range(self.n_routed_experts)])
        self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shape = x.size()
        x = x.view(-1, self.dim)
        weights, indices = self.gate(x)
        y = torch.zeros_like(x)
        counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()
        for i in range(self.experts_start_idx, self.experts_end_idx):
            if counts[i] == 0:
                continue
            expert = self.experts[i]
            idx, top = torch.where(indices == i)
            y[idx] += expert(x[idx]) * weights[idx, top, None]
        z = self.shared_experts(x)
        if world_size > 1:
            torch.distributed.all_reduce(y)
        return (y + z).view(shape)


# Transformer block
class Block(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.attn = MLA(args)
        # The first n_dense_layers use a dense MLP; the rest are MoE layers.
        self.ffn = MLP(args.dim, args.inter_dim) if layer_id < args.n_dense_layers else MoE(args)
        self.attn_norm = RMSNorm(args.dim)
        self.ffn_norm = RMSNorm(args.dim)

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
        x = x + self.attn(self.attn_norm(x), start_pos, freqs_cis, mask)
        x = x + self.ffn(self.ffn_norm(x))
        return x
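
# Hedged illustration of the top-k routing performed by Gate: each token keeps
# its n_activated_experts highest-scoring experts, and the softmax scores of
# those experts become the combine weights. Toy numbers, no model weights.
def _toy_topk_routing():
    scores = torch.tensor([[0.1, 0.7, 0.2], [0.5, 0.1, 0.4]]).softmax(dim=-1)
    weights, indices = scores.topk(2, dim=-1)
    return weights, indices  # e.g. token 0 routes to experts 1 and 2
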
# Transformer model
class Transformer(nn.Module):
    def __init__(self, args: ModelArgs):
        global world_size, rank
        world_size = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
        rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
        Linear.dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16
        super().__init__()
        self.max_seq_len = args.max_seq_len
        self.embed = ParallelEmbedding(args.vocab_size, args.dim)
        self.layers = torch.nn.ModuleList()
        for layer_id in range(args.n_layers):
            self.layers.append(Block(layer_id, args))
        self.norm = RMSNorm(args.dim)
        self.head = ColumnParallelLinear(args.dim, args.vocab_size, dtype=torch.get_default_dtype())
        self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False)

    @torch.inference_mode()
    def forward(self, tokens: torch.Tensor, start_pos: int = 0):
        seqlen = tokens.size(1)
        h = self.embed(tokens)
        freqs_cis = self.freqs_cis[start_pos:start_pos + seqlen]
        mask = None
        if seqlen > 1:
            # Causal mask: position i may only attend to positions <= i.
            mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1)
        for layer in self.layers:
            h = layer(h, start_pos, freqs_cis, mask)
        h = self.norm(h)[:, -1]  # keep only the last position for next-token prediction
        logits = self.head(h)
        if world_size > 1:
            all_logits = [torch.empty_like(logits) for _ in range(world_size)]
            torch.distributed.all_gather(all_logits, logits)
            logits = torch.cat(all_logits, dim=-1)
        return logits


# FAISS Retriever
class FAISSRetriever:
    def __init__(self, knowledge_base: faiss.Index, dim: int = 768, num_results: int = 5):
        self.index = knowledge_base
        self.dim = dim
        self.num_results = num_results

    def search(self, query_embedding: torch.Tensor, k: Optional[int] = None) -> torch.Tensor:
        if k is None:
            k = self.num_results
        # FAISS expects a contiguous float32 array of shape (n, dim).
        query_np = np.ascontiguousarray(query_embedding.detach().float().cpu().numpy(), dtype="float32")
        distances, indices = self.index.search(query_np, k)
        return torch.tensor(indices, device=query_embedding.device)
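
# Hedged usage sketch for FAISSRetriever: build a tiny L2 index over random
# vectors and query it. faiss.IndexFlatL2 and index.add/search are standard
# FAISS calls; the vectors here are random placeholders.
def _faiss_demo(dim=32):
    index = faiss.IndexFlatL2(dim)
    index.add(np.random.rand(10, dim).astype("float32"))
    retriever = FAISSRetriever(index, dim=dim, num_results=3)
    query = torch.randn(2, dim)
    return retriever.search(query)  # (2, 3) tensor of nearest row indices
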
# Complete Multi-Modal LLM
class CombinedMultiModalTransformer(nn.Module):
    def __init__(self, args: ModelArgs, knowledge_base: faiss.Index):
        super().__init__()
        self.args = args
        self.transformer = Transformer(args)

        # Multi-modal components
        # (padding='same' is not supported with stride > 1, so explicit padding is used)
        self.audio_encoder = nn.Sequential(
            nn.Conv1d(256, 256, kernel_size=11, stride=2, padding=5),
            nn.ReLU(),
            nn.Conv1d(256, 256, kernel_size=11, stride=2, padding=5),
            nn.ReLU(),
            nn.Conv1d(256, args.dim, kernel_size=11, stride=2, padding=5),
            nn.ReLU()
        )
        self.image_encoder = nn.Sequential(
            # Heavily simplified ResNet-style stem; the conv stack ends with 64
            # channels, so the projection below maps 64 -> args.dim.
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(64, args.dim)
        )

        # Music generation components
        self.pitch_embedding = nn.Embedding(128, args.dim)
        self.duration_embedding = nn.Embedding(32, args.dim)
        self.velocity_embedding = nn.Embedding(128, args.dim)

        # Anomaly detection components
        self.anomaly_detector = nn.Sequential(
            nn.Linear(args.dim, args.dim),
            nn.ReLU(),
            nn.Linear(args.dim, 1),
            nn.Sigmoid()
        )

        # RAG components
        self.knowledge_base = FAISSRetriever(knowledge_base, dim=args.dim)
        self.query_encoder = nn.Sequential(
            nn.Linear(args.dim, args.dim),
            nn.ReLU(),
            nn.Linear(args.dim, args.dim)
        )

    def forward(self, inputs, task, start_pos=0):
        if task == 'text_generation':
            # RAG component: embed the prompt tokens, mean-pool into one query
            # vector per sequence, and retrieve nearest documents.
            query_embedding = self.query_encoder(self.transformer.embed(inputs)).mean(dim=1)
            retrieved_indices = self.knowledge_base.search(query_embedding, k=5)
            # In practice, the retrieved indices would be mapped back to document
            # tokens and prepended to the prompt; placeholder tokens are used here.
            retrieved_tokens = torch.zeros(inputs.size(0), 5, dtype=torch.long, device=inputs.device)
            inputs = torch.cat([retrieved_tokens, inputs], dim=1)
            # Pass through transformer
            logits = self.transformer(inputs, start_pos)
            return logits
        elif task == 'speech_recognition':
            x = self.audio_encoder(inputs)  # (bsz, args.dim, seq_len)
            # Placeholder bridging: the encoded audio is not yet fused into the
            # transformer; dummy tokens of the encoded length are used instead.
            batch_size, seq_len = x.size(0), x.size(2)
            tokens = torch.zeros(batch_size, seq_len, dtype=torch.long, device=x.device)
            logits = self.transformer(tokens, start_pos)
            return logits
        elif task == 'image_captioning':
            image_features = self.image_encoder(inputs)
            # Placeholder bridging, as above: one dummy token per image.
            batch_size = image_features.size(0)
            tokens = torch.zeros(batch_size, 1, dtype=torch.long, device=image_features.device)
            logits = self.transformer(tokens, start_pos)
            return logits
        elif task == 'music_generation':
            pitch, duration, velocity = inputs
            x = self.pitch_embedding(pitch) + self.duration_embedding(duration) + self.velocity_embedding(velocity)
            # Placeholder bridging, as above: dummy tokens of the note-sequence length.
            batch_size, seq_len = x.size(0), x.size(1)
            tokens = torch.zeros(batch_size, seq_len, dtype=torch.long, device=x.device)
            logits = self.transformer(tokens, start_pos)
            return logits
        elif task == 'anomaly_detection':
            # Inputs are continuous features of shape (bsz, seq_len, dim),
            # scored directly by the detector head.
            anomaly_scores = self.anomaly_detector(inputs)
            return anomaly_scores
        else:
            raise ValueError(f"Unknown task: {task}")


# Helper functions
def act_quant(x: torch.Tensor, block_size: int = 128):
    # Simplified stand-in for block-wise activation quantization:
    # returns the activation unchanged with a unit scale.
    return x, torch.ones(1, device=x.device)


def weight_dequant(weight: torch.Tensor, scale: torch.Tensor, block_size: int = 128):
    # Simplified stand-in for block-wise weight dequantization.
    return weight * scale


def fp8_gemm(x: torch.Tensor, x_scale: torch.Tensor, weight: torch.Tensor, weight_scale: torch.Tensor):
    # Simplified stand-in for an FP8 GEMM with scale correction.
    return torch.matmul(x, weight.t()) * x_scale * weight_scale


# Training function
# Note: Transformer.forward runs under @torch.inference_mode(), so this loop is
# illustrative; as written, only task branches that bypass the transformer
# (e.g. anomaly_detection) can actually be trained with it.
def train_model(model, dataloader, optimizer, criterion, device, num_epochs=10):
    model.train()
    model.to(device)
    for epoch in range(num_epochs):
        total_loss = 0.0
        # The dataloader is assumed to yield (inputs, targets, task_name) batches.
        for batch_idx, (inputs, targets, tasks) in enumerate(dataloader):
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs, tasks)
            if isinstance(outputs, dict):
                # Handle multi-task outputs
                loss = 0.0
                for task, output in outputs.items():
                    task_targets = targets[task]
                    loss += criterion(output, task_targets)
            else:
                loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            if batch_idx % 100 == 0:
                print(f'Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}/{len(dataloader)}, Loss: {loss.item():.4f}')
        avg_loss = total_loss / len(dataloader)
        print(f'Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}')


# Inference function
def generate_text(model, prompt, max_length=100, temperature=1.0, device='cpu'):
    model.eval()
    model.to(device)
    # Convert prompt to tokens
    tokens = torch.tensor([prompt], dtype=torch.long, device=device)
    with torch.no_grad():
        for _ in range(max_length):
            # The transformer already returns last-position logits of shape (bsz, vocab).
            logits = model(tokens, 'text_generation')
            # Apply temperature scaling (in float32 for stable sampling)
            logits = logits.float() / temperature
            # Get probabilities for next token
            probs = F.softmax(logits, dim=-1)
            # Sample next token
            next_token = torch.multinomial(probs, num_samples=1)
            # Append new token to sequence
            tokens = torch.cat([tokens, next_token], dim=1)
    return tokens[0].tolist()
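
# Hedged illustration of the temperature scaling used in generate_text: the
# same logits become sharper at T < 1 and flatter at T > 1. Values arbitrary.
def _temperature_demo():
    logits = torch.tensor([2.0, 1.0, 0.1])
    for t in (0.5, 1.0, 2.0):
        print(f"T={t}: {F.softmax(logits / t, dim=-1).tolist()}")
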
# Example usage
if __name__ == "__main__":
    # Use bfloat16 everywhere so the standard nn modules match the custom
    # Linear layers, which default to bfloat16 weights.
    torch.set_default_dtype(torch.bfloat16)

    # Initialize model parameters (a deliberately small configuration so this
    # smoke test runs quickly; the dataclass defaults describe the full model)
    args = ModelArgs(dim=512, n_layers=2, n_heads=8, vocab_size=1024,
                     max_seq_len=512, original_seq_len=512, max_batch_size=2)

    # Create a dummy knowledge base for testing
    dim = args.dim
    knowledge_base = faiss.IndexFlatL2(dim)
    # Add some dummy vectors
    vectors = np.random.rand(100, dim).astype('float32')
    knowledge_base.add(vectors)

    # Initialize model
    model = CombinedMultiModalTransformer(args, knowledge_base)

    # Weights are allocated with torch.empty; randomize them so the smoke test
    # produces finite values (a real run would load a checkpoint instead).
    with torch.no_grad():
        for p in model.parameters():
            p.normal_(0, 0.02)

    # Print model structure
    print(model)

    # Test text generation
    prompt = [1, 2, 3, 4, 5]  # Example token sequence
    generated_tokens = generate_text(model, prompt, max_length=20)
    print(f"Generated tokens: {generated_tokens}")

    # Test other tasks
    # Note: In practice, you would provide appropriate input data
    try:
        # Speech recognition
        audio_input = torch.randn(1, 256, 160)  # Example audio input
        speech_output = model(audio_input, 'speech_recognition')
        print(f"Speech recognition output shape: {speech_output.shape}")

        # Image captioning
        image_input = torch.randn(1, 3, 224, 224)  # Example image input
        caption_output = model(image_input, 'image_captioning')
        print(f"Image captioning output shape: {caption_output.shape}")

        # Music generation
        pitch = torch.randint(0, 128, (1, 100))
        duration = torch.randint(0, 32, (1, 100))
        velocity = torch.randint(0, 128, (1, 100))
        music_output = model((pitch, duration, velocity), 'music_generation')
        print(f"Music generation output shape: {music_output.shape}")

        # Anomaly detection
        anomaly_input = torch.randn(1, 100, args.dim)
        anomaly_output = model(anomaly_input, 'anomaly_detection')
        print(f"Anomaly detection output shape: {anomaly_output.shape}")
    except Exception as e:
        print(f"Error during testing: {e}")
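
# Hedged note on tensor parallelism: Transformer.__init__ reads world_size and
# rank from torch.distributed, so the parallel layers shard automatically once
# a process group exists. A typical (illustrative) launch:
#   torchrun --nproc_per_node=2 this_script.py
# with torch.distributed.init_process_group(backend="nccl") called before
# constructing the model.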