# sofzcc's picture
# Update app.py
# b37715b verified
# raw
# history blame
# 25.6 kB
import os
import re
import json
import yaml
import argparse
from pathlib import Path
from typing import List, Dict, Tuple, Optional
import numpy as np
import faiss
import gradio as gr
from transformers import pipeline, AutoTokenizer, AutoModelForQuestionAnswering
from sentence_transformers import SentenceTransformer
import PyPDF2
import docx
# ----------- Configuration Loader -----------
class Config:
    """Load and manage configuration from a YAML file.

    Every accessor falls back to a sensible default, so a partial (or
    even empty) configuration file still yields a working app.
    """

    def __init__(self, config_path: str = "config.yaml"):
        with open(config_path, 'r', encoding='utf-8') as f:
            self.data = yaml.safe_load(f)

    def _lookup(self, section: str, key: str, default=None):
        """Return ``data[section][key]``, or *default* when either level is missing."""
        return self.data.get(section, {}).get(key, default)

    @property
    def client_name(self) -> str:
        return self._lookup('client', 'name', 'RAG Assistant')

    @property
    def client_description(self) -> str:
        return self._lookup('client', 'description', 'AI-powered Q&A with document retrieval and citation')

    @property
    def client_logo(self) -> Optional[str]:
        # No default: absence of a logo is represented by None.
        return self._lookup('client', 'logo')

    @property
    def theme_color(self) -> str:
        return self._lookup('client', 'theme_color', 'blue')

    @property
    def kb_directory(self) -> Path:
        return Path(self._lookup('kb', 'directory', './kb'))

    @property
    def index_directory(self) -> Path:
        return Path(self._lookup('kb', 'index_directory', './.index'))

    @property
    def embedding_model(self) -> str:
        return self._lookup('models', 'embedding', 'sentence-transformers/all-MiniLM-L6-v2')

    @property
    def qa_model(self) -> str:
        return self._lookup('models', 'qa', 'deepset/roberta-base-squad2')

    @property
    def confidence_threshold(self) -> float:
        return self._lookup('thresholds', 'confidence', 0.25)

    @property
    def similarity_threshold(self) -> float:
        return self._lookup('thresholds', 'similarity', 0.35)

    @property
    def chunk_size(self) -> int:
        return self._lookup('chunking', 'chunk_size', 800)

    @property
    def chunk_overlap(self) -> int:
        return self._lookup('chunking', 'overlap', 200)

    @property
    def quick_actions(self) -> List[Tuple[str, str]]:
        # quick_actions is a top-level list of {label, query} mappings.
        return [(entry['label'], entry['query']) for entry in self.data.get('quick_actions', [])]

    @property
    def welcome_message(self) -> str:
        return self._lookup('messages', 'welcome',
            'πŸ‘‹ How can I help? Ask me anything or use a quick action button below.')

    @property
    def no_answer_message(self) -> str:
        return self._lookup('messages', 'no_answer',
            "❌ **I don't know the answer to that** but if you have any document with details I can learn about it.")

    @property
    def upload_prompt(self) -> str:
        return self._lookup('messages', 'upload_prompt',
            "πŸ“€ Upload a relevant document above, and I'll be able to help you find the information you need!")
# Global config instance
config = None
# ----------- Document Extraction -----------
def extract_text_from_pdf(file_path: str) -> str:
    """Extract plain text from a PDF file.

    Args:
        file_path: Path to the PDF on disk.

    Returns:
        Concatenated text of all pages, one newline appended per page.

    Raises:
        RuntimeError: If the file cannot be opened or parsed.
    """
    text = ""
    try:
        with open(file_path, 'rb') as file:
            pdf_reader = PyPDF2.PdfReader(file)
            for page in pdf_reader.pages:
                # extract_text() may return None for pages with no
                # extractable text (e.g. scanned images); guard so we
                # never concatenate None and raise a TypeError.
                text += (page.extract_text() or "") + "\n"
    except Exception as e:
        # Chain the original exception for easier debugging.
        raise RuntimeError(f"Error reading PDF: {str(e)}") from e
    return text
