# RAG_AIEXP_01 / app.py
# NOTE: the lines below were Hugging Face Space page chrome captured with the
# file ("MrSimple07's picture / new version of rag / 147e01b / raw /
# history blame / 14.7 kB"); kept here as a comment so the module parses.
import json
import os
import shutil
import sys
import zipfile
from typing import Any, Dict, List

import gradio as gr
import pandas as pd
from huggingface_hub import hf_hub_download, list_repo_files
from llama_index.core import Document, KeywordTableIndex, Settings, VectorStoreIndex
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.response_synthesizers import ResponseMode, get_response_synthesizer
from llama_index.core.retrievers import QueryFusionRetriever, VectorIndexRetriever
from llama_index.core.text_splitter import SentenceSplitter
from llama_index.retrievers.bm25 import BM25Retriever
from sentence_transformers import SentenceTransformer
# API keys come from the environment; a missing key stays None and only
# causes a failure when a model from that provider is actually instantiated.
GOOGLE_API_KEY = os.getenv('GOOGLE_API_KEY')
OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
# Hugging Face repo that holds the zipped JSON source documents
# (downloaded with repo_type="dataset" in download_file_from_hf).
HF_REPO_ID = "MrSimple01/AIEXP_RAG_FILES"
HF_TOKEN = os.getenv('HF_TOKEN')
# UI label -> provider/model/key config consumed by get_llm_model().
AVAILABLE_MODELS = {
    "Gemini 2.5 Flash": {
        "provider": "google",
        "model_name": "gemini-2.5-flash",
        "api_key": GOOGLE_API_KEY
    },
    "Gemini 2.5 Pro": {
        "provider": "google",
        "model_name": "gemini-2.5-pro",
        "api_key": GOOGLE_API_KEY
    },
    "GPT-4o": {
        "provider": "openai",
        "model_name": "gpt-4o",
        "api_key": OPENAI_API_KEY
    },
    "GPT-4o Mini": {
        "provider": "openai",
        "model_name": "gpt-4o-mini",
        "api_key": OPENAI_API_KEY
    },
    "GPT-5": {
        "provider": "openai",
        "model_name": "gpt-5",
        "api_key": OPENAI_API_KEY
    }
}
DEFAULT_MODEL = "Gemini 2.5 Flash"
# Local scratch directory where downloaded zips are copied.
DOWNLOAD_DIR = "rag_files"
# Repo subfolder containing the zipped JSON documents.
JSON_FILES_DIR = "JSON"
# NOTE(review): the two directories below ("tabular data" / "images" in
# Russian) are not referenced anywhere in this file's visible code —
# presumably used by other modules or kept for future use; confirm.
TABLE_DATA_DIR = "Табличные данные_JSON"
IMAGE_DATA_DIR = "Изображения"
# Sentence-splitter chunking parameters (also mirrored into Settings).
CHUNK_SIZE = 512
CHUNK_OVERLAP = 50
# Maximum number of table rows serialized into a single table chunk.
TABLE_MAX_ROWS_PER_CHUNK = 30
os.makedirs(DOWNLOAD_DIR, exist_ok=True)
def get_llm_model(model_name):
    """Instantiate the llama-index LLM configured under *model_name*.

    Args:
        model_name: Key into AVAILABLE_MODELS.

    Returns:
        A Gemini or OpenAI llama-index LLM instance.

    Raises:
        KeyError: If *model_name* is not a configured model (with a message
            listing the valid choices, instead of a bare key error).
    """
    try:
        config = AVAILABLE_MODELS[model_name]
    except KeyError:
        raise KeyError(
            f"Unknown model '{model_name}'. "
            f"Available models: {', '.join(AVAILABLE_MODELS)}"
        ) from None
    # Provider packages are imported lazily so only the selected provider's
    # dependency needs to be importable.
    if config["provider"] == "google":
        from llama_index.llms.gemini import Gemini
        return Gemini(model=config["model_name"], api_key=config["api_key"])
    from llama_index.llms.openai import OpenAI
    return OpenAI(model=config["model_name"], api_key=config["api_key"])
