File size: 5,482 Bytes
4960ef6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
"""
Qwen Text Encoder
Qwen 文本编码器 - 使用 Qwen3 模型进行文本编码
"""
import torch
import torch.nn as nn
from typing import List, Optional, Union, Tuple
def load_qwen_model(model_path: str, device: str = "cuda"):
"""
Load Qwen3 embedding model
加载 Qwen3 嵌入模型
"""
try:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer(model_path)
model.to(device)
return model
except ImportError:
print("Warning: sentence-transformers not available. Using mock embeddings.")
return None
def encode_text_with_qwen(
qwen_model,
texts: List[str],
device: str = "cuda",
max_length: int = 512,
use_query_mode: bool = False
) -> torch.Tensor:
"""
Encode text using Qwen3 model
使用 Qwen3 模型编码文本
Args:
qwen_model: Qwen3 embedding model
texts: List of text strings to encode
device: Device to run on
max_length: Maximum sequence length
use_query_mode: Whether to use query prompt for better understanding
"""
if qwen_model is None:
# Mock embeddings for testing when sentence-transformers is not available
batch_size = len(texts)
return torch.randn(batch_size, 1024, device=device, dtype=torch.float32)
with torch.no_grad():
# Use query prompt for better text understanding when specified
embeddings = qwen_model.encode(
texts,
prompt_name="query" if use_query_mode else None,
convert_to_tensor=True,
device=device,
max_seq_length=max_length,
output_value="token_embeddings" if not use_query_mode else "sentence_embedding"
)
return embeddings if use_query_mode else torch.stack(embeddings, dim=0)
class QwenTextEncoder(nn.Module):
"""
Qwen Text Encoder wrapper for training and inference
用于训练和推理的 Qwen 文本编码器包装器
"""
def __init__(
self,
model_path: str = "models/Qwen3-Embedding-0.6B",
device: str = "cuda",
max_length: int = 512,
freeze_encoder: bool = True
):
super().__init__()
self.device = device
self.max_length = max_length
self.freeze_encoder = freeze_encoder
# Load Qwen model
self.qwen_model = load_qwen_model(model_path, device)
# Freeze parameters if specified
if self.freeze_encoder and self.qwen_model is not None:
for param in self.qwen_model.parameters():
param.requires_grad = False
def encode_prompts(
self,
prompts: List[str],
negative_prompts: Optional[List[str]] = None,
do_classifier_free_guidance: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Encode prompts using Qwen3 model
使用 Qwen3 模型编码提示词
Returns:
tuple: (text_embeddings, pooled_embeddings)
- text_embeddings: [batch_size, 1024] for sequence embeddings
- pooled_embeddings: [batch_size, 1024] for pooled embeddings
"""
batch_size = len(prompts)
# Encode positive prompts for text embeddings (normal mode)
text_embeddings = encode_text_with_qwen(
self.qwen_model, prompts, self.device,
max_length=self.max_length, use_query_mode=False
)
# Encode positive prompts for pooled embeddings (query mode)
pooled_embeddings = encode_text_with_qwen(
self.qwen_model, prompts, self.device,
max_length=self.max_length, use_query_mode=True
)
# Handle negative prompts
if do_classifier_free_guidance:
if negative_prompts is None:
negative_prompts = [""] * batch_size
# Encode negative prompts
negative_text_embeddings = encode_text_with_qwen(
self.qwen_model, negative_prompts, self.device,
max_length=self.max_length, use_query_mode=False
)
negative_pooled_embeddings = encode_text_with_qwen(
self.qwen_model, negative_prompts, self.device,
max_length=self.max_length, use_query_mode=True
)
# Concatenate for classifier-free guidance
text_embeddings = torch.cat([negative_text_embeddings, text_embeddings], dim=0)
pooled_embeddings = torch.cat([negative_pooled_embeddings, pooled_embeddings], dim=0)
return text_embeddings, pooled_embeddings
def forward(self, prompts: List[str], negative_prompts: Optional[List[str]] = None):
"""
Forward pass for text encoding
Args:
prompts: List of text prompts
negative_prompts: Optional list of negative prompts
Returns:
tuple: (text_embeddings, pooled_embeddings)
"""
return self.encode_prompts(prompts, negative_prompts, do_classifier_free_guidance=(negative_prompts is not None))
def train(self, mode: bool = True):
"""Override train mode to handle frozen encoder"""
super().train(mode)
if self.freeze_encoder and self.qwen_model is not None:
self.qwen_model.eval() # Keep encoder in eval mode
return self
|