"""OpenAI-compatible embeddings service backed by a Qwen3-VL model.

Exposes ``POST /v1/embeddings``. The request carries a chat-style conversation
(text / image / video content parts). The response holds the L2-normalized
hidden state of the last attended token.
"""

import asyncio
import logging
import os
import threading
import unicodedata
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Union
from urllib.parse import urlparse

import numpy as np
import torch
import torch.nn.functional as F
from dotenv import load_dotenv
from fastapi import FastAPI, Form, HTTPException
from openai.types.create_embedding_response import CreateEmbeddingResponse, Usage
from openai.types.embedding import Embedding
from PIL import Image
from pydantic import BaseModel, Field
from qwen_vl_utils.vision_process import process_vision_info
from transformers.cache_utils import Cache
from transformers.modeling_outputs import ModelOutput
from transformers.models.qwen3_vl.modeling_qwen3_vl import (
    Qwen3VLConfig,
    Qwen3VLModel,
    Qwen3VLPreTrainedModel,
)
from transformers.models.qwen3_vl.processing_qwen3_vl import Qwen3VLProcessor
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs
from transformers.utils.generic import check_model_inputs

logger = logging.getLogger(__name__)

# Constants for configuration
MAX_LENGTH = 2048
IMAGE_BASE_FACTOR = 16
IMAGE_FACTOR = IMAGE_BASE_FACTOR * 2
MIN_PIXELS = 4 * IMAGE_FACTOR * IMAGE_FACTOR
MAX_PIXELS = 1800 * IMAGE_FACTOR * IMAGE_FACTOR
FPS = 1
MAX_FRAMES = 64
FRAME_MAX_PIXELS = 768 * IMAGE_FACTOR * IMAGE_FACTOR
MAX_TOTAL_PIXELS = 10 * FRAME_MAX_PIXELS
PAD_TOKEN = "<|endoftext|>"


@dataclass
class Qwen3VLForEmbeddingOutput(ModelOutput):
    """Output of :class:`Qwen3VLForEmbedding`.

    Attributes:
        last_hidden_state: Final hidden states returned by the backbone.
        attention_mask: The attention mask that was passed in, echoed back so
            callers can pool without keeping the inputs around.
    """

    last_hidden_state: Optional[torch.FloatTensor] = None
    attention_mask: Optional[torch.Tensor] = None


class Qwen3VLForEmbedding(Qwen3VLPreTrainedModel):
    """Thin wrapper around ``Qwen3VLModel`` that returns raw hidden states.

    No LM head is attached. Pooling into a single vector is done by
    :class:`Qwen3VLEmbedder`.
    """

    _checkpoint_conversion_mapping = {}
    accepts_loss_kwargs = False
    config: Qwen3VLConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = Qwen3VLModel(config)
        self.post_init()

    def get_input_embeddings(self):
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.model.set_input_embeddings(value)

    def set_decoder(self, decoder):
        self.model.set_decoder(decoder)

    def get_decoder(self):
        return self.model.get_decoder()

    def get_video_features(
        self,
        pixel_values_videos: torch.FloatTensor,
        video_grid_thw: Optional[torch.LongTensor] = None,
    ):
        """Delegate video feature extraction to the backbone."""
        return self.model.get_video_features(pixel_values_videos, video_grid_thw)

    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        image_grid_thw: Optional[torch.LongTensor] = None,
    ):
        """Delegate image feature extraction to the backbone."""
        return self.model.get_image_features(pixel_values, image_grid_thw)

    @property
    def language_model(self):
        """The backbone's language-model submodule."""
        return self.model.language_model

    @property
    def visual(self):
        """The backbone's vision-encoder submodule."""
        return self.model.visual

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, Qwen3VLForEmbeddingOutput]:
        """Run the backbone and return its last hidden state with the mask.

        ``logits_to_keep`` is accepted for signature compatibility with the
        generation models but is not used here (no LM head).
        """
        outputs = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            pixel_values_videos=pixel_values_videos,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            **kwargs,
        )
        return Qwen3VLForEmbeddingOutput(
            last_hidden_state=outputs.last_hidden_state,
            attention_mask=attention_mask,
        )