def get_embedding_model():
    """Return the HuggingFace sentence-embedding model used for indexing."""
    # Imported lazily so the embedding stack loads only when first needed.
    from llama_index.embeddings.huggingface import HuggingFaceEmbedding
    embedder = HuggingFaceEmbedding(model_name="sentence-transformers/all-MiniLM-L6-v2")
    return embedder
def list_zip_files_in_repo(repo_id: str) -> List[str]:
    """List .zip archives under JSON_FILES_DIR in the HF dataset *repo_id*.

    Returns repo-relative paths of every zip file in the JSON folder.
    """
    # repo_type must match download_file_from_hf, which fetches from the
    # dataset repo; the default repo_type ("model") would query the wrong
    # namespace and miss (or 404 on) the dataset's files.
    files = list_repo_files(repo_id, repo_type="dataset", token=HF_TOKEN)
    return [f for f in files if f.startswith(JSON_FILES_DIR) and f.endswith('.zip')]
def download_file_from_hf(repo_id: str, path_in_repo: str, dest_dir: str) -> str:
    """Download *path_in_repo* from the HF dataset *repo_id* into *dest_dir*.

    The file lands in the HF cache first and is then copied into *dest_dir*.

    Returns:
        The path of the copy in *dest_dir*, or — if that copy fails — the
        cached download path, which does exist. (The original silently
        swallowed copy errors and returned a destination path that was
        never written.)
    """
    local_path = hf_hub_download(repo_id=repo_id, filename=path_in_repo, repo_type="dataset", token=HF_TOKEN)
    dst = os.path.join(dest_dir, os.path.basename(local_path))
    if local_path == dst:
        return dst
    try:
        # Streamed copy: avoids reading whole archives into memory.
        shutil.copyfile(local_path, dst)
    except OSError as e:
        print(f"Warning: could not copy {local_path} to {dst}: {e}")
        return local_path
    return dst
def read_jsons_from_zip(zip_path: str) -> List[Dict[str, Any]]:
    """Parse every ``.json`` member of the archive at *zip_path*.

    Members that fail to decode or parse are skipped, with a message
    printed to the console; all successfully parsed objects are returned
    in archive order.
    """
    parsed = []
    with zipfile.ZipFile(zip_path, 'r') as archive:
        json_members = (n for n in archive.namelist() if n.lower().endswith('.json'))
        for member in json_members:
            with archive.open(member) as handle:
                try:
                    payload = json.loads(handle.read().decode('utf-8'))
                except Exception as e:
                    print(f"Failed to load {member} in {zip_path}: {e}")
                else:
                    parsed.append(payload)
    return parsed
def chunk_text_field(text: str, doc_meta: Dict[str, Any], splitter: SentenceSplitter) -> List[Document]:
    """Sentence-split *text* and wrap each piece in a Document.

    Each chunk's metadata is a copy of *doc_meta* extended with a unique
    ``chunk_id`` and ``chunk_type='text'``.
    """
    doc_id = doc_meta.get('document_id', 'unknown')
    return [
        Document(
            text=piece,
            metadata={**doc_meta, 'chunk_id': f"{doc_id}_text_{idx}", 'chunk_type': 'text'},
        )
        for idx, piece in enumerate(splitter.split_text(text))
    ]
def chunk_table(table: Dict[str, Any], table_meta: Dict[str, Any], max_rows: int = TABLE_MAX_ROWS_PER_CHUNK) -> List[Document]:
    """Serialize one table dict into Documents of at most *max_rows* rows.

    A table without data rows collapses to a single Document carrying its
    description or title (possibly empty text).
    """
    headers = table.get('headers') or []
    rows = table.get('data') or []
    doc_id = table_meta.get('document_id')
    if not rows:
        # No row data: fall back to description/title as one chunk.
        fallback = table.get('table_description') or table.get('table_title') or ''
        meta = {**table_meta, 'chunk_type': 'table', 'chunk_id': f"{doc_id}_table_single"}
        return [Document(text=fallback, metadata=meta)]
    # Every chunk repeats the table banner and headers so it stands alone.
    banner = f"Table {table_meta.get('table_number','?')} - {table_meta.get('table_title','')}"
    documents = []
    for chunk_idx, start in enumerate(range(0, len(rows), max_rows)):
        body = [
            " | ".join(f"{k}: {v}" for k, v in row.items())
            for row in rows[start:start + max_rows]
        ]
        chunk_body = "\n".join([banner, f"Headers: {headers}"] + body)
        meta = {**table_meta, 'chunk_type': 'table', 'chunk_id': f"{doc_id}_table_{chunk_idx}"}
        documents.append(Document(text=chunk_body, metadata=meta))
    return documents
