Spaces:
Running
Running
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| from transformers import AutoProcessor, AutoModelForSequenceClassification | |
| from scripts.qwen3_vl_embedding import Qwen3VLEmbedder | |
| from scripts.qwen3_vl_reranker import Qwen3VLReranker | |
| import cv2 | |
| import os | |
| from typing import List | |
| from PIL import Image | |
| # ================================ | |
| # 全局模型加载 | |
| # ================================ | |
# Embedding model: loaded once when the module runs and shared by both tabs.
model_id = "Qwen/Qwen3-VL-Embedding-2B"
embedder = Qwen3VLEmbedder(model_name_or_path=model_id, torch_dtype=torch.float16)

# Reranker model: loading is best-effort. On any failure the error is printed
# and reranker_model is set to None, in which case qwen_vl_rerank() returns [].
try:
    reranker_id = "Qwen/Qwen3-VL-Reranker-2B"
    # Loaded through the official Qwen3-VL reranker script; GPU when available.
    reranker_model = Qwen3VLReranker(
        model_name_or_path=reranker_id,
        torch_dtype=torch.float16,
        device_map=("cuda" if torch.cuda.is_available() else "cpu"),
    )
    print("Qwen3-VL-Reranker-2B 加载完成 (Using official script)")
except Exception as e:
    print(f"Reranker 加载失败: {str(e)}")
    reranker_model = None
| # ================================ | |
| # Reranker 函数(核心修复版) | |
| # ================================ | |
def extract_content_from_list(content_list):
    """Collapse a type-tagged content list into a (text, image, video) triple.

    Each item is a dict such as {"type": "text", "text": ...},
    {"type": "image", "image": ...} or {"type": "video", "video": ...};
    items whose type is anything else are ignored.

    Returns:
        text: every text part joined with newlines, or None if there were none.
        image: the first image item, or None.
        video: the first video item, or None.
    """
    collected = {"text": [], "image": [], "video": []}
    for entry in content_list:
        kind = entry['type']
        if kind in ("text", "image", "video"):
            # The payload lives under a key named after the item's type.
            collected[kind].append(entry[kind])

    texts = collected["text"]
    images = collected["image"]
    videos = collected["video"]

    joined_text = "\n".join(texts) if texts else None
    first_image = images[0] if images else None
    first_video = videos[0] if videos else None
    return joined_text, first_image, first_video
def qwen_vl_rerank(query_content: list, candidates: list):
    """Score candidates against a query with the Qwen3-VL reranker.

    Args:
        query_content: content list for the query, e.g.
            [{"type": "text", "text": "..."}, {"type": "image", "image": pil_img}].
        candidates: one content list per candidate, e.g.
            [[{"type": "image", "image": frame}], ...].

    Returns:
        List of (candidate_index, score) pairs sorted by score, highest first.
        Empty when there are no candidates, when the reranker failed to load,
        or when the reranker call raised (the error is printed).
    """
    if not candidates or reranker_model is None:
        return []

    def to_fields(content):
        # The reranker's process() input is built from {"text", "image",
        # "video"} dicts rather than the type-tagged lists used elsewhere.
        text, image, video = extract_content_from_list(content)
        return {"text": text, "image": image, "video": video}

    payload = {
        "query": to_fields(query_content),
        "documents": [to_fields(cand) for cand in candidates],
    }

    try:
        raw_scores = reranker_model.process(payload)
        # Keep each score paired with its candidate's original position.
        return sorted(enumerate(raw_scores), key=lambda pair: pair[1], reverse=True)
    except Exception as err:
        print(f"Reranker 执行失败: {str(err)}")
        return []
