File size: 5,752 Bytes
10a30d9 6a5243f 10a30d9 6a5243f 10a30d9 6a5243f 10a30d9 6a5243f 10a30d9 6a5243f 10a30d9 6a5243f 10a30d9 6a5243f 10a30d9 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 | import os
import asyncio
import numbers
from typing import Any, Callable, Union

from dotenv import load_dotenv
from transformers import AutoTokenizer

from ._llm import Qwen3
from ._utils import logger
from ._videoutil import (
    retrieved_segment_caption,
    retrieved_segment_caption_kw,
)
from .prompt import PROMPTS
load_dotenv()

# One shared Qwen3 client for this module; it is used both for keyword
# extraction and for generating the final answer.
qwen3_model = Qwen3()

# Directory holding the tokenizer files for the Qwen3 model.
# NOTE(review): assumes download_tokenizer_files() returns a local path -- confirm in _llm.
tiktoken_model_path = os.path.abspath(qwen3_model.download_tokenizer_files())

try:
    # Despite the "tiktoken" name this is a Hugging Face tokenizer; below it is
    # used by truncate_list_by_token_size to count tokens per text chunk.
    tiktoken_model_p = AutoTokenizer.from_pretrained(
        tiktoken_model_path,
        # NOTE(review): allows custom tokenizer code found in that directory to execute.
        trust_remote_code=True,
        # Key parameter: load strictly from disk, never try to download.
        local_files_only=True,
    )
except Exception as e:
    # Report through the module logger (not stdout), then re-raise with the
    # original traceback intact.
    logger.error(f"加载本地 tokenizer 失败: {e}")
    logger.error(f"请检查路径是否存在: {tiktoken_model_path}")
    raise
def truncate_list_by_token_size(
    list_data: list,
    key: Callable[[Any], str],
    max_token_size: int,
    *,
    tokenizer: Any = None,
) -> list:
    """Return the longest prefix of *list_data* that fits in a token budget.

    Items are taken in order; ``key(item)`` yields the text whose tokens are
    counted. Counting stops at the first item that would push the running
    total above *max_token_size*, and everything before that item is returned.

    Args:
        list_data: Items to truncate.
        key: Maps an item to the text to be tokenized.
        max_token_size: Maximum total number of tokens allowed.
        tokenizer: Object with an ``encode(text)`` method returning a sized
            sequence of tokens. Defaults to the module-level tokenizer.

    Returns:
        ``[]`` when *max_token_size* is not positive. Otherwise returns a
        slice of *list_data*, or *list_data* itself (not a copy) when every
        item fits.
    """
    if max_token_size <= 0:
        return []
    # Only touch the module-level tokenizer when no override was supplied.
    encoder = tiktoken_model_p if tokenizer is None else tokenizer
    tokens = 0
    for i, data in enumerate(list_data):
        tokens += len(encoder.encode(key(data)))
        if tokens > max_token_size:
            # Item i overflows the budget, so keep only items before it.
            return list_data[:i]
    return list_data
def _extract_keywords_query(query):
    """Ask the shared Qwen3 model to extract keywords from *query*.

    The ``keywords_extraction`` prompt template is filled with the query as
    ``input_text`` and sent as a single user message. The model's result is
    returned unchanged.
    """
    filled_prompt = PROMPTS["keywords_extraction"].format(input_text=query)
    return qwen3_model.generate_result([{"role": "user", "content": filled_prompt}])
def _result_query(query, sys_prompt):
    """Generate an answer to *query* under the system prompt *sys_prompt*.

    Builds a two-turn chat (system message first, then the user query), sends
    it to the shared Qwen3 model and returns whatever ``generate_result``
    produces.
    """
    conversation = [{"role": "system", "content": sys_prompt}]
    conversation.append({"role": "user", "content": query})
    return qwen3_model.generate_result(conversation)
def enclose_string_with_quotes(content: Any) -> str:
    """Render *content* as a double-quoted CSV-style field.

    Numeric values (anything registered as ``numbers.Number``, including
    ``bool``) are returned as plain ``str`` without quotes. Everything else is
    converted to ``str`` and trimmed of surrounding whitespace. Leading and
    trailing single quotes are then stripped, followed by double quotes.
    Finally the text is wrapped in one pair of double quotes. Quotes inside
    the text are not escaped.
    """
    if not isinstance(content, numbers.Number):
        text = str(content).strip()
        text = text.strip("'")
        text = text.strip('"')
        return '"' + text + '"'
    return str(content)
def list_of_list_to_csv(data: list[list]):
    """Serialise a table (list of rows) into a quoted, comma+tab separated string.

    Every cell goes through ``enclose_string_with_quotes``. Cells within a row
    are joined with ``",\\t"`` and rows are joined with ``"\\n"``, with no
    trailing newline.
    """
    rendered_rows = []
    for row in data:
        rendered_cells = map(enclose_string_with_quotes, row)
        rendered_rows.append(",\t".join(rendered_cells))
    return "\n".join(rendered_rows)
