Spaces:
Sleeping
Sleeping
| from doctr.models import ocr_predictor | |
| from PIL import Image | |
| import numpy as np | |
| import torch | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| from langchain_core.documents import Document | |
| import os | |
| from groq import Groq | |
| import base64 | |
| from io import BytesIO | |
| import fitz # PyMuPDF | |
| import time | |
| import shutil | |
| # Flask imports | |
| from flask import Flask, request, jsonify | |
| from werkzeug.utils import secure_filename | |
| import tempfile | |
| # Qdrant imports | |
| from langchain_qdrant import Qdrant | |
| from qdrant_client import QdrantClient | |
# -------------------------------
# Configuration
# -------------------------------
# Module-level side effects: the OCR and embedding models are loaded once at
# import time so every request reuses the same instances.
device = "cuda" if torch.cuda.is_available() else "cpu"
ocr_model = ocr_predictor(pretrained=True).to(device)  # docTR OCR pipeline
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
# NOTE(review): a missing GROQ_API_KEY passes api_key=None here and defers the
# failure to the first API call — confirm that is the intended behavior.
groq_client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
VISION_MODEL = "meta-llama/llama-4-scout-17b-16e-instruct"  # Groq vision model for image analysis
LLM_MODEL = "llama-3.3-70b-versatile"  # Groq text model for answer synthesis
QDRANT_URL = "https://bdf142ef-7e2a-433b-87a0-301ff303e3af.us-east4-0.gcp.cloud.qdrant.io:6333"
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")
COLLECTION_NAME = "multimodal_rag_store"  # single shared collection for all documents
def has_substantial_text(text, min_words=10):
    """Return True when *text* contains at least *min_words* whitespace-separated words."""
    return len(text.split()) >= min_words
def analyze_image_with_vision(img_path=None, img_bytes=None, pil_image=None, max_retries=3):
    """Describe an image with the Groq vision model, retrying on failure.

    Exactly one image argument should be supplied; precedence is
    pil_image > img_path > img_bytes.

    Args:
        img_path: Path to an image file on disk.
        img_bytes: Raw image bytes (assumed PNG-encoded).
        pil_image: A PIL Image object (re-encoded to PNG before upload).
        max_retries: Attempts before giving up on empty/short answers or errors.

    Returns:
        The model's description, or "" when no image was given or all
        retries were exhausted.
    """
    for attempt in range(max_retries):
        try:
            # Normalize the input into raw bytes plus a format tag.
            # Bug fix: explicit `is not None` checks — empty bytes (b"")
            # previously fell through the truthiness test to `return ""`.
            if pil_image is not None:
                buffered = BytesIO()
                pil_image.save(buffered, format="PNG")
                img_data = buffered.getvalue()
                img_format = "png"
            elif img_path is not None:
                with open(img_path, "rb") as img_file:
                    img_data = img_file.read()
                img_format = img_path.lower().split('.')[-1]
            elif img_bytes is not None:
                img_data = img_bytes
                img_format = "png"
            else:
                return ""
            base64_image = base64.b64encode(img_data).decode('utf-8')
            # MIME subtype is "jpeg", never the "jpg" file extension.
            if img_format == 'jpg':
                img_format = 'jpeg'
            vision_prompt = """Analyze this image carefully and provide a detailed description:
1. IDENTIFY THE TYPE: Is this a chart, graph, table, diagram, photograph, or text document?
2. IF IT'S A CHART/GRAPH/TABLE:
- Specify the exact type (bar chart, pie chart, line graph, scatter plot, table, etc.)
- List ALL categories/labels shown
- Describe the data values and trends
- Mention axis labels, title, legend if present
- Highlight key insights or patterns
3. IF IT'S A PHOTOGRAPH/DIAGRAM:
- Describe what you see in detail
- Identify key objects, people, or concepts
- Note any text visible in the image
4. IF IT'S A TEXT DOCUMENT:
- Summarize the main content and structure
Provide a comprehensive description suitable for semantic search. Be specific and detailed."""
            chat_completion = groq_client.chat.completions.create(
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": vision_prompt},
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": f"data:image/{img_format};base64,{base64_image}"
                                }
                            }
                        ]
                    }
                ],
                model=VISION_MODEL,
                temperature=0.2,
                max_tokens=1500,
            )
            summary = chat_completion.choices[0].message.content
            # Reject trivially short answers and retry; 30 chars is a heuristic
            # threshold for "the model actually described something".
            if summary and len(summary.strip()) > 30:
                return summary
            if attempt < max_retries - 1:
                time.sleep(1)
                continue
            return ""
        except Exception:
            # Transient API errors: back off briefly and retry.
            if attempt < max_retries - 1:
                time.sleep(2)
                continue
            return ""
    return ""