def chunk_image(image_entry: Dict[str, Any], image_meta: Dict[str, Any]) -> Document:
    """Render one image metadata entry as a single searchable Document.

    Russian-language keys are tried first, with English fallbacks for each
    field; the rendered text mixes an English label with Russian ones,
    matching the source data's conventions.
    """
    title = image_entry.get('Название изображения') or image_entry.get('title','')
    description = image_entry.get('Описание изображение') or image_entry.get('description','')
    filename = image_entry.get('Файл изображения') or image_entry.get('file','')
    body = (
        f"Image: {title}. "
        f"Описание: {description}. "
        f"Файл: {filename}."
    )
    meta = {
        **image_meta,
        'chunk_type': 'image',
        'chunk_id': f"{image_meta.get('document_id')}_image_{image_entry.get('№ Изображения','0')}",
    }
    return Document(text=body, metadata=meta)
def build_chunks_from_repo(repo_id: str) -> List[Document]:
    """Download every JSON zip from *repo_id* and build Document chunks.

    For each JSON document this produces:
      * sentence-split text chunks from its sections,
      * row-batched table chunks from its sheets/tables,
      * one chunk per image entry,
    all tagged with document-level metadata (document_id, document_name).

    Returns the flat list of all chunks.
    """
    zip_paths = list_zip_files_in_repo(repo_id)
    print(f"Found {len(zip_paths)} zip files under {JSON_FILES_DIR} in repo {repo_id}")
    splitter = SentenceSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
    all_chunks = []
    for remote_path in zip_paths:
        print(f"Downloading {remote_path}...")
        local_zip = download_file_from_hf(repo_id, remote_path, DOWNLOAD_DIR)
        print(f"Parsing {local_zip}...")
        for doc in read_jsons_from_zip(local_zip):
            doc_meta = doc.get('document_metadata', {})
            doc_id = doc_meta.get('document_id') or doc_meta.get('document_name') or 'unknown_doc'
            base_meta = {'document_id': doc_id, 'document_name': doc_meta.get('document_name', '')}
            # Plain-text sections. `or []` also guards keys present with a
            # None value, which `.get('sections', [])` would not.
            for sec in doc.get('sections') or []:
                text = sec.get('section_text') or sec.get('text') or ''
                if text.strip():
                    sec_meta = dict(base_meta)
                    sec_meta.update({'section_id': sec.get('section_id'), 'section_title': None})
                    all_chunks.extend(chunk_text_field(text, sec_meta, splitter))
            # Tables may appear under 'sheets' (spreadsheets) or 'tables'.
            # The previous `a + b if (a or b) else []` form raised TypeError
            # when either key existed with a None value; `or []` is None-safe.
            for sheet in (doc.get('sheets') or []) + (doc.get('tables') or []):
                table_meta = dict(base_meta)
                table_meta.update({
                    'sheet_name': sheet.get('sheet_name') or sheet.get('table_title'),
                    'section': sheet.get('section'),
                    'table_number': sheet.get('table_number'),
                    'table_title': sheet.get('table_title')
                })
                all_chunks.extend(chunk_table(sheet, table_meta, max_rows=TABLE_MAX_ROWS_PER_CHUNK))
            # Image entries may live under any of these keys; the first
            # non-empty list wins.
            for img in doc.get('images') or doc.get('image_data') or doc.get('image_entries') or []:
                all_chunks.append(chunk_image(img, dict(base_meta)))
    print(f"Built total {len(all_chunks)} chunks")
    return all_chunks
