import torch
from torch import Tensor, nn
from transformers import (CLIPTextModel, CLIPTokenizer, T5EncoderModel,
                          T5Tokenizer)

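# NOTE: The triple-quoted block below is a retired debug variant of HFEmbedder.
# It logs tokenizer/model vocab sizes, validates input-ID ranges, and passes
# explicit position_ids to CLIP. It is kept for reference only; the active
# implementation follows after the closing quotes.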
'''
class HFEmbedder(nn.Module):
    def __init__(self, version: str, max_length: int, is_clip: bool, **hf_kwargs):
        super().__init__()
        self.is_clip = is_clip
        self.max_length = max_length
        self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"

        if self.is_clip:
            #self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length, truncation=True)
            self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained("/home/user/app/models/tokenizer", max_length=max_length, truncation=True)
            self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)

            # --- Debug info ---
            print(f"--- CLIP Model Info ---")
            print(f"  Requested version/path: {version}")
            print(f"  Tokenizer loaded from: {getattr(self.tokenizer, 'name_or_path', 'Unknown')}")
            print(f"  Model loaded from: {getattr(self.hf_module, 'name_or_path', 'Unknown')}")
            print(f"  Tokenizer max length: {getattr(self.tokenizer, 'model_max_length', 'N/A')}")
            print(f"  Model max position embeddings: {getattr(self.hf_module.config, 'max_position_embeddings', 'N/A')}")
            
            # Key debug info: vocabulary sizes
            tokenizer_vocab_size = len(self.tokenizer.get_vocab()) if hasattr(self.tokenizer, 'get_vocab') else getattr(self.tokenizer, 'vocab_size', 'Unknown')
            print(f"  Tokenizer vocab size (len(get_vocab())): {tokenizer_vocab_size}")
            print(f"  Tokenizer vocab size (attribute): {getattr(self.tokenizer, 'vocab_size', 'N/A')}")
            print(f"  Model config vocab size: {self.hf_module.config.vocab_size}")
            print(f"  Actual model embedding weight shape: {self.hf_module.text_model.embeddings.token_embedding.weight.shape}")
            print(f"-------------------------")

        else:
            self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length, truncation=True)
            self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs)
            
            # --- Debug info ---
            print(f"--- T5 Model Info ---")
            print(f"  Requested version/path: {version}")
            print(f"  Tokenizer loaded from: {getattr(self.tokenizer, 'name_or_path', 'Unknown')}")
            print(f"  Model loaded from: {getattr(self.hf_module, 'name_or_path', 'Unknown')}")
            print(f"  Tokenizer max length: {getattr(self.tokenizer, 'model_max_length', 'N/A')}")
            print(f"  Model max position embeddings: {getattr(self.hf_module.config, 'd_model', 'N/A (T5 uses relative pos)')}") # T5 uses relative
            tokenizer_vocab_size = len(self.tokenizer.get_vocab()) if hasattr(self.tokenizer, 'get_vocab') else getattr(self.tokenizer, 'vocab_size', 'Unknown')
            print(f"  Tokenizer vocab size (len(get_vocab())): {tokenizer_vocab_size}")
            print(f"  Tokenizer vocab size (attribute): {getattr(self.tokenizer, 'vocab_size', 'N/A')}")
            print(f"  Model config vocab size: {self.hf_module.config.vocab_size}")
            print(f"  Actual model embedding weight shape: {self.hf_module.encoder.embed_tokens.weight.shape}")
            print(f"----------------------")

        self.hf_module = self.hf_module.eval().requires_grad_(False)


    def forward(self, text: str | list[str]) -> Tensor:
        # Ensure text is a list
        if isinstance(text, str):
            text = [text]

        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=False,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        print(f'Batch Encoding {batch_encoding}')
        encoder_type = 'clip' if self.is_clip else 't5'
        print(f'Forward pass for {encoder_type}')
        input_ids = batch_encoding["input_ids"]
        print(f"Input IDs shape: {input_ids.shape}, Max Length: {self.max_length}")
        
        # Stricter assertion: the tokenizer must emit exactly (batch, max_length)
        assert input_ids.shape == (len(text), self.max_length), f"Input IDs shape {input_ids.shape} does not match expected ({len(text)}, {self.max_length})"
        #print(f"Input IDs:\n{input_ids}")

        # --- Key debug check: validate the input-ID range ---
        min_id, max_id = input_ids.min().item(), input_ids.max().item()
        print(f"Input IDs range: [{min_id}, {max_id}]")

        vocab_source = "tokenizer" if self.is_clip else "model_config"
        vocab_size = len(self.tokenizer.get_vocab()) if self.is_clip and hasattr(self.tokenizer, 'get_vocab') else self.hf_module.config.vocab_size
        print(f"Vocab size (from {vocab_source}): {vocab_size}")

        if max_id >= vocab_size:
            raise IndexError(f"Found input ID ({max_id}) >= vocab size ({vocab_size}). This will cause an embedding error.")

        if min_id < 0:
            raise IndexError(f"Found negative input ID ({min_id}). This is invalid.")

        # Make sure the inputs are on the correct device
        input_ids = input_ids.to(self.hf_module.device)
        attention_mask = batch_encoding["attention_mask"].to(self.hf_module.device)

        print(f"Input IDs device: {input_ids.device}")
        print(f"Attention Mask device: {attention_mask.device}")

        # --- FIX FOR CLIP POSITION IDs ---
        # Prepare arguments for the model call
        model_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "output_hidden_states": False,
        }

        # If it's a CLIP model, explicitly generate and pass position_ids
        if self.is_clip:
            # Generate position_ids: [0, 1, 2, ..., max_length-1] for each item in the batch
            # Shape: (batch_size, max_length)
            position_ids = torch.arange(self.max_length, dtype=torch.long, device=input_ids.device).expand(input_ids.size(0), -1)
            print(f"Generated CLIP position_ids: shape={position_ids.shape}, range=[{position_ids.min().item()}, {position_ids.max().item()}]")
            # Check if generated position_ids are within the model's limit
            max_pos_emb = getattr(self.hf_module.config, 'max_position_embeddings', -1)
            if max_pos_emb > 0 and position_ids.max() >= max_pos_emb:
                raise ValueError(f"Generated position_ids max ({position_ids.max().item()}) >= model's max_position_embeddings ({max_pos_emb})")
            # Pass the explicitly created position_ids to the model
            model_kwargs["position_ids"] = position_ids
        try:
            outputs = self.hf_module(**model_kwargs)
        except IndexError as e:
            # Catch the error and report richer context before re-raising
            print(f"*** IndexError caught during model forward pass ***")
            print(f"Error: {e}")
            print(f"Input IDs shape: {input_ids.shape}")
            print(f"Input IDs range: [{input_ids.min().item()}, {input_ids.max().item()}]")
            print(f"Model vocab size: {self.hf_module.config.vocab_size}")
            if self.is_clip:
                print(f"Tokenizer vocab size: {len(self.tokenizer.get_vocab()) if hasattr(self.tokenizer, 'get_vocab') else 'N/A'}")
                print(f"Embedding layer weight shape: {self.hf_module.text_model.embeddings.token_embedding.weight.shape}")
            raise # Re-raise the error after logging

        return outputs[self.output_key]
'''


class HFEmbedder(nn.Module):
    def __init__(self, version: str, max_length: int, is_clip: bool, **hf_kwargs):
        super().__init__()
        self.is_clip = is_clip
        self.max_length = max_length
        self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
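        # CLIP yields one pooled vector per prompt ("pooler_output");
        # T5 yields the full per-token sequence ("last_hidden_state").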

        if self.is_clip:
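            # The hub-based tokenizer load is kept commented out below; the
            # active call points at a local snapshot via a hardcoded path.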
            #self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length, truncation=True)
            self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained("/home/user/app/models/tokenizer", max_length=max_length, truncation=True)
            self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)

        else:
            self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length)
            #self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained("black-forest-labs/FLUX.1-dev/tokenizer_2", max_length=max_length)
            self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs)

        self.hf_module = self.hf_module.eval().requires_grad_(False)

    def forward(self, text: list[str]) -> Tensor:
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=False,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )

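        # attention_mask is intentionally None: the padded sequence is encoded
        # at full max_length, mirroring the reference FLUX HFEmbedder.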
        outputs = self.hf_module(
            input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
            attention_mask=None,
            output_hidden_states=False,
        )
        return outputs[self.output_key]
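

# --- Usage sketch (illustrative; the checkpoint name and max length below are
# assumptions, not from the original file). Note the CLIP branch loads its
# tokenizer from the hardcoded local path above, so it only works where that
# snapshot exists; the T5 branch resolves entirely from the Hugging Face hub.
if __name__ == "__main__":
    t5 = HFEmbedder("google/t5-v1_1-xxl", max_length=512, is_clip=False)
    seq = t5(["a photo of a cat"])
    print(seq.shape)  # e.g. torch.Size([1, 512, 4096]) for T5-XXL (d_model=4096)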