def extract_text_from_docx(file_path: str) -> str:
    """Extract text from DOCX file, one line per paragraph."""
    try:
        document = docx.Document(file_path)
        paragraphs = [paragraph.text for paragraph in document.paragraphs]
        return "\n".join(paragraphs)
    except Exception as e:
        raise RuntimeError(f"Error reading DOCX: {str(e)}")
def extract_text_from_txt(file_path: str) -> str:
    """Extract text from TXT file, silently skipping undecodable bytes."""
    try:
        return Path(file_path).read_text(encoding='utf-8', errors='ignore')
    except Exception as e:
        raise RuntimeError(f"Error reading TXT: {str(e)}")
def extract_text_from_file(file_path: str) -> Tuple[str, str]:
    """Extract text from an uploaded file, dispatching on its extension.

    Returns:
        A (text, human-readable file-type label) pair.

    Raises:
        ValueError: For any extension other than .pdf/.docx/.txt/.md.
    """
    ext = Path(file_path).suffix.lower()
    if ext == '.pdf':
        return extract_text_from_pdf(file_path), 'PDF'
    if ext == '.docx':
        return extract_text_from_docx(file_path), 'DOCX'
    if ext in ('.txt', '.md'):
        return extract_text_from_txt(file_path), 'Text'
    raise ValueError(f"Unsupported file type: {ext}. Supported: .pdf, .docx, .txt, .md")
# ----------- Document Processing -----------
# Matches any markdown heading (1-6 hashes) in MULTILINE mode; group 2 is
# the heading text.
HEADING_RE = re.compile(r"^(#{1,6})\s+(.*)$", re.MULTILINE)


def read_markdown_files(kb_dir: Path) -> List[Dict]:
    """Read all markdown files from the knowledge base directory."""
    documents = []
    for md_path in sorted(kb_dir.glob("*.md")):
        raw = md_path.read_text(encoding="utf-8", errors="ignore")
        # Prefer the first level-1 heading as the document title; fall
        # back to a title derived from the file name.
        match = re.search(r"^#\s+(.*)$", raw, flags=re.MULTILINE)
        if match:
            title = match.group(1).strip()
        else:
            title = md_path.stem.replace("_", " ").title()
        documents.append({
            "filepath": str(md_path),
            "filename": md_path.name,
            "title": title,
            "text": raw,
        })
    return documents


def chunk_markdown(doc: Dict, chunk_chars: int = None, overlap: int = None) -> List[Dict]:
    """Split a markdown document into overlapping character chunks.

    Sections are first split on level-2/3 headings; each section is then
    windowed into chunks of ``chunk_chars`` with ``overlap`` characters of
    overlap.  Defaults come from the global config when not supplied.
    """
    chunk_chars = config.chunk_size if chunk_chars is None else chunk_chars
    overlap = config.chunk_overlap if overlap is None else overlap
    text = doc["text"]
    sections = re.split(r"(?=^##\s+|\n##\s+|\n###\s+|^###\s+)", text, flags=re.MULTILINE)
    if len(sections) == 1:
        sections = [text]
    chunks = []
    for raw_section in sections:
        section = raw_section.strip()
        # Skip empty or trivially short sections.
        if len(section) < 50:
            continue
        heading_match = HEADING_RE.search(section)
        if heading_match:
            section_heading = heading_match.group(2).strip()
        else:
            section_heading = doc["title"]
        # Sliding window over the section text.
        start = 0
        while start < len(section):
            end = min(start + chunk_chars, len(section))
            piece = section[start:end].strip()
            if len(piece) > 50:
                chunks.append({
                    "doc_title": doc["title"],
                    "filename": doc["filename"],
                    "filepath": doc["filepath"],
                    "section": section_heading,
                    "content": piece,
                })
            if end == len(section):
                break
            start = max(0, end - overlap)
    return chunks
