"""Gradio demo: Qwen3-VL multimodal embeddings and reranked content retrieval.

Tab 1 shows the embedding of a text and/or an uploaded image.
Tab 2 searches a long text (typed or uploaded as TXT) or a video for the parts
that best match a text and/or image query: candidates are preselected by
embedding cosine similarity and then re-ordered by the reranker model.
"""
import os
from typing import List, Optional, Tuple

import gradio as gr
import torch
import numpy as np
from transformers import AutoProcessor, AutoModelForSequenceClassification
import cv2
from PIL import Image

from scripts.qwen3_vl_embedding import Qwen3VLEmbedder
from scripts.qwen3_vl_reranker import Qwen3VLReranker

# ================================
# Tunables
# ================================
_SEGMENT_STEP = 150      # stride, in characters, between text segment starts
_SEGMENT_OVERLAP = 100   # extra characters each segment shares with the next
_PRESELECT_TOP_K = 10    # candidates kept after the embedding preselection
_FINAL_TOP_K = 3         # results shown to the user
_PREVIEW_CHARS = 200     # characters of a segment shown in the result text
_DEFAULT_FPS = 30        # used when the container reports no usable FPS

# ================================
# Global model loading
# ================================
# Embedding model
model_id = "Qwen/Qwen3-VL-Embedding-2B"
embedder = Qwen3VLEmbedder(
    model_name_or_path=model_id,
    torch_dtype=torch.float16,
)

# Reranker model. A load failure is tolerated: retrieval then falls back to
# plain embedding ordering.
try:
    reranker_id = "Qwen/Qwen3-VL-Reranker-2B"
    # Load the model through the project's reranker script.
    reranker_model = Qwen3VLReranker(
        model_name_or_path=reranker_id,
        torch_dtype=torch.float16,
        device_map="cuda" if torch.cuda.is_available() else "cpu"
    )
    print("Qwen3-VL-Reranker-2B 加载完成 (Using official script)")
except Exception as e:
    print(f"Reranker 加载失败: {str(e)}")
    reranker_model = None


class _RetrievalError(Exception):
    """Carries a user-facing message that aborts the current retrieval."""


# ================================
# Reranker helpers
# ================================
def extract_content_from_list(content_list):
    """Collapse a content list into a ``(text, image, video)`` triple.

    Args:
        content_list: iterable of dicts, each with a ``'type'`` key equal to
            ``'text'``, ``'image'`` or ``'video'`` and a payload stored under
            the key of the same name. Items of any other type are ignored.

    Returns:
        ``(text, image, video)``: all text parts joined with newlines, the
        first image and the first video; each is ``None`` when absent.
    """
    text_parts = []
    images = []
    videos = []
    for item in content_list:
        kind = item['type']
        if kind == 'text':
            text_parts.append(item['text'])
        elif kind == 'image':
            images.append(item['image'])
        elif kind == 'video':
            videos.append(item['video'])

    text = "\n".join(text_parts) if text_parts else None
    image = images[0] if images else None
    video = videos[0] if videos else None
    return text, image, video


def qwen_vl_rerank(query_content: list, candidates: list):
    """Score candidates against the query with the reranker model.

    Args:
        query_content: content list of the query, e.g.
            ``[{"type": "text", "text": "..."}, {"type": "image", "image": img}]``.
        candidates: list of content lists, one per candidate, e.g.
            ``[{"type": "image", "image": frame}]``.

    Returns:
        List of ``(original_index, score)`` sorted by descending score. Empty
        when there are no candidates, the reranker is not loaded, or scoring
        raised (the error is printed).
    """
    if not candidates or reranker_model is None:
        return []

    # Build the {"query": ..., "documents": [...]} payload passed to process().
    q_text, q_image, q_video = extract_content_from_list(query_content)
    documents = []
    for cand in candidates:
        d_text, d_image, d_video = extract_content_from_list(cand)
        documents.append({"text": d_text, "image": d_image, "video": d_video})

    inputs = {
        "query": {"text": q_text, "image": q_image, "video": q_video},
        "documents": documents,
    }

    try:
        scores = reranker_model.process(inputs)
        # Pair each score with its candidate index, best first.
        return sorted(enumerate(scores), key=lambda x: x[1], reverse=True)
    except Exception as e:
        print(f"Reranker 执行失败: {str(e)}")
        return []