| # ================================ | |
| # Tab 1: Embedding 生成 | |
| # ================================ | |
def compute_embedding(text: str = "", image=None):
    """Embed a text and/or image input and summarise the resulting vector.

    Args:
        text: Optional input text. Empty or whitespace-only text counts as absent.
        image: Optional PIL image (the Gradio input component uses type="pil").

    Returns:
        A (message, image) tuple. The message shows the first 10 dimensions and
        the total dimension, or an error message. The input image is echoed back
        for the preview output (None on the "no input" path).
    """
    has_text = bool(text and text.strip())
    if not has_text and image is None:
        return "请输入文本或上传图片!", None

    # Build ONE input item holding every supplied modality. Per the
    # Qwen3-VL-Embedding usage docs, process() returns one embedding per dict
    # in its input list, so separate text/image dicts would yield two vectors
    # and taking [0] would silently drop the image.
    # NOTE(review): confirm against scripts/qwen3_vl_embedding.py.
    item = {}
    if has_text:
        item["text"] = text
    if image is not None:
        item["image"] = image

    try:
        with torch.no_grad():
            embeddings = embedder.process([item], normalize=True)
        embedding = embeddings[0].cpu().numpy().flatten()
        emb_str = ", ".join(f"{x:.4f}" for x in embedding[:10])
        result_text = f"Embedding 向量(前10维):[{emb_str}]...\n总维度: {len(embedding)}"
    except Exception as e:
        # UI boundary: report the failure in the output box instead of crashing.
        result_text = f"生成 embedding 失败:{str(e)}"
    return result_text, image
# Tab 1 UI: optional text plus an optional locally uploaded image in;
# embedding summary and an image preview out.
embedding_tab = gr.Interface(
    fn=compute_embedding,
    inputs=[
        gr.Textbox(label="输入文本(可选)", placeholder="例如:一只猫在阳光下睡觉"),
        gr.Image(label="上传本地图片(仅支持上传,不支持链接)", sources=["upload"], type="pil"),
    ],
    outputs=[
        gr.Textbox(label="生成的 Embedding"),
        gr.Image(label="上传的原图片预览"),
    ],
    title="Embedding 生成(仅支持本地上传图片)",
    description="支持文本 + 本地上传图片的多模态 embedding 生成。图片上传后会自动展示。",
    # Text-only examples (image slot left empty); results are computed on click.
    examples=[["一只可爱的猫咪在阳光下睡觉", None], ["这张图片里有什么动物?", None]],
    cache_examples=False,
)
| # ================================ | |
| # Tab 2: 内容检索(含 Reranker) | |
| # ================================ | |
# How many embedding-prefiltered candidates go to the reranker, and how many
# results are shown to the user.
_PREFILTER_K = 10
_RESULT_K = 3
# Text chunking: a window starts every _SEG_STEP characters and spans
# _SEG_STEP + _SEG_OVERLAP characters, so neighbouring windows overlap.
_SEG_STEP = 150
_SEG_OVERLAP = 100


class _RetrievalError(Exception):
    """Carries a user-facing error message back to retrieve_content()."""


def _embed_single(content):
    """Embed one input dict with the global embedder; return a numpy vector."""
    with torch.no_grad():
        return embedder.process([content], normalize=True)[0].cpu().numpy()


def _cosine_scores(query_emb, embs):
    """Cosine similarity between query_emb and each vector in embs (float32 ndarray)."""
    # Upcast to float32: embeddings may be float16, and one matrix-vector
    # product replaces the per-vector Python loop.
    query = np.asarray(query_emb, dtype=np.float32).ravel()
    matrix = np.stack([np.asarray(e, dtype=np.float32).ravel() for e in embs])
    denom = np.linalg.norm(matrix, axis=1) * np.linalg.norm(query) + 1e-8
    return (matrix @ query) / denom


def _top_indices(sims, k):
    """Indices of the k highest scores, best first, as plain ints."""
    return [int(i) for i in np.argsort(sims)[-k:][::-1]]


def _split_text(text):
    """Split text into overlapping windows (see _SEG_STEP / _SEG_OVERLAP)."""
    return [text[i:i + _SEG_STEP + _SEG_OVERLAP] for i in range(0, len(text), _SEG_STEP)]


def _load_text(text_input, txt_file):
    """Return the text to search: the uploaded TXT file if given, else text_input."""
    if txt_file is None:
        return text_input or ""
    # gr.File passes a filepath string in Gradio 4+; older versions passed a
    # tempfile wrapper exposing .name. Accept both.
    path = getattr(txt_file, "name", txt_file)
    try:
        with open(path, "r", encoding="utf-8") as f:
            return f.read()
    except Exception as e:
        raise _RetrievalError(f"读取 TXT 文件失败:{str(e)}") from e