def create_hybrid_index(documents):
    """Build both indices over *documents*.

    Returns:
        Tuple of (vector_index, keyword_index).
    """
    print("Creating vector index...")
    dense_index = VectorStoreIndex.from_documents(documents)
    print("Creating keyword index...")
    sparse_index = KeywordTableIndex.from_documents(documents)
    return dense_index, sparse_index
def create_fusion_retriever(vector_index, keyword_index, documents):
    """Build a reciprocal-rank fusion retriever over dense + BM25 retrieval.

    *keyword_index* and *documents* are accepted for interface compatibility
    but are not used; BM25 is built from the vector index's docstore.
    """
    dense = VectorIndexRetriever(index=vector_index, similarity_top_k=5)
    sparse = BM25Retriever.from_defaults(
        docstore=vector_index.docstore,
        similarity_top_k=5
    )
    return QueryFusionRetriever(
        [dense, sparse],
        similarity_top_k=5,
        num_queries=1,  # 1 = no LLM-generated query expansion
        mode="reciprocal_rerank",
        use_async=False
    )
def create_query_engine(vector_index, keyword_index, documents):
    """Wire the fusion retriever and a compact synthesizer into a query engine."""
    retriever = create_fusion_retriever(vector_index, keyword_index, documents)
    synthesizer = get_response_synthesizer(
        response_mode=ResponseMode.COMPACT,
        use_async=False
    )
    return RetrieverQueryEngine(retriever=retriever, response_synthesizer=synthesizer)
def initialize_system():
    """Configure llama-index globals, ingest the corpus and build the engine.

    Returns:
        Tuple of (query_engine, vector_index, keyword_index, documents).
    """
    print("Initializing system...")
    # Register models and chunking parameters globally so every index and
    # engine created afterwards picks them up.
    Settings.embed_model = get_embedding_model()
    Settings.llm = get_llm_model(DEFAULT_MODEL)
    Settings.chunk_size = CHUNK_SIZE
    Settings.chunk_overlap = CHUNK_OVERLAP
    print("Loading documents...")
    documents = build_chunks_from_repo(HF_REPO_ID)
    print("Creating indices...")
    vector_index, keyword_index = create_hybrid_index(documents)
    print("Creating query engine...")
    query_engine = create_query_engine(vector_index, keyword_index, documents)
    print("System initialized successfully!")
    return query_engine, vector_index, keyword_index, documents
def answer_question(question, query_engine):
    """Answer *question* via *query_engine*; return (answer_html, sources_html).

    Always returns a 2-tuple so both Gradio output components receive a
    value — the original empty-question branch returned a single string,
    which mismatches the two-output event handlers wired in the UI.

    Any exception from the engine is rendered as an error message in both
    outputs rather than propagating into Gradio.
    """
    if not question.strip():
        prompt_html = "<div style='color: black;'>Please enter a question</div>"
        return prompt_html, prompt_html
    try:
        response = query_engine.query(question)
        # NOTE(review): model/source text is interpolated into HTML without
        # escaping — assumes LLM/corpus output is trusted; consider
        # html.escape if that assumption does not hold.
        answer_html = f"""
        <div style='background-color: #f8f9fa; padding: 20px; border-radius: 10px; color: black;'>
            <h3 style='color: #007bff;'>Answer:</h3>
            <p>{response.response}</p>
        </div>
        """
        sources_html = "<div style='background-color: #e9ecef; padding: 15px; border-radius: 8px; color: black;'>"
        sources_html += "<h4>Sources:</h4>"
        for i, node in enumerate(response.source_nodes):
            # Some retrievers return score=None; guard so :.3f cannot raise.
            score = node.score if node.score is not None else 0.0
            sources_html += f"""
            <div style='margin: 10px 0; padding: 10px; background-color: white; border-left: 3px solid #007bff;'>
                <strong>Document {i+1}:</strong> {node.metadata.get('document_id', 'unknown')}<br>
                <strong>Score:</strong> {score:.3f}<br>
                <strong>Text:</strong> {node.text[:200]}...
            </div>
            """
        sources_html += "</div>"
        return answer_html, sources_html
    except Exception as e:
        error_html = f"<div style='color: red;'>Error: {str(e)}</div>"
        return error_html, error_html