# ----------- KB Index -----------
class KBIndex:
    """Dense-retrieval + extractive-QA index over the knowledge base.

    Pairs a SentenceTransformer embedder with a FAISS inner-product index
    for retrieval, and a transformers question-answering pipeline for span
    extraction.  The index can be built from the KB directory (persisted to
    disk), built transiently from one uploaded file, or loaded from cache.
    """

    def __init__(self):
        # Embedding model used for both document chunks and queries.
        self.embedder = SentenceTransformer(config.embedding_model)
        self.reader_tokenizer = AutoTokenizer.from_pretrained(config.qa_model)
        self.reader_model = AutoModelForQuestionAnswering.from_pretrained(config.qa_model)
        # handle_impossible_answer lets the reader return an empty span
        # when the context does not actually contain an answer.
        self.reader = pipeline(
            "question-answering",
            model=self.reader_model,
            tokenizer=self.reader_tokenizer,
            max_answer_len=200,
            handle_impossible_answer=True
        )
        self.index = None        # FAISS index once built/loaded
        self.embeddings = None   # chunk embeddings (L2-normalized)
        self.metadata = []       # one dict per chunk (doc_title/filename/section/content)
        self.uploaded_file_active = False  # True while serving an uploaded file
        # Paths based on config
        self.embeddings_path = config.index_directory / "kb_embeddings.npy"
        self.metadata_path = config.index_directory / "kb_metadata.json"
        self.faiss_path = config.index_directory / "kb_faiss.index"

    def build(self, kb_dir: Path):
        """Build the FAISS index from markdown files and persist it to disk.

        Raises:
            RuntimeError: If no markdown files or no usable chunks are found.
        """
        docs = read_markdown_files(kb_dir)
        if not docs:
            raise RuntimeError(f"No markdown files found in {kb_dir.resolve()}")
        all_chunks = []
        for d in docs:
            all_chunks.extend(chunk_markdown(d))
        if not all_chunks:
            raise RuntimeError("No content chunks generated from KB.")
        texts = [c["content"] for c in all_chunks]
        embeddings = self.embedder.encode(
            texts,
            batch_size=32,
            convert_to_numpy=True,
            show_progress_bar=True
        )
        # L2-normalize so inner product equals cosine similarity.
        faiss.normalize_L2(embeddings)
        dim = embeddings.shape[1]
        index = faiss.IndexFlatIP(dim)
        index.add(embeddings)
        self.index = index
        self.embeddings = embeddings
        self.metadata = all_chunks
        self.uploaded_file_active = False
        # Ensure index directory exists
        config.index_directory.mkdir(exist_ok=True, parents=True)
        # Persist embeddings, metadata, and FAISS index so load() can
        # restore everything without re-encoding.
        np.save(self.embeddings_path, embeddings)
        with open(self.metadata_path, "w", encoding="utf-8") as f:
            json.dump(self.metadata, f, ensure_ascii=False, indent=2)
        faiss.write_index(index, str(self.faiss_path))

    def build_from_uploaded_file(self, file_path: str, filename: str):
        """Build a temporary in-memory index from an uploaded file.

        Unlike build(), nothing is persisted; the uploaded-file index
        replaces the KB index until load() is called again.

        Returns:
            (number of chunks indexed, detected file-type label).

        Raises:
            RuntimeError: If the file is empty/too short or yields no chunks.
        """
        text_content, file_type = extract_text_from_file(file_path)
        if not text_content or len(text_content.strip()) < 100:
            raise RuntimeError("File appears to be empty or too short.")
        doc = {
            "filepath": file_path,
            "filename": filename,
            "title": Path(filename).stem.replace("_", " ").title(),
            "text": text_content
        }
        all_chunks = chunk_markdown(doc)
        if not all_chunks:
            raise RuntimeError("Could not extract meaningful content from file.")
        texts = [c["content"] for c in all_chunks]
        embeddings = self.embedder.encode(
            texts,
            batch_size=32,
            convert_to_numpy=True,
            show_progress_bar=False
        )
        faiss.normalize_L2(embeddings)
        dim = embeddings.shape[1]
        index = faiss.IndexFlatIP(dim)
        index.add(embeddings)
        self.index = index
        self.embeddings = embeddings
        self.metadata = all_chunks
        self.uploaded_file_active = True
        return len(all_chunks), file_type

    def load(self) -> bool:
        """Load a pre-built index from disk.

        Returns:
            True if all three cache files existed and were loaded.
        """
        if not (self.embeddings_path.exists() and self.metadata_path.exists() and self.faiss_path.exists()):
            return False
        self.embeddings = np.load(self.embeddings_path)
        with open(self.metadata_path, "r", encoding="utf-8") as f:
            self.metadata = json.load(f)
        self.index = faiss.read_index(str(self.faiss_path))
        self.uploaded_file_active = False
        return True

    def retrieve(self, query: str, top_k: int = 6) -> List[Tuple[int, float]]:
        """Retrieve top-k most similar chunks for a query.

        Returns:
            (metadata index, similarity score) pairs, best first.
        """
        q_emb = self.embedder.encode([query], convert_to_numpy=True)
        faiss.normalize_L2(q_emb)
        D, I = self.index.search(q_emb, top_k)
        return list(zip(I[0].tolist(), D[0].tolist()))

    def answer(self, question: str, retrieved: List[Tuple[int, float]]) -> Tuple[Optional[str], float, List[Dict], float]:
        """Extract the best answer span from the retrieved chunks.

        Args:
            question: The user's natural-language question.
            retrieved: (chunk index, similarity) pairs from retrieve().

        Returns:
            (answer text or None, QA confidence score, citation dicts,
            best retrieval similarity).
        """
        candidates = []
        for idx, sim in retrieved:
            meta = self.metadata[idx]
            ctx = meta["content"]
            try:
                out = self.reader(question=question, context=ctx)
                score = float(out.get("score", 0.0))
                answer_text = out.get("answer", "").strip()
                # Discard empty/trivially short spans (incl. the reader's
                # "impossible answer" empty string).
                if answer_text and len(answer_text) > 3:
                    # Widen the raw span to sentence boundaries for readability.
                    expanded_answer = self._expand_answer(answer_text, ctx)
                    candidates.append({
                        "text": expanded_answer,
                        "original": answer_text,
                        "score": score,
                        "meta": meta,
                        "sim": float(sim),
                        "context": ctx
                    })
            except Exception as e:
                # Best-effort: skip chunks the reader cannot process.
                continue
        if not candidates:
            return None, 0.0, [], max([s for _, s in retrieved]) if retrieved else 0.0
        # Rank by a fixed blend of reader confidence (70%) and retrieval
        # similarity (30%).
        candidates.sort(key=lambda x: x["score"] * 0.7 + x["sim"] * 0.3, reverse=True)
        best = candidates[0]
        # Cite the top-3 retrieved chunks, de-duplicated by (file, section).
        citations = []
        seen = set()
        for idx, _ in retrieved[:3]:
            m = self.metadata[idx]
            key = (m["filename"], m["section"])
            if key in seen:
                continue
            seen.add(key)
            citations.append({
                "title": m["doc_title"],
                "filename": m["filename"],
                "section": m["section"]
            })
        best_sim = max([s for _, s in retrieved]) if retrieved else 0.0
        return best["text"], best["score"], citations, best_sim

    def _expand_answer(self, answer: str, context: str, max_chars: int = 300) -> str:
        """Expand an extracted answer span to nearby sentence boundaries.

        Walks outward from the span until sentence punctuation or a
        newline, capped at roughly max_chars total; falls back to a
        sentence-level search when the expansion is too short.
        """
        answer_pos = context.lower().find(answer.lower())
        if answer_pos == -1:
            # Reader may have normalized the text; return the span as-is.
            return answer
        start = answer_pos
        end = answer_pos + len(answer)
        # Walk left to the previous sentence boundary (bounded by max_chars//2).
        while start > 0 and context[start - 1] not in '.!?\n':
            start -= 1
            if answer_pos - start > max_chars // 2:
                break
        # Walk right to the next sentence boundary (bounded by max_chars//2).
        while end < len(context) and context[end] not in '.!?\n':
            end += 1
            if end - answer_pos > max_chars // 2:
                break
        # Include the trailing punctuation mark if we stopped on one.
        if end < len(context) and context[end] in '.!?':
            end += 1
        expanded = context[start:end].strip()
        if len(expanded) < 50:
            # Fallback: return the sentence containing the answer (plus the
            # next sentence when the first one is very short).
            sentences = context.split('.')
            for i, sent in enumerate(sentences):
                if answer.lower() in sent.lower():
                    result = sent.strip()
                    if i + 1 < len(sentences) and len(result) < 100:
                        result += ". " + sentences[i + 1].strip()
                    return result + ("." if not result.endswith(".") else "")
        return expanded
