Spaces:
Running
on
Zero
Running
on
Zero
| from torch import Tensor, nn | |
| from transformers import (CLIPTextModel, CLIPTokenizer, T5EncoderModel, | |
| T5Tokenizer) | |
| import os | |
| import torch | |
# NOTE: A disabled, debug-instrumented copy of HFEmbedder used to live here as a
# module-level triple-quoted string. That variant printed tokenizer/model vocab
# sizes and embedding-weight shapes, rejected out-of-range input ids, forwarded
# the tokenizer's attention mask, and built explicit CLIP position_ids. Because a
# bare string at module level is never executed, it was dead code and has been
# removed; recover it from version history if that debugging is needed again.
class HFEmbedder(nn.Module):
    """Frozen Hugging Face text encoder (CLIP or T5) used as a prompt embedder.

    With a truthy ``is_clip`` the module wraps a ``CLIPTextModel`` and
    ``forward`` returns its ``"pooler_output"``; otherwise it wraps a
    ``T5EncoderModel`` and returns its ``"last_hidden_state"``. The wrapped
    model is switched to eval mode with gradients disabled at construction.

    Args:
        version: Model id or local path passed to the encoder's
            ``from_pretrained``; also the default tokenizer source for T5.
        max_length: Length every prompt is truncated/padded to in ``forward``.
        is_clip: Selects the CLIP branch (truthy) or the T5 branch (falsy).
        tokenizer_path: Optional explicit tokenizer location. When omitted,
            T5 uses ``version``; CLIP uses ``CLIP_TOKENIZER_DIR`` if that
            directory exists and otherwise falls back to ``version``.
        **hf_kwargs: Forwarded to the encoder's ``from_pretrained`` only.
    """

    # Bundled CLIP tokenizer location inside the deployed app. Previously this
    # was the only place the CLIP tokenizer could be loaded from, so running
    # outside that container failed.
    CLIP_TOKENIZER_DIR = "/home/user/app/models/tokenizer"

    def __init__(self, version: str, max_length: int, is_clip, *, tokenizer_path: "str | None" = None, **hf_kwargs):
        super().__init__()
        self.is_clip = is_clip
        self.max_length = max_length
        self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
        tokenizer_source = self._resolve_tokenizer_path(version, is_clip, tokenizer_path)
        if self.is_clip:
            self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(tokenizer_source, max_length=max_length, truncation=True)
            self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)
        else:
            self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(tokenizer_source, max_length=max_length)
            self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs)
        # Inference-only encoder: eval mode and no gradient tracking.
        self.hf_module = self.hf_module.eval().requires_grad_(False)

    @classmethod
    def _resolve_tokenizer_path(cls, version, is_clip, tokenizer_path=None):
        """Return the location the tokenizer should be loaded from.

        An explicit ``tokenizer_path`` always wins. For CLIP, the bundled
        ``CLIP_TOKENIZER_DIR`` is preferred when it exists on disk; in every
        other case the model ``version`` is used.
        """
        if tokenizer_path is not None:
            return tokenizer_path
        if is_clip and os.path.isdir(cls.CLIP_TOKENIZER_DIR):
            return cls.CLIP_TOKENIZER_DIR
        return version

    def forward(self, text: list[str]) -> Tensor:
        """Encode a batch of prompts with the wrapped encoder.

        Args:
            text: Prompts to encode; each is truncated/padded to exactly
                ``self.max_length`` tokens.

        Returns:
            ``outputs["pooler_output"]`` for CLIP or
            ``outputs["last_hidden_state"]`` for T5. Presumably shaped
            (batch, hidden) and (batch, max_length, hidden) respectively —
            confirm against the loaded model config.
        """
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=False,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        outputs = self.hf_module(
            input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
            # NOTE(review): the tokenizer's attention mask is deliberately not
            # forwarded, so padding positions are not masked out; presumably
            # this matches how the downstream model was trained — confirm
            # before changing.
            attention_mask=None,
            output_hidden_states=False,
        )
        return outputs[self.output_key]