qxyf's picture
降低废用
d92871a
import gradio as gr
import torch
import numpy as np
from transformers import AutoProcessor, AutoModelForSequenceClassification
from scripts.qwen3_vl_embedding import Qwen3VLEmbedder
from scripts.qwen3_vl_reranker import Qwen3VLReranker
import cv2
import os
from typing import List
from PIL import Image
# ================================
# Global model loading
# ================================
# Embedding model. It is created at import time with no error handling, so a
# loading failure propagates and stops the app; both compute_embedding and
# retrieve_content call `embedder.process`.
model_id = "Qwen/Qwen3-VL-Embedding-2B"
embedder = Qwen3VLEmbedder(
    model_name_or_path=model_id,
    torch_dtype=torch.float16,
)
# Reranker model. Loading is best-effort: if construction raises, the error is
# printed and `reranker_model` is set to None, which qwen_vl_rerank checks
# before doing any work.
try:
    reranker_id = "Qwen/Qwen3-VL-Reranker-2B"
    # Load the model through the official script wrapper
    reranker_model = Qwen3VLReranker(
        model_name_or_path=reranker_id,
        torch_dtype=torch.float16,
        # Use the GPU when one is visible to torch, otherwise the CPU
        device_map="cuda" if torch.cuda.is_available() else "cpu"
    )
    print("Qwen3-VL-Reranker-2B 加载完成 (Using official script)")
except Exception as e:
    print(f"Reranker 加载失败: {str(e)}")
    reranker_model = None
# ================================
# Reranker functions (core fixed version)
# ================================
def extract_content_from_list(content_list):
    """Collapse a content list into a single ``(text, image, video)`` triple.

    Each item is a dict with a ``'type'`` key (``'text'``, ``'image'`` or
    ``'video'``) and a payload stored under the key of the same name. Items
    with any other type are ignored.

    Returns:
        A tuple ``(text, image, video)`` where ``text`` is all text payloads
        joined with newlines, and ``image`` / ``video`` are the FIRST payload
        of that type (later ones are dropped). Each element is ``None`` when
        no item of that type is present.
    """
    collected = {"text": [], "image": [], "video": []}
    for entry in content_list:
        kind = entry['type']
        # Compare with == in a fixed order, exactly like an if/elif chain.
        for field, bucket in collected.items():
            if kind == field:
                bucket.append(entry[field])
                break

    texts = collected["text"]
    text = None
    if texts:
        text = "\n".join(texts)
    image = next(iter(collected["image"]), None)
    video = next(iter(collected["video"]), None)
    return text, image, video
def qwen_vl_rerank(query_content: list, candidates: list):
    """Score candidates against a query with the reranker, best first.

    Args:
        query_content: list of dicts, e.g.
            ``[{"type": "text", "text": "..."}, {"type": "image", "image": pil_img}]``.
        candidates: list of content lists, one per candidate, e.g.
            ``[{"type": "image", "image": frame}]``.

    Returns:
        A list of ``(original_index, score)`` pairs sorted by score in
        descending order. An empty list is returned when there are no
        candidates, when the reranker failed to load, or when scoring raises
        (the error is printed).
    """
    if not candidates or reranker_model is None:
        return []

    # Build the query/documents payload handed to the reranker's process().
    field_names = ("text", "image", "video")
    query = dict(zip(field_names, extract_content_from_list(query_content)))
    documents = [
        dict(zip(field_names, extract_content_from_list(candidate)))
        for candidate in candidates
    ]
    payload = {"query": query, "documents": documents}

    try:
        scores = reranker_model.process(payload)
        # Pair each score with its candidate position, then order by score.
        ranked = list(enumerate(scores))
        ranked.sort(key=lambda pair: pair[1], reverse=True)
    except Exception as e:
        print(f"Reranker 执行失败: {str(e)}")
        return []
    return ranked
