# GPokeT2 / conditioned_gpt2.py
# Committed by iamthinbaker
# Latest commit eb4af6f (verified): "fix: update chars_in_row before computing remaining to prevent 64th pixel per row"
# (Hugging Face file-view residue: Raw | History | Blame | Contribute | Delete | 21.2 kB)
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from transformers import GPT2Config, GPT2LMHeadModel, LogitsProcessor
# Canonical Pokémon type names. A name's position in this list is the
# embedding index used for both the primary and the secondary type.
_TYPES = (
    "normal fire water electric grass ice fighting poison ground flying "
    "psychic bug rock ghost dragon dark steel fairy"
).split()

# Special indices placed right after the 18 real types.
_TYPE1_UNK = len(_TYPES)  # unknown primary type (18)
_TYPE2_NONE = len(_TYPES)  # no secondary type (18)
_TYPE2_UNK = len(_TYPES) + 1  # unknown secondary type (19)

# Embedding-table sizes for the metadata conditioning inputs.
NUM_GENERATIONS = 10
NUM_TYPES1 = len(_TYPES) + 1  # 18 types + UNK
NUM_TYPES2 = len(_TYPES) + 2  # 18 types + NONE + UNK
NUM_EVO_STAGES = 4  # 3 stages + 1 extra slot
NUM_HAS_EVOLUTION = 2
NUM_COLOR_SHIFTS = 6  # 0 = no shift, 1-5 = ColorShift permutations
def _resolve_type(
    val: "str | int | None", default_none_idx: int
) -> "int | None":
    """Map a Pokémon type given by name or index to an embedding index.

    Args:
        val: A type name such as ``"fire"`` (case-insensitive; surrounding
            whitespace is ignored), an integer index, or ``None``.
        default_none_idx: Index returned when ``val`` is a string that is
            not one of the names in ``_TYPES``. Despite the name, callers
            pass the UNK index for the primary type and the NONE index for
            the secondary type.

    Returns:
        ``None`` when ``val`` is ``None``. Otherwise the resolved index.
        Non-string values are passed through ``int()`` without any range
        check.
    """
    if val is None:
        return None
    if isinstance(val, str):
        # Normalise once so " Fire" and "fire" resolve identically.
        name = val.strip().lower()
        return _TYPES.index(name) if name in _TYPES else default_none_idx
    return int(val)
