modified: src/flux/modules/conditioner.py (commit eb86c26)
import os

import torch
from torch import Tensor, nn
from transformers import (CLIPTextModel, CLIPTokenizer, T5EncoderModel,
                          T5Tokenizer)
'''
class HFEmbedder(nn.Module):
    def __init__(self, version: str, max_length: int, is_clip: bool, **hf_kwargs):
        super().__init__()
        self.is_clip = is_clip
        self.max_length = max_length
        self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"

        if self.is_clip:
            #self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length, truncation=True)
            self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained("/home/user/app/models/tokenizer", max_length=max_length, truncation=True)
            self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)

            # --- Debug info ---
            print(f"--- CLIP Model Info ---")
            print(f"  Requested version/path: {version}")
            print(f"  Tokenizer loaded from: {getattr(self.tokenizer, 'name_or_path', 'Unknown')}")
            print(f"  Model loaded from: {getattr(self.hf_module, 'name_or_path', 'Unknown')}")
            print(f"  Tokenizer max length: {getattr(self.tokenizer, 'model_max_length', 'N/A')}")
            print(f"  Model max position embeddings: {getattr(self.hf_module.config, 'max_position_embeddings', 'N/A')}")

            # Key debug info: vocabulary sizes (tokenizer vs. model embedding table)
            tokenizer_vocab_size = len(self.tokenizer.get_vocab()) if hasattr(self.tokenizer, 'get_vocab') else getattr(self.tokenizer, 'vocab_size', 'Unknown')
            print(f"  Tokenizer vocab size (len(get_vocab())): {tokenizer_vocab_size}")
            print(f"  Tokenizer vocab size (attribute): {getattr(self.tokenizer, 'vocab_size', 'N/A')}")
            print(f"  Model config vocab size: {self.hf_module.config.vocab_size}")
            print(f"  Actual model embedding weight shape: {self.hf_module.text_model.embeddings.token_embedding.weight.shape}")
            print(f"-------------------------")
        else:
            self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length, truncation=True)
            self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs)

            # --- Debug info ---
            print(f"--- T5 Model Info ---")
            print(f"  Requested version/path: {version}")
            print(f"  Tokenizer loaded from: {getattr(self.tokenizer, 'name_or_path', 'Unknown')}")
            print(f"  Model loaded from: {getattr(self.hf_module, 'name_or_path', 'Unknown')}")
            print(f"  Tokenizer max length: {getattr(self.tokenizer, 'model_max_length', 'N/A')}")
            print(f"  Model d_model: {getattr(self.hf_module.config, 'd_model', 'N/A')}")  # T5 uses relative position embeddings
            tokenizer_vocab_size = len(self.tokenizer.get_vocab()) if hasattr(self.tokenizer, 'get_vocab') else getattr(self.tokenizer, 'vocab_size', 'Unknown')
            print(f"  Tokenizer vocab size (len(get_vocab())): {tokenizer_vocab_size}")
            print(f"  Tokenizer vocab size (attribute): {getattr(self.tokenizer, 'vocab_size', 'N/A')}")
            print(f"  Model config vocab size: {self.hf_module.config.vocab_size}")
            print(f"  Actual model embedding weight shape: {self.hf_module.encoder.embed_tokens.weight.shape}")
            print(f"----------------------")

        self.hf_module = self.hf_module.eval().requires_grad_(False)

    def forward(self, text: list[str]) -> Tensor:
        # Ensure text is a list
        if isinstance(text, str):
            text = [text]

        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=False,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        print(f'Batch Encoding {batch_encoding}')
        encoder_type = 'clip' if self.is_clip else 't5'
        print(f'Forward pass for {encoder_type}')

        input_ids = batch_encoding["input_ids"]
        print(f"Input IDs shape: {input_ids.shape}, Max Length: {self.max_length}")
        # Strict shape check on the tokenized batch
        assert input_ids.shape == (len(text), self.max_length), f"Input IDs shape {input_ids.shape} does not match expected ({len(text)}, {self.max_length})"
        #print(f"Input IDs:\n{input_ids}")

        # --- Key debug check: input ID range vs. vocabulary size ---
        min_id, max_id = input_ids.min().item(), input_ids.max().item()
        print(f"Input IDs range: [{min_id}, {max_id}]")
        vocab_source = "tokenizer" if self.is_clip else "model_config"
        vocab_size = len(self.tokenizer.get_vocab()) if self.is_clip and hasattr(self.tokenizer, 'get_vocab') else self.hf_module.config.vocab_size
        print(f"Vocab size (from {vocab_source}): {vocab_size}")
        if max_id >= vocab_size:
            raise IndexError(f"Found input ID ({max_id}) >= vocab size ({vocab_size}). This will cause an embedding error.")
        if min_id < 0:
            raise IndexError(f"Found negative input ID ({min_id}). This is invalid.")

        # Make sure the inputs live on the same device as the model
        input_ids = input_ids.to(self.hf_module.device)
        attention_mask = batch_encoding["attention_mask"].to(self.hf_module.device)
        print(f"Input IDs device: {input_ids.device}")
        print(f"Attention Mask device: {attention_mask.device}")

        # --- FIX FOR CLIP POSITION IDs ---
        # Prepare arguments for the model call
        model_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "output_hidden_states": False,
        }

        # If it's a CLIP model, explicitly generate and pass position_ids
        if self.is_clip:
            # Generate position_ids [0, 1, ..., max_length-1] for each item in the batch.
            # Shape: (batch_size, max_length)
            position_ids = torch.arange(self.max_length, dtype=torch.long, device=input_ids.device).expand(input_ids.size(0), -1)
            print(f"Generated CLIP position_ids: shape={position_ids.shape}, range=[{position_ids.min().item()}, {position_ids.max().item()}]")

            # Check that the generated position_ids are within the model's limit
            max_pos_emb = getattr(self.hf_module.config, 'max_position_embeddings', -1)
            if max_pos_emb > 0 and position_ids.max() >= max_pos_emb:
                raise ValueError(f"Generated position_ids max ({position_ids.max().item()}) >= model's max_position_embeddings ({max_pos_emb})")

            # Pass the explicitly created position_ids to the model
            model_kwargs["position_ids"] = position_ids

        try:
            outputs = self.hf_module(**model_kwargs)
        except IndexError as e:
            # Catch and report detailed error context before re-raising
            print(f"*** IndexError caught during model forward pass ***")
            print(f"Error: {e}")
            print(f"Input IDs shape: {input_ids.shape}")
            print(f"Input IDs range: [{input_ids.min().item()}, {input_ids.max().item()}]")
            print(f"Model vocab size: {self.hf_module.config.vocab_size}")
            if self.is_clip:
                print(f"Tokenizer vocab size: {len(self.tokenizer.get_vocab()) if hasattr(self.tokenizer, 'get_vocab') else 'N/A'}")
                print(f"Embedding layer weight shape: {self.hf_module.text_model.embeddings.token_embedding.weight.shape}")
            raise  # Re-raise the error after logging
        return outputs[self.output_key]
'''


# Active implementation: a lean version of the embedder above, with the debug
# prints, input validation, and explicit CLIP position_ids handling removed.
class HFEmbedder(nn.Module):
    def __init__(self, version: str, max_length: int, is_clip: bool, **hf_kwargs):
        super().__init__()
        self.is_clip = is_clip
        self.max_length = max_length
        self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"

        if self.is_clip:
            #self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length, truncation=True)
            self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained("/home/user/app/models/tokenizer", max_length=max_length, truncation=True)
            self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)
        else:
            self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length)
            #self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained("black-forest-labs/FLUX.1-dev/tokenizer_2", max_length=max_length)
            self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs)

        self.hf_module = self.hf_module.eval().requires_grad_(False)

    def forward(self, text: list[str]) -> Tensor:
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=False,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )

        outputs = self.hf_module(
            input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
            attention_mask=None,
            output_hidden_states=False,
        )
        return outputs[self.output_key]
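

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. Assumptions: the
    # checkpoint names below follow the usual FLUX defaults ("openai/clip-vit-large-patch14"
    # for CLIP, "google/t5-v1_1-xxl" for T5), and the hard-coded local tokenizer path
    # in __init__ exists on this machine; adjust both before running.
    clip = HFEmbedder("openai/clip-vit-large-patch14", max_length=77, is_clip=True,
                      torch_dtype=torch.bfloat16)
    t5 = HFEmbedder("google/t5-v1_1-xxl", max_length=512, is_clip=False,
                    torch_dtype=torch.bfloat16)

    prompt = ["a photo of a forest at dawn"]
    print(clip(prompt).shape)  # pooled CLIP embedding, e.g. (1, 768)
    print(t5(prompt).shape)    # T5 last hidden state, e.g. (1, 512, 4096)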