Spaces:
Sleeping
Sleeping
Commit ·
37f0716
1
Parent(s): 3937ce7
changing the path urls
Browse files
app.py
CHANGED
|
@@ -29,7 +29,6 @@ from qdrant_client import QdrantClient
|
|
| 29 |
# 1️⃣ Load OCR + Embedding Models + Groq Client
|
| 30 |
# -------------------------------
|
| 31 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 32 |
-
# Using a try-except block in case the model initialization fails (e.g., memory constraints)
|
| 33 |
try:
|
| 34 |
print(f"Loading OCR model to {device}...")
|
| 35 |
ocr_model = ocr_predictor(pretrained=True).to(device)
|
|
@@ -60,7 +59,7 @@ QDRANT_URL = os.environ.get("QDRANT_URL", "https://bdf142ef-7e2a-433b-87a0-301ff
|
|
| 60 |
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")
|
| 61 |
COLLECTION_NAME = "multimodal_rag_store"
|
| 62 |
|
| 63 |
-
# --- Helper Functions ---
|
| 64 |
|
| 65 |
# -------------------------------
|
| 66 |
# 2️⃣ Helper: Check if image has substantial text
|
|
@@ -74,11 +73,10 @@ def has_substantial_text(text, min_words=10):
|
|
| 74 |
|
| 75 |
|
| 76 |
# -------------------------------
|
| 77 |
-
# 3️⃣
|
| 78 |
# -------------------------------
|
| 79 |
def analyze_image_with_vision(img_path=None, img_bytes=None, pil_image=None, max_retries=3):
|
| 80 |
if not groq_client:
|
| 81 |
-
print("❌ Groq client not initialized. Cannot use vision model.")
|
| 82 |
return ""
|
| 83 |
|
| 84 |
for attempt in range(max_retries):
|
|
@@ -96,7 +94,6 @@ def analyze_image_with_vision(img_path=None, img_bytes=None, pil_image=None, max
|
|
| 96 |
img_format = img_path.lower().split('.')[-1]
|
| 97 |
elif img_bytes:
|
| 98 |
img_data = img_bytes
|
| 99 |
-
# We assume PNG for raw bytes if format is unknown/irrelevant for a PIL-based source
|
| 100 |
else:
|
| 101 |
return ""
|
| 102 |
|
|
@@ -107,7 +104,7 @@ def analyze_image_with_vision(img_path=None, img_bytes=None, pil_image=None, max
|
|
| 107 |
vision_prompt = """Analyze this image carefully and provide a detailed description:
|
| 108 |
1. IDENTIFY THE TYPE: Is this a chart, graph, table, diagram, photograph, or text document?
|
| 109 |
2. IF IT'S A CHART/GRAPH/TABLE:
|
| 110 |
-
- Specify the exact type
|
| 111 |
- List ALL categories/labels shown
|
| 112 |
- Describe the data values and trends
|
| 113 |
- Mention axis labels, title, legend if present
|
|
@@ -153,7 +150,6 @@ Provide a comprehensive description suitable for semantic search. Be specific an
|
|
| 153 |
print(f"❌ Vision model '{VISION_MODEL}' not available! Skipping vision analysis.")
|
| 154 |
return ""
|
| 155 |
else:
|
| 156 |
-
print(f"❌ Vision API Error on attempt {attempt+1}: {e}")
|
| 157 |
if attempt < max_retries - 1:
|
| 158 |
time.sleep(2)
|
| 159 |
continue
|
|
@@ -186,7 +182,6 @@ def extract_text_from_image(img_path):
|
|
| 186 |
else:
|
| 187 |
print(f"🖼️ {os.path.basename(img_path)}: Using Vision Model (graph/chart/picture)")
|
| 188 |
vision_summary = analyze_image_with_vision(img_path=img_path)
|
| 189 |
-
# Fallback to sparse OCR text if vision summary fails
|
| 190 |
return vision_summary if vision_summary else ocr_text
|
| 191 |
except Exception as e:
|
| 192 |
print(f"❌ Error processing {img_path}: {e}")
|
|
@@ -208,7 +203,7 @@ def extract_text_from_txt(file_path):
|
|
| 208 |
|
| 209 |
|
| 210 |
# -------------------------------
|
| 211 |
-
# 6️⃣
|
| 212 |
# -------------------------------
|
| 213 |
def extract_content_from_pdf(pdf_path):
|
| 214 |
try:
|
|
@@ -221,11 +216,9 @@ def extract_content_from_pdf(pdf_path):
|
|
| 221 |
text = page.get_text()
|
| 222 |
if text.strip():
|
| 223 |
page_content.append(f"[Page {page_num} - Text Content]\n{text}")
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
# 2. Vision analysis of the entire page image (for charts/layouts)
|
| 227 |
if groq_client:
|
| 228 |
-
print(f"🔄 {os.path.basename(pdf_path)} (Page {page_num}): Rendering page for vision analysis...")
|
| 229 |
try:
|
| 230 |
mat = fitz.Matrix(2, 2)
|
| 231 |
pix = page.get_pixmap(matrix=mat)
|
|
@@ -236,11 +229,10 @@ def extract_content_from_pdf(pdf_path):
|
|
| 236 |
if vision_analysis and len(vision_analysis.strip()) > 30:
|
| 237 |
vision_section = f"[Page {page_num} - Visual Analysis]\n{vision_analysis}"
|
| 238 |
page_content.append(vision_section)
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
print(f"❌ Error rendering/analyzing page {page_num}: {e}")
|
| 242 |
|
| 243 |
-
# 3. OCR on embedded images (
|
| 244 |
if ocr_model:
|
| 245 |
image_list = page.get_images(full=True)
|
| 246 |
for img_index, img_info in enumerate(image_list, 1):
|
|
@@ -248,10 +240,8 @@ def extract_content_from_pdf(pdf_path):
|
|
| 248 |
xref = img_info[0]
|
| 249 |
base_image = doc.extract_image(xref)
|
| 250 |
image_bytes = base_image["image"]
|
| 251 |
-
|
| 252 |
image = Image.open(BytesIO(image_bytes)).convert("RGB")
|
| 253 |
image_np = np.array(image)
|
| 254 |
-
|
| 255 |
result = ocr_model([image_np])
|
| 256 |
ocr_text = []
|
| 257 |
for ocr_page in result.pages:
|
|
@@ -264,14 +254,12 @@ def extract_content_from_pdf(pdf_path):
|
|
| 264 |
if has_substantial_text(extracted_text, min_words=10):
|
| 265 |
page_content.append(f"[Page {page_num} - Embedded Image {img_index} OCR]\n{extracted_text}")
|
| 266 |
else:
|
| 267 |
-
# Use vision for image within PDF if OCR is sparse
|
| 268 |
vision_summary = analyze_image_with_vision(img_bytes=image_bytes)
|
| 269 |
if vision_summary:
|
| 270 |
page_content.append(
|
| 271 |
f"[Page {page_num} - Embedded Image {img_index} Analysis]\n{vision_summary}")
|
| 272 |
|
| 273 |
-
except Exception
|
| 274 |
-
print(f"❌ Error processing embedded image {img_index}: {e}")
|
| 275 |
continue
|
| 276 |
|
| 277 |
if page_content:
|
|
@@ -288,7 +276,7 @@ def extract_content_from_pdf(pdf_path):
|
|
| 288 |
|
| 289 |
|
| 290 |
# -------------------------------
|
| 291 |
-
# 7️⃣ Process All Document Types
|
| 292 |
# -------------------------------
|
| 293 |
def create_documents_from_folder(folder_path):
|
| 294 |
docs = []
|
|
@@ -296,9 +284,6 @@ def create_documents_from_folder(folder_path):
|
|
| 296 |
for filename in files:
|
| 297 |
full_path = os.path.join(root, filename)
|
| 298 |
file_ext = filename.lower().split('.')[-1]
|
| 299 |
-
print(f"\n{'=' * 60}")
|
| 300 |
-
print(f"Processing: {filename}")
|
| 301 |
-
print(f"{'=' * 60}")
|
| 302 |
text = ""
|
| 303 |
|
| 304 |
if file_ext in ["jpg", "jpeg", "png"]:
|
|
@@ -308,7 +293,6 @@ def create_documents_from_folder(folder_path):
|
|
| 308 |
elif file_ext == "pdf":
|
| 309 |
text = extract_content_from_pdf(full_path)
|
| 310 |
else:
|
| 311 |
-
print(f"⏭️ Skipping unsupported file: {filename}")
|
| 312 |
continue
|
| 313 |
|
| 314 |
if text.strip():
|
|
@@ -323,30 +307,22 @@ def create_documents_from_folder(folder_path):
|
|
| 323 |
}
|
| 324 |
)
|
| 325 |
docs.append(doc)
|
| 326 |
-
|
| 327 |
-
else:
|
| 328 |
-
print(f"⚠️ Skipping {filename} - no content extracted")
|
| 329 |
return docs
|
| 330 |
|
|
|
|
| 331 |
|
| 332 |
# -------------------------------
|
| 333 |
-
# 8️⃣
|
| 334 |
# -------------------------------
|
| 335 |
def build_or_update_qdrant_store(folder_path):
|
| 336 |
if not QDRANT_API_KEY:
|
| 337 |
-
print("❌ QDRANT_API_KEY is missing. Skipping database build.")
|
| 338 |
return None
|
| 339 |
|
| 340 |
-
print("\n" + "=" * 60)
|
| 341 |
-
print("🔄 STARTING DOCUMENT PROCESSING FOR QDRANT")
|
| 342 |
-
print("=" * 60)
|
| 343 |
docs = create_documents_from_folder(folder_path)
|
| 344 |
if not docs:
|
| 345 |
-
print("\n⚠️ No valid documents found!")
|
| 346 |
return None
|
| 347 |
|
| 348 |
-
print(f"\n✅ Successfully processed {len(docs)} documents")
|
| 349 |
-
print(f"☁️ Uploading documents to Qdrant Cloud collection: {COLLECTION_NAME}...")
|
| 350 |
try:
|
| 351 |
vector_store = Qdrant.from_documents(
|
| 352 |
docs,
|
|
@@ -356,20 +332,18 @@ def build_or_update_qdrant_store(folder_path):
|
|
| 356 |
collection_name=COLLECTION_NAME,
|
| 357 |
force_recreate=True
|
| 358 |
)
|
| 359 |
-
print(f"✅ Successfully created/updated Qdrant collection: {COLLECTION_NAME}")
|
| 360 |
return vector_store
|
| 361 |
except Exception as e:
|
| 362 |
print(f"❌ Error connecting or uploading to Qdrant: {e}")
|
| 363 |
-
print("Please check your QDRANT_URL and QDRANT_API_KEY")
|
| 364 |
return None
|
| 365 |
|
| 366 |
|
| 367 |
# -------------------------------
|
| 368 |
-
# 9️⃣
|
| 369 |
# -------------------------------
|
| 370 |
def query_qdrant_store(query_text, k=3):
|
| 371 |
if not QDRANT_API_KEY:
|
| 372 |
-
print("❌ QDRANT_API_KEY is missing. Cannot query database.")
|
| 373 |
return []
|
| 374 |
|
| 375 |
try:
|
|
@@ -383,7 +357,6 @@ def query_qdrant_store(query_text, k=3):
|
|
| 383 |
collection_name=COLLECTION_NAME,
|
| 384 |
embeddings=embedding_model
|
| 385 |
)
|
| 386 |
-
print(f"✅ Connected to Qdrant collection: {COLLECTION_NAME}")
|
| 387 |
except Exception as e:
|
| 388 |
print(f"❌ Error connecting to Qdrant: {e}")
|
| 389 |
return []
|
|
@@ -395,33 +368,25 @@ def query_qdrant_store(query_text, k=3):
|
|
| 395 |
is_visual_query = any(keyword in query_text.lower() for keyword in visual_query_keywords)
|
| 396 |
|
| 397 |
if is_visual_query:
|
| 398 |
-
print(f"🔍 Detected visual content query - applying smart re-ranking...")
|
| 399 |
reranked_results = []
|
| 400 |
for doc, score in results:
|
| 401 |
boost = 0.0
|
| 402 |
-
|
| 403 |
-
if "Visual Analysis]" in doc.page_content or "bar chart" in doc.page_content.lower() or "line graph" in doc.page_content.lower():
|
| 404 |
visual_content = doc.page_content.lower()
|
| 405 |
|
| 406 |
-
# Apply high boost if query type matches content type
|
| 407 |
if 'bar chart' in query_text.lower() and 'bar chart' in visual_content:
|
| 408 |
boost += 1.0
|
| 409 |
elif 'pie chart' in query_text.lower() and 'pie chart' in visual_content:
|
| 410 |
boost += 1.0
|
| 411 |
-
elif (
|
| 412 |
-
boost +=
|
| 413 |
-
# Apply moderate boost if a general visual query matches any visual content
|
| 414 |
-
elif any(kw in query_text.lower() for kw in ['chart', 'graph', 'visualization']):
|
| 415 |
-
if any(kw in visual_content for kw in ['chart', 'graph', 'plot', 'diagram', 'table']):
|
| 416 |
-
boost += 0.5
|
| 417 |
else:
|
| 418 |
boost += 0.2
|
| 419 |
|
| 420 |
-
# Note: Qdrant similarity search returns distance (lower is better), so we *subtract* the boost to make the score lower (better).
|
| 421 |
adjusted_score = score - boost
|
| 422 |
reranked_results.append((doc, adjusted_score, score))
|
| 423 |
|
| 424 |
-
reranked_results.sort(key=lambda x: x[1])
|
| 425 |
results = [(doc, adj_score) for doc, adj_score, _ in reranked_results[:k]]
|
| 426 |
else:
|
| 427 |
results = results[:k]
|
|
@@ -438,12 +403,11 @@ def query_qdrant_store(query_text, k=3):
|
|
| 438 |
|
| 439 |
|
| 440 |
# -------------------------------
|
| 441 |
-
#
|
| 442 |
# -------------------------------
|
| 443 |
def answer_question_with_llm(query_text, retrieved_docs, max_tokens=1000):
|
| 444 |
if not groq_client:
|
| 445 |
return "❌ Groq client not initialized. Cannot generate answer."
|
| 446 |
-
|
| 447 |
if not retrieved_docs:
|
| 448 |
return "❌ No relevant documents found to answer your question."
|
| 449 |
|
|
@@ -454,12 +418,7 @@ def answer_question_with_llm(query_text, retrieved_docs, max_tokens=1000):
|
|
| 454 |
metadata = doc['metadata']
|
| 455 |
timestamp = metadata.get('upload_timestamp')
|
| 456 |
|
| 457 |
-
readable_time = "N/A"
|
| 458 |
-
if timestamp:
|
| 459 |
-
try:
|
| 460 |
-
readable_time = time.ctime(float(timestamp))
|
| 461 |
-
except (ValueError, TypeError):
|
| 462 |
-
readable_time = str(timestamp)
|
| 463 |
|
| 464 |
metadata_str = (
|
| 465 |
f"Source: {source}\n"
|
|
@@ -481,8 +440,6 @@ def answer_question_with_llm(query_text, retrieved_docs, max_tokens=1000):
|
|
| 481 |
|
| 482 |
system_prompt = """You are a concise AI assistant. Answer the user's question *only* using the provided documents.
|
| 483 |
- Be brief and to the point.
|
| 484 |
-
- The documents include `[METADATA]` and `[CONTENT]`.
|
| 485 |
-
- Use the metadata to answer questions about file details (like upload time, source, or file type).
|
| 486 |
- If the answer is not in the documents or metadata, simply state 'That information is not available in the documents.'"""
|
| 487 |
|
| 488 |
user_prompt = f"""DOCUMENTS:
|
|
@@ -493,7 +450,6 @@ QUESTION: {query_text}
|
|
| 493 |
ANSWER: (Provide a concise answer based *only* on the documents)"""
|
| 494 |
|
| 495 |
try:
|
| 496 |
-
print(f"\n🤖 Generating answer with {LLM_MODEL}...")
|
| 497 |
response = groq_client.chat.completions.create(
|
| 498 |
model=LLM_MODEL,
|
| 499 |
messages=[
|
|
@@ -507,36 +463,22 @@ ANSWER: (Provide a concise answer based *only* on the documents)"""
|
|
| 507 |
answer = response.choices[0].message.content
|
| 508 |
return answer
|
| 509 |
except Exception as e:
|
| 510 |
-
print(f"❌ Error calling LLM: {e}")
|
| 511 |
return f"❌ Error generating answer: {str(e)}"
|
| 512 |
|
| 513 |
# -------------------------------
|
| 514 |
-
#
|
| 515 |
# -------------------------------
|
| 516 |
def get_rag_response(query_text: str, k: int = 3) -> Dict[str, Any]:
|
| 517 |
-
"""
|
| 518 |
-
Core RAG pipeline: retrieves, generates, and formats response.
|
| 519 |
-
"""
|
| 520 |
-
print("\n" + "=" * 80)
|
| 521 |
print(f"❓ QUERY: {query_text}")
|
| 522 |
-
print("=" * 80)
|
| 523 |
-
|
| 524 |
-
print("\n📚 Retrieving relevant documents from Qdrant...")
|
| 525 |
retrieved_docs = query_qdrant_store(query_text, k=k)
|
| 526 |
|
| 527 |
if not retrieved_docs:
|
| 528 |
-
print("❌ No relevant documents found.")
|
| 529 |
return {
|
| 530 |
"answer": "❌ No relevant documents found to answer your question. Please upload files first.",
|
| 531 |
"sources": []
|
| 532 |
}
|
| 533 |
|
| 534 |
-
print(f"\n📄 Retrieved {len(retrieved_docs)} relevant documents:")
|
| 535 |
-
for i, doc in enumerate(retrieved_docs, 1):
|
| 536 |
-
# We need the original score for display, which is `score` in the dict.
|
| 537 |
-
# The stored score might be the adjusted one if re-ranking occurred.
|
| 538 |
-
print(f" {i}. {doc['source']} (Score: {doc['score']:.4f})")
|
| 539 |
-
|
| 540 |
answer = answer_question_with_llm(query_text, retrieved_docs)
|
| 541 |
|
| 542 |
sources_list = [
|
|
@@ -548,19 +490,13 @@ def get_rag_response(query_text: str, k: int = 3) -> Dict[str, Any]:
|
|
| 548 |
"sources": sources_list
|
| 549 |
}
|
| 550 |
|
| 551 |
-
print("\n" + "=" * 80)
|
| 552 |
-
print(f"💡 ANSWER: {answer}")
|
| 553 |
-
print("=" * 80)
|
| 554 |
-
|
| 555 |
return response_data
|
| 556 |
|
| 557 |
# -------------------------------
|
| 558 |
-
#
|
| 559 |
# -------------------------------
|
| 560 |
def process_single_file(file_path: str, filename: str) -> Document:
|
| 561 |
-
"""
|
| 562 |
-
Processes a single file from a file path and returns a LangChain Document.
|
| 563 |
-
"""
|
| 564 |
file_ext = filename.lower().split('.')[-1]
|
| 565 |
text = ""
|
| 566 |
|
|
@@ -570,10 +506,7 @@ def process_single_file(file_path: str, filename: str) -> Document:
|
|
| 570 |
text = extract_text_from_txt(file_path)
|
| 571 |
elif file_ext == "pdf":
|
| 572 |
text = extract_content_from_pdf(file_path)
|
| 573 |
-
|
| 574 |
-
print(f"⏭️ Skipping unsupported file: {filename}")
|
| 575 |
-
return None
|
| 576 |
-
|
| 577 |
if text.strip():
|
| 578 |
doc = Document(
|
| 579 |
page_content=text,
|
|
@@ -584,26 +517,15 @@ def process_single_file(file_path: str, filename: str) -> Document:
|
|
| 584 |
"upload_timestamp": time.time()
|
| 585 |
}
|
| 586 |
)
|
| 587 |
-
print(f"✅ Processed {filename} ({len(text)} chars)")
|
| 588 |
return doc
|
| 589 |
-
|
| 590 |
-
print(f"⚠️ Skipping {filename} - no content extracted")
|
| 591 |
-
return None
|
| 592 |
|
| 593 |
def add_documents_to_qdrant(docs: List[Document]):
|
| 594 |
-
"""
|
| 595 |
-
|
| 596 |
-
"""
|
| 597 |
-
if not QDRANT_API_KEY:
|
| 598 |
-
print("❌ QDRANT_API_KEY is missing. Cannot add documents.")
|
| 599 |
-
return
|
| 600 |
-
|
| 601 |
-
if not docs:
|
| 602 |
-
print("No documents to add.")
|
| 603 |
return
|
| 604 |
|
| 605 |
try:
|
| 606 |
-
print(f"\n☁️ Connecting to Qdrant to add {len(docs)} new documents...")
|
| 607 |
client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
|
| 608 |
vector_store = Qdrant(
|
| 609 |
client=client,
|
|
@@ -616,10 +538,10 @@ def add_documents_to_qdrant(docs: List[Document]):
|
|
| 616 |
print(f"❌ Error adding documents to Qdrant: {e}")
|
| 617 |
raise HTTPException(status_code=500, detail=f"Error updating vector store: {e}")
|
| 618 |
|
|
|
|
| 619 |
# -------------------------------
|
| 620 |
-
# 🚀 14. Gradio UI Setup
|
| 621 |
# -------------------------------
|
| 622 |
-
|
| 623 |
def create_gradio_ui():
|
| 624 |
"""
|
| 625 |
Creates the Gradio Blocks UI.
|
|
@@ -656,8 +578,7 @@ def create_gradio_ui():
|
|
| 656 |
failed_count = 0
|
| 657 |
|
| 658 |
for file_obj in file_list:
|
| 659 |
-
|
| 660 |
-
full_path = file_obj.name
|
| 661 |
filename = os.path.basename(full_path)
|
| 662 |
|
| 663 |
try:
|
|
@@ -679,10 +600,12 @@ def create_gradio_ui():
|
|
| 679 |
|
| 680 |
return f"✅ Processing complete. Added {processed_count} files. Failed: {failed_count}."
|
| 681 |
|
|
|
|
| 682 |
with gr.Blocks(theme="soft") as demo:
|
| 683 |
gr.Markdown("# 🧠 Multimodal RAG System (Powered by Qdrant Cloud)")
|
| 684 |
|
| 685 |
with gr.Tabs():
|
|
|
|
| 686 |
with gr.TabItem("Chat with Documents"):
|
| 687 |
gr.ChatInterface(
|
| 688 |
fn=gradio_chat_response_func,
|
|
@@ -696,23 +619,28 @@ def create_gradio_ui():
|
|
| 696 |
],
|
| 697 |
)
|
| 698 |
|
| 699 |
-
|
| 700 |
-
|
| 701 |
-
gr.
|
| 702 |
-
|
|
|
|
|
|
|
|
|
|
| 703 |
file_count="multiple",
|
| 704 |
-
file_types=["
|
| 705 |
-
|
| 706 |
-
|
| 707 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 708 |
fn=gradio_upload_func,
|
| 709 |
-
inputs=
|
| 710 |
-
outputs=
|
| 711 |
)
|
| 712 |
-
|
| 713 |
-
|
| 714 |
-
|
| 715 |
-
return demo
|
| 716 |
|
| 717 |
|
| 718 |
# -------------------------------
|
|
@@ -738,60 +666,34 @@ app = FastAPI(title="🧠 Multimodal RAG API")
|
|
| 738 |
|
| 739 |
@app.on_event("startup")
|
| 740 |
def on_startup():
|
| 741 |
-
"""
|
| 742 |
-
Checks keys and builds the initial database on server startup.
|
| 743 |
-
"""
|
| 744 |
print("🚀 FastAPI app starting up...")
|
| 745 |
|
| 746 |
-
# Check Groq API key
|
| 747 |
if not os.environ.get("GROQ_API_KEY"):
|
| 748 |
print("⚠️ WARNING: GROQ_API_KEY not set!")
|
| 749 |
-
else:
|
| 750 |
-
print("✅ Groq API Key found")
|
| 751 |
-
|
| 752 |
-
# Check Qdrant API Key
|
| 753 |
if not QDRANT_API_KEY:
|
| 754 |
print("⚠️ WARNING: QDRANT_API_KEY not set! Database functions will fail.")
|
| 755 |
-
else:
|
| 756 |
-
print("✅ Qdrant API Key found")
|
| 757 |
-
|
| 758 |
-
print(f"✅ Vision Model: {VISION_MODEL}")
|
| 759 |
-
print(f"✅ LLM Model: {LLM_MODEL}\n")
|
| 760 |
|
| 761 |
-
# Rebuild the cloud database on startup
|
| 762 |
folder = "data"
|
| 763 |
if os.path.exists(folder):
|
| 764 |
-
print("\n" + "=" * 60)
|
| 765 |
-
print(f"🔄 Found 'data' folder, rebuilding Qdrant collection...")
|
| 766 |
-
print("=" * 60)
|
| 767 |
build_or_update_qdrant_store(folder)
|
| 768 |
else:
|
| 769 |
-
print("
|
| 770 |
-
print(f"ℹ️ No 'data' folder found. Skipping initial build.")
|
| 771 |
-
print(" Database will be populated via the /upload endpoint.")
|
| 772 |
-
print("=" * 60)
|
| 773 |
|
| 774 |
-
# ---
|
| 775 |
-
# 🚀 API Endpoints
|
| 776 |
-
# ---
|
| 777 |
|
| 778 |
@app.post("/query/", response_model=QueryResponse)
|
| 779 |
async def handle_query(request: QueryRequest):
|
| 780 |
-
"""
|
| 781 |
-
Executes a RAG query against the vector database.
|
| 782 |
-
"""
|
| 783 |
try:
|
| 784 |
response_data = get_rag_response(request.query, request.k)
|
| 785 |
return response_data
|
| 786 |
except Exception as e:
|
| 787 |
-
print(f"❌ Error during query: {e}")
|
| 788 |
raise HTTPException(status_code=500, detail=str(e))
|
| 789 |
|
| 790 |
@app.post("/upload/", response_model=UploadResponse)
|
| 791 |
async def handle_upload(files: List[UploadFile] = File(...)):
|
| 792 |
-
"""
|
| 793 |
-
Uploads one or more files, processes them, and adds them to the vector DB.
|
| 794 |
-
"""
|
| 795 |
if not QDRANT_API_KEY:
|
| 796 |
raise HTTPException(status_code=500, detail="QDRANT_API_KEY is not set. Upload failed.")
|
| 797 |
|
|
@@ -802,16 +704,12 @@ async def handle_upload(files: List[UploadFile] = File(...)):
|
|
| 802 |
for file in files:
|
| 803 |
tmp_path = None
|
| 804 |
try:
|
| 805 |
-
# 1. Save file to a temporary location
|
| 806 |
with tempfile.NamedTemporaryFile(delete=False, suffix=f"_{file.filename}") as tmp:
|
| 807 |
shutil.copyfileobj(file.file, tmp)
|
| 808 |
tmp_path = tmp.name
|
| 809 |
|
| 810 |
-
# 2. Process the single file
|
| 811 |
-
print(f"\nProcessing uploaded file: {file.filename}")
|
| 812 |
doc = process_single_file(tmp_path, file.filename)
|
| 813 |
|
| 814 |
-
# 3. Collect processed document or log failure
|
| 815 |
if doc:
|
| 816 |
docs_to_add.append(doc)
|
| 817 |
processed_files.append(file.filename)
|
|
@@ -819,25 +717,19 @@ async def handle_upload(files: List[UploadFile] = File(...)):
|
|
| 819 |
failed_files.append(file.filename)
|
| 820 |
|
| 821 |
except Exception as e:
|
| 822 |
-
print(f"❌ Error processing file {file.filename}: {e}")
|
| 823 |
failed_files.append(file.filename)
|
| 824 |
|
| 825 |
finally:
|
| 826 |
-
# 4. Cleanup temp file
|
| 827 |
if tmp_path and os.path.exists(tmp_path):
|
| 828 |
os.unlink(tmp_path)
|
| 829 |
-
# 5. Close the UploadFile stream
|
| 830 |
file.file.close()
|
| 831 |
|
| 832 |
-
# 6. Add all successfully processed documents to Qdrant in a batch
|
| 833 |
if docs_to_add:
|
| 834 |
try:
|
| 835 |
add_documents_to_qdrant(docs_to_add)
|
| 836 |
except HTTPException:
|
| 837 |
-
# If batch upload failed, revert the success list to failures
|
| 838 |
failed_files.extend(processed_files)
|
| 839 |
processed_files = []
|
| 840 |
-
print(f"❌ Batch upload to Qdrant failed for {len(docs_to_add)} documents.")
|
| 841 |
|
| 842 |
return {
|
| 843 |
"message": f"Processing complete. Added {len(processed_files)} file(s) to the database.",
|
|
@@ -849,7 +741,6 @@ async def handle_upload(files: List[UploadFile] = File(...)):
|
|
| 849 |
# 🚀 15. Create and Mount the Apps
|
| 850 |
# -------------------------------
|
| 851 |
|
| 852 |
-
# Define the root redirect and API info endpoints BEFORE mounting Gradio
|
| 853 |
@app.get("/")
|
| 854 |
def redirect_to_ui():
|
| 855 |
"""Redirect root to the Gradio UI"""
|
|
@@ -862,14 +753,8 @@ def api_info():
|
|
| 862 |
"message": "Welcome to the Multimodal RAG API",
|
| 863 |
"endpoints": {
|
| 864 |
"ui": "/ui - Gradio interface",
|
| 865 |
-
"docs": "/docs - API documentation (Swagger UI)",
|
| 866 |
-
"redoc": "/redoc - Alternative API documentation",
|
| 867 |
"query": "POST /query/ - Execute RAG queries",
|
| 868 |
"upload": "POST /upload/ - Upload and process files"
|
| 869 |
-
},
|
| 870 |
-
"models": {
|
| 871 |
-
"vision_model": VISION_MODEL,
|
| 872 |
-
"llm_model": LLM_MODEL
|
| 873 |
}
|
| 874 |
}
|
| 875 |
|
|
@@ -877,7 +762,4 @@ def api_info():
|
|
| 877 |
gradio_ui = create_gradio_ui()
|
| 878 |
|
| 879 |
# Mount the Gradio UI at /ui path
|
| 880 |
-
app = gr.mount_gradio_app(app, gradio_ui, path="/ui")
|
| 881 |
-
|
| 882 |
-
# The application is now ready to be run by uvicorn:
|
| 883 |
-
# uvicorn app:app --host 0.0.0.0 --port 7860
|
|
|
|
| 29 |
# 1️⃣ Load OCR + Embedding Models + Groq Client
|
| 30 |
# -------------------------------
|
| 31 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
| 32 |
try:
|
| 33 |
print(f"Loading OCR model to {device}...")
|
| 34 |
ocr_model = ocr_predictor(pretrained=True).to(device)
|
|
|
|
| 59 |
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")
|
| 60 |
COLLECTION_NAME = "multimodal_rag_store"
|
| 61 |
|
| 62 |
+
# --- Helper Functions (2 to 7) ---
|
| 63 |
|
| 64 |
# -------------------------------
|
| 65 |
# 2️⃣ Helper: Check if image has substantial text
|
|
|
|
| 73 |
|
| 74 |
|
| 75 |
# -------------------------------
|
| 76 |
+
# 3️⃣ Vision Analysis using Groq Llama 4 Scout
|
| 77 |
# -------------------------------
|
| 78 |
def analyze_image_with_vision(img_path=None, img_bytes=None, pil_image=None, max_retries=3):
|
| 79 |
if not groq_client:
|
|
|
|
| 80 |
return ""
|
| 81 |
|
| 82 |
for attempt in range(max_retries):
|
|
|
|
| 94 |
img_format = img_path.lower().split('.')[-1]
|
| 95 |
elif img_bytes:
|
| 96 |
img_data = img_bytes
|
|
|
|
| 97 |
else:
|
| 98 |
return ""
|
| 99 |
|
|
|
|
| 104 |
vision_prompt = """Analyze this image carefully and provide a detailed description:
|
| 105 |
1. IDENTIFY THE TYPE: Is this a chart, graph, table, diagram, photograph, or text document?
|
| 106 |
2. IF IT'S A CHART/GRAPH/TABLE:
|
| 107 |
+
- Specify the exact type
|
| 108 |
- List ALL categories/labels shown
|
| 109 |
- Describe the data values and trends
|
| 110 |
- Mention axis labels, title, legend if present
|
|
|
|
| 150 |
print(f"❌ Vision model '{VISION_MODEL}' not available! Skipping vision analysis.")
|
| 151 |
return ""
|
| 152 |
else:
|
|
|
|
| 153 |
if attempt < max_retries - 1:
|
| 154 |
time.sleep(2)
|
| 155 |
continue
|
|
|
|
| 182 |
else:
|
| 183 |
print(f"🖼️ {os.path.basename(img_path)}: Using Vision Model (graph/chart/picture)")
|
| 184 |
vision_summary = analyze_image_with_vision(img_path=img_path)
|
|
|
|
| 185 |
return vision_summary if vision_summary else ocr_text
|
| 186 |
except Exception as e:
|
| 187 |
print(f"❌ Error processing {img_path}: {e}")
|
|
|
|
| 203 |
|
| 204 |
|
| 205 |
# -------------------------------
|
| 206 |
+
# 6️⃣ Extract Content from PDFs with Vision Analysis
|
| 207 |
# -------------------------------
|
| 208 |
def extract_content_from_pdf(pdf_path):
|
| 209 |
try:
|
|
|
|
| 216 |
text = page.get_text()
|
| 217 |
if text.strip():
|
| 218 |
page_content.append(f"[Page {page_num} - Text Content]\n{text}")
|
| 219 |
+
|
| 220 |
+
# 2. Vision analysis of the entire page image
|
|
|
|
| 221 |
if groq_client:
|
|
|
|
| 222 |
try:
|
| 223 |
mat = fitz.Matrix(2, 2)
|
| 224 |
pix = page.get_pixmap(matrix=mat)
|
|
|
|
| 229 |
if vision_analysis and len(vision_analysis.strip()) > 30:
|
| 230 |
vision_section = f"[Page {page_num} - Visual Analysis]\n{vision_analysis}"
|
| 231 |
page_content.append(vision_section)
|
| 232 |
+
except Exception:
|
| 233 |
+
pass # Ignore rendering/analysis errors
|
|
|
|
| 234 |
|
| 235 |
+
# 3. OCR on embedded images (if OCR model is loaded)
|
| 236 |
if ocr_model:
|
| 237 |
image_list = page.get_images(full=True)
|
| 238 |
for img_index, img_info in enumerate(image_list, 1):
|
|
|
|
| 240 |
xref = img_info[0]
|
| 241 |
base_image = doc.extract_image(xref)
|
| 242 |
image_bytes = base_image["image"]
|
|
|
|
| 243 |
image = Image.open(BytesIO(image_bytes)).convert("RGB")
|
| 244 |
image_np = np.array(image)
|
|
|
|
| 245 |
result = ocr_model([image_np])
|
| 246 |
ocr_text = []
|
| 247 |
for ocr_page in result.pages:
|
|
|
|
| 254 |
if has_substantial_text(extracted_text, min_words=10):
|
| 255 |
page_content.append(f"[Page {page_num} - Embedded Image {img_index} OCR]\n{extracted_text}")
|
| 256 |
else:
|
|
|
|
| 257 |
vision_summary = analyze_image_with_vision(img_bytes=image_bytes)
|
| 258 |
if vision_summary:
|
| 259 |
page_content.append(
|
| 260 |
f"[Page {page_num} - Embedded Image {img_index} Analysis]\n{vision_summary}")
|
| 261 |
|
| 262 |
+
except Exception:
|
|
|
|
| 263 |
continue
|
| 264 |
|
| 265 |
if page_content:
|
|
|
|
| 276 |
|
| 277 |
|
| 278 |
# -------------------------------
|
| 279 |
+
# 7️⃣ Process All Document Types for folder build
|
| 280 |
# -------------------------------
|
| 281 |
def create_documents_from_folder(folder_path):
|
| 282 |
docs = []
|
|
|
|
| 284 |
for filename in files:
|
| 285 |
full_path = os.path.join(root, filename)
|
| 286 |
file_ext = filename.lower().split('.')[-1]
|
|
|
|
|
|
|
|
|
|
| 287 |
text = ""
|
| 288 |
|
| 289 |
if file_ext in ["jpg", "jpeg", "png"]:
|
|
|
|
| 293 |
elif file_ext == "pdf":
|
| 294 |
text = extract_content_from_pdf(full_path)
|
| 295 |
else:
|
|
|
|
| 296 |
continue
|
| 297 |
|
| 298 |
if text.strip():
|
|
|
|
| 307 |
}
|
| 308 |
)
|
| 309 |
docs.append(doc)
|
| 310 |
+
|
|
|
|
|
|
|
| 311 |
return docs
|
| 312 |
|
| 313 |
+
# --- Core RAG/DB Functions (8 to 12) ---
|
| 314 |
|
| 315 |
# -------------------------------
|
| 316 |
+
# 8️⃣ Build or Update QDRANT Store
|
| 317 |
# -------------------------------
|
| 318 |
def build_or_update_qdrant_store(folder_path):
|
| 319 |
if not QDRANT_API_KEY:
|
|
|
|
| 320 |
return None
|
| 321 |
|
|
|
|
|
|
|
|
|
|
| 322 |
docs = create_documents_from_folder(folder_path)
|
| 323 |
if not docs:
|
|
|
|
| 324 |
return None
|
| 325 |
|
|
|
|
|
|
|
| 326 |
try:
|
| 327 |
vector_store = Qdrant.from_documents(
|
| 328 |
docs,
|
|
|
|
| 332 |
collection_name=COLLECTION_NAME,
|
| 333 |
force_recreate=True
|
| 334 |
)
|
| 335 |
+
print(f"✅ Successfully created/updated Qdrant collection: {COLLECTION_NAME} with {len(docs)} documents.")
|
| 336 |
return vector_store
|
| 337 |
except Exception as e:
|
| 338 |
print(f"❌ Error connecting or uploading to Qdrant: {e}")
|
|
|
|
| 339 |
return None
|
| 340 |
|
| 341 |
|
| 342 |
# -------------------------------
|
| 343 |
+
# 9️⃣ Query QDRANT Function with Chart-Aware Re-ranking
|
| 344 |
# -------------------------------
|
| 345 |
def query_qdrant_store(query_text, k=3):
|
| 346 |
if not QDRANT_API_KEY:
|
|
|
|
| 347 |
return []
|
| 348 |
|
| 349 |
try:
|
|
|
|
| 357 |
collection_name=COLLECTION_NAME,
|
| 358 |
embeddings=embedding_model
|
| 359 |
)
|
|
|
|
| 360 |
except Exception as e:
|
| 361 |
print(f"❌ Error connecting to Qdrant: {e}")
|
| 362 |
return []
|
|
|
|
| 368 |
is_visual_query = any(keyword in query_text.lower() for keyword in visual_query_keywords)
|
| 369 |
|
| 370 |
if is_visual_query:
|
|
|
|
| 371 |
reranked_results = []
|
| 372 |
for doc, score in results:
|
| 373 |
boost = 0.0
|
| 374 |
+
if "Visual Analysis]" in doc.page_content or "bar chart" in doc.page_content.lower():
|
|
|
|
| 375 |
visual_content = doc.page_content.lower()
|
| 376 |
|
|
|
|
| 377 |
if 'bar chart' in query_text.lower() and 'bar chart' in visual_content:
|
| 378 |
boost += 1.0
|
| 379 |
elif 'pie chart' in query_text.lower() and 'pie chart' in visual_content:
|
| 380 |
boost += 1.0
|
| 381 |
+
elif any(kw in query_text.lower() for kw in ['chart', 'graph']) and any(kw in visual_content for kw in ['chart', 'graph', 'plot', 'diagram', 'table']):
|
| 382 |
+
boost += 0.5
|
|
|
|
|
|
|
|
|
|
|
|
|
| 383 |
else:
|
| 384 |
boost += 0.2
|
| 385 |
|
|
|
|
| 386 |
adjusted_score = score - boost
|
| 387 |
reranked_results.append((doc, adjusted_score, score))
|
| 388 |
|
| 389 |
+
reranked_results.sort(key=lambda x: x[1])
|
| 390 |
results = [(doc, adj_score) for doc, adj_score, _ in reranked_results[:k]]
|
| 391 |
else:
|
| 392 |
results = results[:k]
|
|
|
|
| 403 |
|
| 404 |
|
| 405 |
# -------------------------------
|
| 406 |
+
# 10️⃣ Answer Question using Llama 3.3 70B
|
| 407 |
# -------------------------------
|
| 408 |
def answer_question_with_llm(query_text, retrieved_docs, max_tokens=1000):
|
| 409 |
if not groq_client:
|
| 410 |
return "❌ Groq client not initialized. Cannot generate answer."
|
|
|
|
| 411 |
if not retrieved_docs:
|
| 412 |
return "❌ No relevant documents found to answer your question."
|
| 413 |
|
|
|
|
| 418 |
metadata = doc['metadata']
|
| 419 |
timestamp = metadata.get('upload_timestamp')
|
| 420 |
|
| 421 |
+
readable_time = time.ctime(float(timestamp)) if timestamp else "N/A"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 422 |
|
| 423 |
metadata_str = (
|
| 424 |
f"Source: {source}\n"
|
|
|
|
| 440 |
|
| 441 |
system_prompt = """You are a concise AI assistant. Answer the user's question *only* using the provided documents.
|
| 442 |
- Be brief and to the point.
|
|
|
|
|
|
|
| 443 |
- If the answer is not in the documents or metadata, simply state 'That information is not available in the documents.'"""
|
| 444 |
|
| 445 |
user_prompt = f"""DOCUMENTS:
|
|
|
|
| 450 |
ANSWER: (Provide a concise answer based *only* on the documents)"""
|
| 451 |
|
| 452 |
try:
|
|
|
|
| 453 |
response = groq_client.chat.completions.create(
|
| 454 |
model=LLM_MODEL,
|
| 455 |
messages=[
|
|
|
|
| 463 |
answer = response.choices[0].message.content
|
| 464 |
return answer
|
| 465 |
except Exception as e:
|
|
|
|
| 466 |
return f"❌ Error generating answer: {str(e)}"
|
| 467 |
|
| 468 |
# -------------------------------
|
| 469 |
+
# 11️⃣ Core RAG Response Function
|
| 470 |
# -------------------------------
|
| 471 |
def get_rag_response(query_text: str, k: int = 3) -> Dict[str, Any]:
|
| 472 |
+
"""Core RAG pipeline: retrieves, generates, and formats response."""
|
|
|
|
|
|
|
|
|
|
| 473 |
print(f"❓ QUERY: {query_text}")
|
|
|
|
|
|
|
|
|
|
| 474 |
retrieved_docs = query_qdrant_store(query_text, k=k)
|
| 475 |
|
| 476 |
if not retrieved_docs:
|
|
|
|
| 477 |
return {
|
| 478 |
"answer": "❌ No relevant documents found to answer your question. Please upload files first.",
|
| 479 |
"sources": []
|
| 480 |
}
|
| 481 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 482 |
answer = answer_question_with_llm(query_text, retrieved_docs)
|
| 483 |
|
| 484 |
sources_list = [
|
|
|
|
| 490 |
"sources": sources_list
|
| 491 |
}
|
| 492 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 493 |
return response_data
|
| 494 |
|
| 495 |
# -------------------------------
|
| 496 |
+
# 12️⃣ Core File Processing & Qdrant Addition
|
| 497 |
# -------------------------------
|
| 498 |
def process_single_file(file_path: str, filename: str) -> Document:
|
| 499 |
+
"""Processes a single file and returns a LangChain Document."""
|
|
|
|
|
|
|
| 500 |
file_ext = filename.lower().split('.')[-1]
|
| 501 |
text = ""
|
| 502 |
|
|
|
|
| 506 |
text = extract_text_from_txt(file_path)
|
| 507 |
elif file_ext == "pdf":
|
| 508 |
text = extract_content_from_pdf(file_path)
|
| 509 |
+
|
|
|
|
|
|
|
|
|
|
| 510 |
if text.strip():
|
| 511 |
doc = Document(
|
| 512 |
page_content=text,
|
|
|
|
| 517 |
"upload_timestamp": time.time()
|
| 518 |
}
|
| 519 |
)
|
|
|
|
| 520 |
return doc
|
| 521 |
+
return None
|
|
|
|
|
|
|
| 522 |
|
| 523 |
def add_documents_to_qdrant(docs: List[Document]):
|
| 524 |
+
"""Adds a list of processed documents to the Qdrant cloud."""
|
| 525 |
+
if not QDRANT_API_KEY or not docs:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 526 |
return
|
| 527 |
|
| 528 |
try:
|
|
|
|
| 529 |
client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
|
| 530 |
vector_store = Qdrant(
|
| 531 |
client=client,
|
|
|
|
| 538 |
print(f"❌ Error adding documents to Qdrant: {e}")
|
| 539 |
raise HTTPException(status_code=500, detail=f"Error updating vector store: {e}")
|
| 540 |
|
| 541 |
+
|
| 542 |
# -------------------------------
|
| 543 |
+
# 🚀 14. Gradio UI Setup
|
| 544 |
# -------------------------------
|
|
|
|
| 545 |
def create_gradio_ui():
|
| 546 |
"""
|
| 547 |
Creates the Gradio Blocks UI.
|
|
|
|
| 578 |
failed_count = 0
|
| 579 |
|
| 580 |
for file_obj in file_list:
|
| 581 |
+
full_path = file_obj.name
|
|
|
|
| 582 |
filename = os.path.basename(full_path)
|
| 583 |
|
| 584 |
try:
|
|
|
|
| 600 |
|
| 601 |
return f"✅ Processing complete. Added {processed_count} files. Failed: {failed_count}."
|
| 602 |
|
| 603 |
+
# Create the Gradio UI using Blocks
|
| 604 |
with gr.Blocks(theme="soft") as demo:
|
| 605 |
gr.Markdown("# 🧠 Multimodal RAG System (Powered by Qdrant Cloud)")
|
| 606 |
|
| 607 |
with gr.Tabs():
|
| 608 |
+
# --- CHAT TAB ---
|
| 609 |
with gr.TabItem("Chat with Documents"):
|
| 610 |
gr.ChatInterface(
|
| 611 |
fn=gradio_chat_response_func,
|
|
|
|
| 619 |
],
|
| 620 |
)
|
| 621 |
|
| 622 |
+
# --- UPLOAD TAB ---
|
| 623 |
+
with gr.TabItem("Upload New Documents"):
|
| 624 |
+
gr.Markdown("Upload new PDF, image, or text files to add them to the knowledge base.")
|
| 625 |
+
|
| 626 |
+
# Define components
|
| 627 |
+
file_uploader = gr.File(
|
| 628 |
+
label="Upload Documents",
|
| 629 |
file_count="multiple",
|
| 630 |
+
file_types=["image", ".pdf", ".txt", ".md"],
|
| 631 |
+
interactive=True
|
| 632 |
+
)
|
| 633 |
+
upload_button = gr.Button("Process and Add Documents", variant="primary")
|
| 634 |
+
status_output = gr.Markdown("Status: Ready to upload new documents.")
|
| 635 |
+
|
| 636 |
+
# Connect the upload button to the processing function
|
| 637 |
+
upload_button.click(
|
| 638 |
fn=gradio_upload_func,
|
| 639 |
+
inputs=[file_uploader],
|
| 640 |
+
outputs=[status_output]
|
| 641 |
)
|
| 642 |
+
|
| 643 |
+
return demo
|
|
|
|
|
|
|
| 644 |
|
| 645 |
|
| 646 |
# -------------------------------
|
|
|
|
| 666 |
|
| 667 |
@app.on_event("startup")
|
| 668 |
def on_startup():
|
| 669 |
+
"""Checks keys and builds the initial database on server startup."""
|
|
|
|
|
|
|
| 670 |
print("🚀 FastAPI app starting up...")
|
| 671 |
|
|
|
|
| 672 |
if not os.environ.get("GROQ_API_KEY"):
|
| 673 |
print("⚠️ WARNING: GROQ_API_KEY not set!")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 674 |
if not QDRANT_API_KEY:
|
| 675 |
print("⚠️ WARNING: QDRANT_API_KEY not set! Database functions will fail.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 676 |
|
|
|
|
| 677 |
folder = "data"
|
| 678 |
if os.path.exists(folder):
|
|
|
|
|
|
|
|
|
|
| 679 |
build_or_update_qdrant_store(folder)
|
| 680 |
else:
|
| 681 |
+
print("ℹ️ No 'data' folder found. Skipping initial build.")
|
|
|
|
|
|
|
|
|
|
| 682 |
|
| 683 |
+
# --- API Endpoints ---
|
|
|
|
|
|
|
| 684 |
|
| 685 |
@app.post("/query/", response_model=QueryResponse)
|
| 686 |
async def handle_query(request: QueryRequest):
|
| 687 |
+
"""Executes a RAG query against the vector database."""
|
|
|
|
|
|
|
| 688 |
try:
|
| 689 |
response_data = get_rag_response(request.query, request.k)
|
| 690 |
return response_data
|
| 691 |
except Exception as e:
|
|
|
|
| 692 |
raise HTTPException(status_code=500, detail=str(e))
|
| 693 |
|
| 694 |
@app.post("/upload/", response_model=UploadResponse)
|
| 695 |
async def handle_upload(files: List[UploadFile] = File(...)):
|
| 696 |
+
"""Uploads one or more files, processes them, and adds them to the vector DB."""
|
|
|
|
|
|
|
| 697 |
if not QDRANT_API_KEY:
|
| 698 |
raise HTTPException(status_code=500, detail="QDRANT_API_KEY is not set. Upload failed.")
|
| 699 |
|
|
|
|
| 704 |
for file in files:
|
| 705 |
tmp_path = None
|
| 706 |
try:
|
|
|
|
| 707 |
with tempfile.NamedTemporaryFile(delete=False, suffix=f"_{file.filename}") as tmp:
|
| 708 |
shutil.copyfileobj(file.file, tmp)
|
| 709 |
tmp_path = tmp.name
|
| 710 |
|
|
|
|
|
|
|
| 711 |
doc = process_single_file(tmp_path, file.filename)
|
| 712 |
|
|
|
|
| 713 |
if doc:
|
| 714 |
docs_to_add.append(doc)
|
| 715 |
processed_files.append(file.filename)
|
|
|
|
| 717 |
failed_files.append(file.filename)
|
| 718 |
|
| 719 |
except Exception as e:
|
|
|
|
| 720 |
failed_files.append(file.filename)
|
| 721 |
|
| 722 |
finally:
|
|
|
|
| 723 |
if tmp_path and os.path.exists(tmp_path):
|
| 724 |
os.unlink(tmp_path)
|
|
|
|
| 725 |
file.file.close()
|
| 726 |
|
|
|
|
| 727 |
if docs_to_add:
|
| 728 |
try:
|
| 729 |
add_documents_to_qdrant(docs_to_add)
|
| 730 |
except HTTPException:
|
|
|
|
| 731 |
failed_files.extend(processed_files)
|
| 732 |
processed_files = []
|
|
|
|
| 733 |
|
| 734 |
return {
|
| 735 |
"message": f"Processing complete. Added {len(processed_files)} file(s) to the database.",
|
|
|
|
| 741 |
# 🚀 15. Create and Mount the Apps
|
| 742 |
# -------------------------------
|
| 743 |
|
|
|
|
| 744 |
@app.get("/")
|
| 745 |
def redirect_to_ui():
|
| 746 |
"""Redirect root to the Gradio UI"""
|
|
|
|
| 753 |
"message": "Welcome to the Multimodal RAG API",
|
| 754 |
"endpoints": {
|
| 755 |
"ui": "/ui - Gradio interface",
|
|
|
|
|
|
|
| 756 |
"query": "POST /query/ - Execute RAG queries",
|
| 757 |
"upload": "POST /upload/ - Upload and process files"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 758 |
}
|
| 759 |
}
|
| 760 |
|
|
|
|
| 762 |
gradio_ui = create_gradio_ui()
|
| 763 |
|
| 764 |
# Mount the Gradio UI at /ui path
|
| 765 |
+
app = gr.mount_gradio_app(app, gradio_ui, path="/ui")
|
|
|
|
|
|
|
|
|