def extract_text_from_image(img_path):
    """OCR an image file; fall back to the vision model when OCR finds little text.

    Returns the OCR text when it is substantial, otherwise the vision-model
    description (or the thin OCR text if the vision call yields nothing).
    Returns "" on any processing error.
    """
    try:
        pil_img = Image.open(img_path).convert("RGB")
        result = ocr_model([np.array(pil_img)])
        # Flatten the docTR page/block/line hierarchy into one line per OCR line.
        extracted_lines = [
            " ".join(word.value for word in line.words)
            for page in result.pages
            for block in page.blocks
            for line in block.lines
        ]
        ocr_text = "\n".join(extracted_lines)
        if has_substantial_text(ocr_text, min_words=10):
            print(f"📄 {os.path.basename(img_path)}: Using OCR")
            return ocr_text
        print(f"🖼️ {os.path.basename(img_path)}: Using Vision Model")
        vision_summary = analyze_image_with_vision(img_path=img_path)
        return vision_summary if vision_summary else ocr_text
    except Exception as e:
        print(f"❌ Error processing {img_path}: {e}")
        return ""
def extract_text_from_txt(file_path):
    """Read a plain-text or markdown file and return its contents.

    Tries strict UTF-8 first; on a decode error re-reads with
    errors="replace" so a non-UTF-8 file still yields usable text.
    (Bug fix: previously any non-UTF-8 file raised UnicodeDecodeError into
    the broad handler and was silently dropped as "".)

    Returns "" on any other I/O failure.
    """
    try:
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                text = f.read()
        except UnicodeDecodeError:
            with open(file_path, 'r', encoding='utf-8', errors='replace') as f:
                text = f.read()
        print(f"📝 {os.path.basename(file_path)}: Extracted text")
        return text
    except Exception as e:
        print(f"❌ Error reading text file {file_path}: {e}")
        return ""
def extract_content_from_pdf(pdf_path):
    """Extract text, per-page visual analysis, and embedded-image content from a PDF.

    For each page, up to three sections are collected:
      1. the native text layer,
      2. a vision-model description of the page rendered at 2x zoom,
      3. OCR of each embedded image (vision fallback when OCR text is thin).
    Sections are joined with "---SECTION BREAK---" and pages with
    "---PAGE BREAK---". Returns "" on failure to open/process the PDF.
    """
    try:
        doc = fitz.open(pdf_path)
        try:
            all_content = []
            for page_num, page in enumerate(doc, 1):
                page_content = []
                # 1) Native text layer.
                text = page.get_text()
                if text.strip():
                    page_content.append(f"[Page {page_num} - Text Content]\n{text}")
                # 2) Render the whole page (2x zoom for legibility) and ask the
                # vision model to describe it; short answers are discarded.
                try:
                    mat = fitz.Matrix(2, 2)
                    pix = page.get_pixmap(matrix=mat)
                    img_data = pix.tobytes("png")
                    page_image = Image.open(BytesIO(img_data)).convert("RGB")
                    vision_analysis = analyze_image_with_vision(pil_image=page_image)
                    if vision_analysis and len(vision_analysis.strip()) > 30:
                        page_content.append(f"[Page {page_num} - Visual Analysis]\n{vision_analysis}")
                except Exception as e:
                    print(f"❌ Error rendering page {page_num}: {e}")
                # 3) OCR every embedded image; fall back to the vision model
                # when the OCR result has too few words to be useful.
                image_list = page.get_images(full=True)
                for img_index, img_info in enumerate(image_list, 1):
                    try:
                        xref = img_info[0]
                        base_image = doc.extract_image(xref)
                        image_bytes = base_image["image"]
                        image = Image.open(BytesIO(image_bytes)).convert("RGB")
                        image_np = np.array(image)
                        result = ocr_model([image_np])
                        ocr_text = []
                        for ocr_page in result.pages:
                            for block in ocr_page.blocks:
                                for line in block.lines:
                                    line_text = " ".join([word.value for word in line.words])
                                    ocr_text.append(line_text)
                        extracted_text = "\n".join(ocr_text)
                        if has_substantial_text(extracted_text, min_words=10):
                            page_content.append(f"[Page {page_num} - Embedded Image {img_index} OCR]\n{extracted_text}")
                        else:
                            vision_summary = analyze_image_with_vision(img_bytes=image_bytes)
                            if vision_summary:
                                page_content.append(f"[Page {page_num} - Embedded Image {img_index} Analysis]\n{vision_summary}")
                    except Exception as e:
                        print(f"❌ Error processing embedded image {img_index}: {e}")
                        continue
                if page_content:
                    combined_page = "\n\n---SECTION BREAK---\n\n".join(page_content)
                    all_content.append(combined_page)
            return "\n\n---PAGE BREAK---\n\n".join(all_content)
        finally:
            # Bug fix: doc.close() previously ran only on the success path,
            # leaking the document handle when page processing raised.
            doc.close()
    except Exception as e:
        print(f"❌ Error processing PDF {pdf_path}: {e}")
        return ""