class ConditionedGPT2(GPT2LMHeadModel):
    """GPT-2 LM for sprite generation, conditioned on Pokémon identity and metadata.

    Conditioning is applied when ``forward`` receives ``input_ids`` together
    with ``pokemon_idx`` or ``pokemon_cond``. In that case each token
    embedding is augmented with:

    * a row embedding and a column embedding computed from the ``[ROW_xx]``
      marker tokens (see ``_ids_to_row_ids`` / ``_ids_to_col_ids``); and
    * one per-sequence conditioning vector. This is the Pokémon embedding
      (or the supplied ``pokemon_cond``) plus learned embeddings for type1,
      type2, shininess, generation, evolution stage, has-evolution and
      colour shift.

    Without conditioning, the call is forwarded to plain GPT-2. When
    ``labels`` are given, the loss is a class-weighted cross-entropy that
    uses the ``token_weights`` buffer.
    """

    def __init__(
        self,
        config: GPT2Config,
        num_pokemon: int | None = None,
        noise_std: float = 0.1,
        row_marker_token_ids: list[int] | None = None,
        num_types1: int = NUM_TYPES1,
        num_types2: int = NUM_TYPES2,
        num_generations: int = NUM_GENERATIONS,
        num_evo_stages: int = NUM_EVO_STAGES,
        token_weights: torch.Tensor | None = None,
        num_color_shifts: int = NUM_COLOR_SHIFTS,
    ):
        """Build the model and its extra embedding tables.

        Args:
            config: GPT-2 configuration. ``config.num_pokemon`` is set as a
                side effect.
            num_pokemon: Number of rows in the per-Pokémon embedding table.
                Falls back to ``config.num_pokemon`` when not given (or
                falsy).
            noise_std: Std of the Gaussian noise added to the conditioning
                vector in training mode. Disabled when <= 0.
            row_marker_token_ids: Token ids of ``[ROW_00]``..``[ROW_63]`` in
                row order. Defaults to 64 zeros as a placeholder.
            num_types1, num_types2, num_generations, num_evo_stages,
            num_color_shifts: Sizes of the metadata embedding tables.
            token_weights: Per-token-id loss weights with one entry per
                vocabulary id. Defaults to all ones.

        Raises:
            ValueError: If no ``num_pokemon`` is available, or if
                ``row_marker_token_ids`` holds fewer than 64 ids (the row
                embedding covers rows 0-63).
        """
        num_pokemon = num_pokemon or getattr(config, "num_pokemon", None)
        if num_pokemon is None:
            raise ValueError(
                "num_pokemon must be provided or present in config"
            )
        marker_ids = row_marker_token_ids or [0] * 64
        # _ids_to_row_ids looks up all 64 rows; fail early with a clear
        # message instead of an IndexError on the first forward pass.
        if len(marker_ids) < 64:
            raise ValueError(
                "row_marker_token_ids must contain 64 ids "
                f"([ROW_00]..[ROW_63]), got {len(marker_ids)}"
            )
        config.num_pokemon = num_pokemon
        super().__init__(config)

        self.conditioning = nn.Embedding(num_pokemon, config.n_embd)
        nn.init.normal_(self.conditioning.weight, std=0.02)
        self.noise_std = noise_std

        # Row embedding: 0-63 for sprite rows, 64 = padding (BOS/EOS/pre-row
        # tokens). The padding row is re-zeroed because normal_ overwrote it.
        self.row_emb = nn.Embedding(65, config.n_embd, padding_idx=64)
        nn.init.normal_(self.row_emb.weight, std=0.02)
        self.row_emb.weight.data[64].zero_()
        # Column embedding: 0-63 for the position within a row, 64 = padding.
        self.col_emb = nn.Embedding(65, config.n_embd, padding_idx=64)
        nn.init.normal_(self.col_emb.weight, std=0.02)
        self.col_emb.weight.data[64].zero_()

        # Metadata conditioning embeddings.
        self.type1_emb = nn.Embedding(num_types1, config.n_embd)
        self.type2_emb = nn.Embedding(num_types2, config.n_embd)
        # is_shiny is a binary flag; its table reuses NUM_HAS_EVOLUTION (== 2)
        # as the size.
        self.is_shiny_emb = nn.Embedding(NUM_HAS_EVOLUTION, config.n_embd)
        self.generation_emb = nn.Embedding(num_generations, config.n_embd)
        self.evo_stage_emb = nn.Embedding(num_evo_stages, config.n_embd)
        self.has_evolution_emb = nn.Embedding(NUM_HAS_EVOLUTION, config.n_embd)
        self.color_shift_emb = nn.Embedding(num_color_shifts, config.n_embd)
        for emb in (
            self.type1_emb,
            self.type2_emb,
            self.is_shiny_emb,
            self.generation_emb,
            self.evo_stage_emb,
            self.has_evolution_emb,
            self.color_shift_emb,
        ):
            nn.init.normal_(emb.weight, std=0.02)

        # Per-token loss weights (e.g. to downweight background tokens).
        # Never None, so PyTorch always includes it in state_dict and
        # from_pretrained can load it without an "unexpected key" warning.
        if token_weights is None:
            token_weights = torch.ones(config.vocab_size)
        self.register_buffer("token_weights", token_weights)

        # Row marker ids live in a buffer so they are saved with the model.
        self.register_buffer(
            "row_marker_ids",
            torch.tensor(marker_ids, dtype=torch.long),
        )

    def _ids_to_row_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return, for each token, the row of the nearest marker at or before it.

        Marker tokens get their own row index (0-63). Tokens before the first
        marker get 64, the padding index of ``row_emb``. If several entries of
        ``row_marker_ids`` share a token id, the largest row index wins
        because later assignments overwrite earlier ones.

        Args:
            input_ids: (batch, seq) token ids.

        Returns:
            (batch, seq) long tensor of row indices.
        """
        B, T = input_ids.shape
        device = input_ids.device
        row_ids = input_ids.new_full((B, T), 64)
        for row_idx in range(64):
            row_ids[input_ids == self.row_marker_ids[row_idx]] = row_idx
        # Vectorised forward-fill via cummax: for each position, find the
        # position of the last marker seen so far.
        is_assigned = row_ids < 64
        t_idx = torch.arange(T, device=device).unsqueeze(0).expand(B, -1)
        last_marker_t, _ = torch.where(
            is_assigned,
            t_idx,
            torch.zeros_like(t_idx),
        ).cummax(dim=1)
        row_ids_filled = torch.gather(row_ids, 1, last_marker_t)
        in_row = is_assigned.long().cumsum(dim=1) >= 1
        return torch.where(
            in_row,
            row_ids_filled,
            input_ids.new_full((B, T), 64),
        )

    def _ids_to_col_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the 0-based index of each non-marker token within its row.

        Indices count tokens, not characters. Marker tokens, tokens before
        the first marker, and tokens at index >= 64 get 64, the padding index
        of ``col_emb``.
        NOTE(review): every non-marker token (including BOS/EOS/pad) counts
        toward the column index.

        Args:
            input_ids: (batch, seq) token ids.

        Returns:
            (batch, seq) long tensor of column indices.
        """
        B, T = input_ids.shape
        device = input_ids.device
        is_marker = torch.isin(input_ids, self.row_marker_ids.to(device))
        is_pixel = ~is_marker
        in_row = is_marker.long().cumsum(dim=1) >= 1
        # Inclusive running count of non-marker tokens, and its value at
        # each marker (the count of tokens before that marker).
        pixel_cumsum = is_pixel.long().cumsum(dim=1)
        marker_base = torch.where(
            is_marker,
            pixel_cumsum,
            torch.zeros_like(pixel_cumsum),
        )
        t_idx = torch.arange(T, device=device).unsqueeze(0).expand(B, -1)
        last_marker_t, _ = torch.where(
            is_marker,
            t_idx,
            torch.zeros_like(t_idx),
        ).cummax(dim=1)
        last_marker_base = torch.gather(marker_base, 1, last_marker_t)
        # 0-indexed column position within the current row.
        col_pos = pixel_cumsum - last_marker_base - 1
        return torch.where(
            is_pixel & in_row & (col_pos < 64),
            col_pos.clamp(min=0),
            input_ids.new_full((B, T), 64),
        )

    @torch.compiler.disable
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        pokemon_idx=None,
        pokemon_cond=None,
        row_ids=None,
        col_ids=None,
        type1=None,
        type2=None,
        is_shiny=None,
        generation=None,
        evolution_stage=None,
        has_evolution=None,
        color_shift=None,
        logits_to_keep: int | torch.Tensor = 0,
        num_items_in_batch=None,
        **kwargs,
    ):
        """Run GPT-2 with spatial and conditioning embeddings, plus a weighted loss.

        Args:
            input_ids: (batch, seq) token ids.
            attention_mask: Passed through to ``GPT2LMHeadModel.forward``.
            pokemon_idx: Indices into the per-Pokémon embedding table.
            pokemon_cond: Precomputed conditioning vector(s). Takes precedence
                over ``pokemon_idx``.
            row_ids, col_ids: Precomputed row/column indices (0-63, with 64
                as padding). Computed from ``input_ids`` when omitted.
            type1, type2, is_shiny, generation, evolution_stage,
            has_evolution, color_shift: Metadata index tensors, one entry per
                sequence (or a single entry to broadcast). Any left as
                ``None`` is sampled uniformly at random on every call.
            logits_to_keep: Passed through to the parent forward.
            num_items_in_batch: Accepted but unused; the loss is a weighted
                mean.
            **kwargs: Forwarded to the parent forward, except ``labels``,
                which is consumed here.

        Returns:
            The parent model's output. When ``labels`` is given, ``loss`` is
            the class-weighted cross-entropy, with NaN mapped to 0.
        """
        # Pop labels so the parent does not compute its own (unweighted) loss.
        labels = kwargs.pop("labels", None)
        if input_ids is not None and (
            pokemon_idx is not None or pokemon_cond is not None
        ):
            token_embs = self.transformer.wte(input_ids)
            # Spatial 2-D embeddings: row (which row) + col (position in row).
            if row_ids is None:
                row_ids = self._ids_to_row_ids(input_ids)
            if col_ids is None:
                col_ids = self._ids_to_col_ids(input_ids)
            token_embs = (
                token_embs + self.row_emb(row_ids) + self.col_emb(col_ids)
            )
            # Combined conditioning: pokemon identity + metadata.
            B, device = token_embs.shape[0], token_embs.device

            def _rand_or_use(val, emb):
                # Missing metadata -> a uniformly random index per sequence,
                # resampled on every forward call.
                if val is None:
                    val = torch.randint(
                        0,
                        emb.num_embeddings,
                        (B,),
                        device=device,
                    )
                return emb(val)

            # pokemon_cond allows passing a pre-computed conditioning vector
            # directly, e.g. a blend of identity embeddings.
            base_cond = (
                pokemon_cond
                if pokemon_cond is not None
                else self.conditioning(pokemon_idx)
            )
            cond = (
                base_cond
                + _rand_or_use(type1, self.type1_emb)
                + _rand_or_use(type2, self.type2_emb)
                + _rand_or_use(is_shiny, self.is_shiny_emb)
                + _rand_or_use(generation, self.generation_emb)
                + _rand_or_use(evolution_stage, self.evo_stage_emb)
                + _rand_or_use(has_evolution, self.has_evolution_emb)
                + _rand_or_use(color_shift, self.color_shift_emb)
            )
            if self.training and self.noise_std > 0:
                cond = cond + torch.randn_like(cond) * self.noise_std
            # Broadcast the per-sequence vector over all positions.
            token_embs = token_embs + cond.unsqueeze(1)
            kwargs["inputs_embeds"] = token_embs
            input_ids = None
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            logits_to_keep=logits_to_keep,
            **kwargs,
        )
        if labels is not None:
            # Causal shift: position t predicts token t + 1.
            shift_logits = outputs.logits[..., :-1, :].contiguous().float()
            shift_labels = (
                labels[..., 1:].contiguous().to(outputs.logits.device)
            )
            # Match the weight dtype to the float32 logits. This keeps the
            # loss working when the model (and so this buffer) is cast to
            # fp16/bf16.
            weights = self.token_weights.to(
                device=shift_logits.device,
                dtype=shift_logits.dtype,
            )
            loss = F.cross_entropy(
                shift_logits.view(-1, self.config.vocab_size),
                shift_labels.view(-1),
                weight=weights,
                ignore_index=-100,
                reduction="mean",
            )
            # A weighted mean over no effective targets (e.g. every label
            # is -100) is 0/0 = NaN; report 0 instead.
            outputs.loss = torch.nan_to_num(loss, nan=0.0)
        return outputs

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        **kwargs,
    ):
        """Add row/col ids and conditioning kwargs to the parent's generation inputs.

        Row/col ids are computed on the full sequence, because they depend on
        every marker seen so far. They are then cut to the same suffix the
        parent kept in ``input_ids``. This works whether the parent trims to
        the last token (warm cache) or keeps the whole prompt (first step,
        possibly with an empty cache object).
        """
        row_ids_full = self._ids_to_row_ids(input_ids)
        col_ids_full = self._ids_to_col_ids(input_ids)
        inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            **kwargs,
        )
        kept_ids = inputs.get("input_ids")
        if kept_ids is not None:
            start = row_ids_full.shape[1] - kept_ids.shape[-1]
            inputs["row_ids"] = row_ids_full[:, start:]
            inputs["col_ids"] = col_ids_full[:, start:]
        elif past_key_values is not None:
            # No input_ids returned (embeddings path): fall back to the
            # cache-based heuristic.
            inputs["row_ids"] = row_ids_full[:, -1:]
            inputs["col_ids"] = col_ids_full[:, -1:]
        else:
            inputs["row_ids"] = row_ids_full
            inputs["col_ids"] = col_ids_full
        for field in (
            "pokemon_idx",
            "pokemon_cond",
            "type1",
            "type2",
            "is_shiny",
            "generation",
            "evolution_stage",
            "has_evolution",
            "color_shift",
        ):
            if field in kwargs:
                inputs[field] = kwargs[field]
        return inputs

    def sample_conditioning(self, idx: int | None = None) -> torch.Tensor:
        """Return the identity embedding of Pokémon ``idx`` as a (1, n_embd) tensor.

        A random index is drawn when ``idx`` is None. No gradient is tracked.
        """
        if idx is None:
            idx = torch.randint(
                0,
                self.conditioning.num_embeddings,
                (1,),
            ).item()
        with torch.no_grad():
            return self.conditioning(
                torch.tensor([idx], device=self.conditioning.weight.device),
            )

    def _sample_metadata(self, device: str = "cpu") -> dict:
        """Draw one uniformly random index (shape (1,)) per metadata field.

        Every bound comes from the corresponding embedding table, so models
        built with non-default table sizes never get an out-of-range index.
        """

        def draw(emb: nn.Embedding) -> torch.Tensor:
            return torch.randint(
                0,
                emb.num_embeddings,
                (1,),
                dtype=torch.long,
                device=device,
            )

        return {
            "type1": draw(self.type1_emb),
            "type2": draw(self.type2_emb),
            "is_shiny": draw(self.is_shiny_emb),
            "generation": draw(self.generation_emb),
            "evolution_stage": draw(self.evo_stage_emb),
            "has_evolution": draw(self.has_evolution_emb),
            "color_shift": draw(self.color_shift_emb),
        }

    def sample_random_conditioning(self, device: str = "cpu") -> dict:
        """Sample a random known Pokémon index plus random metadata.

        Returns:
            ``forward``/``generate`` kwargs. The dict has ``pokemon_idx`` and
            every metadata field, each a (1,) long tensor on ``device``.
        """
        return {
            "pokemon_idx": torch.randint(
                0,
                self.conditioning.num_embeddings,
                (1,),
                device=device,
            ),
            **self._sample_metadata(device),
        }

    def sample_novel_conditioning(
        self,
        n_mix: int = 3,
        device: str = "cpu",
    ) -> dict:
        """Blend n_mix random Pokémon embeddings to produce a novel conditioning vector.

        Blend weights are a softmax over Gaussian noise. Returns ``forward``
        kwargs with ``pokemon_cond`` of shape (1, n_embd) plus random
        metadata indices, each of shape (1,).
        """
        with torch.no_grad():
            idxs = torch.randint(
                0,
                self.conditioning.num_embeddings,
                (n_mix,),
                device=device,
            )
            weights = torch.softmax(torch.randn(n_mix, device=device), dim=0)
            pokemon_cond = (weights.unsqueeze(1) * self.conditioning(idxs)).sum(
                0,
                keepdim=True,
            )
        return {"pokemon_cond": pokemon_cond, **self._sample_metadata(device)}

    def generate_sprite(
        self,
        tokenizer,
        temperature: float = 1.2,
        top_p: float = 0.95,
        verbose: bool = False,
        type1: "str | int | None" = None,
        type2: "str | int | None" = None,
    ) -> np.ndarray:
        """Generate a 64×64 RGB sprite for a novel (blended) Pokémon.

        Conditioning comes from ``sample_novel_conditioning``. Colour shift
        and shininess are forced to 0, and the generation index is drawn
        uniformly from {2, 3}. Decoding starts from ``[ROW_00]`` and is
        constrained by ``_RowLengthLogitsProcessor``.

        Args:
            tokenizer: Tokenizer providing the ``[ROW_xx]`` tokens and the
                pad/EOS ids.
            temperature, top_p: Nucleus-sampling parameters (top-k disabled).
            verbose: Stream tokens to stdout, starting a new line per row.
            type1, type2: A type name ("fire", "water", …), matched
                case-insensitively, or its integer index. ``None`` keeps the
                random value. Unknown type1 names map to UNK, and unknown
                type2 names to NONE.

        Returns:
            A uint8 array of shape (64, 64, 3).
        """
        device = next(self.parameters()).device
        cond = self.sample_novel_conditioning(device=str(device))
        cond["color_shift"] = torch.tensor([0], dtype=torch.long, device=device)
        cond["is_shiny"] = torch.tensor([0], dtype=torch.long, device=device)
        cond["generation"] = torch.tensor(
            [2 + torch.randint(0, 2, (1,)).item()],
            dtype=torch.long,
            device=device,
        )
        t1 = _resolve_type(type1, _TYPE1_UNK)
        t2 = _resolve_type(type2, _TYPE2_NONE)
        if t1 is not None:
            cond["type1"] = torch.tensor([t1], dtype=torch.long, device=device)
        if t2 is not None:
            cond["type2"] = torch.tensor([t2], dtype=torch.long, device=device)
        row_marker_ids = [
            tokenizer.convert_tokens_to_ids(f"[ROW_{i:02d}]") for i in range(64)
        ]
        inputs = tokenizer(
            "[ROW_00]",
            return_tensors="pt",
            add_special_tokens=False,
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}
        inputs.update(cond)
        streamer = None
        if verbose:
            from transformers import TextStreamer

            class _CompactStreamer(TextStreamer):
                # Prints each decoded chunk, starting a new line at row markers.
                def on_finalized_text(
                    self, text: str, stream_end: bool = False
                ):
                    prefix = "\n" if "[ROW" in text else ""
                    print(
                        prefix + text,
                        end="" if not stream_end else "\n",
                        flush=True,
                    )

            streamer = _CompactStreamer(tokenizer, skip_special_tokens=False)
        with torch.no_grad():
            output_ids = self.generate(
                **inputs,
                max_length=4096,
                do_sample=True,
                top_k=0,
                top_p=top_p,
                temperature=temperature,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                logits_processor=[
                    _RowLengthLogitsProcessor(tokenizer, row_marker_ids),
                ],
                streamer=streamer,
            )
        return _tokens_to_image(
            tokenizer.decode(output_ids[0], skip_special_tokens=False),
        )