# ================================
# Tab 1: embedding generation
# ================================
def compute_embedding(text: str = "", image=None):
    """Embed the given text and/or image and return a short preview.

    Args:
        text: optional text to embed.
        image: optional image to embed (passed through to the embedder as is).

    Returns:
        A tuple ``(message, image)``. ``message`` shows the first 10 values of
        the embedding and its total length, or an error description if the
        embedder raised. When neither input is supplied the tuple is
        ``(prompt message, None)``.

    NOTE(review): only element 0 of the embedder's output is used even when
    both a text item and an image item are supplied -- confirm that
    ``embedder.process`` fuses the content list into one embedding.
    """
    has_image = image is not None
    if not (text or has_image):
        return "请输入文本或上传图片!", None

    parts = []
    if text:
        parts.append({"type": "text", "text": text})
    if has_image:
        parts.append({"type": "image", "image": image})

    try:
        with torch.no_grad():
            outputs = embedder.process(parts, normalize=True)
            vector = outputs[0].cpu().numpy().flatten()
        preview = ", ".join(format(value, ".4f") for value in vector[:10])
        result_text = f"Embedding 向量(前10维):[{preview}]...\n总维度: {len(vector)}"
    except Exception as e:
        result_text = f"生成 embedding 失败:{str(e)}"
    return result_text, image
# Gradio UI for tab 1: a text box and an image upload feed compute_embedding,
# whose (message, image) return value fills the two outputs below.
embedding_tab = gr.Interface(
    fn=compute_embedding,
    inputs=[
        gr.Textbox(label="输入文本(可选)", placeholder="例如:一只猫在阳光下睡觉"),
        gr.Image(
            label="上传本地图片(仅支持上传,不支持链接)",
            # Upload is the only enabled source; type="pil" requests a PIL image
            sources=["upload"],
            type="pil"
        ),
    ],
    outputs=[
        gr.Textbox(label="生成的 Embedding"),
        # Shows the image that compute_embedding returns unchanged
        gr.Image(label="上传的原图片预览"),
    ],
    title="Embedding 生成(仅支持本地上传图片)",
    description="支持文本 + 本地上传图片的多模态 embedding 生成。图片上传后会自动展示。",
    # Text-only examples (the image slot is None)
    examples=[
        ["一只可爱的猫咪在阳光下睡觉", None],
        ["这张图片里有什么动物?", None],
    ],
    cache_examples=False,
)
# ================================
# Tab 2: content retrieval (with reranker)
# ================================
# Long text is cut with a sliding window: a segment starts every
# _SEGMENT_STEP characters and is _SEGMENT_STEP + _SEGMENT_OVERLAP characters
# long, so neighbouring segments overlap.
_SEGMENT_STEP = 150
_SEGMENT_OVERLAP = 100
# Number of embedding-ranked candidates handed to the reranker.
_TEXT_PREFILTER_TOP_K = 5
_VIDEO_PREFILTER_TOP_K = 10
# Every N-th video frame is embedded (60 frames is about 2 s at 30 fps).
_VIDEO_FRAME_STRIDE = 60
# Number of hits shown to the user.
_MAX_RESULTS = 3


def _embed_content(content):
    """Return the first embedding produced for *content* as a NumPy array.

    NOTE(review): only element 0 of the embedder output is used even when
    *content* holds several items -- confirm that ``embedder.process`` fuses
    the list into one embedding.
    """
    with torch.no_grad():
        return embedder.process(content, normalize=True)[0].cpu().numpy()


def _cosine_similarities(query_emb, embs):
    """Return the cosine similarity of *query_emb* with each vector in *embs*."""
    query_norm = np.linalg.norm(query_emb)
    # The small epsilon guards against division by zero for null vectors.
    return [
        np.dot(query_emb, emb) / (query_norm * np.linalg.norm(emb) + 1e-8)
        for emb in embs
    ]


def _top_k_indices(sims, k):
    """Return the indices of the *k* largest similarities, best first."""
    return np.argsort(sims)[-k:][::-1]


def _bgr_frame_to_pil(frame):
    """Convert an OpenCV BGR frame into an RGB PIL image."""
    return Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))


