| | import torch |
| | import torch.nn as nn |
| | from typing import Dict, Optional, Tuple |
| |
|
| | import torch.nn.functional as F |
| | import torch |
| | import torch.nn as nn |
| |
|
| |
|
class WMDetector(nn.Module):
    """
    Detect a watermark and decode its message from a feature sequence.

    The input is a feature tensor of shape [batch, input_channels, time].
    The ``nbits``-bit message is split into ``nchunks = nbits // nchunk_size``
    chunks, and each chunk is decoded as a ``2**nchunk_size``-way
    classification. One learned query token per chunk is prepended to the
    frame sequence. The Transformer outputs at those query positions feed the
    per-chunk message heads, and the outputs at the frame positions feed a
    per-frame detection head.
    """

    def __init__(
        self,
        input_channels: int,
        nbits: int,
        nchunk_size: int,
        d_model: int = 512,
        nhead: int = 1,
        num_layers: int = 8,
    ):
        """
        Args:
            input_channels (int): Number of channels of the input features.
            nbits (int): Total number of message bits; must be a multiple of
                ``nchunk_size``.
            nchunk_size (int): Number of bits decoded jointly per chunk; each
                chunk head classifies among ``2**nchunk_size`` values.
            d_model (int): Embedding dimension of the Transformer.
            nhead (int): Attention heads per encoder layer (``d_model`` must
                be divisible by it). Defaults to the previous fixed value 1.
            num_layers (int): Number of encoder layers. Defaults to the
                previous fixed value 8.
        """
        super().__init__()
        assert nbits % nchunk_size == 0, "nbits must be a multiple of nchunk_size!"
        self.nchunk_size = nchunk_size
        self.nbits = nbits
        self.d_model = d_model
        self.nchunks = nbits // nchunk_size

        # NOTE: submodules/parameters are created in the same order and under
        # the same names as before, so state_dict keys and seeded init match.

        # Pointwise (kernel_size=1) projection of every frame to d_model.
        self.embedding = nn.Conv1d(input_channels, d_model, kernel_size=1)

        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=d_model * 2,
                activation="gelu",
                batch_first=True,
            ),
            num_layers=num_layers,
        )

        # One watermark-presence logit per frame.
        self.watermark_head = nn.Linear(d_model, 1)

        # One classifier per chunk over all 2**nchunk_size chunk values.
        self.message_heads = nn.ModuleList(
            nn.Linear(d_model, 2**nchunk_size) for _ in range(self.nchunks)
        )

        # Learned query token per chunk, prepended to the frame sequence.
        self.nchunk_embeddings = nn.Parameter(torch.randn(self.nchunks, d_model))

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Run the detector.

        Args:
            x: Features of shape [batch, input_channels, time].

        Returns:
            logits (torch.Tensor): Per-frame watermark detection logits of
                shape [batch, time].
            chunk_logits (torch.Tensor): Per-chunk classification logits of
                shape [batch, nchunks, 2**nchunk_size].
        """
        # The 3-way unpack also rejects inputs that are not 3-D.
        batch_size, _, _ = x.shape

        # [batch, channels, time] -> [batch, time, d_model]
        x = self.embedding(x).permute(0, 2, 1)

        # Prepend the chunk query tokens -> [batch, nchunks + time, d_model].
        queries = self.nchunk_embeddings.unsqueeze(0).expand(batch_size, -1, -1)
        x = self.transformer(torch.cat([queries, x], dim=1))

        # Frame positions -> per-frame detection logits.
        logits = self.watermark_head(x[:, self.nchunks :]).squeeze(-1)

        # Query position i -> message head i.
        chunk_logits = torch.stack(
            [head(x[:, i]) for i, head in enumerate(self.message_heads)], dim=1
        )
        return logits, chunk_logits

    def detect_watermark(
        self,
        x: torch.Tensor,
        sample_rate: Optional[int] = None,
        threshold: float = 0.5,
    ) -> Tuple[float, torch.Tensor, torch.Tensor]:
        """
        Convenience wrapper for inference.

        Args:
            x: Features of shape [batch, input_channels, time].
            sample_rate: Currently unused; accepted for API compatibility.
            threshold: Currently unused; accepted for API compatibility.

        Returns:
            detect_prob (float): Mean per-frame watermark probability,
                averaged over all frames of all batch items. For a batch of
                one this is that item's time-averaged probability.
            binary_message (torch.Tensor): Decoded bits of shape
                [batch, nbits]. Chunks appear in order, with the
                least-significant bit first within each chunk.
            detected (torch.Tensor): Per-frame sigmoid probabilities of shape
                [batch, time].
        """
        logits, chunk_logits = self.forward(x)

        detected = torch.sigmoid(logits)
        # Reduce over every element so .item() also works for batch > 1.
        # The previous mean(dim=-1).item() raised for any batch size > 1.
        detect_prob = detected.mean().item()

        # argmax of the logits equals argmax of their softmax.
        chunk_indices = chunk_logits.argmax(dim=-1)  # [batch, nchunks]

        # Unpack each chunk index into nchunk_size bits (LSB first), then
        # flatten the chunks in order -> [batch, nchunks * nchunk_size].
        shifts = torch.arange(self.nchunk_size, device=chunk_indices.device)
        bits = (chunk_indices.unsqueeze(-1) >> shifts) & 1
        binary_message = bits.flatten(start_dim=1)

        return detect_prob, binary_message, detected
| |
|
| |
|
| |
|
class WMEmbedder(nn.Module):
    """
    Inject a secret message into a hidden feature sequence.

    The ``nbits``-bit message is split into ``nchunks = nbits // nchunk_size``
    chunks. Each chunk's integer value indexes its own embedding table, which
    yields a short sequence of watermark tokens. A TransformerDecoder then
    cross-attends from the projected hidden sequence (target) to those tokens
    (memory). The decoder output is projected back to ``input_dim`` and added
    to the input as a residual.
    """

    def __init__(
        self,
        nbits: int,
        input_dim: int,
        nchunk_size: int,
        hidden_dim: int = 256,
        num_heads: int = 1,
        num_layers: int = 4,
    ):
        """
        Args:
            nbits: Total number of message bits; must be a multiple of
                ``nchunk_size``.
            input_dim: Channel dimension of the hidden features.
            nchunk_size: Bits per chunk; each chunk embedding table has
                ``2**nchunk_size`` entries.
            hidden_dim: Internal model dimension of the decoder.
            num_heads: Attention heads per decoder layer (``hidden_dim`` must
                be divisible by it).
            num_layers: Number of decoder layers.
        """
        super().__init__()
        assert nbits % nchunk_size == 0, "nbits must be a multiple of nchunk_size!"
        self.nchunk_size = nchunk_size
        self.nbits = nbits
        self.nchunks = nbits // nchunk_size

        # One embedding table per chunk, indexed by the chunk's integer value.
        self.msg_embeddings = nn.ModuleList(
            nn.Embedding(2**nchunk_size, hidden_dim) for _ in range(self.nchunks)
        )

        self.input_projection = nn.Linear(input_dim, hidden_dim)

        # Decoder: hidden frames (tgt) cross-attend to message tokens (memory).
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=2 * hidden_dim,
            activation="gelu",
            batch_first=True,
        )
        self.transformer_decoder = nn.TransformerDecoder(
            decoder_layer, num_layers=num_layers
        )

        self.output_projection = nn.Linear(hidden_dim, input_dim)

    def forward(self, hidden: torch.Tensor, msg: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden: [batch, input_dim, seq_len]
            msg: [batch, nbits] message bits (0/1). Integer, float or bool
                dtypes are accepted; values are cast to int64 before packing.
                Only the first ``nbits`` columns are used.
        Returns:
            A tensor [batch, input_dim, seq_len]: ``hidden`` plus the
            watermark residual.
        """
        # [batch, input_dim, seq_len] -> [batch, seq_len, hidden_dim]
        hidden_projected = self.input_projection(hidden.permute(0, 2, 1))

        # Pack each chunk of bits into an integer, LSB first:
        # value = sum_b bit_b * 2**b. Casting first makes float/bool messages
        # work with integer arithmetic and nn.Embedding, which needs
        # Long/Int indices.
        bits = msg[:, : self.nbits].long().reshape(
            msg.shape[0], self.nchunks, self.nchunk_size
        )
        weights = 2 ** torch.arange(self.nchunk_size, device=bits.device)
        chunk_vals = (bits * weights).sum(dim=-1)  # [batch, nchunks]

        # Look up each chunk in its own table -> [batch, nchunks, hidden_dim].
        chunk_emb_seq = torch.stack(
            [table(chunk_vals[:, i]) for i, table in enumerate(self.msg_embeddings)],
            dim=1,
        )

        x_decoded = self.transformer_decoder(
            tgt=hidden_projected,
            memory=chunk_emb_seq,
        )

        # Back to [batch, input_dim, seq_len], then add as a residual.
        return self.output_projection(x_decoded).permute(0, 2, 1) + hidden