class Qwen3VLEmbedder:
    """Turns chat-style conversations into (optionally normalized) embeddings.

    The model is loaded onto CUDA if available, otherwise CPU, and put in
    eval mode. The processor pads on the right, which :meth:`_pooling_last`
    relies on.
    """

    def __init__(
        self,
        model_name_or_path: str,
        max_length: int = MAX_LENGTH,
        min_pixels: int = MIN_PIXELS,
        max_pixels: int = MAX_PIXELS,
        total_pixels: int = MAX_TOTAL_PIXELS,
        fps: float = FPS,
        max_frames: int = MAX_FRAMES,
        default_instruction: str = "Represent the user's input.",
        **kwargs,
    ):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.max_length = max_length
        self.min_pixels = min_pixels
        self.max_pixels = max_pixels
        self.total_pixels = total_pixels
        self.fps = fps
        self.max_frames = max_frames
        self.default_instruction = default_instruction

        self.model = Qwen3VLForEmbedding.from_pretrained(
            model_name_or_path, trust_remote_code=True, **kwargs
        ).to(device)
        # Right padding: the last attended token is the last 1 in each mask row.
        self.processor = Qwen3VLProcessor.from_pretrained(
            model_name_or_path, padding_side='right'
        )
        self.model.eval()

    @torch.no_grad()
    def forward(self, inputs: Dict[str, Any]) -> Dict[str, torch.Tensor]:
        """Run the model without gradients and return hidden states + mask."""
        outputs = self.model(**inputs)
        return {
            'last_hidden_state': outputs.last_hidden_state,
            'attention_mask': inputs.get('attention_mask'),
        }

    def _preprocess_inputs(
        self, conversations: Union[List[Dict], List[List[Dict]]]
    ) -> Dict[str, torch.Tensor]:
        """Render the chat template and pack text/vision inputs into tensors.

        ``conversations`` may be a single conversation (list of message dicts)
        or a batch of conversations (list of lists).

        If vision preprocessing fails, the error is logged and each
        conversation is replaced by a text-only "NULL" placeholder. This is a
        best-effort fallback: the number of outputs still matches the number
        of inputs.
        """
        text = self.processor.apply_chat_template(
            conversations, add_generation_prompt=True, tokenize=False
        )

        try:
            images, video_inputs, video_kwargs = process_vision_info(
                conversations,
                image_patch_size=16,
                return_video_metadata=True,
                return_video_kwargs=True,
            )
        except Exception as e:
            logger.exception("Error in processing vision info: %s", e)
            images = None
            video_inputs = None
            video_kwargs = {'do_sample_frames': False}
            null_conversation = [
                {'role': 'user', 'content': [{'type': 'text', 'text': 'NULL'}]}
            ]
            is_batched = bool(conversations) and isinstance(conversations[0], list)
            # Keep the batch size intact so callers get one embedding per input.
            fallback = (
                [null_conversation] * len(conversations) if is_batched else null_conversation
            )
            text = self.processor.apply_chat_template(
                fallback, add_generation_prompt=True, tokenize=False
            )

        if video_inputs is not None:
            # process_vision_info returns (video, metadata) pairs; split them.
            videos, video_metadata = zip(*video_inputs)
            videos = list(videos)
            video_metadata = list(video_metadata)
        else:
            videos, video_metadata = None, None

        inputs = self.processor(
            text=text,
            images=images,
            videos=videos,
            video_metadata=video_metadata,
            truncation=True,
            max_length=self.max_length,
            padding=True,
            do_resize=False,  # resizing was already done by process_vision_info
            return_tensors='pt',
            **video_kwargs,
        )
        return inputs

    @staticmethod
    def _pooling_last(hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Select, per row, the hidden state at the last position where the mask is 1.

        The mask is flipped and ``argmax`` finds the first 1 from the right.
        A row with no 1s yields position ``seq_len - 1``. Assumes
        ``hidden_state`` is (batch, seq, hidden) — TODO confirm for all inputs.
        """
        flipped_tensor = attention_mask.flip(dims=[1])
        last_one_positions = flipped_tensor.argmax(dim=1)
        col = attention_mask.shape[1] - last_one_positions - 1
        row = torch.arange(hidden_state.shape[0], device=hidden_state.device)
        return hidden_state[row, col]

    def process_with_usage(
        self,
        inputs: Union[List[Dict], List[List[Dict]]],
        normalize: bool = True,
    ) -> Tuple[torch.Tensor, int]:
        """Embed ``inputs`` and also report how many tokens were attended.

        Returns:
            A tuple ``(embeddings, token_count)``. ``embeddings`` is the
            last-token pooled hidden state (L2-normalized if ``normalize``).
            ``token_count`` is the number of non-padding positions in the
            attention mask, summed over the batch.
        """
        processed_inputs = self._preprocess_inputs(inputs)
        processed_inputs = {k: v.to(self.model.device) for k, v in processed_inputs.items()}
        outputs = self.forward(processed_inputs)
        attention_mask = outputs['attention_mask']
        embeddings = self._pooling_last(outputs['last_hidden_state'], attention_mask)
        if normalize:
            embeddings = F.normalize(embeddings, p=2, dim=-1)
        token_count = int(attention_mask.sum().item())
        return embeddings, token_count

    def process(
        self,
        inputs: Union[List[Dict], List[List[Dict]]],
        normalize: bool = True,
    ) -> torch.Tensor:
        """Embed ``inputs``; returns the pooled (optionally L2-normalized) tensor."""
        embeddings, _ = self.process_with_usage(inputs, normalize=normalize)
        return embeddings


class EmbeddingRequest(BaseModel):
    """Request body for ``/v1/embeddings``.

    ``encoding_format``, ``continue_final_message`` and ``add_special_tokens``
    are accepted for client compatibility but are not read by the endpoint.
    Embeddings are always returned as floats.
    """

    # A single chat conversation: a list of {"role": ..., "content": ...} messages.
    messages: List[Dict[str, Any]] = Field(..., description="输入文本或文本列表")
    model: str = Field(default="Qwen3VL", description="模型名称")
    encoding_format: str = Field(default="float", description="输出格式")
    continue_final_message: Optional[bool] = Field(default=False, description="是否继续生成最终消息")
    add_special_tokens: Optional[bool] = Field(default=False, description="是否添加特殊标记")


load_dotenv()
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

app = FastAPI(
    title="Qwen3VL Embedding API",
    description="API for Qwen3VL Embedding model",
    version="1.0.0",
)

# Load the model once at import time. On failure the service still starts,
# and the endpoint reports the problem per request.
logger.info("Loading Qwen3VL Embedding model...")
try:
    model_path = os.getenv("EMBEDDING_MODEL_PATH", "")
    embedding_model = Qwen3VLEmbedder(
        model_name_or_path=model_path,
        max_length=MAX_LENGTH,
    )
    logger.info("Qwen3VL Embedding model loaded successfully!")
except Exception as e:
    logger.exception("Failed to load model: %s", e)
    embedding_model = None

# Serializes inference so concurrent requests never run the model in parallel.
_inference_lock = threading.Lock()


def _embed_conversation(conversation: List[Dict[str, Any]]) -> Tuple[List[List[float]], int]:
    """Run one blocking embedding pass under the inference lock.

    Intended to be executed in a worker thread. Returns the embeddings as
    nested Python lists and the attended token count.
    """
    with _inference_lock:
        embeddings, token_count = embedding_model.process_with_usage(conversation, normalize=True)
        return embeddings.cpu().tolist(), token_count


def wrap_embedding_list(embedding_list: List[float], index: int = 0) -> Embedding:
    """Wrap a list of floats into an OpenAI ``Embedding`` object.

    Args:
        embedding_list: The embedding vector values.
        index: Position of this embedding within the response's ``data`` list.

    Returns:
        An ``Embedding`` instance with ``object="embedding"``.
    """
    return Embedding(
        embedding=embedding_list,
        index=index,
        object="embedding",
    )


def wrap_multiple_embedding_lists(embedding_lists: List[List[float]]) -> List[Embedding]:
    """Wrap several embedding vectors into ``Embedding`` objects, indexed in order.

    Args:
        embedding_lists: One list of floats per embedding.

    Returns:
        A list of ``Embedding`` instances whose ``index`` matches their position.
    """
    return [wrap_embedding_list(embedding_list, idx) for idx, embedding_list in enumerate(embedding_lists)]


@app.post("/v1/embeddings", response_model=CreateEmbeddingResponse)
async def create_embeddings(request: EmbeddingRequest):
    """OpenAI-compatible embeddings endpoint.

    Model inference is blocking, so it runs in a worker thread to keep the
    event loop responsive. ``usage`` reports attended tokens, not messages.
    """
    # Precondition checks live outside the try so their HTTPException is not
    # re-wrapped into a generic 500 by the handler below.
    if embedding_model is None:
        raise HTTPException(status_code=500, detail="模型未正确加载")
    if not request.messages:
        raise HTTPException(status_code=400, detail="messages must not be empty")

    try:
        embedding_list, token_count = await asyncio.to_thread(
            _embed_conversation, request.messages
        )
    except Exception as e:
        logger.exception("Error during embedding: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e

    return CreateEmbeddingResponse(
        data=wrap_multiple_embedding_lists(embedding_list),
        model=request.model,
        object="list",
        usage=Usage(
            prompt_tokens=token_count,
            total_tokens=token_count,
        ),
    )


if __name__ == "__main__":
    import uvicorn

    port = int(os.getenv("EMBEDDING_API_PORT", 8006))
    uvicorn.run(app, host="0.0.0.0", port=port)