# SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: MIT
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

import os
import sys
import time
import calendar
import json
from model_setup_manager import download_model_by_name, build_engine_by_name
import logging
import gc
import torch
from pathlib import Path
from trt_llama_api import TrtLlmAPI
from whisper.trt_whisper import WhisperTRTLLM, decode_audio_file
#from langchain.embeddings.huggingface import HuggingFaceEmbeddings
#from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from collections import defaultdict
from llama_index import ServiceContext
from llama_index.llms.llama_utils import messages_to_prompt, completion_to_prompt
from llama_index import set_global_service_context
from faiss_vector_storage import FaissEmbeddingStorage
from ui.user_interface import MainInterface
from scipy.io import wavfile
import scipy.signal as sps
import numpy as np
from pynvml import nvmlInit, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo
from CLIP import run_model_on_images, CLIPEmbeddingStorageEngine
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
import shutil
from llm_prompt_templates import LLMPromptTemplate
from utils import read_model_name
import win32api
import win32security

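# Module-level state shared across the UI callbacks below: the selected model type,
# engine handles and configuration file paths are kept global so they can be swapped at runtime.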
selected_CLIP = False
clip_engine = None
selected_ChatGLM = False

app_config_file = 'config\\app_config.json'
model_config_file = 'config\\config.json'
preference_config_file = 'config\\preferences.json'
data_source = 'directory'

# Use GetCurrentProcess to get a handle to the current process
hproc = win32api.GetCurrentProcess()
# Use OpenProcessToken to get the access token of the current process
htok = win32security.OpenProcessToken(hproc, win32security.TOKEN_QUERY)
# Retrieve the list of privileges held by the token
privileges = win32security.GetTokenInformation(htok, win32security.TokenPrivileges)
# Iterate over the privileges and collect the ones that are enabled
priv_list = []
for priv_id, priv_flags in privileges:
    # Check whether the privilege is enabled (or enabled by default)
    if priv_flags & (win32security.SE_PRIVILEGE_ENABLED | win32security.SE_PRIVILEGE_ENABLED_BY_DEFAULT):
        # Look up the human-readable name of the privilege
        priv_name = win32security.LookupPrivilegeName(None, priv_id)
        priv_list.append(priv_name)
print(f"Privileges of app process: {priv_list}")

def read_config(file_name):
    try:
        with open(file_name, 'r', encoding='utf8') as file:
            return json.load(file)
    except FileNotFoundError:
        print(f"The file {file_name} was not found.")
    except json.JSONDecodeError:
        print(f"There was an error decoding the JSON from the file {file_name}.")
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
    return None

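# Resolve engine, tokenizer and generation settings for the selected model. NGC-packaged
# TRT-LLM engines and plain Hugging Face checkpoints are laid out differently under ./model.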
def get_model_config(config, model_name=None):
    selected_model = next((model for model in config["models"]["supported"] if model["name"] == model_name),
                          config["models"]["supported"][0])
    metadata = selected_model["metadata"]
    cwd = os.getcwd()  # Current working directory, cached to avoid repeated os.getcwd() calls
    if "ngc_model_name" in selected_model:
        return {
            "model_path": os.path.join(cwd, "model", selected_model["id"], "engine") if "id" in selected_model else None,
            "engine": metadata.get("engine", None),
            "tokenizer_path": os.path.join(cwd, "model", selected_model["id"], selected_model["prerequisite"]["tokenizer_local_dir"]) if "tokenizer_local_dir" in selected_model["prerequisite"] else None,
            "vocab": os.path.join(cwd, "model", selected_model["id"], selected_model["prerequisite"]["vocab_local_dir"], selected_model["prerequisite"]["tokenizer_files"]["vocab_file"]) if "vocab_local_dir" in selected_model["prerequisite"] else None,
            "max_new_tokens": metadata.get("max_new_tokens", None),
            "max_input_token": metadata.get("max_input_token", None),
            "temperature": metadata.get("temperature", None),
            "prompt_template": metadata.get("prompt_template", None)
        }
    elif "hf_model_name" in selected_model:
        return {
            "model_path": os.path.join(cwd, "model", selected_model["id"]) if "id" in selected_model else None,
            "tokenizer_path": os.path.join(cwd, "model", selected_model["id"]) if "id" in selected_model else None,
            "prompt_template": metadata.get("prompt_template", None)
        }

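# ASR (Whisper) engine and asset paths are resolved relative to the current working directory.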
def get_asr_model_config(config, model_name=None):
    models = config["models"]["supported_asr"]
    selected_model = next((model for model in models if model["name"] == model_name), models[0])
    return {
        "model_path": os.path.join(os.getcwd(), selected_model["metadata"]["model_path"]),
        "assets_path": os.path.join(os.getcwd(), selected_model["metadata"]["assets_path"])
    }

def get_data_path(config):
    return os.path.join(os.getcwd(), config["dataset"]["path"])

# read the app specific config
app_config = read_config(app_config_file)
streaming = app_config["streaming"]
similarity_top_k = app_config["similarity_top_k"]
is_chat_engine = app_config["is_chat_engine"]
embedded_model_name = app_config["embedded_model"]
embedded_model = os.path.join(os.getcwd(), "model", embedded_model_name)
embedded_dimension = app_config["embedded_dimension"]
use_py_session = app_config["use_py_session"]
trtLlm_debug_mode = app_config["trtLlm_debug_mode"]
add_special_tokens = app_config["add_special_tokens"]
verbose = app_config["verbose"]

# read model specific config
selected_model_name = None
selected_data_directory = None
config = read_config(model_config_file)
if os.path.exists(preference_config_file):
    perf_config = read_config(preference_config_file)
    selected_model_name = perf_config.get('models', {}).get('selected')
    selected_data_directory = perf_config.get('dataset', {}).get('path')

if selected_model_name is None:
    selected_model_name = config["models"].get("selected")

if selected_model_name == "CLIP":
    selected_CLIP = True
if selected_model_name == "ChatGLM 3 6B int4 (Supports Chinese)":
    selected_ChatGLM = True

model_config = get_model_config(config, selected_model_name)
data_dir = config["dataset"]["path"] if selected_data_directory is None else selected_data_directory

asr_model_name = "Whisper Medium Int8"
asr_model_config = get_asr_model_config(config, asr_model_name)
asr_engine_path = asr_model_config["model_path"]
asr_assets_path = asr_model_config["assets_path"]
whisper_model = None
whisper_model_loaded = False
enable_asr = config["models"]["enable_asr"]
nvmlInit()

def generate_inferance_engine(data, force_rewrite=False):
    """
    Initialize and return a FAISS-based inference engine.

    Args:
        data: The directory where the data for the inference engine is located.
        force_rewrite (bool): If True, force rewriting the index.

    Returns:
        The initialized inference engine.

    Raises:
        RuntimeError: If unable to generate the inference engine.
    """
    try:
        global engine
        faiss_storage = FaissEmbeddingStorage(data_dir=data,
                                              dimension=embedded_dimension)
        faiss_storage.initialize_index(force_rewrite=force_rewrite)
        engine = faiss_storage.get_engine(is_chat_engine=is_chat_engine, streaming=streaming,
                                          similarity_top_k=similarity_top_k)
    except Exception as e:
        raise RuntimeError(f"Unable to generate the inference engine: {e}")

def generate_clip_engine(data_dir, model_path, clip_model, clip_processor, force_rewrite=False):
    global clip_engine
    clip_engine = CLIPEmbeddingStorageEngine(data_dir, model_path, clip_model, clip_processor)
    clip_engine.create_nodes(force_rewrite)
    clip_engine.initialize_index(force_rewrite)

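# Start-up initialization: exactly one of the two paths below runs. Either the CLIP image-search
# engine is built, or a TRT-LLM text model plus embedding model and FAISS index are created.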
llm = None
embed_model = None
service_context = None
clip_model = None
clip_processor = None

if selected_CLIP:
    # Initialize the CLIP model and processor
    clip_model = CLIPModel.from_pretrained(model_config["model_path"]).to('cuda')
    clip_processor = CLIPProcessor.from_pretrained(model_config["model_path"])
    generate_clip_engine(data_dir, model_config["model_path"], clip_model, clip_processor)
else:
    # create the TRT-LLM engine object
    model_name, _ = read_model_name(model_config["model_path"])
    prompt_template_obj = LLMPromptTemplate()
    text_qa_template_str = prompt_template_obj.model_context_template(model_name)
    selected_completion_to_prompt = text_qa_template_str
    llm = TrtLlmAPI(
        model_path=model_config["model_path"],
        engine_name=model_config["engine"],
        tokenizer_dir=model_config["tokenizer_path"],
        temperature=model_config["temperature"],
        max_new_tokens=model_config["max_new_tokens"],
        context_window=model_config["max_input_token"],
        vocab_file=model_config["vocab"],
        messages_to_prompt=messages_to_prompt,
        completion_to_prompt=selected_completion_to_prompt,
        use_py_session=use_py_session,
        add_special_tokens=add_special_tokens,
        trtLlm_debug_mode=trtLlm_debug_mode,
        verbose=verbose
    )
    # create the embeddings model object
    embed_model = HuggingFaceEmbeddings(model_name=embedded_model)
    service_context = ServiceContext.from_defaults(llm=llm, embed_model=embed_model,
                                                   context_window=model_config["max_input_token"], chunk_size=512,
                                                   chunk_overlap=200)
    set_global_service_context(service_context)
    # load the vectorstore index
    generate_inferance_engine(data_dir)

def call_llm_streamed(query):
    partial_response = ""
    response = llm.stream_complete(query, formatted=False)
    for token in response:
        partial_response += token.delta
        yield partial_response

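# Non-streaming chat handler. For CLIP it returns matching images as HTML; otherwise it runs the
# RAG query engine and appends a link to the best-matching source document.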
def chatbot(query, chat_history, session_id):
    if selected_CLIP:
        ts = calendar.timegm(time.gmtime())
        temp_image_folder_name = "Temp/Temp_Images"
        if os.path.isdir(temp_image_folder_name):
            try:
                shutil.rmtree(os.path.join(os.getcwd(), temp_image_folder_name))
            except Exception as e:
                print("Exception during folder delete", e)
        image_results_path = os.path.join(os.getcwd(), temp_image_folder_name, str(ts))
        res_im_paths = clip_engine.query(query, image_results_path)
        if len(res_im_paths) == 0:
            yield "No supported images found in the selected folder"
            torch.cuda.empty_cache()
            gc.collect()
            return
        div_start = '<div class="chat-output-images">'
        div_end = '</div>'
        im_elements = ''
        for i, im in enumerate(res_im_paths):
            if i > 2:
                break  # display at most 3 images
            cur_data_link_src = temp_image_folder_name + "/" + str(ts) + "/" + os.path.basename(im)
            cur_src = "file/" + temp_image_folder_name + "/" + str(ts) + "/" + os.path.basename(im)
            im_elements += '<img data-link="{data_link_src}" src="{src}"/>'.format(src=cur_src, data_link_src=cur_data_link_src)
        full_div = (div_start + im_elements + div_end)
        folder_link = f'<a data-link="{image_results_path}">{"See all matches"}</a>'
        prefix = ""
        if len(res_im_paths) > 1:
            prefix = "Here are the top matching pictures from your dataset"
        else:
            prefix = "Here is the top matching picture from your dataset"
        response = prefix + "<br>" + full_div + "<br>" + folder_link
        gc.collect()
        torch.cuda.empty_cache()
        yield response
        torch.cuda.empty_cache()
        gc.collect()
        return

    if data_source == "nodataset":
        yield llm.complete(query, formatted=False).text
        return

    if is_chat_engine:
        response = engine.chat(query)
    else:
        response = engine.query(query)

    # Track the best-matching source file (lowest retrieval score)
    lowest_score_file = None
    lowest_score = sys.float_info.max
    for node in response.source_nodes:
        metadata = node.metadata
        if 'filename' in metadata:
            if node.score < lowest_score:
                lowest_score = node.score
                lowest_score_file = metadata['filename']

    file_links = []
    seen_files = set()  # Set to track unique file names
    ts = calendar.timegm(time.gmtime())
    temp_docs_folder_name = "Temp/Temp_Docs"
    docs_path = os.path.join(os.getcwd(), temp_docs_folder_name, str(ts))
    os.makedirs(docs_path, exist_ok=True)

    # Generate a link for the best-matching source file
    if lowest_score_file:
        abs_path = Path(os.path.join(os.getcwd(), lowest_score_file.replace('\\', '/')))
        file_name = os.path.basename(abs_path)
        doc_path = os.path.join(docs_path, file_name)
        shutil.copy(abs_path, doc_path)
        if file_name not in seen_files:  # Ensure the file hasn't already been processed
            if data_source == 'directory':
                file_link = f'<a data-link="{doc_path}">{file_name}</a>'
            else:
                exit("Wrong data_source type")
            file_links.append(file_link)
            seen_files.add(file_name)  # Mark file as processed

    response_txt = str(response)
    if file_links:
        response_txt += "<br>Reference files:<br>" + "<br>".join(file_links)
    if not lowest_score_file:  # If no source file was found, fall back to a plain LLM completion
        response_txt = llm.complete(query).text
    yield response_txt

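# Streaming variant of the handler above: yields the growing response token by token and appends
# reference-file links once generation has finished.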
def stream_chatbot(query, chat_history, session_id):
    if selected_CLIP:
        ts = calendar.timegm(time.gmtime())
        temp_image_folder_name = "Temp/Temp_Images"
        if os.path.isdir(temp_image_folder_name):
            try:
                shutil.rmtree(os.path.join(os.getcwd(), temp_image_folder_name))
            except Exception as e:
                print("Exception during folder delete", e)
        image_results_path = os.path.join(os.getcwd(), temp_image_folder_name, str(ts))
        res_im_paths = clip_engine.query(query, image_results_path)
        if len(res_im_paths) == 0:
            yield "No supported images found in the selected folder"
            torch.cuda.empty_cache()
            gc.collect()
            return
        div_start = '<div class="chat-output-images">'
        div_end = '</div>'
        im_elements = ''
        for i, im in enumerate(res_im_paths):
            if i > 2:
                break  # display at most 3 images
            cur_data_link_src = temp_image_folder_name + "/" + str(ts) + "/" + os.path.basename(im)
            cur_src = "file/" + temp_image_folder_name + "/" + str(ts) + "/" + os.path.basename(im)
            im_elements += '<img data-link="{data_link_src}" src="{src}"/>'.format(src=cur_src, data_link_src=cur_data_link_src)
        full_div = (div_start + im_elements + div_end)
        folder_link = f'<a data-link="{image_results_path}">{"See all matches"}</a>'
        prefix = ""
        if len(res_im_paths) > 1:
            prefix = "Here are the top matching pictures from your dataset"
        else:
            prefix = "Here is the top matching picture from your dataset"
        response = prefix + "<br>" + full_div + "<br>" + folder_link
        yield response
        torch.cuda.empty_cache()
        gc.collect()
        return

    if data_source == "nodataset":
        for response in call_llm_streamed(query):
            yield response
        return

    if is_chat_engine:
        response = engine.stream_chat(query)
    else:
        response = engine.query(query)

    partial_response = ""
    if len(response.source_nodes) == 0:
        response = llm.stream_complete(query, formatted=False)
        for token in response:
            partial_response += token.delta
            yield partial_response
    else:
        # Track the best-matching source file (lowest retrieval score)
        lowest_score_file = None
        lowest_score = sys.float_info.max
        for node in response.source_nodes:
            if 'filename' in node.metadata:
                if node.score < lowest_score:
                    lowest_score = node.score
                    lowest_score_file = node.metadata['filename']

        file_links = []
        seen_files = set()
        for token in response.response_gen:
            partial_response += token
            yield partial_response
            time.sleep(0.05)
        time.sleep(0.2)

        ts = calendar.timegm(time.gmtime())
        temp_docs_folder_name = "Temp/Temp_Docs"
        docs_path = os.path.join(os.getcwd(), temp_docs_folder_name, str(ts))
        os.makedirs(docs_path, exist_ok=True)
        if lowest_score_file:
            abs_path = Path(os.path.join(os.getcwd(), lowest_score_file.replace('\\', '/')))
            file_name = os.path.basename(abs_path)
            doc_path = os.path.join(docs_path, file_name)
            shutil.copy(abs_path, doc_path)
            if file_name not in seen_files:  # Check if file_name has already been seen
                if data_source == 'directory':
                    file_link = f'<a data-link="{doc_path}">{file_name}</a>'
                else:
                    exit("Wrong data_source type")
                file_links.append(file_link)
                seen_files.add(file_name)  # Add file_name to the set

        if file_links:
            partial_response += "<br>Reference files:<br>" + "<br>".join(file_links)
            yield partial_response

    # call the garbage collector after inference
    torch.cuda.empty_cache()
    gc.collect()

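# Wire up the UI: the streaming or non-streaming handler is chosen based on the app config.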
interface = MainInterface(chatbot=stream_chatbot if streaming else chatbot, streaming=streaming)

def on_shutdown_handler(session_id):
    global llm, whisper_model, clip_model, clip_processor, clip_engine
    import gc
    if whisper_model is not None:
        whisper_model.unload_model()
        del whisper_model
        whisper_model = None
    if llm is not None:
        llm.unload_model()
        del llm
        llm = None
    if clip_model is not None:
        del clip_model
        del clip_processor
        del clip_engine
        clip_model = None
        clip_processor = None
        clip_engine = None
    temp_data_folder_name = "Temp"
    if os.path.isdir(temp_data_folder_name):
        try:
            shutil.rmtree(os.path.join(os.getcwd(), temp_data_folder_name))
        except Exception as e:
            print("Exception during temp folder delete", e)
    # Force a garbage collection cycle
    gc.collect()

interface.on_shutdown(on_shutdown_handler)

def reset_chat_handler(session_id):
    global faiss_storage
    global engine
    print('reset chat called', session_id)
    if selected_CLIP:
        return
    if is_chat_engine:
        faiss_storage.reset_engine(engine)

interface.on_reset_chat(reset_chat_handler)

def on_dataset_path_updated_handler(source, new_directory, video_count, session_id):
    print('data set path updated to ', source, new_directory, video_count, session_id)
    global engine
    global data_dir
    if selected_CLIP:
        data_dir = new_directory
        generate_clip_engine(data_dir, model_config["model_path"], clip_model, clip_processor)
        return
    if source == 'directory':
        if data_dir != new_directory:
            data_dir = new_directory
            generate_inferance_engine(data_dir)

interface.on_dataset_path_updated(on_dataset_path_updated_handler)

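# Model switch: unload the current LLM or CLIP engine, build the TRT-LLM engine for the newly
# selected model, then refresh the service context and the FAISS index.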
def on_model_change_handler(model, model_info, session_id):
    global llm, embedded_model, engine, data_dir, service_context, clip_model, clip_processor, selected_CLIP, selected_model_name, embed_model, model_config, selected_ChatGLM, clip_engine
    selected_model_name = model
    selected_ChatGLM = False

    if llm is not None:
        llm.unload_model()
        del llm
        llm = None

    if clip_model is not None:
        del clip_model
        clip_model = None
        del clip_processor
        clip_processor = None
        del clip_engine
        clip_engine = None

    torch.cuda.empty_cache()
    gc.collect()

    cwd = os.getcwd()
    model_config = get_model_config(config, selected_model_name)
    selected_CLIP = False

    if selected_model_name == "CLIP":
        selected_CLIP = True
        if clip_model is None:
            clip_model = CLIPModel.from_pretrained(model_config["model_path"]).to('cuda')
            clip_processor = CLIPProcessor.from_pretrained(model_config["model_path"])
            generate_clip_engine(data_dir, model_config["model_path"], clip_model, clip_processor)
        return

    model_path = os.path.join(cwd, "model", model_info["id"], "engine") if "id" in model_info else None
    engine_name = model_info["metadata"].get('engine', None)
    if not model_path or not engine_name:
        print("Model path or engine not provided in metadata")
        return

    if selected_model_name == "ChatGLM 3 6B int4 (Supports Chinese)":
        selected_ChatGLM = True

    model_name, _ = read_model_name(model_path)
    prompt_template = LLMPromptTemplate()
    text_qa_template_str = prompt_template.model_context_template(model_name)
    selected_completion_to_prompt = text_qa_template_str
    #selected_completion_to_prompt = chatglm_completion_to_prompt if selected_ChatGLM else completion_to_prompt
    llm = TrtLlmAPI(
        model_path=model_path,
        engine_name=engine_name,
        tokenizer_dir=os.path.join(cwd, "model", model_info["id"], model_info["prerequisite"]["tokenizer_local_dir"]) if "tokenizer_local_dir" in model_info["prerequisite"] else None,
        temperature=model_info["metadata"].get("temperature"),
        max_new_tokens=model_info["metadata"].get("max_new_tokens"),
        context_window=model_info["metadata"].get("max_input_token"),
        vocab_file=os.path.join(cwd, "model", model_info["id"], model_info["prerequisite"]["vocab_local_dir"], model_info["prerequisite"]["tokenizer_files"]["vocab_file"]) if "vocab_local_dir" in model_info["prerequisite"] else None,
        messages_to_prompt=messages_to_prompt,
        completion_to_prompt=selected_completion_to_prompt,
        use_py_session=use_py_session,
        add_special_tokens=add_special_tokens,
        trtLlm_debug_mode=trtLlm_debug_mode,
        verbose=verbose
    )

    if embed_model is None:
        embed_model = HuggingFaceEmbeddings(model_name=embedded_model)
    if service_context is None:
        service_context = ServiceContext.from_defaults(llm=llm, embed_model=embed_model,
                                                       context_window=model_config["max_input_token"], chunk_size=512,
                                                       chunk_overlap=200)
    else:
        service_context = ServiceContext.from_service_context(service_context=service_context, llm=llm)
    set_global_service_context(service_context)
    generate_inferance_engine(data_dir)

interface.on_model_change(on_model_change_handler)

def on_dataset_source_change_handler(source, path, session_id):
    global data_source, data_dir, engine
    data_source = source

    if data_source == "nodataset":
        print(' No dataset source selected', session_id)
        return

    print('dataset source updated ', source, path, session_id)

    if data_source == "directory":
        data_dir = path
    else:
        print("Wrong data type selected")
    generate_inferance_engine(data_dir)

interface.on_dataset_source_updated(on_dataset_source_change_handler)

def handle_regenerate_index(source, path, session_id):
    if selected_CLIP:
        generate_clip_engine(data_dir, model_config["model_path"], clip_model, clip_processor, force_rewrite=True)
    else:
        generate_inferance_engine(path, force_rewrite=True)
    print("on regenerate index", source, path, session_id)

def mic_init_handler():
    global whisper_model, whisper_model_loaded, enable_asr
    enable_asr = config["models"]["enable_asr"]
    if not enable_asr:
        return False
    vid_mem_info = nvmlDeviceGetMemoryInfo(nvmlDeviceGetHandleByIndex(0))
    free_vid_mem = vid_mem_info.free / (1024 * 1024)
    print("free video memory in MB = ", free_vid_mem)
    if whisper_model is not None:
        whisper_model.unload_model()
        del whisper_model
        whisper_model = None
    whisper_model = WhisperTRTLLM(asr_engine_path, assets_dir=asr_assets_path)
    whisper_model_loaded = True
    return True

interface.on_mic_button_click(mic_init_handler)

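# Transcribe the recorded audio: resample it to the 16 kHz input Whisper expects, run the
# TRT-LLM Whisper engine, then unload the model to free VRAM.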
def mic_recording_done_handler(audio_path):
    transcription = ""
    global whisper_model, enable_asr, whisper_model_loaded
    if not enable_asr:
        return ""

    # Check and wait until the model is loaded before running it.
    checks_left_for_model_loading = 40
    sleep_time = 0.2
    while checks_left_for_model_loading > 0 and not whisper_model_loaded:
        time.sleep(sleep_time)
        checks_left_for_model_loading -= 1
    assert checks_left_for_model_loading > 0, f"Whisper model loading did not finish within {40 * sleep_time} seconds"
    if checks_left_for_model_loading == 0:
        return ""

    # Convert the audio file to the required sampling rate
    current_sampling_rate, data = wavfile.read(audio_path)
    new_sampling_rate = 16000
    number_of_samples = round(len(data) * float(new_sampling_rate) / current_sampling_rate)
    data = sps.resample(data, number_of_samples)
    new_file_path = os.path.join(os.path.dirname(audio_path), "whisper_audio_input.wav")
    wavfile.write(new_file_path, new_sampling_rate, data.astype(np.int16))

    language = "english"
    if selected_ChatGLM:
        language = "chinese"
    transcription = decode_audio_file(new_file_path, whisper_model, language=language, mel_filters_dir=asr_assets_path)

    if whisper_model is not None:
        whisper_model.unload_model()
        del whisper_model
        whisper_model = None
        whisper_model_loaded = False
    return transcription

interface.on_mic_recording_done(mic_recording_done_handler)

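# Model lifecycle callbacks: download model files, build the TRT-LLM engine, and delete local model folders.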
def model_download_handler(model_info):
    download_path = os.path.join(os.getcwd(), "model")
    status = download_model_by_name(model_info=model_info, download_path=download_path)
    print(f"Model download status: {status}")
    return status

interface.on_model_downloaded(model_download_handler)

def model_install_handler(model_info):
    download_path = os.path.join(os.getcwd(), "model")
    global llm, service_context
    # unload the current model
    if llm is not None:
        llm.unload_model()
        del llm
        llm = None
    # build the engine
    status = build_engine_by_name(model_info=model_info, download_path=download_path)
    print(f"Engine build status: {status}")
    return status

interface.on_model_installed(model_install_handler)

def model_delete_handler(model_info):
    print("Model deleting ", model_info)
    model_dir = os.path.join(os.getcwd(), "model", model_info['id'])
    isSuccess = True
    if os.path.isdir(model_dir):
        try:
            shutil.rmtree(model_dir)
        except Exception as e:
            print("Exception during model folder delete", e)
            isSuccess = False
    return isSuccess

interface.on_model_delete(model_delete_handler)
interface.on_regenerate_index(handle_regenerate_index)

# render the interface
interface.render()