def switch_model(model_name, vector_index, keyword_index, documents):
    """Swap the global LLM and rebuild the query engine.

    Returns:
        (new_query_engine, status_message); the engine is None on failure,
        in which case the status message carries the error.
    """
    try:
        print(f"Switching to model: {model_name}")
        # Settings.llm is global, so the new LLM is picked up by the
        # freshly-built engine below.
        Settings.llm = get_llm_model(model_name)
        engine = create_query_engine(vector_index, keyword_index, documents)
    except Exception as e:
        return None, f"❌ Error: {str(e)}"
    return engine, f"✅ Model switched to: {model_name}"
# Module-level state shared by the Gradio callbacks below; populated once by
# main() at startup and mutated by main_switch_model().
query_engine = None
vector_index = None
keyword_index = None
documents = None
current_model = DEFAULT_MODEL  # label of the LLM currently backing query_engine
def main_answer_question(question):
    """Gradio callback: answer *question* with the current global query engine."""
    # Reading a module global needs no `global` declaration.
    return answer_question(question, query_engine)
def main_switch_model(model_name):
    """Gradio callback: switch the LLM, updating globals only on success."""
    global query_engine, current_model
    engine, status = switch_model(model_name, vector_index, keyword_index, documents)
    if engine:
        # Only commit the new engine/model name when the switch succeeded;
        # otherwise the previous engine keeps serving queries.
        query_engine = engine
        current_model = model_name
    return status
def create_interface():
    """Build and return the Gradio Blocks UI.

    Layout: a model selector row (dropdown + switch button + status box),
    a question row (textbox + ask button), and an output row (answer HTML +
    sources HTML). Callbacks are the module-level main_* functions, which
    operate on the shared global query engine.
    """
    with gr.Blocks(title="AIEXP - RAG System", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# AIEXP - AI Expert for Regulatory Documentation")
        with gr.Row():
            model_dropdown = gr.Dropdown(
                choices=list(AVAILABLE_MODELS.keys()),
                value=current_model,
                label="Select Language Model"
            )
            switch_btn = gr.Button("Switch Model")
        # Read-only status line reflecting the last model switch result.
        model_status = gr.Textbox(
            value=f"Current model: {current_model}",
            label="Model Status",
            interactive=False
        )
        with gr.Row():
            question_input = gr.Textbox(
                label="Your Question",
                placeholder="Ask a question about the documents...",
                lines=3
            )
            ask_btn = gr.Button("Get Answer", variant="primary")
        with gr.Row():
            answer_output = gr.HTML(
                label="Answer",
                value="<div style='padding: 20px; text-align: center;'>Answer will appear here...</div>"
            )
            sources_output = gr.HTML(
                label="Sources",
                value="<div style='padding: 20px; text-align: center;'>Sources will appear here...</div>"
            )
        switch_btn.click(
            fn=main_switch_model,
            inputs=[model_dropdown],
            outputs=[model_status]
        )
        # Both the button and Enter-in-textbox submit trigger the same answer
        # callback, writing to the two HTML outputs.
        ask_btn.click(
            fn=main_answer_question,
            inputs=[question_input],
            outputs=[answer_output, sources_output]
        )
        question_input.submit(
            fn=main_answer_question,
            inputs=[question_input],
            outputs=[answer_output, sources_output]
        )
    return demo
def main():
    """Build the RAG system, then serve the Gradio UI on port 7860."""
    global query_engine, vector_index, keyword_index, documents
    print("Starting AIEXP - AI Expert for Regulatory Documentation")
    query_engine, vector_index, keyword_index, documents = initialize_system()
    # Guard clause: bail out before launching if initialization failed.
    if not query_engine:
        print("Failed to initialize system")
        sys.exit(1)
    print("Launching web interface...")
    demo = create_interface()
    # share=True also publishes a temporary public gradio.live URL.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True
    )
if __name__ == "__main__":
    main()