def _rerank_or_fallback(query_content, candidates, candidate_ids, fallback_ids):
    """Choose up to _RESULT_K results.

    Returns (orig_idx, reranker_score) pairs in reranker order. When reranking
    yields nothing (no candidates, reranker not loaded, or reranker error),
    returns the first _RESULT_K fallback_ids (embedding order) paired with None.
    """
    reranked = qwen_vl_rerank(query_content, candidates)
    if reranked:
        return [(candidate_ids[r_idx], score) for r_idx, score in reranked[:_RESULT_K]]
    return [(idx, None) for idx in fallback_ids[:_RESULT_K]]


def _search_text(query_content, query_emb, text_input, txt_file):
    """Rank overlapping text windows against the query; return result strings."""
    text = _load_text(text_input, txt_file)
    if not text.strip():
        raise _RetrievalError("没有提供有效文本内容!")

    segments = _split_text(text)
    seg_contents = [[{"type": "text", "text": seg}] for seg in segments]
    seg_embs = []
    for content in seg_contents:
        try:
            seg_embs.append(_embed_single(content[0]))
        except Exception as e:
            raise _RetrievalError(f"段落 embedding 生成失败:{str(e)}") from e

    sims = _cosine_scores(query_emb, seg_embs)
    top_indices = _top_indices(sims, _PREFILTER_K)
    ranked = _rerank_or_fallback(
        query_content, [seg_contents[i] for i in top_indices], top_indices, top_indices
    )

    results = []
    for idx, score in ranked:
        if score is None:
            results.append(f"相似度: {sims[idx]:.4f}\n段落: {segments[idx][:200]}...\n(仅 Embedding 排序)")
        else:
            results.append(f"Reranker 分数: {score:.4f} | 原始相似度: {sims[idx]:.4f}\n段落: {segments[idx][:200]}...\n")
    return results


def _bgr_to_pil(frame):
    """Convert an OpenCV BGR frame to an RGB PIL image."""
    return Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))


