"""
Qwen Text Encoder
Qwen 文本编码器 - 使用 Qwen3 模型进行文本编码
"""

import torch
import torch.nn as nn
from typing import List, Optional, Tuple


def load_qwen_model(model_path: str, device: str = "cuda"):
    """
    Load Qwen3 embedding model
    加载 Qwen3 嵌入模型
    """
    try:
        from sentence_transformers import SentenceTransformer
        model = SentenceTransformer(model_path)
        model.to(device)
        return model
    except ImportError:
        print("Warning: sentence-transformers not available. Using mock embeddings.")
        return None


def encode_text_with_qwen(
    qwen_model, 
    texts: List[str], 
    device: str = "cuda",
    max_length: int = 512,
    use_query_mode: bool = False
) -> torch.Tensor:
    """
    Encode text using Qwen3 model
    使用 Qwen3 模型编码文本
    Args:
        qwen_model: Qwen3 embedding model
        texts: List of text strings to encode
        device: Device to run on
        max_length: Maximum sequence length
        use_query_mode: Whether to use query prompt for better understanding
    """
    if qwen_model is None:
        # Mock embeddings for testing when sentence-transformers is unavailable;
        # 1024 matches the hidden size of Qwen3-Embedding-0.6B
        batch_size = len(texts)
        if use_query_mode:
            return torch.randn(batch_size, 1024, device=device, dtype=torch.float32)
        return torch.randn(batch_size, max_length, 1024, device=device, dtype=torch.float32)
    
    # encode() takes no max-length argument in sentence-transformers; the
    # sequence cap is set as an attribute on the model itself
    qwen_model.max_seq_length = max_length

    with torch.no_grad():
        # Query mode -> one pooled vector per text; otherwise per-token embeddings
        embeddings = qwen_model.encode(
            texts,
            prompt_name="query" if use_query_mode else None,
            convert_to_tensor=True,
            device=device,
            output_value="sentence_embedding" if use_query_mode else "token_embeddings"
        )

    if use_query_mode:
        return embeddings  # [batch_size, 1024]
    # Token embeddings come back as a list of [seq_len_i, 1024] tensors of
    # varying length; pad to the batch maximum so they stack into one tensor
    return torch.nn.utils.rnn.pad_sequence(embeddings, batch_first=True)


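# Example usage (a minimal sketch; assumes a local Qwen3-Embedding-0.6B
# checkpoint and a sentence-transformers version that supports `prompt_name`):
#
#     model = load_qwen_model("models/Qwen3-Embedding-0.6B", device="cpu")
#     tokens = encode_text_with_qwen(model, ["a cat", "two dogs"], device="cpu")
#     pooled = encode_text_with_qwen(model, ["a cat", "two dogs"], device="cpu",
#                                    use_query_mode=True)
#     # tokens: [2, seq_len, 1024] (padded), pooled: [2, 1024]
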
class QwenTextEncoder(nn.Module):
    """
    Qwen Text Encoder wrapper for training and inference
    用于训练和推理的 Qwen 文本编码器包装器
    """
    
    def __init__(
        self,
        model_path: str = "models/Qwen3-Embedding-0.6B",
        device: str = "cuda",
        max_length: int = 512,
        freeze_encoder: bool = True
    ):
        super().__init__()
        self.device = device
        self.max_length = max_length
        self.freeze_encoder = freeze_encoder
        
        # Load Qwen model
        self.qwen_model = load_qwen_model(model_path, device)
        
        # Freeze parameters if specified
        if self.freeze_encoder and self.qwen_model is not None:
            for param in self.qwen_model.parameters():
                param.requires_grad = False
    
    def encode_prompts(
        self,
        prompts: List[str],
        negative_prompts: Optional[List[str]] = None,
        do_classifier_free_guidance: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Encode prompts using Qwen3 model
        使用 Qwen3 模型编码提示词
        
        Returns:
            tuple: (text_embeddings, pooled_embeddings)
                - text_embeddings: [batch_size, 1024] for sequence embeddings
                - pooled_embeddings: [batch_size, 1024] for pooled embeddings
        """
        batch_size = len(prompts)
        
        # Encode positive prompts for text embeddings (normal mode)
        text_embeddings = encode_text_with_qwen(
            self.qwen_model, prompts, self.device, 
            max_length=self.max_length, use_query_mode=False
        )
        
        # Encode positive prompts for pooled embeddings (query mode)
        pooled_embeddings = encode_text_with_qwen(
            self.qwen_model, prompts, self.device,
            max_length=self.max_length, use_query_mode=True
        )
        
        # Handle negative prompts
        if do_classifier_free_guidance:
            if negative_prompts is None:
                negative_prompts = [""] * batch_size
            
            # Encode negative prompts
            negative_text_embeddings = encode_text_with_qwen(
                self.qwen_model, negative_prompts, self.device,
                max_length=self.max_length, use_query_mode=False
            )
            
            negative_pooled_embeddings = encode_text_with_qwen(
                self.qwen_model, negative_prompts, self.device,
                max_length=self.max_length, use_query_mode=True
            )
            
            # The positive and negative token-embedding batches may be padded
            # to different lengths; pad the shorter one before concatenating
            if text_embeddings.dim() == 3:
                target_len = max(text_embeddings.shape[1], negative_text_embeddings.shape[1])
                text_embeddings = nn.functional.pad(
                    text_embeddings, (0, 0, 0, target_len - text_embeddings.shape[1]))
                negative_text_embeddings = nn.functional.pad(
                    negative_text_embeddings, (0, 0, 0, target_len - negative_text_embeddings.shape[1]))

            # Concatenate [negative, positive] for classifier-free guidance
            text_embeddings = torch.cat([negative_text_embeddings, text_embeddings], dim=0)
            pooled_embeddings = torch.cat([negative_pooled_embeddings, pooled_embeddings], dim=0)
        
        return text_embeddings, pooled_embeddings
    
    def forward(self, prompts: List[str], negative_prompts: Optional[List[str]] = None):
        """
        Forward pass for text encoding
        Args:
            prompts: List of text prompts
            negative_prompts: Optional list of negative prompts
        Returns:
            tuple: (text_embeddings, pooled_embeddings)
        """
        return self.encode_prompts(prompts, negative_prompts, do_classifier_free_guidance=(negative_prompts is not None))
    
    def train(self, mode: bool = True):
        """Override train mode to handle frozen encoder"""
        super().train(mode)
        if self.freeze_encoder and self.qwen_model is not None:
            self.qwen_model.eval()  # Keep encoder in eval mode
        return self
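

if __name__ == "__main__":
    # Smoke test (a sketch, not part of any training pipeline): falls back to
    # mock embeddings when sentence-transformers is not installed
    device = "cuda" if torch.cuda.is_available() else "cpu"
    encoder = QwenTextEncoder(device=device)
    text_emb, pooled_emb = encoder.encode_prompts(
        ["a photo of an astronaut riding a horse"],
        negative_prompts=["blurry, low quality"],
        do_classifier_free_guidance=True,
    )
    print("text embeddings:", tuple(text_emb.shape))      # (2, seq_len, 1024)
    print("pooled embeddings:", tuple(pooled_emb.shape))  # (2, 1024)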