# ================================
# Shared embedding helpers
# ================================
def _build_content(text, image) -> list:
    """Build a content list from an optional text and an optional image."""
    content = []
    if text:
        content.append({"type": "text", "text": text})
    if image is not None:
        content.append({"type": "image", "image": image})
    return content


def _embed_content(content: list):
    """Return the first embedding produced for ``content`` as a NumPy array.

    NOTE(review): ``content`` is a list of type-tagged parts and only
    ``embeddings[0]`` is used. If ``Qwen3VLEmbedder.process`` treats every list
    element as a separate input, a text+image query is embedded from its text
    part only -- confirm against scripts/qwen3_vl_embedding.py.
    """
    with torch.no_grad():
        embeddings = embedder.process(content, normalize=True)
    return embeddings[0].cpu().numpy()


def _cosine_similarities(query_emb, embs) -> np.ndarray:
    """Cosine similarity between ``query_emb`` and every vector in ``embs``.

    Computed in float32 so that half-precision embeddings do not lose accuracy
    while the dot products are accumulated.
    """
    query = np.asarray(query_emb, dtype=np.float32).reshape(-1)
    matrix = np.stack(
        [np.asarray(emb, dtype=np.float32).reshape(-1) for emb in embs]
    )
    denom = np.linalg.norm(matrix, axis=1) * np.linalg.norm(query) + 1e-8
    return (matrix @ query) / denom


def _top_k_indices(sims, k: int) -> np.ndarray:
    """Indices of the ``k`` largest similarities, best first."""
    return np.argsort(sims)[-k:][::-1]


# ================================
# Tab 1: embedding generation
# ================================
def compute_embedding(text: str = "", image=None):
    """Embed the given text and/or image and describe the resulting vector.

    Returns:
        ``(message, image)``: a summary of the embedding (first 10 values and
        total dimension) or an error message, plus the image to preview.
    """
    if not text and image is None:
        return "请输入文本或上传图片!", None

    inputs = _build_content(text, image)
    try:
        embedding = _embed_content(inputs).flatten()
        emb_str = ", ".join(f"{x:.4f}" for x in embedding[:10])
        result_text = f"Embedding 向量(前10维):[{emb_str}]...\n总维度: {len(embedding)}"
    except Exception as e:
        result_text = f"生成 embedding 失败:{str(e)}"
    return result_text, image


embedding_tab = gr.Interface(
    fn=compute_embedding,
    inputs=[
        gr.Textbox(label="输入文本(可选)", placeholder="例如:一只猫在阳光下睡觉"),
        gr.Image(
            label="上传本地图片(仅支持上传,不支持链接)",
            sources=["upload"],
            type="pil"
        ),
    ],
    outputs=[
        gr.Textbox(label="生成的 Embedding"),
        gr.Image(label="上传的原图片预览"),
    ],
    title="Embedding 生成(仅支持本地上传图片)",
    description="支持文本 + 本地上传图片的多模态 embedding 生成。图片上传后会自动展示。",
    examples=[
        ["一只可爱的猫咪在阳光下睡觉", None],
        ["这张图片里有什么动物?", None],
    ],
    cache_examples=False,
)


# ================================
# Tab 2: content retrieval (with reranker)
# ================================
def _load_text(text_input, txt_file) -> str:
    """Return the uploaded file's text if a file is given, else ``text_input``."""
    if txt_file is None:
        return text_input or ""
    # The upload may be an object exposing ``.name`` or a plain path string.
    path = getattr(txt_file, "name", txt_file)
    try:
        with open(path, "r", encoding="utf-8") as f:
            return f.read()
    except Exception as e:
        raise _RetrievalError(f"读取 TXT 文件失败:{str(e)}") from e