def _embed_video_frames(video_path):
    """Embed about one frame per second of video.

    Returns (frame_embs, timestamps "MM:SS", frame_indices), aligned by position.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise _RetrievalError("无法打开视频文件!")
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps or not np.isfinite(fps) or fps <= 0:
            fps = 30
        # Sample roughly one frame per second, whatever the frame rate.
        frame_step = max(1, int(round(fps)))

        frame_embs, timestamps, frame_indices = [], [], []
        frame_idx = 0
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            if frame_idx % frame_step == 0:
                pil_frame = _bgr_to_pil(frame)
                try:
                    emb = _embed_single({"type": "image", "image": pil_frame})
                except Exception as e:
                    raise _RetrievalError(f"视频帧 embedding 生成失败:{str(e)}") from e
                frame_embs.append(emb)
                time_sec = frame_idx / fps
                timestamps.append(f"{int(time_sec // 60):02d}:{int(time_sec % 60):02d}")
                frame_indices.append(frame_idx)
            frame_idx += 1
        return frame_embs, timestamps, frame_indices
    finally:
        cap.release()


def _read_candidate_frames(video_path, frame_indices, wanted):
    """Re-decode the sampled frames listed in `wanted` for the reranker.

    Returns (candidates, valid_indices). Frames that fail to decode are skipped.
    """
    candidates, valid_indices = [], []
    cap = cv2.VideoCapture(video_path)
    try:
        for idx in wanted:
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_indices[idx])
            ok, frame = cap.read()
            if ok:
                candidates.append([{"type": "image", "image": _bgr_to_pil(frame)}])
                valid_indices.append(idx)
    finally:
        cap.release()
    return candidates, valid_indices


def _search_video(query_content, query_emb, video_file):
    """Rank sampled video frames against the query; return result strings."""
    if video_file is None:
        raise _RetrievalError("请上传视频文件!")
    video_path = video_file  # gr.Video passes the uploaded file's path

    frame_embs, timestamps, frame_indices = _embed_video_frames(video_path)
    if not frame_embs:
        raise _RetrievalError("视频无有效帧!")

    sims = _cosine_scores(query_emb, frame_embs)
    top_indices = _top_indices(sims, _PREFILTER_K)
    candidates, valid_indices = _read_candidate_frames(video_path, frame_indices, top_indices)
    # Fall back to embedding order whenever reranking yields nothing, not only
    # when no frame could be re-read.
    ranked = _rerank_or_fallback(query_content, candidates, valid_indices, top_indices)

    results = []
    for idx, score in ranked:
        if score is None:
            results.append(f"相似度: {sims[idx]:.4f} | 时间戳: {timestamps[idx]}\n(仅 Embedding 排序)")
        else:
            results.append(f"Reranker 分数: {score:.4f} | 原始相似度: {sims[idx]:.4f} | 时间戳: {timestamps[idx]}\n(Reranker 优选)")
    return results


def retrieve_content(query_text: str, query_image, source_type: str, text_input: str, txt_file, video_file):
    """Find the parts of a text or a video that best match a text/image query.

    Candidates are embedded and cosine-ranked against the query embedding. The
    top _PREFILTER_K are reranked with the Qwen3-VL reranker, and the best
    _RESULT_K are reported. When reranking yields nothing, embedding order is
    used instead.

    Args:
        query_text: optional query text (whitespace-only counts as absent).
        query_image: optional PIL query image.
        source_type: "文本输入" or "上传 TXT" searches text; "视频上传" searches video frames.
        text_input: text to search when no TXT file is uploaded.
        txt_file: uploaded TXT file (path string, or object with .name); takes
            precedence over text_input.
        video_file: path of the uploaded video.

    Returns:
        Result entries separated by "---" lines, or a user-facing error/status message.
    """
    has_text = bool(query_text and query_text.strip())
    if not has_text and query_image is None:
        return "请至少提供查询文本 或 上传查询图片!"

    # Reranker input: one type-tagged dict per modality.
    query_content = []
    # Embedder input: a single dict carrying every modality, so a text+image
    # query yields one joint embedding. Per the Qwen3-VL-Embedding usage docs,
    # process() embeds each list element separately.
    # NOTE(review): confirm against scripts/qwen3_vl_embedding.py.
    query_item = {}
    if has_text:
        query_content.append({"type": "text", "text": query_text})
        query_item["text"] = query_text
    if query_image is not None:
        query_content.append({"type": "image", "image": query_image})
        query_item["image"] = query_image

    try:
        query_emb = _embed_single(query_item)
    except Exception as e:
        return f"查询 embedding 生成失败:{str(e)}"

    try:
        if source_type in ["文本输入", "上传 TXT"]:
            results = _search_text(query_content, query_emb, text_input, txt_file)
        elif source_type == "视频上传":
            results = _search_video(query_content, query_emb, video_file)
        else:
            results = []
    except _RetrievalError as err:
        return str(err)

    if not results:
        return "未找到匹配内容"
    return "\n---\n".join(results)
# Tab 2 UI: a text and/or image query, a source selector, and the three
# possible sources (inline text, TXT upload, video upload). The output is a
# single text box holding the top results.
retrieval_tab = gr.Interface(
    fn=retrieve_content,
    inputs=[
        gr.Textbox(label="查询描述(文字,可选)", placeholder="穿着白衬衫的男人 或 可爱的猫咪在阳光下睡觉"),
        gr.Image(label="查询图片(可选,仅支持上传)", sources=["upload"], type="pil"),
        gr.Radio(choices=["文本输入", "上传 TXT", "视频上传"], label="查找来源类型", value="文本输入"),
        gr.Textbox(label="直接输入长文本 或 TXT 内容", lines=8),
        gr.File(label="上传 TXT 文件(.txt)", file_types=[".txt"]),
        gr.Video(label="上传视频文件(mp4,建议短视频)", sources=["upload"]),
    ],
    outputs=gr.Textbox(label="检索结果(Top 3 最相似部分)", lines=12),
    title="内容检索模式(已集成 Qwen3-VL-Reranker)",
    description=(
        "支持文字和/或图片查询(至少一个),返回最相似的文本段落或视频帧 + 时间戳。"
        "\n已使用 Reranker 提升排序精度。"
    ),
)
# Combine both interfaces into one tabbed app and start the server.
# launch() runs unconditionally when this module executes.
demo = gr.TabbedInterface(
    interface_list=[embedding_tab, retrieval_tab],
    tab_names=["Embedding 生成", "内容检索"],
)
demo.launch()