class _RowLengthLogitsProcessor(LogitsProcessor):
    """Logits processor that enforces the sprite row layout during sampling.

    It tracks how many characters the current row holds by summing ``len``
    of each generated pixel token. It then:

    * while the row has room, masks every token that would overflow it, plus
      every zero-length token (special tokens, row markers, EOS, and logit
      columns the tokenizer vocabulary does not cover), so a row cannot end
      early;
    * once the row holds ``row_width`` characters, allows only the next row
      marker, or EOS after the last row.

    State is updated from ``input_ids[0]`` only, and the same mask is applied
    to every batch row, so this assumes batch size 1.
    """

    def __init__(
        self,
        tokenizer,
        row_marker_ids: list[int],
        row_width: int = 63,
    ):
        """Args:
        tokenizer: Provides ``get_vocab()`` and the eos/bos/pad token ids.
        row_marker_ids: Token ids of the row markers, in row order.
        row_width: Number of characters each row must contain.
        """
        vocab = tokenizer.get_vocab()
        self.row_marker_set = set(row_marker_ids)
        self.row_marker_ids = row_marker_ids
        self.eos_id = tokenizer.eos_token_id
        self.bos_id = tokenizer.bos_token_id
        self.row_width = row_width
        # Index of the final row; EOS is only permitted once it is full.
        self.last_row = len(row_marker_ids) - 1
        special = {
            tokenizer.eos_token_id,
            tokenizer.bos_token_id,
            tokenizer.pad_token_id,
            *row_marker_ids,
        }
        # Characters each token id contributes to a row (0 for specials).
        # Sized by the largest id so a non-contiguous vocabulary still fits.
        table_size = max(vocab.values(), default=-1) + 1
        self.pixel_len = torch.zeros(table_size, dtype=torch.long)
        for token, idx in vocab.items():
            if idx not in special:
                self.pixel_len[idx] = len(token)
        # pixel_len resized to the logits width and placed on the logits
        # device. Rebuilt only when either of those changes.
        self._scores_pixel_len = None
        self.current_row = self.chars_in_row = 0

    def _pixel_len_like(self, scores: torch.Tensor) -> torch.Tensor:
        """Return ``pixel_len`` resized to ``scores.shape[-1]`` on ``scores.device``.

        Logit columns beyond the tokenizer vocabulary (e.g. a padded model
        vocab) get length 0 and are therefore always masked.
        """
        width = scores.shape[-1]
        cached = self._scores_pixel_len
        if (
            cached is None
            or cached.shape[0] != width
            or cached.device != scores.device
        ):
            resized = torch.zeros(width, dtype=torch.long)
            n = min(width, self.pixel_len.shape[0])
            resized[:n] = self.pixel_len[:n]
            cached = resized.to(scores.device)
            self._scores_pixel_len = cached
        return cached

    def __call__(
        self,
        input_ids: torch.Tensor,
        scores: torch.Tensor,
    ) -> torch.Tensor:
        """Update the row state from the last generated token and mask ``scores``."""
        pixel_len = self._pixel_len_like(scores)
        last_id = int(input_ids[0, -1].item())
        if last_id in self.row_marker_set:
            self.current_row = self.row_marker_ids.index(last_id)
            self.chars_in_row = 0
        elif last_id not in {self.eos_id, self.bos_id}:
            # Read the CPU table (no device sync). Ids outside the tokenizer
            # vocabulary count as 0 characters.
            if 0 <= last_id < self.pixel_len.shape[0]:
                self.chars_in_row += int(self.pixel_len[last_id])
        remaining = self.row_width - self.chars_in_row
        if remaining > 0:
            # Row not yet full: forbid overflowing and zero-length tokens.
            mask = (pixel_len > remaining) | (pixel_len == 0)
            if self.current_row < self.last_row and self.eos_id is not None:
                mask[self.eos_id] = True
            return scores.masked_fill(mask.unsqueeze(0), float("-inf"))
        # Row full: only the next row marker, or EOS after the last row.
        blocked = torch.ones(
            scores.shape[-1], dtype=torch.bool, device=scores.device
        )
        if self.current_row < self.last_row:
            blocked[self.row_marker_ids[self.current_row + 1]] = False
        else:
            blocked[self.eos_id] = False
        return scores.masked_fill(blocked.unsqueeze(0), float("-inf"))