def _retrieve_from_text(query_content, query_emb, text_input, txt_file) -> List[str]:
    """Find the text segments that best match the query."""
    text = _load_text(text_input, txt_file)
    if not text.strip():
        raise _RetrievalError("没有提供有效文本内容!")

    # Overlapping fixed-size windows over the text.
    window = _SEGMENT_STEP + _SEGMENT_OVERLAP
    segments = [text[i:i + window] for i in range(0, len(text), _SEGMENT_STEP)]

    seg_embs = []
    seg_contents = []  # content lists later handed to the reranker
    for seg in segments:
        seg_content = [{"type": "text", "text": seg}]
        try:
            seg_embs.append(_embed_content(seg_content))
        except Exception as e:
            raise _RetrievalError(f"段落 embedding 生成失败:{str(e)}") from e
        seg_contents.append(seg_content)

    # Embedding preselection.
    sims = _cosine_similarities(query_emb, seg_embs)
    top_indices = _top_k_indices(sims, _PRESELECT_TOP_K)

    reranked = qwen_vl_rerank(query_content, [seg_contents[i] for i in top_indices])

    results = []
    if reranked:
        for r_idx, score in reranked[:_FINAL_TOP_K]:
            orig_idx = top_indices[r_idx]
            results.append(f"Reranker 分数: {score:.4f} | 原始相似度: {sims[orig_idx]:.4f}\n段落: {segments[orig_idx][:_PREVIEW_CHARS]}...\n")
    else:
        # Reranker unavailable or failed: fall back to embedding order.
        for idx in top_indices[:_FINAL_TOP_K]:
            results.append(f"相似度: {sims[idx]:.4f}\n段落: {segments[idx][:_PREVIEW_CHARS]}...\n(仅 Embedding 排序)")
    return results


def _frame_to_pil(frame):
    """Convert an OpenCV BGR frame into an RGB PIL image."""
    return Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))