def create_documents_from_folder(folder_path):
    """Walk *folder_path* recursively and build one Document per supported file.

    Supported extensions: jpg/jpeg/png (OCR + vision), txt/md (plain read),
    pdf (full multimodal extraction). Unsupported files, and files that yield
    no text, are skipped. Each Document's metadata records the path relative
    to *folder_path*, the bare filename, the extension, and the file mtime.
    """
    docs = []
    for root, dirs, files in os.walk(folder_path):
        for filename in files:
            full_path = os.path.join(root, filename)
            file_ext = filename.lower().split('.')[-1]
            text = ""
            if file_ext in ["jpg", "jpeg", "png"]:
                text = extract_text_from_image(full_path)
            elif file_ext in ["txt", "md"]:
                text = extract_text_from_txt(full_path)
            elif file_ext == "pdf":
                text = extract_content_from_pdf(full_path)
            else:
                continue  # unsupported file type
            if text.strip():
                relative_path = os.path.relpath(full_path, folder_path)
                doc = Document(
                    page_content=text,
                    metadata={
                        "source": relative_path,
                        "filename": filename,
                        "file_type": file_ext,
                        "upload_timestamp": os.path.getmtime(full_path)
                    }
                )
                docs.append(doc)
                # Bug fix: the log line printed a literal placeholder
                # instead of naming the file that was added.
                print(f"✅ Added {relative_path}")
    return docs
def build_or_update_qdrant_store(folder_path):
    """Rebuild the Qdrant collection from every supported file under *folder_path*.

    Note: force_recreate=True drops and recreates the collection, replacing
    any previously indexed content. Returns the vector store, or None when
    no documents were found or Qdrant failed.
    """
    print("\n🔄 Building Qdrant collection...")
    documents = create_documents_from_folder(folder_path)
    if not documents:
        print("⚠️ No valid documents found!")
        return None
    try:
        store = Qdrant.from_documents(
            documents,
            embedding_model,
            url=QDRANT_URL,
            api_key=QDRANT_API_KEY,
            collection_name=COLLECTION_NAME,
            force_recreate=True,
        )
        print(f"✅ Created collection with {len(documents)} documents")
    except Exception as e:
        print(f"❌ Error with Qdrant: {e}")
        return None
    return store
def query_qdrant_store(query_text, k=3):
    """Retrieve the top-*k* documents for *query_text* from Qdrant.

    Over-fetches k*3 candidates; for visually-oriented queries (keywords like
    "chart", "table") documents containing a vision-model analysis section get
    a score boost before trimming to k. The re-ranking treats a LOWER score as
    better (boost is subtracted, then sorted ascending).

    Returns a list of dicts with source/content/score/metadata, or [] on any
    Qdrant failure.
    """
    try:
        client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY, timeout=20)
        vector_store = Qdrant(
            client=client,
            collection_name=COLLECTION_NAME,
            embeddings=embedding_model
        )
        # Over-fetch so the visual re-ranking below has candidates to promote.
        initial_k = k * 3
        # Bug fix: the search itself used to run OUTSIDE this try block, so a
        # network/search failure escaped to the caller instead of returning [].
        results = vector_store.similarity_search_with_score(query_text, k=initial_k)
    except Exception as e:
        print(f"❌ Error connecting to Qdrant: {e}")
        return []
    visual_query_keywords = ['chart', 'graph', 'bar', 'pie', 'plot', 'diagram', 'table', 'visual', 'visualization']
    is_visual_query = any(keyword in query_text.lower() for keyword in visual_query_keywords)
    if is_visual_query:
        reranked_results = []
        for doc, score in results:
            boost = 0.0
            # Pages that went through the vision model carry this marker text.
            if "Visual Analysis]" in doc.page_content:
                boost += 0.5
            reranked_results.append((doc, score - boost))
        reranked_results.sort(key=lambda x: x[1])  # ascending: lower score wins
        results = reranked_results[:k]
    else:
        results = results[:k]
    return [
        {
            # .get guards against documents indexed without a source field.
            "source": doc.metadata.get('source', 'unknown'),
            "content": doc.page_content,
            "score": float(score),
            "metadata": doc.metadata,
        }
        for doc, score in results
    ]