# Initialize KB (will be done after config is loaded)
kb = None
def ensure_index():
    """Load the cached index if present; otherwise try building from the KB.

    Never raises: all failure modes degrade to console hints, since the
    user can still upload a document through the UI.
    """
    try:
        # Fast path: reuse the index cached on disk.
        if kb.load():
            print(f"βœ… Loaded existing index from {config.index_directory}")
            return
    except Exception as e:
        print(f"⚠️ Could not load existing index: {e}")
    # Guard: no KB directory at all — create it and bail out.
    if not config.kb_directory.exists():
        print(f"ℹ️ KB directory {config.kb_directory} not found. Creating it...")
        config.kb_directory.mkdir(exist_ok=True, parents=True)
        print(f"ℹ️ Add .md files to {config.kb_directory} or upload documents via the UI")
        return
    md_files = list(config.kb_directory.glob("*.md"))
    # Guard: directory exists but holds no markdown files.
    if not md_files:
        print(f"ℹ️ No markdown files found in {config.kb_directory}")
        print(f"ℹ️ Upload documents via the UI or add .md files to start using the knowledge base")
        return
    try:
        print(f"πŸ”¨ Building index from {len(md_files)} markdown files...")
        kb.build(config.kb_directory)
        print(f"βœ… Index built successfully!")
    except Exception as e:
        print(f"⚠️ Could not build index: {e}")
        print(f"ℹ️ You can upload documents via the UI or add .md files to {config.kb_directory}")