def _embed_video_frames(video_path) -> Tuple[list, List[str], List[int]]:
    """Embed roughly one frame per second of the video.

    Returns:
        ``(embeddings, timestamps, frame_numbers)`` as parallel lists;
        timestamps are ``MM:SS`` strings.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise _RetrievalError("无法打开视频文件!")

    frame_embs = []
    timestamps = []
    frame_indices = []
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not np.isfinite(fps) or fps <= 0:
            fps = _DEFAULT_FPS
        # Sample one frame per second of footage, whatever the frame rate.
        interval = max(1, int(round(fps)))

        frame_idx = 0
        # grab() advances without decoding; only sampled frames are decoded.
        while cap.grab():
            if frame_idx % interval == 0:
                ok, frame = cap.retrieve()
                if ok:
                    try:
                        emb = _embed_content(
                            [{"type": "image", "image": _frame_to_pil(frame)}]
                        )
                    except Exception as e:
                        raise _RetrievalError(f"视频帧 embedding 生成失败:{str(e)}") from e
                    frame_embs.append(emb)
                    time_sec = frame_idx / fps
                    timestamps.append(f"{int(time_sec // 60):02d}:{int(time_sec % 60):02d}")
                    frame_indices.append(frame_idx)
            frame_idx += 1
    finally:
        # Release the capture on every exit path, including errors.
        cap.release()
    return frame_embs, timestamps, frame_indices


def _load_frames(video_path, frame_numbers) -> List[Tuple[int, Optional[Image.Image]]]:
    """Re-read the given frame numbers; unreadable frames yield ``None``."""
    cap = cv2.VideoCapture(video_path)
    loaded = []
    try:
        for pos, frame_num in enumerate(frame_numbers):
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_num)
            ok, frame = cap.read()
            loaded.append((pos, _frame_to_pil(frame) if ok else None))
    finally:
        cap.release()
    return loaded


def _retrieve_from_video(query_content, query_emb, video_file) -> List[str]:
    """Find the video moments that best match the query."""
    if video_file is None:
        raise _RetrievalError("请上传视频文件!")
    video_path = video_file

    frame_embs, timestamps, frame_indices = _embed_video_frames(video_path)
    if not frame_embs:
        raise _RetrievalError("视频无有效帧!")

    # Embedding preselection.
    sims = _cosine_similarities(query_emb, frame_embs)
    top_indices = _top_k_indices(sims, _PRESELECT_TOP_K)

    # Frames are re-read from disk rather than kept in memory during sampling.
    candidates = []
    valid_indices = []
    wanted = [frame_indices[i] for i in top_indices]
    for pos, pil_frame in _load_frames(video_path, wanted):
        if pil_frame is not None:
            candidates.append([{"type": "image", "image": pil_frame}])
            valid_indices.append(top_indices[pos])

    reranked = qwen_vl_rerank(query_content, candidates) if candidates else []

    results = []
    if reranked:
        for r_idx, score in reranked[:_FINAL_TOP_K]:
            original_idx = valid_indices[r_idx]
            ts = timestamps[original_idx]
            results.append(f"Reranker 分数: {score:.4f} | 原始相似度: {sims[original_idx]:.4f} | 时间戳: {ts}\n(Reranker 优选)")
    else:
        # No candidates, or the reranker is unavailable/failed: fall back to
        # embedding order instead of reporting that nothing matched.
        for idx in top_indices[:_FINAL_TOP_K]:
            results.append(f"相似度: {sims[idx]:.4f} | 时间戳: {timestamps[idx]}\n(仅 Embedding 排序)")
    return results


def retrieve_content(query_text: str, query_image, source_type: str, text_input: str, txt_file, video_file):
    """Search a text or a video for the parts most similar to the query.

    Args:
        query_text: optional query text.
        query_image: optional query image; at least one of the two is required.
        source_type: ``"文本输入"``, ``"上传 TXT"`` or ``"视频上传"``.
        text_input: text to search; ignored when ``txt_file`` is given.
        txt_file: optional uploaded TXT file, read as UTF-8.
        video_file: path of the uploaded video (video source only).

    Returns:
        The top results joined by ``"\\n---\\n"``, or a user-facing message
        when the input is invalid, a step fails, or nothing is found.
    """
    if not query_text and query_image is None:
        return "请至少提供查询文本 或 上传查询图片!"

    # The same content list feeds both the embedder and the reranker.
    query_content = _build_content(query_text, query_image)

    try:
        query_emb = _embed_content(query_content)
    except Exception as e:
        return f"查询 embedding 生成失败:{str(e)}"

    try:
        if source_type in ["文本输入", "上传 TXT"]:
            results = _retrieve_from_text(query_content, query_emb, text_input, txt_file)
        elif source_type == "视频上传":
            results = _retrieve_from_video(query_content, query_emb, video_file)
        else:
            results = []
    except _RetrievalError as err:
        return str(err)

    if not results:
        return "未找到匹配内容"
    return "\n---\n".join(results)


retrieval_tab = gr.Interface(
    fn=retrieve_content,
    inputs=[
        gr.Textbox(label="查询描述(文字,可选)", placeholder="穿着白衬衫的男人 或 可爱的猫咪在阳光下睡觉"),
        gr.Image(
            label="查询图片(可选,仅支持上传)",
            sources=["upload"],
            type="pil"
        ),
        gr.Radio(["文本输入", "上传 TXT", "视频上传"], label="查找来源类型", value="文本输入"),
        gr.Textbox(label="直接输入长文本 或 TXT 内容", lines=8),
        gr.File(label="上传 TXT 文件(.txt)", file_types=[".txt"]),
        gr.Video(label="上传视频文件(mp4,建议短视频)", sources=["upload"]),
    ],
    outputs=gr.Textbox(label="检索结果(Top 3 最相似部分)", lines=12),
    title="内容检索模式(已集成 Qwen3-VL-Reranker)",
    description="支持文字和/或图片查询(至少一个),返回最相似的文本段落或视频帧 + 时间戳。\n已使用 Reranker 提升排序精度。",
)

# Combine the two tabs.
demo = gr.TabbedInterface(
    [embedding_tab, retrieval_tab],
    tab_names=["Embedding 生成", "内容检索"]
)

demo.launch()