def answer_question_with_llm(query_text, retrieved_docs, max_tokens=1000, max_content_length=2500):
    """Answer *query_text* with the Groq LLM, grounded only in *retrieved_docs*.

    Args:
        query_text: The user's question.
        retrieved_docs: Dicts with 'source' and 'content' keys, as returned
            by query_qdrant_store.
        max_tokens: Completion-length cap passed to the LLM.
        max_content_length: Per-document character cap before truncation
            (generalized from a hard-coded 2500; default preserves old behavior).

    Returns:
        The model's answer string, or an error message string on failure.
    """
    if not retrieved_docs:
        return "❌ No relevant documents found."
    context_parts = []
    for i, doc in enumerate(retrieved_docs, 1):
        source = doc['source']
        content = doc['content']
        # Truncate oversized documents so the combined prompt stays within
        # the model's context budget.
        if len(content) > max_content_length:
            content = content[:max_content_length] + "...[truncated]"
        context_parts.append(f"--- Document {i} ---\nSource: {source}\n\n{content}\n")
    context = "\n".join(context_parts)
    system_prompt = """You are a concise AI assistant. Answer the user's question *only* using the provided documents.
- Be brief and to the point.
- If the answer is not in the documents, state 'That information is not available in the documents.'"""
    user_prompt = f"""DOCUMENTS:
{context}
QUESTION: {query_text}
ANSWER:"""
    try:
        response = groq_client.chat.completions.create(
            model=LLM_MODEL,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ],
            temperature=0.2,
            max_tokens=max_tokens,
        )
        return response.choices[0].message.content
    except Exception as e:
        return f"❌ Error: {str(e)}"
def get_rag_response(query_text, k=3):
    """Run the full RAG pipeline: retrieve top-k documents, then synthesize an answer.

    Returns a dict with "answer" (string) and "sources" (list of
    {"source", "score"} dicts, empty when nothing was retrieved).
    """
    print(f"\n❓ Query: {query_text}")
    retrieved_docs = query_qdrant_store(query_text, k=k)
    if not retrieved_docs:
        return {"answer": "❌ No relevant documents found.", "sources": []}
    return {
        "answer": answer_question_with_llm(query_text, retrieved_docs),
        "sources": [
            {"source": d['source'], "score": d['score']} for d in retrieved_docs
        ],
    }
def process_single_file(file_path, filename):
    """Extract text from one uploaded file and wrap it in a Document.

    Args:
        file_path: On-disk path to read (typically a temp file).
        filename: Original upload name; recorded as the Document source and
            used to pick the extractor by extension.

    Returns:
        A Document, or None for unsupported extensions or empty extractions.
    """
    file_ext = filename.lower().split('.')[-1]
    text = ""
    if file_ext in ["jpg", "jpeg", "png"]:
        text = extract_text_from_image(file_path)
    elif file_ext in ["txt", "md"]:
        text = extract_text_from_txt(file_path)
    elif file_ext == "pdf":
        text = extract_content_from_pdf(file_path)
    else:
        return None
    if text.strip():
        doc = Document(
            page_content=text,
            metadata={
                "source": filename,
                "filename": filename,
                "file_type": file_ext,
                "upload_timestamp": time.time()
            }
        )
        # Bug fix: the log line printed a literal placeholder instead of
        # the processed file's name.
        print(f"✅ Processed {filename}")
        return doc
    return None