# ----------- Response Generation -----------
def format_citations(citations: List[Dict]) -> str:
    """Render citation dicts as a markdown bullet list, one per line."""
    if not citations:
        return ""
    return "\n".join(
        f"β€’ **{c['title']}** β€” _{c['section']}_" for c in citations
    )
def respond(user_msg: str, history: List, uploaded_file_info: str = None) -> str:
    """Answer a user query with the retrieve-then-read pipeline.

    Returns a fully formatted markdown reply: an answer with citations,
    a low-confidence variant, or a "don't know" message with an upload
    prompt.
    """
    user_msg = (user_msg or "").strip()
    if not user_msg:
        return config.welcome_message
    # Nothing searchable yet: point the user at the upload box.
    if kb.index is None or len(kb.metadata) == 0:
        return f"{config.no_answer_message}\n\n{config.upload_prompt}"
    if kb.uploaded_file_active and uploaded_file_info:
        source_info = " in the uploaded file"
    else:
        source_info = " in the knowledge base"
    retrieved = kb.retrieve(user_msg, top_k=6)
    top_sim = max([s for _, s in retrieved]) if retrieved else 0.0
    # Hard floor on retrieval similarity before even running the reader.
    if not retrieved or top_sim < 0.20:
        return f"{config.no_answer_message}\n\n{config.upload_prompt}"
    answer, qa_score, citations, best_sim = kb.answer(user_msg, retrieved)
    if not answer or qa_score < 0.15 or best_sim < 0.25:
        return (
            f"{config.no_answer_message}\n\n"
            f"The question seems outside the scope of what I currently know{source_info}. "
            f"Try uploading a relevant document, or rephrase your question if you think the information might be here."
        )
    answer = answer.strip()
    if answer and answer[-1] not in '.!?':
        answer += "."
    citations_md = format_citations(citations)
    low_confidence = (qa_score < config.confidence_threshold) or (best_sim < config.similarity_threshold)
    if low_confidence:
        return (
            f"⚠️ **Answer (Low Confidence):**\n\n{answer}\n\n"
            f"---\n"
            f"πŸ“š **Related Sources:**\n{citations_md}\n\n"
            f"πŸ’¬ *I'm not entirely certain about this answer. If you have a more detailed document about this topic, please upload it for better accuracy.*"
        )
    return (
        f"βœ… **Answer:**\n\n{answer}\n\n"
        f"---\n"
        f"πŸ“š **Sources:**\n{citations_md}\n\n"
        f"πŸ’‘ *Say \"show more details\" to see the full context.*"
    )
