# Import unsloth first: it patches transformers/peft at import time, and the
# Unsloth docs recommend importing it before those libraries are loaded.
from unsloth import FastLanguageModel

import json
import traceback

import faiss
import gradio as gr
import torch
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer

from llm_handler import generate_response
from retrieval import (
    process_law_data_to_chunks,
    tokenize_vi_for_bm25_setup,
    search_relevant_laws,
)
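
# --- Configuration: paths to the preprocessed law-corpus JSON, the FAISS
# index built from its embeddings, and the fine-tuned LLM / embedding model
# directories -----------------------------------------------------------------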
JSON_FILE_PATH = "data/luat_chi_tiet_output_openai_sdk_final_cleaned.json"
FAISS_INDEX_PATH = "data/my_law_faiss_flatip_normalized.index"
LLM_MODEL_PATH = "models/lora_model_base"
EMBEDDING_MODEL_PATH = "models/embedding_model"
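
# --- Step 1: Load the law corpus and split it into retrievable chunks --------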
| | print("Loading and processing law data...") |
| | try: |
| | with open(JSON_FILE_PATH, 'r', encoding='utf-8') as f: |
| | raw_data_from_file = json.load(f) |
| | chunks_data = process_law_data_to_chunks(raw_data_from_file) |
| | print(f"Loaded {len(chunks_data)} chunks.") |
| | if not chunks_data: |
| | raise ValueError("Chunks data is empty after processing.") |
| | except Exception as e: |
| | print(f"Error loading/processing law data: {e}") |
| | chunks_data = [] |
| |
|
| | |
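
# --- Step 2: Load the sentence-embedding model used for dense retrieval ------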
| | print(f"Loading embedding model: {EMBEDDING_MODEL_PATH}...") |
| | device = "cuda" if torch.cuda.is_available() else "cpu" |
| | try: |
| | embedding_model = SentenceTransformer(EMBEDDING_MODEL_PATH, device=device) |
| | print("Embedding model loaded successfully.") |
| | except Exception as e: |
| | print(f"Error loading embedding model: {e}") |
| | embedding_model = None |
| |
|
| | |
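
# --- Step 3: Load the FAISS index. The filename suggests an IndexFlatIP over
# L2-normalized vectors, where inner-product search is equivalent to cosine
# similarity -------------------------------------------------------------------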
| | print(f"Loading FAISS index from: {FAISS_INDEX_PATH}...") |
| | try: |
| | faiss_index = faiss.read_index(FAISS_INDEX_PATH) |
| | print(f"FAISS index loaded. Total vectors: {faiss_index.ntotal}") |
| | except Exception as e: |
| | print(f"Error loading FAISS index: {e}") |
| | faiss_index = None |
| |
|
| | |
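
# --- Step 4: Build a BM25 model over the chunk texts for lexical (keyword)
# retrieval, complementing the dense FAISS search above -----------------------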
| | print("Creating BM25 model...") |
| | bm25_model = None |
| | if chunks_data: |
| | try: |
| | corpus_texts_for_bm25 = [chunk.get('text', '') for chunk in chunks_data] |
| | tokenized_corpus_bm25 = [tokenize_vi_for_bm25_setup(text) for text in corpus_texts_for_bm25] |
| | bm25_model = BM25Okapi(tokenized_corpus_bm25) |
| | print("BM25 model created successfully.") |
| | except Exception as e: |
| | print(f"Error creating BM25 model: {e}") |
| | else: |
| | print("Skipping BM25 model creation as chunks_data is empty.") |
| |
|
| |
|
| | |
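
# --- Step 5: Load the fine-tuned LLM with Unsloth. load_in_4bit=True loads
# the weights with 4-bit quantization to cut VRAM usage; dtype=None lets
# Unsloth auto-detect a suitable dtype for the GPU ----------------------------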
| | print(f"Loading LLM model: {LLM_MODEL_PATH}...") |
| | try: |
| | |
| | |
| | llm_model, llm_tokenizer = FastLanguageModel.from_pretrained( |
| | model_name=LLM_MODEL_PATH, |
| | max_seq_length=2048, |
| | dtype=None, |
| | load_in_4bit=True, |
| | ) |
| | FastLanguageModel.for_inference(llm_model) |
| | print("LLM model and tokenizer loaded successfully.") |
| | except Exception as e: |
| | print(f"Error loading LLM model: {e}") |
| | llm_model = None |
| | llm_tokenizer = None |
| | |
| |
|
| | |
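
# --- Chat handler: retrieve relevant law chunks and generate an answer -------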
def respond(message, history: list[tuple[str, str]]):
    # Bail out early if any pipeline component failed to initialize.
    if not all([chunks_data, embedding_model, faiss_index, bm25_model, llm_model, llm_tokenizer]):
        missing_components = []
        if not chunks_data: missing_components.append("chunks_data")
        if not embedding_model: missing_components.append("embedding_model")
        if not faiss_index: missing_components.append("faiss_index")
        if not bm25_model: missing_components.append("bm25_model")
        if not llm_model: missing_components.append("llm_model")
        if not llm_tokenizer: missing_components.append("llm_tokenizer")
        # (Vietnamese) "Error: one or more system components failed to initialize.
        # Missing components: ... Please check the Space logs."
        error_msg = f"Lỗi: Một hoặc nhiều thành phần của hệ thống chưa được khởi tạo thành công. Thành phần thiếu: {', '.join(missing_components)}. Vui lòng kiểm tra logs của Space."
        print(error_msg)
        # respond() is a generator (it yields below), so the error message must
        # be yielded too; `return error_msg` from a generator never reaches the UI.
        yield error_msg
        return

    try:
        response_text = generate_response(
            query=message,
            llama_model=llm_model,
            tokenizer=llm_tokenizer,
            faiss_index=faiss_index,
            embed_model=embedding_model,
            chunks_data_list=chunks_data,
            bm25_model=bm25_model,
            search_function=search_relevant_laws,
        )
        yield response_text
    except Exception as e:
        print(f"Error during response generation for query '{message}': {e}")
        print(traceback.format_exc())
        # (Vietnamese) "A serious error occurred while processing your request.
        # Please try again later or contact the administrator."
        yield "Đã xảy ra lỗi nghiêm trọng khi xử lý yêu cầu của bạn. Vui lòng thử lại sau hoặc liên hệ quản trị viên."
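
# --- Gradio UI: a simple chat interface wired to the respond() generator -----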
demo = gr.ChatInterface(
    respond,
)
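# Presentation options (e.g. title=, description=, examples=) are valid
# gr.ChatInterface keyword arguments and could be added above; none are set here.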

if __name__ == "__main__":
    demo.launch()
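    # When running locally, demo.launch(share=True) would create a temporary
    # public link; on a Hugging Face Space the plain launch() above is enough.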