def _tokens_to_image(text: str) -> np.ndarray:
    """Decode generated sprite text into a 64x64 RGB uint8 image.

    The text is split on whitespace. A ``[ROW_..]`` token starts a new row.
    Any other token that starts with ``[`` and ends with ``]`` (e.g. a
    BOS/EOS marker) is ignored. Every remaining token contributes one pixel
    per character, so a multi-character pixel token fills several
    consecutive pixels.

    Pixel characters:
      * ``"~"`` -> white (255, 255, 255).
      * ``";"`` (ASCII 59) through ``"z"`` (ASCII 122) -> one of 64 palette
        colours. ``idx = ord(c) - 59`` is split into 2-bit red/green/blue
        levels, and each level maps to ``level * 64 + 32``.
      * Any other character is out of palette. It is left black but still
        occupies its column.

    Only the first 64 rows and the first 64 pixels of each row are used;
    missing pixels stay black.
    """
    rows: list[list[str]] = []
    current: list[str] = []
    for token in text.split():
        if token.startswith("[ROW_") and token.endswith("]"):
            if current:
                rows.append(current)
            current = []
        elif not (token.startswith("[") and token.endswith("]")):
            # One pixel per character; single-character tokens are unaffected.
            current.extend(token)
    if current:
        rows.append(current)

    image = np.zeros((64, 64, 3), dtype=np.uint8)
    for y, row in enumerate(rows[:64]):
        for x, char in enumerate(row[:64]):
            if char == "~":
                image[y, x] = [255, 255, 255]
                continue
            idx = ord(char) - 59
            if not 0 <= idx < 64:
                # Out of palette: leave black rather than writing a channel
                # value outside the uint8 range.
                continue
            r, g, b = idx // 16, (idx // 4) % 4, idx % 4
            image[y, x] = [r * 64 + 32, g * 64 + 32, b * 64 + 32]
    return image