| |
|
| |
|
| | from speechtokenizer import SpeechTokenizer |
| |
|
| |
|
class SBW(nn.Module):
    """
    Watermarking wrapper around a pretrained SpeechTokenizer.

    Holds the SpeechTokenizer model (``st_model``), a WMEmbedder
    (``msg_processor``) that is handed to ``st_model`` during the forward
    pass, and a WMDetector (``detector``) applied to
    ``st_model.forward_feature`` outputs.
    """

    def __init__(
        self,
        config_path: Optional[str] = None,
        ckpt_path: Optional[str] = None,
        nbits: int = 16,
        nchunk_size: int = 4,
        feature_dim: int = 1024,
    ):
        """
        Args:
            config_path: SpeechTokenizer config JSON. Defaults to the
                previously hard-coded relative path.
            ckpt_path: SpeechTokenizer checkpoint. Defaults to the previously
                hard-coded relative path.
            nbits: Message length in bits.
            nchunk_size: Bits per chunk for the embedder and detector.
            feature_dim: Channel dimension seen by the embedder/detector.
                NOTE(review): assumed to match the SpeechTokenizer feature
                size (1024 for the default checkpoint) — confirm when
                changing checkpoints.
        """
        super().__init__()
        self.nbits = nbits
        if config_path is None:
            config_path = (
                "speechtokenizer/pretrained_model/speechtokenizer_hubert_avg_config.json"
            )
        if ckpt_path is None:
            ckpt_path = "speechtokenizer/pretrained_model/SpeechTokenizer.pt"
        self.st_model = SpeechTokenizer.load_from_checkpoint(config_path, ckpt_path)
        self.msg_processor = WMEmbedder(
            nbits=nbits,
            input_dim=feature_dim,
            nchunk_size=nchunk_size,
        )
        self.detector = WMDetector(
            feature_dim,
            nbits,
            nchunk_size=nchunk_size,
        )

    def detect_watermark(self, x: torch.Tensor, return_logits: bool = False) -> Tuple:
        """
        Extract features with ``st_model.forward_feature`` and run the detector.

        Args:
            x: Input passed to ``self.st_model.forward_feature``.
            return_logits: If True, return the detector's raw
                ``(logits, chunk_logits)``. Otherwise return
                ``(detect_prob, binary_message, detected)`` from
                ``self.detector.detect_watermark``.
        """
        features = self.st_model.forward_feature(x)
        if return_logits:
            return self.detector(features)
        return self.detector.detect_watermark(features)

    def forward(
        self,
        speech_input: torch.Tensor,
        message: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Reconstruct the input with and without the watermark.

        Args:
            speech_input: Input passed to ``self.st_model``.
            message: Optional message forwarded to ``self.st_model`` together
                with ``msg_processor``. How ``None`` is handled is up to
                ``st_model`` — TODO confirm.

        Returns:
            ``{"recon": ..., "recon_wm": ...}``. Both are truncated along the
            last axis to ``min(speech_input.size(-1), recon_wm.size(-1))``.
        """
        recon, recon_wm, _, _ = self.st_model(
            speech_input, msg_processor=self.msg_processor, message=message
        )
        wav_length = min(speech_input.size(-1), recon_wm.size(-1))
        return {
            "recon": recon[..., :wav_length],
            "recon_wm": recon_wm[..., :wav_length],
        }
| |
|