# ----------- UI Handlers -----------
def process_message(user_input: str, history: List, uploaded_file_info: str) -> Tuple[List, Dict]:
    """Append the user turn and the assistant reply to the chat history.

    Returns the updated history plus a gr.update that clears the textbox.
    """
    text = (user_input or "").strip()
    clear_box = gr.update(value="")
    if not text:
        # Nothing to do: leave history untouched, just clear the input.
        return history, clear_box
    prior = history or []
    reply = respond(text, prior, uploaded_file_info)
    updated = prior + [
        {"role": "user", "content": text},
        {"role": "assistant", "content": reply},
    ]
    return updated, clear_box
def process_quick(label: str, history: List, uploaded_file_info: str) -> Tuple[List, Dict]:
    """Look up a quick-action button's configured query and process it."""
    query = next(
        (q for btn_label, q in config.quick_actions if btn_label == label),
        None,
    )
    if query is None:
        # Unknown label: no-op, just clear the textbox.
        return history, gr.update(value="")
    return process_message(query, history, uploaded_file_info)
def handle_file_upload(file):
    """Process an uploaded file: extract text, chunk it, build a temp index.

    Args:
        file: Gradio file object (or None when the upload was cleared).

    Returns:
        (status markdown, filename) — the filename is kept in UI state so
        later queries know an uploaded document is active.
    """
    if file is None:
        return "ℹ️ No file uploaded.", ""
    try:
        filename = Path(file.name).name
        num_chunks, file_type = kb.build_from_uploaded_file(file.name, filename)
        return (
            f"βœ… **File processed successfully!**\n\n"
            # Bug fix: the status previously showed the literal placeholder
            # "(unknown)" instead of the actual uploaded file's name.
            f"πŸ“„ **File:** {filename}\n"
            f"πŸ“‹ **Type:** {file_type}\n"
            f"πŸ”’ **Chunks:** {num_chunks}\n\n"
            f"You can now ask questions about this document!"
        ), filename
    except Exception as e:
        return f"❌ **Error processing file:** {str(e)}\n\nPlease ensure the file is a valid PDF, DOCX, TXT, or MD file.", ""
def clear_uploaded_file():
    """Drop any uploaded-file index and restore the persistent KB index.

    Returns (status markdown, cleared filename state, cleared file widget).
    """
    try:
        if kb.load():
            return "βœ… Switched back to knowledge base.", "", None
        # No cached KB index on disk: reset the index to an empty state.
        kb.index = None
        kb.embeddings = None
        kb.metadata = []
        kb.uploaded_file_active = False
        return "ℹ️ No knowledge base found. Please upload a file or build the KB index.", "", None
    except Exception as e:
        return f"⚠️ Error: {str(e)}", "", None
def rebuild_index_handler():
    """Rebuild the KB search index and return a status message for the UI."""
    try:
        kb.build(config.kb_directory)
    except Exception as e:
        return f"❌ Error rebuilding index: {str(e)}"
    return "βœ… Index rebuilt successfully! Ready to answer questions."