if __name__ == "__main__":
    # Demo entry point: fetch a checkpoint revision and write a few sprites
    # to ./generated as PNG files.
    from pathlib import Path

    import cv2
    from huggingface_hub import snapshot_download
    from transformers import AutoModelForCausalLM, PreTrainedTokenizerFast

    REPO_ID = "iamthinbaker/GPokeT2"
    REVISION = "v0.1-wip-3400"
    N_SAMPLES = 5
    OUTPUT_DIR = Path("generated")

    print(f"Descargando {REPO_ID} ({REVISION})...")
    ckpt_dir = snapshot_download(REPO_ID, revision=REVISION)
    tokenizer = PreTrainedTokenizerFast.from_pretrained(ckpt_dir)
    model = AutoModelForCausalLM.from_pretrained(ckpt_dir, trust_remote_code=True)
    model.eval()

    OUTPUT_DIR.mkdir(exist_ok=True)
    for sample_no in range(1, N_SAMPLES + 1):
        print(f"\n[{sample_no}/{N_SAMPLES}] generando...")
        rgb = model.generate_sprite(tokenizer, verbose=True)
        out_path = OUTPUT_DIR / f"pokemon_{sample_no:02d}.png"
        # OpenCV expects BGR channel order.
        bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
        cv2.imwrite(str(out_path), bgr)
        print(f"guardado en {out_path}")