def add_documents_to_qdrant(docs):
    """Append *docs* to the existing Qdrant collection.

    No-op on an empty list. Logs and re-raises on any Qdrant failure so the
    caller can report the error.
    """
    if not docs:
        return
    try:
        qdrant_client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
        store = Qdrant(
            client=qdrant_client,
            collection_name=COLLECTION_NAME,
            embeddings=embedding_model,
        )
        store.add_documents(docs)
    except Exception as e:
        print(f"❌ Error adding to Qdrant: {e}")
        raise
    print(f"✅ Added {len(docs)} documents to Qdrant")
# -------------------------------
# Flask App - API ONLY
# -------------------------------
app = Flask(__name__)


@app.route('/', methods=['GET'])
def home():
    """Index endpoint: lists the available API endpoints.

    Bug fix: the handler was defined but never registered with Flask —
    without @app.route the endpoint advertised in __main__ returned 404.
    """
    return jsonify({
        "message": "🧠 Multimodal RAG API",
        "endpoints": {
            "/query": "POST - Query documents",
            "/upload": "POST - Upload files",
            "/health": "GET - Health check"
        }
    })
@app.route('/query', methods=['POST'])
def handle_query():
    """Answer a question against the indexed documents.

    Expects JSON: {"query": str, "k": optional int (default 3)}.
    Returns the RAG answer plus sources, 400 on a missing query, or 500 on
    pipeline failure.

    Bug fix: the route decorator was missing, so POST /query returned 404.
    """
    data = request.get_json()
    if not data or 'query' not in data:
        return jsonify({"error": "No query provided"}), 400
    query = data.get('query', '')
    k = data.get('k', 3)
    try:
        response_data = get_rag_response(query, k)
        return jsonify(response_data)
    except Exception as e:
        return jsonify({"error": str(e)}), 500
@app.route('/upload', methods=['POST'])
def handle_upload():
    """Accept multipart file uploads, extract their content, and index them.

    Expects one or more files under the 'files' form field. Each file is
    saved to a temp file, processed into a Document, and the batch is added
    to Qdrant. Responds with processed/failed filename lists; 400 when no
    files field is present; 500 when indexing fails.

    Bug fixes: registers the route (the handler was never wired to Flask),
    and deletes the temp file in a finally block (it previously leaked when
    process_single_file raised).
    """
    if 'files' not in request.files:
        return jsonify({"error": "No files provided"}), 400
    files = request.files.getlist('files')
    processed_files = []
    failed_files = []
    docs_to_add = []
    for file in files:
        if file.filename == '':
            continue  # browsers send an empty part when no file is selected
        try:
            filename = secure_filename(file.filename)
            with tempfile.NamedTemporaryFile(delete=False, suffix=filename) as tmp:
                file.save(tmp.name)
                tmp_path = tmp.name
            try:
                doc = process_single_file(tmp_path, filename)
            finally:
                os.unlink(tmp_path)
            if doc:
                docs_to_add.append(doc)
                processed_files.append(filename)
            else:
                failed_files.append(filename)
        except Exception as e:
            print(f"❌ Error: {e}")
            failed_files.append(file.filename)
    if docs_to_add:
        try:
            add_documents_to_qdrant(docs_to_add)
        except Exception as e:
            return jsonify({"error": f"Failed to add to database: {str(e)}"}), 500
    return jsonify({
        "message": f"Processed {len(processed_files)} files",
        "processed_files": processed_files,
        "failed_files": failed_files
    })
@app.route('/health', methods=['GET'])
def health_check():
    """Liveness probe confirming the API process is up.

    Bug fix: the route decorator was missing, so GET /health returned 404.
    """
    return jsonify({"status": "ok", "message": "API is running"})
# -------------------------------
# Initialize and Run
# -------------------------------
if __name__ == "__main__":
    print("🚀 Starting Multimodal RAG API...")
    # Build initial database if data folder exists
    # (one-time index of ./data before the server accepts traffic).
    folder = "data"
    if os.path.exists(folder):
        print(f"\n📂 Found '{folder}' folder, building database...")
        build_or_update_qdrant_store(folder)
    print("\n✅ Flask API starting on http://0.0.0.0:7860")
    print(" Endpoints:")
    print(" - GET / (Home/Docs)")
    print(" - POST /query")
    print(" - POST /upload")
    print(" - GET /health\n")
    # Port 7860 is the Hugging Face Spaces convention; debug stays off in
    # deployment.
    app.run(host='0.0.0.0', port=7860, debug=False)