# ----------- Gradio UI -----------
def create_interface():
    """Create the Gradio Blocks interface from the loaded configuration.

    Returns:
        The assembled (but not yet launched) gr.Blocks demo.
    """
    with gr.Blocks(
        title=config.client_name,
        theme=gr.themes.Soft(primary_hue=config.theme_color),
        css="""
        .contain { max-width: 1200px; margin: auto; }
        .quick-btn { min-width: 180px !important; }
        """
    ) as demo:
        # Holds the uploaded file's name ("" when using the KB index).
        uploaded_file_state = gr.State("")
        # Header
        header_text = f"# πŸ€– {config.client_name}\n### {config.client_description}"
        if config.client_logo:
            header_text += f"\n![Logo]({config.client_logo})"
        gr.Markdown(header_text)
        # File upload section
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### πŸ“€ Upload Document")
                file_upload = gr.File(
                    label="Upload PDF, DOCX, TXT, or MD file",
                    file_types=[".pdf", ".docx", ".txt", ".md"],
                    type="filepath"
                )
                upload_status = gr.Markdown("ℹ️ Upload a file to ask questions about it.")
                with gr.Row():
                    clear_btn = gr.Button("πŸ”„ Clear & Use KB", variant="secondary", size="sm")
        # Main chat interface
        with gr.Row():
            with gr.Column(scale=1):
                chat = gr.Chatbot(
                    height=500,
                    show_copy_button=True,
                    type="messages",
                    avatar_images=(None, "https://em-content.zobj.net/source/twitter/376/robot_1f916.png")
                )
                with gr.Row():
                    txt = gr.Textbox(
                        placeholder="πŸ’¬ Ask a question about the document or knowledge base...",
                        scale=9,
                        show_label=False,
                        container=False
                    )
                    send = gr.Button("Send", variant="primary", scale=1)
        # Quick action buttons (if configured)
        if config.quick_actions:
            with gr.Accordion("⚑ Quick Actions", open=False):
                with gr.Row():
                    quick_buttons = []
                    for label, _ in config.quick_actions:
                        btn = gr.Button(label, elem_classes="quick-btn", size="sm")
                        quick_buttons.append((btn, label))
        # Admin section
        with gr.Accordion("πŸ”§ Admin Panel", open=False):
            # Bug fix: this string interpolates the configured KB directory,
            # so it must be an f-string; previously the user saw the literal
            # text "{config.kb_directory}".
            gr.Markdown(
                f"""
                **Rebuild Index:** Use this after adding or modifying files in the `{config.kb_directory}` directory.
                The system will re-scan all markdown files and update the search index.
                """
            )
            with gr.Row():
                rebuild_btn = gr.Button("πŸ”„ Rebuild KB Index", variant="secondary")
            status_msg = gr.Markdown("")
        # Event handlers
        file_upload.change(
            handle_file_upload,
            inputs=[file_upload],
            outputs=[upload_status, uploaded_file_state]
        )
        clear_btn.click(
            clear_uploaded_file,
            outputs=[upload_status, uploaded_file_state, file_upload]
        )
        send.click(
            process_message,
            inputs=[txt, chat, uploaded_file_state],
            outputs=[chat, txt]
        )
        txt.submit(
            process_message,
            inputs=[txt, chat, uploaded_file_state],
            outputs=[chat, txt]
        )
        if config.quick_actions:
            for btn, label in quick_buttons:
                # gr.State(label) pins each button's label at wiring time.
                btn.click(
                    process_quick,
                    inputs=[gr.State(label), chat, uploaded_file_state],
                    outputs=[chat, txt]
                )
        rebuild_btn.click(rebuild_index_handler, outputs=status_msg)
        # Footer
        gr.Markdown(
            """
            ---
            πŸ’‘ **Tips:**
            - Upload a document to ask questions specifically about that file
            - Use "Clear & Use KB" to switch back to the knowledge base
            - Be specific in your questions for better results
            - Check the cited sources for full context
            """
        )
    return demo
# ----------- Main Entry Point -----------
if __name__ == "__main__":
    # CLI: only the configuration file path is settable.
    parser = argparse.ArgumentParser(description='Configurable RAG Assistant')
    parser.add_argument(
        '--config', type=str, default='config.yaml',
        help='Path to configuration YAML file (default: config.yaml)'
    )
    cli = parser.parse_args()
    # These top-level assignments rebind the module globals that every
    # handler reads (config, kb).
    config = Config(cli.config)
    kb = KBIndex()
    ensure_index()
    # Build the UI and start serving.
    demo = create_interface()
    demo.launch()