"""AIEXP — AI Expert for Regulatory Documentation (Gradio RAG application)."""
import json
import os
import shutil
import sys
import zipfile
from typing import Any, Dict, List

import gradio as gr
import pandas as pd
from huggingface_hub import hf_hub_download, list_repo_files
from llama_index.core import Document, KeywordTableIndex, Settings, VectorStoreIndex
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.response_synthesizers import ResponseMode, get_response_synthesizer
from llama_index.core.retrievers import QueryFusionRetriever, VectorIndexRetriever
from llama_index.core.text_splitter import SentenceSplitter
from llama_index.retrievers.bm25 import BM25Retriever
from sentence_transformers import SentenceTransformer
| GOOGLE_API_KEY = os.getenv('GOOGLE_API_KEY') | |
| OPENAI_API_KEY = os.getenv('OPENAI_API_KEY') | |
| HF_REPO_ID = "MrSimple01/AIEXP_RAG_FILES" | |
| HF_TOKEN = os.getenv('HF_TOKEN') | |
| AVAILABLE_MODELS = { | |
| "Gemini 2.5 Flash": { | |
| "provider": "google", | |
| "model_name": "gemini-2.5-flash", | |
| "api_key": GOOGLE_API_KEY | |
| }, | |
| "Gemini 2.5 Pro": { | |
| "provider": "google", | |
| "model_name": "gemini-2.5-pro", | |
| "api_key": GOOGLE_API_KEY | |
| }, | |
| "GPT-4o": { | |
| "provider": "openai", | |
| "model_name": "gpt-4o", | |
| "api_key": OPENAI_API_KEY | |
| }, | |
| "GPT-4o Mini": { | |
| "provider": "openai", | |
| "model_name": "gpt-4o-mini", | |
| "api_key": OPENAI_API_KEY | |
| }, | |
| "GPT-5": { | |
| "provider": "openai", | |
| "model_name": "gpt-5", | |
| "api_key": OPENAI_API_KEY | |
| } | |
| } | |
| DEFAULT_MODEL = "Gemini 2.5 Flash" | |
| DOWNLOAD_DIR = "rag_files" | |
| JSON_FILES_DIR = "JSON" | |
| TABLE_DATA_DIR = "Табличные данные_JSON" | |
| IMAGE_DATA_DIR = "Изображения" | |
| CHUNK_SIZE = 512 | |
| CHUNK_OVERLAP = 50 | |
| TABLE_MAX_ROWS_PER_CHUNK = 30 | |
| os.makedirs(DOWNLOAD_DIR, exist_ok=True) | |
| def get_llm_model(model_name): | |
| config = AVAILABLE_MODELS[model_name] | |
| if config["provider"] == "google": | |
| from llama_index.llms.gemini import Gemini | |
| return Gemini(model=config["model_name"], api_key=config["api_key"]) | |
| else: | |
| from llama_index.llms.openai import OpenAI | |
| return OpenAI(model=config["model_name"], api_key=config["api_key"]) | |
| def get_embedding_model(): | |
| from llama_index.embeddings.huggingface import HuggingFaceEmbedding | |
| return HuggingFaceEmbedding(model_name="sentence-transformers/all-MiniLM-L6-v2") | |
| def list_zip_files_in_repo(repo_id: str) -> List[str]: | |
| files = list_repo_files(repo_id, token=HF_TOKEN) | |
| return [f for f in files if f.startswith(JSON_FILES_DIR) and f.endswith('.zip')] | |
| def download_file_from_hf(repo_id: str, path_in_repo: str, dest_dir: str) -> str: | |
| local_path = hf_hub_download(repo_id=repo_id, filename=path_in_repo, repo_type="dataset", token=HF_TOKEN) | |
| base = os.path.basename(local_path) | |
| dst = os.path.join(dest_dir, base) | |
| if local_path != dst: | |
| try: | |
| with open(local_path, 'rb') as r, open(dst, 'wb') as w: | |
| w.write(r.read()) | |
| except Exception: | |
| pass | |
| return dst | |
| def read_jsons_from_zip(zip_path: str) -> List[Dict[str, Any]]: | |
| docs = [] | |
| with zipfile.ZipFile(zip_path, 'r') as z: | |
| for name in z.namelist(): | |
| if name.lower().endswith('.json'): | |
| with z.open(name) as f: | |
| try: | |
| text = f.read().decode('utf-8') | |
| data = json.loads(text) | |
| docs.append(data) | |
| except Exception as e: | |
| print(f"Failed to load {name} in {zip_path}: {e}") | |
| return docs | |
| def chunk_text_field(text: str, doc_meta: Dict[str, Any], splitter: SentenceSplitter) -> List[Document]: | |
| nodes = splitter.split_text(text) | |
| chunks = [] | |
| for i, node_text in enumerate(nodes): | |
| md = dict(doc_meta) | |
| md.update({ | |
| 'chunk_id': f"{md.get('document_id','unknown')}_text_{i}", | |
| 'chunk_type': 'text' | |
| }) | |
| chunks.append(Document(text=node_text, metadata=md)) | |
| return chunks | |
| def chunk_table(table: Dict[str, Any], table_meta: Dict[str, Any], max_rows: int = TABLE_MAX_ROWS_PER_CHUNK) -> List[Document]: | |
| headers = table.get('headers') or [] | |
| rows = table.get('data') or [] | |
| if not rows: | |
| text = table.get('table_description') or table.get('table_title') or '' | |
| md = {**table_meta, 'chunk_type': 'table', 'chunk_id': f"{table_meta.get('document_id')}_table_single"} | |
| return [Document(text=text, metadata=md)] | |
| chunks = [] | |
| for i in range(0, len(rows), max_rows): | |
| block = rows[i:i+max_rows] | |
| lines = [] | |
| lines.append(f"Table {table_meta.get('table_number','?')} - {table_meta.get('table_title','')}") | |
| lines.append(f"Headers: {headers}") | |
| for r in block: | |
| row_items = [f"{k}: {v}" for k, v in r.items()] | |
| lines.append(" | ".join(row_items)) | |
| chunk_text = "\n".join(lines) | |
| md = dict(table_meta) | |
| md.update({'chunk_type': 'table', 'chunk_id': f"{table_meta.get('document_id')}_table_{i // max_rows}"}) | |
| chunks.append(Document(text=chunk_text, metadata=md)) | |
| return chunks | |
| def chunk_image(image_entry: Dict[str, Any], image_meta: Dict[str, Any]) -> Document: | |
| txt = f"Image: {image_entry.get('Название изображения') or image_entry.get('title','')}. " | |
| txt += f"Описание: {image_entry.get('Описание изображение') or image_entry.get('description','')}. " | |
| txt += f"Файл: {image_entry.get('Файл изображения') or image_entry.get('file','')}." | |
| md = dict(image_meta) | |
| md.update({'chunk_type': 'image', 'chunk_id': f"{image_meta.get('document_id')}_image_{image_entry.get('№ Изображения','0')}"}) | |
| return Document(text=txt, metadata=md) | |
| def build_chunks_from_repo(repo_id: str) -> List[Document]: | |
| zip_paths = list_zip_files_in_repo(repo_id) | |
| print(f"Found {len(zip_paths)} zip files under {JSON_FILES_DIR} in repo {repo_id}") | |
| splitter = SentenceSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP) | |
| all_chunks = [] | |
| for remote_path in zip_paths: | |
| print(f"Downloading {remote_path}...") | |
| local_zip = download_file_from_hf(repo_id, remote_path, DOWNLOAD_DIR) | |
| print(f"Parsing {local_zip}...") | |
| json_docs = read_jsons_from_zip(local_zip) | |
| for doc in json_docs: | |
| doc_meta = doc.get('document_metadata', {}) | |
| doc_id = doc_meta.get('document_id') or doc_meta.get('document_name') or 'unknown_doc' | |
| base_meta = {'document_id': doc_id, 'document_name': doc_meta.get('document_name','')} | |
| for sec in doc.get('sections', []): | |
| sec_meta = dict(base_meta) | |
| sec_meta.update({'section_id': sec.get('section_id'), 'section_title': None}) | |
| text = sec.get('section_text') or sec.get('text') or '' | |
| if text and text.strip(): | |
| chunks = chunk_text_field(text, sec_meta, splitter) | |
| all_chunks.extend(chunks) | |
| for sheet in doc.get('sheets', []) + doc.get('tables', []) if (doc.get('sheets') or doc.get('tables')) else []: | |
| table_meta = dict(base_meta) | |
| table_meta.update({ | |
| 'sheet_name': sheet.get('sheet_name') or sheet.get('table_title'), | |
| 'section': sheet.get('section'), | |
| 'table_number': sheet.get('table_number'), | |
| 'table_title': sheet.get('table_title') | |
| }) | |
| table_chunks = chunk_table(sheet, table_meta, max_rows=TABLE_MAX_ROWS_PER_CHUNK) | |
| all_chunks.extend(table_chunks) | |
| for img in doc.get('images', []) or doc.get('image_data', []) or doc.get('image_entries', []): | |
| img_meta = dict(base_meta) | |
| chunk = chunk_image(img, img_meta) | |
| all_chunks.append(chunk) | |
| print(f"Built total {len(all_chunks)} chunks") | |
| return all_chunks | |
| def create_hybrid_index(documents): | |
| print("Creating vector index...") | |
| vector_index = VectorStoreIndex.from_documents(documents) | |
| print("Creating keyword index...") | |
| keyword_index = KeywordTableIndex.from_documents(documents) | |
| return vector_index, keyword_index | |
| def create_fusion_retriever(vector_index, keyword_index, documents): | |
| vector_retriever = VectorIndexRetriever(index=vector_index, similarity_top_k=5) | |
| bm25_retriever = BM25Retriever.from_defaults( | |
| docstore=vector_index.docstore, | |
| similarity_top_k=5 | |
| ) | |
| fusion_retriever = QueryFusionRetriever( | |
| [vector_retriever, bm25_retriever], | |
| similarity_top_k=5, | |
| num_queries=1, | |
| mode="reciprocal_rerank", | |
| use_async=False | |
| ) | |
| return fusion_retriever | |
| def create_query_engine(vector_index, keyword_index, documents): | |
| fusion_retriever = create_fusion_retriever(vector_index, keyword_index, documents) | |
| response_synthesizer = get_response_synthesizer( | |
| response_mode=ResponseMode.COMPACT, | |
| use_async=False | |
| ) | |
| query_engine = RetrieverQueryEngine( | |
| retriever=fusion_retriever, | |
| response_synthesizer=response_synthesizer | |
| ) | |
| return query_engine | |
| def initialize_system(): | |
| print("Initializing system...") | |
| embed_model = get_embedding_model() | |
| llm = get_llm_model(DEFAULT_MODEL) | |
| Settings.embed_model = embed_model | |
| Settings.llm = llm | |
| Settings.chunk_size = CHUNK_SIZE | |
| Settings.chunk_overlap = CHUNK_OVERLAP | |
| print("Loading documents...") | |
| documents = build_chunks_from_repo(HF_REPO_ID) | |
| print("Creating indices...") | |
| vector_index, keyword_index = create_hybrid_index(documents) | |
| print("Creating query engine...") | |
| query_engine = create_query_engine(vector_index, keyword_index, documents) | |
| print("System initialized successfully!") | |
| return query_engine, vector_index, keyword_index, documents | |
| def answer_question(question, query_engine): | |
| if not question.strip(): | |
| return "<div style='color: black;'>Please enter a question</div>" | |
| try: | |
| response = query_engine.query(question) | |
| answer_html = f""" | |
| <div style='background-color: #f8f9fa; padding: 20px; border-radius: 10px; color: black;'> | |
| <h3 style='color: #007bff;'>Answer:</h3> | |
| <p>{response.response}</p> | |
| </div> | |
| """ | |
| sources_html = "<div style='background-color: #e9ecef; padding: 15px; border-radius: 8px; color: black;'>" | |
| sources_html += "<h4>Sources:</h4>" | |
| for i, node in enumerate(response.source_nodes): | |
| sources_html += f""" | |
| <div style='margin: 10px 0; padding: 10px; background-color: white; border-left: 3px solid #007bff;'> | |
| <strong>Document {i+1}:</strong> {node.metadata.get('document_id', 'unknown')}<br> | |
| <strong>Score:</strong> {node.score:.3f}<br> | |
| <strong>Text:</strong> {node.text[:200]}... | |
| </div> | |
| """ | |
| sources_html += "</div>" | |
| return answer_html, sources_html | |
| except Exception as e: | |
| error_html = f"<div style='color: red;'>Error: {str(e)}</div>" | |
| return error_html, error_html | |
| def switch_model(model_name, vector_index, keyword_index, documents): | |
| try: | |
| print(f"Switching to model: {model_name}") | |
| new_llm = get_llm_model(model_name) | |
| Settings.llm = new_llm | |
| new_query_engine = create_query_engine(vector_index, keyword_index, documents) | |
| return new_query_engine, f"✅ Model switched to: {model_name}" | |
| except Exception as e: | |
| return None, f"❌ Error: {str(e)}" | |
| query_engine = None | |
| vector_index = None | |
| keyword_index = None | |
| documents = None | |
| current_model = DEFAULT_MODEL | |
| def main_answer_question(question): | |
| global query_engine | |
| return answer_question(question, query_engine) | |
| def main_switch_model(model_name): | |
| global query_engine, vector_index, keyword_index, documents, current_model | |
| new_query_engine, status = switch_model(model_name, vector_index, keyword_index, documents) | |
| if new_query_engine: | |
| query_engine = new_query_engine | |
| current_model = model_name | |
| return status | |
| def create_interface(): | |
| with gr.Blocks(title="AIEXP - RAG System", theme=gr.themes.Soft()) as demo: | |
| gr.Markdown("# AIEXP - AI Expert for Regulatory Documentation") | |
| with gr.Row(): | |
| model_dropdown = gr.Dropdown( | |
| choices=list(AVAILABLE_MODELS.keys()), | |
| value=current_model, | |
| label="Select Language Model" | |
| ) | |
| switch_btn = gr.Button("Switch Model") | |
| model_status = gr.Textbox( | |
| value=f"Current model: {current_model}", | |
| label="Model Status", | |
| interactive=False | |
| ) | |
| with gr.Row(): | |
| question_input = gr.Textbox( | |
| label="Your Question", | |
| placeholder="Ask a question about the documents...", | |
| lines=3 | |
| ) | |
| ask_btn = gr.Button("Get Answer", variant="primary") | |
| with gr.Row(): | |
| answer_output = gr.HTML( | |
| label="Answer", | |
| value="<div style='padding: 20px; text-align: center;'>Answer will appear here...</div>" | |
| ) | |
| sources_output = gr.HTML( | |
| label="Sources", | |
| value="<div style='padding: 20px; text-align: center;'>Sources will appear here...</div>" | |
| ) | |
| switch_btn.click( | |
| fn=main_switch_model, | |
| inputs=[model_dropdown], | |
| outputs=[model_status] | |
| ) | |
| ask_btn.click( | |
| fn=main_answer_question, | |
| inputs=[question_input], | |
| outputs=[answer_output, sources_output] | |
| ) | |
| question_input.submit( | |
| fn=main_answer_question, | |
| inputs=[question_input], | |
| outputs=[answer_output, sources_output] | |
| ) | |
| return demo | |
| def main(): | |
| global query_engine, vector_index, keyword_index, documents | |
| print("Starting AIEXP - AI Expert for Regulatory Documentation") | |
| query_engine, vector_index, keyword_index, documents = initialize_system() | |
| if query_engine: | |
| print("Launching web interface...") | |
| demo = create_interface() | |
| demo.launch( | |
| server_name="0.0.0.0", | |
| server_port=7860, | |
| share=True | |
| ) | |
| else: | |
| print("Failed to initialize system") | |
| sys.exit(1) | |
| if __name__ == "__main__": | |
| main() |