def _split_segment_id(segment_id: str) -> tuple[str, str]:
    """Split ``<video_name>_<index>`` on its LAST underscore into (video_name, index)."""
    # rpartition keeps underscores inside the video name; an id with no "_"
    # yields ("", segment_id), matching the previous split/join behaviour.
    video_name, _, index = segment_id.rpartition("_")
    return video_name, index


def _parse_number(text: str) -> Union[int, float]:
    """Parse an int or float literal without ``eval``; raises ValueError otherwise.

    Integers stay ``int`` so the ``//``/``%`` based time formatting keeps
    producing the same text it did when these values came from ``eval``.
    """
    try:
        return int(text)
    except ValueError:
        return float(text)


def _segment_sort_key(segment_id: str) -> tuple:
    """Sort key: video name first, then the numeric segment index."""
    video_name, index = _split_segment_id(segment_id)
    return video_name, _parse_number(index)


def _seconds_to_hms(seconds: Union[int, float]) -> str:
    """Format a second count as ``H:M:S`` (components are not zero-padded)."""
    return f"{seconds // 3600}:{(seconds % 3600) // 60}:{seconds % 60}"


async def videorag_query(
    query,
    text_chunks_db,
    chunks_vdb,
    video_path_db,
    video_segments,
    video_segment_feature_vdb,
    query_param,
) -> str:
    """Answer *query* from retrieved text chunks plus captioned video segments.

    Steps:
      1. Query ``chunks_vdb``, fetch the matching chunks from
         ``text_chunks_db`` and truncate them to
         ``query_param.naive_max_token_for_text_unit`` tokens.
      2. Query ``video_segment_feature_vdb`` for segment ids and sort them by
         (video name, numeric index).
      3. Extract keywords from the query with the LLM and caption the
         segments via ``retrieved_segment_caption_kw``, sampling
         ``query_param.retrieved_num_sampled_frames`` frames per segment.
      4. Render the captions as a CSV table (video, start, end, caption),
         format the ``videorag_response`` system prompt with both contexts
         and return the LLM answer.

    Returns ``PROMPTS["fail_response"]`` when text retrieval finds nothing.
    """
    # --- 1. Text retrieval -------------------------------------------------
    results = await chunks_vdb.query(query)
    if not results:
        return PROMPTS["fail_response"]
    chunks = await text_chunks_db.get_by_ids([r["id"] for r in results])
    # NOTE(review): key-value stores may return None for ids they do not hold;
    # skip such entries instead of crashing on x["content"] below.
    chunks = [c for c in chunks if c is not None]
    maybe_trun_chunks = truncate_list_by_token_size(
        chunks,
        key=lambda x: x["content"],
        max_token_size=query_param.naive_max_token_for_text_unit,
    )
    logger.info(f"Truncate {len(chunks)} to {len(maybe_trun_chunks)} chunks")
    retrieved_chunk_context = "-----New Chunk-----\n".join(
        c["content"] for c in maybe_trun_chunks
    )

    # --- 2. Visual retrieval -----------------------------------------------
    segment_results = await video_segment_feature_vdb.query(query)
    visual_retrieved_segments = {n["__id__"] for n in segment_results}
    # Parse the numeric suffix explicitly (no eval on stored ids) so that
    # "seg_10" sorts after "seg_9".
    retrieved_segments = sorted(visual_retrieved_segments, key=_segment_sort_key)
    logger.info(f"Retrieved segments: {retrieved_segments}")

    # --- 3. Keyword-guided captioning ----------------------------------------
    # NOTE(review): keyword extraction, captioning and answer generation are
    # synchronous calls made inside this coroutine, so they block the event
    # loop while running.
    keywords_for_caption = _extract_keywords_query(query)
    logger.info(f"Keywords: {keywords_for_caption}")
    caption_results = retrieved_segment_caption_kw(
        retrieved_segments,
        video_path_db,
        video_segments,
        keywords_for_caption,
        num_sampled_frames=query_param.retrieved_num_sampled_frames,
    )

    # --- 4. Build the video data table and answer ------------------------------
    table_rows = [["video_name", "start_time", "end_time", "content"]]
    for s_id in caption_results:
        video_name, index = _split_segment_id(s_id)
        # Stored as "<start>-<end>" in seconds; parse without eval.
        time_parts = video_segments._data[video_name][index]["time"].split("-")
        start_time = _seconds_to_hms(_parse_number(time_parts[0]))
        end_time = _seconds_to_hms(_parse_number(time_parts[1]))
        table_rows.append([video_name, start_time, end_time, caption_results[s_id]])
    text_units_context = list_of_list_to_csv(table_rows)
    retrieved_video_context = f"\n-----Retrieved Knowledge From Videos-----\n```csv\n{text_units_context}\n```\n"

    sys_prompt = PROMPTS["videorag_response"].format(
        video_data=retrieved_video_context,
        chunk_data=retrieved_chunk_context,
    )
    return _result_query(query, sys_prompt)
|