def _format_results(results):
    """Join result entries for display, or report that nothing matched."""
    if not results:
        return "未找到匹配内容"
    return "\n---\n".join(results)


def _retrieve_from_text(query_content, query_emb, text_input, txt_file):
    """Rank overlapping segments of the supplied text against the query."""
    text = text_input or ""
    if txt_file is not None:
        # Accept both a plain path and a wrapper object exposing `.name`.
        txt_path = getattr(txt_file, "name", txt_file)
        try:
            with open(txt_path, "r", encoding="utf-8") as f:
                text = f.read()
        except Exception as e:
            return f"读取 TXT 文件失败:{e}"
    if not text.strip():
        return "没有提供有效文本内容!"

    segments = [
        text[start:start + _SEGMENT_STEP + _SEGMENT_OVERLAP]
        for start in range(0, len(text), _SEGMENT_STEP)
    ]
    # Content lists are kept so the same payloads can be sent to the reranker.
    seg_contents = [[{"type": "text", "text": seg}] for seg in segments]
    try:
        seg_embs = [_embed_content(content) for content in seg_contents]
    except Exception as e:
        return f"段落 embedding 生成失败:{e}"

    # Embedding pre-filter, then rerank the survivors.
    sims = _cosine_similarities(query_emb, seg_embs)
    top_indices = _top_k_indices(sims, _TEXT_PREFILTER_TOP_K)
    reranked = qwen_vl_rerank(query_content, [seg_contents[i] for i in top_indices])

    results = []
    if reranked:
        for r_idx, score in reranked[:_MAX_RESULTS]:
            # r_idx indexes the candidate list; map it back to the segment.
            orig_idx = top_indices[r_idx]
            results.append(f"Reranker 分数: {score:.4f} | 原始相似度: {sims[orig_idx]:.4f}\n段落: {segments[orig_idx][:200]}...\n")
    else:
        # Reranker unavailable or failed: fall back to the embedding order.
        for idx in top_indices[:_MAX_RESULTS]:
            results.append(f"相似度: {sims[idx]:.4f}\n段落: {segments[idx][:200]}...\n(仅 Embedding 排序)")
    return _format_results(results)


def _retrieve_from_video(query_content, query_emb, video_file):
    """Rank sampled frames of the uploaded video against the query."""
    if video_file is None:
        return "请上传视频文件!"

    frame_embs = []
    timestamps = []
    frame_indices = []
    cap = cv2.VideoCapture(video_file)
    try:
        if not cap.isOpened():
            return "无法打开视频文件!"
        fps = cap.get(cv2.CAP_PROP_FPS) or 30
        frame_idx = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if frame_idx % _VIDEO_FRAME_STRIDE == 0:
                frame_content = [{"type": "image", "image": _bgr_frame_to_pil(frame)}]
                try:
                    emb = _embed_content(frame_content)
                except Exception as e:
                    return f"视频帧 embedding 生成失败:{e}"
                frame_embs.append(emb)
                time_sec = frame_idx / fps
                timestamps.append(f"{int(time_sec // 60):02d}:{int(time_sec % 60):02d}")
                frame_indices.append(frame_idx)
            frame_idx += 1
    finally:
        # Always free the decoder, including on early return or an exception.
        cap.release()

    if not frame_embs:
        return "视频无有效帧!"

    sims = _cosine_similarities(query_emb, frame_embs)
    top_indices = _top_k_indices(sims, _VIDEO_PREFILTER_TOP_K)

    # Frames are not kept in memory during the scan; re-read only the
    # pre-filtered ones for the reranker.
    candidates = []
    valid_indices = []
    rerank_cap = cv2.VideoCapture(video_file)
    try:
        for orig_idx in top_indices:
            rerank_cap.set(cv2.CAP_PROP_POS_FRAMES, frame_indices[orig_idx])
            ret, frame = rerank_cap.read()
            if ret:
                candidates.append([{"type": "image", "image": _bgr_frame_to_pil(frame)}])
                valid_indices.append(orig_idx)
    finally:
        rerank_cap.release()

    # qwen_vl_rerank returns [] for no candidates, a missing reranker, or a
    # scoring failure; all three cases use the embedding-order fallback.
    reranked = qwen_vl_rerank(query_content, candidates)

    results = []
    if reranked:
        for r_idx, score in reranked[:_MAX_RESULTS]:
            original_idx = valid_indices[r_idx]
            ts = timestamps[original_idx]
            results.append(f"Reranker 分数: {score:.4f} | 原始相似度: {sims[original_idx]:.4f} | 时间戳: {ts}\n(Reranker 优选)")
    else:
        for idx in top_indices[:_MAX_RESULTS]:
            results.append(f"相似度: {sims[idx]:.4f} | 时间戳: {timestamps[idx]}\n(仅 Embedding 排序)")
    return _format_results(results)


