SwipeALot-base / processing_swipe.py
dleemiller's picture
Upload folder using huggingface_hub
bf31071 verified
raw
history blame
11.1 kB
"""Processor for handling multimodal swipe inputs (path + text)."""
import numpy as np
import torch
from transformers import ProcessorMixin
class SwipeProcessor(ProcessorMixin):
    """
    Processor for handling multimodal swipe inputs (path coordinates + text).

    This processor combines path coordinate preprocessing with text tokenization,
    creating the inputs needed for SwipeTransformer models.

    Args:
        tokenizer: SwipeTokenizer instance
        max_path_len (int): Maximum path length. Defaults to 64.
        max_char_len (int): Maximum character length. Defaults to 38.
    """

    attributes = ["tokenizer"]
    tokenizer_class = "AutoTokenizer"  # Will use auto_map from tokenizer_config.json

    def __init__(self, tokenizer=None, max_path_len: int = 64, max_char_len: int = 38):
        self.tokenizer = tokenizer
        self.max_path_len = max_path_len
        self.max_char_len = max_char_len
        # Attributes expected by newer transformers (not used for swipe models)
        self.chat_template = None
        self.audio_tokenizer = None
        self.feature_extractor = None
        self.image_processor = None

    def __call__(
        self,
        path_coords: list[list[list[float]]] | torch.Tensor | np.ndarray | None = None,
        text: str | list[str] | None = None,
        padding: bool | str = True,
        truncation: bool = True,
        max_length: int | None = None,
        return_tensors: str | None = "pt",
        **kwargs,
    ) -> dict:
        """
        Process path coordinates and text into model inputs.

        Args:
            path_coords: A single path ``[path_len, 3]`` or a batch of paths
                ``[batch, path_len, 3]`` given as nested lists, a NumPy array or
                a tensor. Each point is (x, y, time). Can be None if only
                processing text.
            text: String or list of strings to encode. Can be None if only
                processing paths.
            padding: Whether to pad sequences. Any truthy value (True or e.g.
                "max_length") pads paths to ``max_path_len`` and text to the
                text limit.
            truncation: Whether to truncate sequences to those same limits.
            max_length: Maximum sequence length for text (overrides max_char_len).
            return_tensors: "pt" for PyTorch, "np" for NumPy, None for lists.
            **kwargs: Additional keyword arguments forwarded to the tokenizer.

        Returns:
            Dictionary with:
                - path_coords: [batch, path_len, 3]; all zeros when no path
                  was provided.
                - input_ids: [batch, char_len]; all pad tokens when no text
                  was provided.
                - attention_mask: [batch, 1 + path_len + 1 + char_len], laid
                  out as [CLS] + path mask + [SEP] + char mask
                  (1 = attend, 0 = padding).

        Raises:
            ValueError: If neither ``path_coords`` nor ``text`` is given, or if
                ``path_coords`` is not 2-D/3-D after conversion.
            TypeError: If ``path_coords`` is of an unsupported type.
        """
        if path_coords is None and text is None:
            raise ValueError("Must provide either path_coords or text (or both)")

        if isinstance(text, str):
            text = [text]

        # The batch size comes from the paths when present, otherwise from the text.
        if path_coords is not None:
            path_coords = self._to_batched_path_tensor(path_coords)
            batch_size = path_coords.shape[0]
            path_coords, path_mask = self._fit_paths(path_coords, padding, truncation)
        else:
            batch_size = len(text)
            # No path provided: fully masked zero placeholder.
            path_coords = torch.zeros(batch_size, self.max_path_len, 3)
            path_mask = torch.zeros(batch_size, self.max_path_len, dtype=torch.long)

        result = {"path_coords": path_coords}

        if text is not None:
            text_limit = max_length if max_length is not None else self.max_char_len
            token_rows = self._encode_text(text, text_limit, padding, truncation, **kwargs)
            pad_id = self.tokenizer.pad_token_id
            # 1 for real tokens + EOS, 0 for padding
            char_mask = [[1 if tok != pad_id else 0 for tok in row] for row in token_rows]
            if return_tensors == "pt":
                token_rows = torch.tensor(token_rows, dtype=torch.long)
                char_mask = torch.tensor(char_mask, dtype=torch.long)
            elif return_tensors == "np":
                token_rows = np.array(token_rows, dtype=np.int64)
                char_mask = np.array(char_mask, dtype=np.int64)
        else:
            # No text provided: fully masked pad-token placeholder.
            pad_id = self.tokenizer.pad_token_id
            placeholder_shape = (batch_size, self.max_char_len)
            if return_tensors == "pt":
                token_rows = torch.full(placeholder_shape, pad_id, dtype=torch.long)
                char_mask = torch.zeros(placeholder_shape, dtype=torch.long)
            elif return_tensors == "np":
                token_rows = np.full(placeholder_shape, pad_id, dtype=np.int64)
                char_mask = np.zeros(placeholder_shape, dtype=np.int64)
            else:
                token_rows = [[pad_id] * self.max_char_len for _ in range(batch_size)]
                char_mask = [[0] * self.max_char_len for _ in range(batch_size)]
        result["input_ids"] = token_rows

        # Combined attention mask: [CLS:1] + path mask + [SEP:1] + char mask
        if return_tensors == "pt":
            special = torch.ones(batch_size, 1, dtype=torch.long)
            attention_mask = torch.cat([special, path_mask, special, char_mask], dim=1)
        elif return_tensors == "np":
            special = np.ones((batch_size, 1), dtype=np.int64)
            attention_mask = np.concatenate(
                [special, path_mask.numpy(), special, char_mask], axis=1
            )
        else:
            attention_mask = [
                [1] + path_row + [1] + char_row
                for path_row, char_row in zip(path_mask.tolist(), char_mask, strict=False)
            ]
        result["attention_mask"] = attention_mask

        # path_coords is the only entry still held as a tensor at this point.
        if return_tensors == "np":
            result["path_coords"] = result["path_coords"].detach().cpu().numpy()
        elif return_tensors is None:
            result["path_coords"] = result["path_coords"].tolist()

        return result

    @staticmethod
    def _to_batched_path_tensor(path_coords) -> torch.Tensor:
        """Convert list / ndarray / tensor path input to a 3-D [batch, path_len, features] tensor."""
        if isinstance(path_coords, torch.Tensor):
            # Tensors keep their own dtype and device.
            tensor = path_coords
        elif isinstance(path_coords, np.ndarray):
            tensor = torch.from_numpy(path_coords).float()
        elif isinstance(path_coords, (list, tuple)):
            tensor = torch.tensor(path_coords, dtype=torch.float32)
            if tensor.dim() == 1 and tensor.numel() == 0:
                # `[]` is a single path with zero points.
                tensor = tensor.reshape(0, 3)
        else:
            raise TypeError(
                "path_coords must be a list, tuple, numpy.ndarray or torch.Tensor, "
                f"got {type(path_coords).__name__}"
            )

        if tensor.dim() == 2:
            # Single path [path_len, features]: add the batch dimension.
            tensor = tensor.unsqueeze(0)
        if tensor.dim() != 3:
            raise ValueError(
                "path_coords must have shape [path_len, 3] or [batch, path_len, 3], "
                f"got {tuple(tensor.shape)}"
            )
        return tensor

    def _fit_paths(
        self, path_coords: torch.Tensor, padding: bool | str, truncation: bool
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Truncate/pad paths along dim 1 to max_path_len and build the matching mask."""
        batch_size, real_len, num_features = path_coords.shape

        if truncation and real_len > self.max_path_len:
            path_coords = path_coords[:, : self.max_path_len, :]
            real_len = self.max_path_len

        if padding and real_len < self.max_path_len:
            # new_zeros keeps the dtype/device of the input so torch.cat cannot mismatch.
            filler = path_coords.new_zeros(
                batch_size, self.max_path_len - real_len, num_features
            )
            path_coords = torch.cat([path_coords, filler], dim=1)

        # Size the mask from the final path length so mask and coordinates always
        # agree, including when padding or truncation is disabled.
        path_mask = torch.zeros(batch_size, path_coords.shape[1], dtype=torch.long)
        path_mask[:, :real_len] = 1  # 1 = real data, 0 = padding
        return path_coords, path_mask

    def _encode_text(
        self, text, text_limit: int, padding: bool | str, truncation: bool, **kwargs
    ) -> list[list[int]]:
        """Tokenize text, append EOS to each sequence, then truncate/pad to text_limit."""
        # Tokenize without padding/truncation first so EOS can be appended manually.
        encoded = self.tokenizer(
            text,
            padding=False,
            truncation=False,
            return_tensors=None,  # Get lists first
            **kwargs,
        )
        eos_id = self.tokenizer.eos_token_id
        pad_id = self.tokenizer.pad_token_id

        token_rows = []
        for ids in encoded["input_ids"]:
            ids = list(ids)
            # Add EOS after each word (matching training dataset behavior);
            # `not ids` guards against an empty encoding.
            if not ids or ids[-1] != eos_id:
                ids.append(eos_id)
            if truncation and len(ids) > text_limit:
                # Truncate but preserve EOS at the end
                ids = ids[: text_limit - 1] + [eos_id]
            if padding and len(ids) < text_limit:
                ids.extend([pad_id] * (text_limit - len(ids)))
            token_rows.append(ids)
        return token_rows

    def batch_decode(self, token_ids, **kwargs):
        """
        Decode token IDs to strings.

        Args:
            token_ids: Token IDs to decode
            **kwargs: Additional arguments passed to tokenizer

        Returns:
            List of decoded strings
        """
        return self.tokenizer.batch_decode(token_ids, **kwargs)

    def decode(self, token_ids, **kwargs):
        """
        Decode single sequence of token IDs to string.

        Args:
            token_ids: Token IDs to decode
            **kwargs: Additional arguments passed to tokenizer

        Returns:
            Decoded string
        """
        return self.tokenizer.decode(token_ids, **kwargs)