qxyf committed on
Commit
0430e9d
·
1 Parent(s): f4e5afb
Files changed (1) hide show
  1. app.py +21 -10
app.py CHANGED
@@ -66,6 +66,7 @@ def retrieve_content(query_text: str, query_image, source_type: str, text_input:
66
  if not query_text and query_image is None:
67
  return "请至少提供查询文本 或 上传查询图片!"
68
 
 
69
  content = []
70
  if query_text:
71
  content.append({"type": "text", "text": query_text})
@@ -77,7 +78,7 @@ def retrieve_content(query_text: str, query_image, source_type: str, text_input:
77
 
78
  try:
79
  with torch.no_grad():
80
- # 修复在这里:直接传 content,不要 [content]
81
  query_emb = embedder.process(content, normalize=True)[0].cpu().numpy()
82
  except Exception as e:
83
  return f"查询 embedding 生成失败:{str(e)}"
@@ -96,6 +97,7 @@ def retrieve_content(query_text: str, query_image, source_type: str, text_input:
96
  if not text.strip():
97
  return "没有提供有效文本内容!"
98
 
 
99
  segments = []
100
  step = 150
101
  for i in range(0, len(text), step):
@@ -105,9 +107,13 @@ def retrieve_content(query_text: str, query_image, source_type: str, text_input:
105
  seg_embs = []
106
  for seg in segments:
107
  seg_content = [{"type": "text", "text": seg}]
108
- with torch.no_grad():
109
- emb = embedder.process([seg_content], normalize=True)[0].cpu().numpy()
110
- seg_embs.append(emb)
 
 
 
 
111
 
112
  sims = [np.dot(query_emb, emb) / (np.linalg.norm(query_emb) * np.linalg.norm(emb) + 1e-8) for emb in seg_embs]
113
  top_indices = np.argsort(sims)[-3:][::-1]
@@ -138,12 +144,17 @@ def retrieve_content(query_text: str, query_image, source_type: str, text_input:
138
  pil_frame = Image.fromarray(frame_rgb)
139
 
140
  frame_content = [{"type": "image", "image": pil_frame}]
141
- with torch.no_grad():
142
- emb = embedder.process([frame_content], normalize=True)[0].cpu().numpy()
143
- frame_embs.append(emb)
144
-
145
- time_sec = frame_idx / fps
146
- timestamps.append(f"{int(time_sec // 60):02d}:{int(time_sec % 60):02d}")
 
 
 
 
 
147
 
148
  frame_idx += 1
149
 
 
66
  if not query_text and query_image is None:
67
  return "请至少提供查询文本 或 上传查询图片!"
68
 
69
+ # 生成 query embedding
70
  content = []
71
  if query_text:
72
  content.append({"type": "text", "text": query_text})
 
78
 
79
  try:
80
  with torch.no_grad():
81
+ # 修复:直接传 content(已经是 list)
82
  query_emb = embedder.process(content, normalize=True)[0].cpu().numpy()
83
  except Exception as e:
84
  return f"查询 embedding 生成失败:{str(e)}"
 
97
  if not text.strip():
98
  return "没有提供有效文本内容!"
99
 
100
+ # 切段
101
  segments = []
102
  step = 150
103
  for i in range(0, len(text), step):
 
107
  seg_embs = []
108
  for seg in segments:
109
  seg_content = [{"type": "text", "text": seg}]
110
+ try:
111
+ with torch.no_grad():
112
+ # 修复:直接传 seg_content,不要套 [ ]
113
+ emb = embedder.process(seg_content, normalize=True)[0].cpu().numpy()
114
+ seg_embs.append(emb)
115
+ except Exception as e:
116
+ return f"段落 embedding 生成失败:{str(e)}"
117
 
118
  sims = [np.dot(query_emb, emb) / (np.linalg.norm(query_emb) * np.linalg.norm(emb) + 1e-8) for emb in seg_embs]
119
  top_indices = np.argsort(sims)[-3:][::-1]
 
144
  pil_frame = Image.fromarray(frame_rgb)
145
 
146
  frame_content = [{"type": "image", "image": pil_frame}]
147
+ try:
148
+ with torch.no_grad():
149
+ # 修复:直接传 frame_content
150
+ emb = embedder.process(frame_content, normalize=True)[0].cpu().numpy()
151
+ frame_embs.append(emb)
152
+
153
+ time_sec = frame_idx / fps
154
+ timestamps.append(f"{int(time_sec // 60):02d}:{int(time_sec % 60):02d}")
155
+ except Exception as e:
156
+ cap.release()
157
+ return f"视频帧 embedding 生成失败:{str(e)}"
158
 
159
  frame_idx += 1
160