def retrieve_content(query_text: str, query_image, source_type: str, text_input: str, txt_file, video_file):
    """Find the text segments or video frames that best match a query.

    The query (text and/or image) is embedded, candidates are pre-filtered by
    cosine similarity, and the survivors are reordered by the reranker. If the
    reranker yields nothing, the embedding order is used instead.

    Args:
        query_text: query description; may be empty if an image is given.
        query_image: query image; may be None if text is given.
        source_type: "文本输入", "上传 TXT" or "视频上传".
        text_input: text to search when no TXT file is supplied.
        txt_file: uploaded TXT file (path or object with ``.name``) or None;
            when present it replaces ``text_input``.
        video_file: path of the uploaded video, or None.

    Returns:
        A display string: up to three hits separated by ``---`` lines, or a
        message describing why no result could be produced.
    """
    if not query_text and query_image is None:
        return "请至少提供查询文本 或 上传查询图片!"

    # The same content list feeds both the embedder and the reranker.
    query_content = []
    if query_text:
        query_content.append({"type": "text", "text": query_text})
    if query_image is not None:
        query_content.append({"type": "image", "image": query_image})

    try:
        query_emb = _embed_content(query_content)
    except Exception as e:
        return f"查询 embedding 生成失败:{e}"

    if source_type in ["文本输入", "上传 TXT"]:
        return _retrieve_from_text(query_content, query_emb, text_input, txt_file)
    if source_type == "视频上传":
        return _retrieve_from_video(query_content, query_emb, video_file)
    # Unknown source type: nothing was searched.
    return _format_results([])
# Gradio UI for tab 2. The six inputs are listed in the same order as the
# parameters of retrieve_content: query text, query image, source type,
# pasted text, TXT upload, video upload.
retrieval_tab = gr.Interface(
    fn=retrieve_content,
    inputs=[
        gr.Textbox(label="查询描述(文字,可选)", placeholder="穿着白衬衫的男人 或 可爱的猫咪在阳光下睡觉"),
        gr.Image(
            label="查询图片(可选,仅支持上传)",
            # Upload is the only enabled source; type="pil" requests a PIL image
            sources=["upload"],
            type="pil"
        ),
        # These choice strings are compared literally inside retrieve_content
        gr.Radio(["文本输入", "上传 TXT", "视频上传"], label="查找来源类型", value="文本输入"),
        gr.Textbox(label="直接输入长文本 或 TXT 内容", lines=8),
        gr.File(label="上传 TXT 文件(.txt)", file_types=[".txt"]),
        gr.Video(label="上传视频文件(mp4,建议短视频)", sources=["upload"]),
    ],
    # retrieve_content returns a single display string
    outputs=gr.Textbox(label="检索结果(Top 3 最相似部分)", lines=12),
    title="内容检索模式(已集成 Qwen3-VL-Reranker)",
    description="支持文字和/或图片查询(至少一个),返回最相似的文本段落或视频帧 + 时间戳。\n已使用 Reranker 提升排序精度。",
)
# Combine the two tabs into a single app
demo = gr.TabbedInterface(
    [embedding_tab, retrieval_tab],
    tab_names=["Embedding 生成", "内容检索"]
)
# Launch runs unconditionally whenever this module is executed or imported.
demo.launch()