import os
import numbers
from typing import Any, Callable

from dotenv import load_dotenv
from transformers import AutoTokenizer

from .prompt import PROMPTS
from ._utils import logger
from ._llm import Qwen3
from ._videoutil import (
    retrieved_segment_caption,
    retrieved_segment_caption_kw,
)

load_dotenv()

# Module-level Qwen3 client shared by the keyword-extraction and answer helpers.
qwen3_model = Qwen3()

# Tokenizer files are fetched through the Qwen3 wrapper and then loaded from disk
# only. (Despite the "tiktoken" names, this is a Hugging Face tokenizer.)
tiktoken_model_path = os.path.abspath(qwen3_model.download_tokenizer_files())
try:
    tiktoken_model_p = AutoTokenizer.from_pretrained(
        tiktoken_model_path,
        trust_remote_code=True,
        local_files_only=True,  # key parameter: never try to download from the Hub
    )
except Exception as e:
    logger.error(f"Failed to load the local tokenizer: {e}")
    logger.error(f"Please check that the path exists: {tiktoken_model_path}")
    raise


def truncate_list_by_token_size(list_data: list, key: Callable, max_token_size: int):
    """Return the longest prefix of list_data whose cumulative token count,
    measured on key(item), does not exceed max_token_size."""
    if max_token_size <= 0:
        return []
    tokens = 0
    for i, data in enumerate(list_data):
        tokens += len(tiktoken_model_p.encode(key(data)))
        if tokens > max_token_size:
            return list_data[:i]
    return list_data


def _extract_keywords_query(query):
    """Ask the LLM to extract retrieval keywords from the user query."""
    keywords_prompt = PROMPTS["keywords_extraction"].format(input_text=query)
    messages = [{"role": "user", "content": keywords_prompt}]
    return qwen3_model.generate_result(messages)


def _result_query(query, sys_prompt):
    """Answer the query with the LLM, using sys_prompt as the system message."""
    messages = [
        {"role": "system", "content": sys_prompt},
        {"role": "user", "content": query},
    ]
    return qwen3_model.generate_result(messages)


def enclose_string_with_quotes(content: Any) -> str:
    """Enclose a string with quotes"""
    if isinstance(content, numbers.Number):
        return str(content)
    content = str(content)
    content = content.strip().strip("'").strip('"')
    return f'"{content}"'


def list_of_list_to_csv(data: list[list]):
    return "\n".join(
        [
            ",\t".join([f"{enclose_string_with_quotes(data_dd)}" for data_dd in data_d])
            for data_d in data
        ]
    )


async def videorag_query(
    query,
    text_chunks_db,
    chunks_vdb,
    video_path_db,
    video_segments,
    video_segment_feature_vdb,
    query_param,
) -> str:
    # 1. Text retrieval: nearest chunks from the chunk vector DB.
    results = await chunks_vdb.query(query)
    if not results:
        return PROMPTS["fail_response"]
    chunks_ids = [r["id"] for r in results]
    chunks = await text_chunks_db.get_by_ids(chunks_ids)
    # Skip ids that are missing from the KV store (returned as None).
    chunks = [c for c in chunks if c is not None]

    maybe_trun_chunks = truncate_list_by_token_size(
        chunks,
        key=lambda x: x["content"],
        max_token_size=query_param.naive_max_token_for_text_unit,
    )
    logger.info(f"Truncate {len(chunks)} to {len(maybe_trun_chunks)} chunks")
    retrieved_chunk_context = "-----New Chunk-----\n".join(
        [c["content"] for c in maybe_trun_chunks]
    )

    # 2. Visual retrieval: segments whose features match the query.
    segment_results = await video_segment_feature_vdb.query(query)
    visual_retrieved_segments = {n["__id__"] for n in segment_results}

    # Segment ids look like "<video_name>_<index>"; order them by video name,
    # then numerically by index (int() instead of eval() on untrusted ids).
    retrieved_segments = sorted(
        visual_retrieved_segments,
        key=lambda x: ("_".join(x.split("_")[:-1]), int(x.split("_")[-1])),
    )
    logger.info(f"Retrieved segments: {retrieved_segments}")

    # Coarse captions stored at indexing time (computed but not used below).
    rough_captions = {}
    for s_id in retrieved_segments:
        video_name = "_".join(s_id.split("_")[:-1])
        index = s_id.split("_")[-1]
        rough_captions[s_id] = video_segments._data[video_name][index]["content"]

    # 3. Re-caption the retrieved segments, guided by keywords from the query.
    remain_segments = retrieved_segments
    keywords_for_caption = _extract_keywords_query(query)
    logger.info(f"Keywords: {keywords_for_caption}")
    # Keyword-free alternative:
    # caption_results = retrieved_segment_caption(
    #     remain_segments,
    #     video_path_db,
    #     video_segments,
    #     num_sampled_frames=query_param.retrieved_num_sampled_frames,
    # )
    caption_results = retrieved_segment_caption_kw(
        remain_segments,
        video_path_db,
        video_segments,
        keywords_for_caption,
        num_sampled_frames=query_param.retrieved_num_sampled_frames,
    )

    # 4. Build a CSV-style table of (video_name, start_time, end_time, content).
    text_units_section_list = [["video_name", "start_time", "end_time", "content"]]
    for s_id in caption_results:
        video_name = "_".join(s_id.split("_")[:-1])
        index = s_id.split("_")[-1]
        # "time" is stored as "<start>-<end>" in seconds; fractions are truncated.
        start_s, end_s = video_segments._data[video_name][index]["time"].split("-", 1)
        start_time, end_time = int(float(start_s)), int(float(end_s))
        start_time = f"{start_time // 3600}:{(start_time % 3600) // 60:02d}:{start_time % 60:02d}"
        end_time = f"{end_time // 3600}:{(end_time % 3600) // 60:02d}:{end_time % 60:02d}"
        text_units_section_list.append(
            [video_name, start_time, end_time, caption_results[s_id]]
        )
    text_units_context = list_of_list_to_csv(text_units_section_list)
    retrieved_video_context = (
        f"\n-----Retrieved Knowledge From Videos-----\n```csv\n{text_units_context}\n```\n"
    )

    # 5. Answer with both the video table and the text chunks in the system prompt.
    sys_prompt = PROMPTS["videorag_response"].format(
        video_data=retrieved_video_context,
        chunk_data=retrieved_chunk_context,
    )
    return _result_query(query, sys_prompt)
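

# ---------------------------------------------------------------------------
# Illustrative sketch (hypothetical helpers; videorag_query does not call them).
# videorag_query assumes segment ids of the form "<video_name>_<index>", where
# video_name may itself contain underscores and index is an integer, and segment
# "time" strings of the form "<start>-<end>" in seconds. Under those assumptions,
# the id parsing, ordering and timestamp formatting it performs inline could be
# factored out and checked in isolation like this.
# ---------------------------------------------------------------------------


def _split_segment_id(segment_id: str) -> tuple[str, int]:
    """Split "<video_name>_<index>" on the last underscore (hypothetical helper)."""
    video_name, _, index = segment_id.rpartition("_")
    return video_name, int(index)


def _sort_segment_ids(segment_ids) -> list[str]:
    """Order ids by video name, then numerically by segment index."""
    return sorted(segment_ids, key=_split_segment_id)


def _seconds_to_hms(seconds: float) -> str:
    """Format a second offset as H:MM:SS (fractional seconds are truncated)."""
    hours, rem = divmod(int(seconds), 3600)
    minutes, secs = divmod(rem, 60)
    return f"{hours}:{minutes:02d}:{secs:02d}"


# Numeric ordering puts segment 2 before segment 10, which a plain string sort
# would not:
#   _sort_segment_ids(["my_clip_10", "my_clip_2", "a_0"])
#     -> ["a_0", "my_clip_2", "my_clip_10"]
#   _seconds_to_hms(3725) -> "1:02:05"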