Spaces:
Sleeping
Sleeping
AdarshRajDS commited on
Commit ·
e23acaf
1
Parent(s): d73f6d0
Fix HF persistent storage paths
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .DS_Store +0 -0
- .env +0 -0
- .gitignore +5 -0
- Dockerfile +17 -0
- README.md +0 -10
- app/.DS_Store +0 -0
- app/api/routes/grading.py +10 -0
- app/api/routes/rag.py +10 -0
- app/api/routes/upload.py +10 -0
- app/api/routes/visualize.py +10 -0
- app/main.py +32 -0
- app/schemas/grading.py +8 -0
- app/schemas/rag.py +11 -0
- app/schemas/upload.py +9 -0
- app/schemas/visualize.py +10 -0
- app/services/grading_service.py +30 -0
- app/services/ingestion_service.py +31 -0
- app/services/rag_service.py +11 -0
- app/services/visualization_service.py +10 -0
- multimodal_rag_thesis.egg-info/PKG-INFO +26 -0
- multimodal_rag_thesis.egg-info/SOURCES.txt +52 -0
- multimodal_rag_thesis.egg-info/dependency_links.txt +1 -0
- multimodal_rag_thesis.egg-info/requires.txt +20 -0
- multimodal_rag_thesis.egg-info/top_level.txt +5 -0
- pyproject.toml +38 -0
- requirements.txt +160 -0
- src/.DS_Store +0 -0
- src/__init__.py +0 -0
- src/assessment/annotation_grader.py +29 -0
- src/assessment/image_query_retriever.py +26 -0
- src/assessment/label_extractor.py +23 -0
- src/assessment/run_annotation_check.py +35 -0
- src/config/__init__.py +0 -0
- src/config/settings.py +31 -0
- src/embeddings/__init__.py +0 -0
- src/embeddings/embedding_factory.py +8 -0
- src/ingestion/__init__.py +0 -0
- src/ingestion/image_extractor.py +100 -0
- src/ingestion/loader.py +84 -0
- src/ingestion/run.py +38 -0
- src/ingestion/run_image_extraction.py +35 -0
- src/llm/__init__.py +0 -0
- src/llm/llm_factory.py +17 -0
- src/main.py +37 -0
- src/multimodal/__init__.py +7 -0
- src/multimodal/clip_embedding.py +17 -0
- src/multimodal/multimodal_indexer.py +69 -0
- src/multimodal/multimodal_rag_chain.py +36 -0
- src/multimodal/multimodal_retriever.py +29 -0
- src/multimodal/run_multimodal_indexing.py +11 -0
.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
.env
ADDED
|
File without changes
|
.gitignore
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data/
|
| 2 |
+
outputs/
|
| 3 |
+
uploads/
|
| 4 |
+
__pycache__/
|
| 5 |
+
*.pyc
|
Dockerfile
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.10-slim
|
| 2 |
+
|
| 3 |
+
WORKDIR /code
|
| 4 |
+
|
| 5 |
+
RUN apt-get update && apt-get install -y \
|
| 6 |
+
build-essential \
|
| 7 |
+
poppler-utils \
|
| 8 |
+
libgl1 \
|
| 9 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 10 |
+
|
| 11 |
+
COPY requirements.txt .
|
| 12 |
+
|
| 13 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 14 |
+
|
| 15 |
+
COPY . .
|
| 16 |
+
|
| 17 |
+
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
|
README.md
CHANGED
|
@@ -1,10 +0,0 @@
|
|
| 1 |
-
---
|
| 2 |
-
title: ThesisBackend
|
| 3 |
-
emoji: 🌖
|
| 4 |
-
colorFrom: green
|
| 5 |
-
colorTo: red
|
| 6 |
-
sdk: docker
|
| 7 |
-
pinned: false
|
| 8 |
-
---
|
| 9 |
-
|
| 10 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
app/api/routes/grading.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter, UploadFile, File
|
| 2 |
+
from app.schemas.grading import GradingResponse
|
| 3 |
+
from app.services.grading_service import grade_annotation
|
| 4 |
+
|
| 5 |
+
router = APIRouter(prefix="/grade-annotation", tags=["Grading"])
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@router.post("/", response_model=GradingResponse)
|
| 9 |
+
def grade(file: UploadFile = File(...)):
|
| 10 |
+
return grade_annotation(file)
|
app/api/routes/rag.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter
|
| 2 |
+
from app.schemas.rag import AskRequest, AskResponse
|
| 3 |
+
from app.services.rag_service import ask_question
|
| 4 |
+
|
| 5 |
+
router = APIRouter(prefix="/rag", tags=["RAG"])
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@router.post("/ask", response_model=AskResponse)
|
| 9 |
+
def ask(req: AskRequest):
|
| 10 |
+
return ask_question(req.question)
|
app/api/routes/upload.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter, UploadFile, File
|
| 2 |
+
from app.schemas.upload import UploadResponse
|
| 3 |
+
from app.services.ingestion_service import upload_pdf
|
| 4 |
+
|
| 5 |
+
router = APIRouter(prefix="/upload-pdf", tags=["Ingestion"])
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@router.post("/", response_model=UploadResponse)
|
| 9 |
+
def upload(file: UploadFile = File(...)):
|
| 10 |
+
return upload_pdf(file)
|
app/api/routes/visualize.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter
|
| 2 |
+
from app.schemas.visualize import VisualizeRequest, VisualizeResponse
|
| 3 |
+
from app.services.visualization_service import visualize
|
| 4 |
+
|
| 5 |
+
router = APIRouter(prefix="/visualize", tags=["Visualization"])
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@router.post("/", response_model=VisualizeResponse)
|
| 9 |
+
def run_visualize(req: VisualizeRequest):
|
| 10 |
+
return visualize(req.question)
|
app/main.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import FastAPI
|
| 2 |
+
from fastapi.staticfiles import StaticFiles
|
| 3 |
+
from app.api.routes import rag
|
| 4 |
+
from app.api.routes import rag, visualize
|
| 5 |
+
from app.api.routes import rag, visualize, grading
|
| 6 |
+
|
| 7 |
+
from app.api.routes import rag, visualize, grading, upload
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
app = FastAPI(title="Multimodal RAG API")
|
| 15 |
+
|
| 16 |
+
app.include_router(rag.router)
|
| 17 |
+
|
| 18 |
+
app.include_router(visualize.router)
|
| 19 |
+
|
| 20 |
+
app.include_router(grading.router)
|
| 21 |
+
|
| 22 |
+
app.include_router(upload.router)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
app.mount("/outputs", StaticFiles(directory="outputs"), name="outputs")
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@app.get("/")
|
| 31 |
+
def root():
|
| 32 |
+
return {"status": "running"}
|
app/schemas/grading.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import BaseModel
|
| 2 |
+
from typing import List, Optional
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class GradingResponse(BaseModel):
|
| 6 |
+
score: Optional[float]
|
| 7 |
+
feedback: str
|
| 8 |
+
missing_structures: List[str]
|
app/schemas/rag.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import BaseModel
|
| 2 |
+
from typing import List, Optional
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class AskRequest(BaseModel):
|
| 6 |
+
question: str
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class AskResponse(BaseModel):
|
| 10 |
+
answer: str
|
| 11 |
+
images: Optional[List[str]] = None
|
app/schemas/upload.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import BaseModel
|
| 2 |
+
from typing import Dict
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class UploadResponse(BaseModel):
|
| 6 |
+
status: str
|
| 7 |
+
message: str
|
| 8 |
+
text_ingestion: Dict
|
| 9 |
+
image_extraction: Dict
|
app/schemas/visualize.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import BaseModel
|
| 2 |
+
from typing import Optional
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class VisualizeRequest(BaseModel):
|
| 6 |
+
question: str
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class VisualizeResponse(BaseModel):
|
| 10 |
+
annotated_image: Optional[str]
|
app/services/grading_service.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from src.assessment.run_annotation_check import run_annotation_grading
|
| 3 |
+
|
| 4 |
+
UPLOAD_DIR = "uploads"
|
| 5 |
+
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def grade_annotation(file):
|
| 9 |
+
|
| 10 |
+
file_path = os.path.join(UPLOAD_DIR, file.filename)
|
| 11 |
+
|
| 12 |
+
with open(file_path, "wb") as f:
|
| 13 |
+
f.write(file.file.read())
|
| 14 |
+
|
| 15 |
+
raw_result = run_annotation_grading(file_path)
|
| 16 |
+
|
| 17 |
+
# 🔥 If your grader returns a STRING → convert to structured format
|
| 18 |
+
if isinstance(raw_result, str):
|
| 19 |
+
return {
|
| 20 |
+
"score": None,
|
| 21 |
+
"feedback": raw_result,
|
| 22 |
+
"missing_structures": []
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
# 🔥 If your grader already returns a dict → pass through safely
|
| 26 |
+
return {
|
| 27 |
+
"score": raw_result.get("score"),
|
| 28 |
+
"feedback": raw_result.get("feedback", ""),
|
| 29 |
+
"missing_structures": raw_result.get("missing_structures", [])
|
| 30 |
+
}
|
app/services/ingestion_service.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from src.config.settings import settings
|
| 4 |
+
from src.ingestion.run import run_ingestion
|
| 5 |
+
from src.ingestion.run_image_extraction import run_image_extraction
|
| 6 |
+
from src.multimodal.run_multimodal_rag import reload_rag
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
UPLOAD_DIR = Path(settings.raw_data_dir)
|
| 10 |
+
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def upload_pdf(file):
|
| 14 |
+
|
| 15 |
+
file_path = UPLOAD_DIR / file.filename
|
| 16 |
+
|
| 17 |
+
with open(file_path, "wb") as f:
|
| 18 |
+
f.write(file.file.read())
|
| 19 |
+
|
| 20 |
+
text_result = run_ingestion()
|
| 21 |
+
image_result = run_image_extraction()
|
| 22 |
+
|
| 23 |
+
reload_rag()
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
return {
|
| 27 |
+
"status": "success",
|
| 28 |
+
"message": f"{file.filename} ingested and RAG reloaded",
|
| 29 |
+
"text_ingestion": text_result,
|
| 30 |
+
"image_extraction": image_result,
|
| 31 |
+
}
|
app/services/rag_service.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from src.multimodal.run_multimodal_rag import run_multimodal_rag
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def ask_question(question: str):
|
| 5 |
+
|
| 6 |
+
result = run_multimodal_rag(question)
|
| 7 |
+
|
| 8 |
+
return {
|
| 9 |
+
"answer": result["answer"],
|
| 10 |
+
"images": result.get("images", [])
|
| 11 |
+
}
|
app/services/visualization_service.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from src.visualization.run_visual_answer import run_visual_answer
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def visualize(question: str):
|
| 5 |
+
|
| 6 |
+
result = run_visual_answer(question)
|
| 7 |
+
|
| 8 |
+
return {
|
| 9 |
+
"annotated_image": result["annotated_image"]
|
| 10 |
+
}
|
multimodal_rag_thesis.egg-info/PKG-INFO
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Metadata-Version: 2.4
|
| 2 |
+
Name: multimodal-rag-thesis
|
| 3 |
+
Version: 0.1.0
|
| 4 |
+
Summary: Add your description here
|
| 5 |
+
Requires-Python: <3.12,>=3.10
|
| 6 |
+
Description-Content-Type: text/markdown
|
| 7 |
+
Requires-Dist: chromadb>=1.5.0
|
| 8 |
+
Requires-Dist: python-dotenv>=1.0.1
|
| 9 |
+
Requires-Dist: langchain>=0.2.0
|
| 10 |
+
Requires-Dist: langchain-chroma>=0.1.0
|
| 11 |
+
Requires-Dist: langchain-community>=0.2.0
|
| 12 |
+
Requires-Dist: langchain-huggingface>=0.0.3
|
| 13 |
+
Requires-Dist: pymupdf
|
| 14 |
+
Requires-Dist: pillow
|
| 15 |
+
Requires-Dist: matplotlib
|
| 16 |
+
Requires-Dist: streamlit
|
| 17 |
+
Requires-Dist: onnxruntime<1.17
|
| 18 |
+
Requires-Dist: opencv-python-headless<4.9
|
| 19 |
+
Requires-Dist: sentence-transformers==2.7.0
|
| 20 |
+
Requires-Dist: numpy<2
|
| 21 |
+
Requires-Dist: torch==2.2.2
|
| 22 |
+
Requires-Dist: langchain-groq>=1.1.2
|
| 23 |
+
Requires-Dist: fastapi>=0.129.0
|
| 24 |
+
Requires-Dist: uvicorn>=0.40.0
|
| 25 |
+
Requires-Dist: pydantic>=2.12.5
|
| 26 |
+
Requires-Dist: python-multipart>=0.0.22
|
multimodal_rag_thesis.egg-info/SOURCES.txt
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
README.md
|
| 2 |
+
pyproject.toml
|
| 3 |
+
app/main.py
|
| 4 |
+
app/api/routes/grading.py
|
| 5 |
+
app/api/routes/rag.py
|
| 6 |
+
app/api/routes/visualize.py
|
| 7 |
+
app/schemas/grading.py
|
| 8 |
+
app/schemas/rag.py
|
| 9 |
+
app/schemas/visualize.py
|
| 10 |
+
app/services/grading_service.py
|
| 11 |
+
app/services/rag_service.py
|
| 12 |
+
app/services/visualization_service.py
|
| 13 |
+
multimodal_rag_thesis.egg-info/PKG-INFO
|
| 14 |
+
multimodal_rag_thesis.egg-info/SOURCES.txt
|
| 15 |
+
multimodal_rag_thesis.egg-info/dependency_links.txt
|
| 16 |
+
multimodal_rag_thesis.egg-info/requires.txt
|
| 17 |
+
multimodal_rag_thesis.egg-info/top_level.txt
|
| 18 |
+
src/__init__.py
|
| 19 |
+
src/main.py
|
| 20 |
+
src/assessment/annotation_grader.py
|
| 21 |
+
src/assessment/image_query_retriever.py
|
| 22 |
+
src/assessment/label_extractor.py
|
| 23 |
+
src/assessment/run_annotation_check.py
|
| 24 |
+
src/config/__init__.py
|
| 25 |
+
src/config/settings.py
|
| 26 |
+
src/embeddings/__init__.py
|
| 27 |
+
src/embeddings/embedding_factory.py
|
| 28 |
+
src/ingestion/__init__.py
|
| 29 |
+
src/ingestion/image_extractor.py
|
| 30 |
+
src/ingestion/loader.py
|
| 31 |
+
src/ingestion/run.py
|
| 32 |
+
src/ingestion/run_image_extraction.py
|
| 33 |
+
src/llm/__init__.py
|
| 34 |
+
src/llm/llm_factory.py
|
| 35 |
+
src/multimodal/__init__.py
|
| 36 |
+
src/multimodal/clip_embedding.py
|
| 37 |
+
src/multimodal/multimodal_indexer.py
|
| 38 |
+
src/multimodal/multimodal_rag_chain.py
|
| 39 |
+
src/multimodal/multimodal_retriever.py
|
| 40 |
+
src/multimodal/run_multimodal_indexing.py
|
| 41 |
+
src/multimodal/run_multimodal_query.py
|
| 42 |
+
src/multimodal/run_multimodal_rag.py
|
| 43 |
+
src/retrieval/__init__.py
|
| 44 |
+
src/retrieval/query.py
|
| 45 |
+
src/retrieval/rag_query.py
|
| 46 |
+
src/retrieval/vector_store.py
|
| 47 |
+
src/utils/__init__.py
|
| 48 |
+
src/utils/logger.py
|
| 49 |
+
src/visualization/annotation_schema.py
|
| 50 |
+
src/visualization/image_annotator.py
|
| 51 |
+
src/visualization/llm_structure_extractor.py
|
| 52 |
+
src/visualization/run_visual_answer.py
|
multimodal_rag_thesis.egg-info/dependency_links.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
multimodal_rag_thesis.egg-info/requires.txt
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
chromadb>=1.5.0
|
| 2 |
+
python-dotenv>=1.0.1
|
| 3 |
+
langchain>=0.2.0
|
| 4 |
+
langchain-chroma>=0.1.0
|
| 5 |
+
langchain-community>=0.2.0
|
| 6 |
+
langchain-huggingface>=0.0.3
|
| 7 |
+
pymupdf
|
| 8 |
+
pillow
|
| 9 |
+
matplotlib
|
| 10 |
+
streamlit
|
| 11 |
+
onnxruntime<1.17
|
| 12 |
+
opencv-python-headless<4.9
|
| 13 |
+
sentence-transformers==2.7.0
|
| 14 |
+
numpy<2
|
| 15 |
+
torch==2.2.2
|
| 16 |
+
langchain-groq>=1.1.2
|
| 17 |
+
fastapi>=0.129.0
|
| 18 |
+
uvicorn>=0.40.0
|
| 19 |
+
pydantic>=2.12.5
|
| 20 |
+
python-multipart>=0.0.22
|
multimodal_rag_thesis.egg-info/top_level.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
app
|
| 2 |
+
data
|
| 3 |
+
outputs
|
| 4 |
+
src
|
| 5 |
+
uploads
|
pyproject.toml
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "multimodal-rag-thesis"
|
| 3 |
+
version = "0.1.0"
|
| 4 |
+
description = "Add your description here"
|
| 5 |
+
readme = "README.md"
|
| 6 |
+
requires-python = ">=3.10,<3.12"
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
dependencies = [
|
| 10 |
+
"chromadb>=1.5.0",
|
| 11 |
+
"python-dotenv>=1.0.1",
|
| 12 |
+
"langchain>=0.2.0",
|
| 13 |
+
"langchain-chroma>=0.1.0",
|
| 14 |
+
"langchain-community>=0.2.0",
|
| 15 |
+
"langchain-huggingface>=0.0.3",
|
| 16 |
+
"pymupdf",
|
| 17 |
+
"pillow",
|
| 18 |
+
"matplotlib",
|
| 19 |
+
"streamlit",
|
| 20 |
+
"onnxruntime<1.17",
|
| 21 |
+
"opencv-python-headless<4.9",
|
| 22 |
+
"sentence-transformers==2.7.0",
|
| 23 |
+
"numpy<2",
|
| 24 |
+
"torch==2.2.2",
|
| 25 |
+
"langchain-groq>=1.1.2",
|
| 26 |
+
"fastapi>=0.129.0",
|
| 27 |
+
"uvicorn>=0.40.0",
|
| 28 |
+
"pydantic>=2.12.5",
|
| 29 |
+
"python-multipart>=0.0.22",
|
| 30 |
+
]
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
[build-system]
|
| 34 |
+
requires = ["setuptools"]
|
| 35 |
+
build-backend = "setuptools.build_meta"
|
| 36 |
+
|
| 37 |
+
[tool.setuptools.packages.find]
|
| 38 |
+
where = ["."]
|
requirements.txt
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
aiohappyeyeballs==2.6.1
|
| 2 |
+
aiohttp==3.13.3
|
| 3 |
+
aiosignal==1.4.0
|
| 4 |
+
altair==6.0.0
|
| 5 |
+
annotated-doc==0.0.4
|
| 6 |
+
annotated-types==0.7.0
|
| 7 |
+
anyio==4.12.1
|
| 8 |
+
async-timeout==4.0.3
|
| 9 |
+
attrs==25.4.0
|
| 10 |
+
backoff==2.2.1
|
| 11 |
+
bcrypt==5.0.0
|
| 12 |
+
blinker==1.9.0
|
| 13 |
+
build==1.4.0
|
| 14 |
+
cachetools==6.2.6
|
| 15 |
+
certifi==2026.1.4
|
| 16 |
+
charset-normalizer==3.4.4
|
| 17 |
+
chromadb==1.5.0
|
| 18 |
+
click==8.3.1
|
| 19 |
+
coloredlogs==15.0.1
|
| 20 |
+
contourpy==1.3.2
|
| 21 |
+
cycler==0.12.1
|
| 22 |
+
dataclasses-json==0.6.7
|
| 23 |
+
distro==1.9.0
|
| 24 |
+
durationpy==0.10
|
| 25 |
+
exceptiongroup==1.3.1
|
| 26 |
+
fastapi==0.129.0
|
| 27 |
+
filelock==3.21.2
|
| 28 |
+
flatbuffers==25.12.19
|
| 29 |
+
fonttools==4.61.1
|
| 30 |
+
frozenlist==1.8.0
|
| 31 |
+
fsspec==2026.2.0
|
| 32 |
+
gitdb==4.0.12
|
| 33 |
+
gitpython==3.1.46
|
| 34 |
+
googleapis-common-protos==1.72.0
|
| 35 |
+
greenlet==3.3.1
|
| 36 |
+
groq==0.37.1
|
| 37 |
+
grpcio==1.78.0
|
| 38 |
+
h11==0.16.0
|
| 39 |
+
hf-xet==1.2.0
|
| 40 |
+
httpcore==1.0.9
|
| 41 |
+
httptools==0.7.1
|
| 42 |
+
httpx==0.28.1
|
| 43 |
+
httpx-sse==0.4.3
|
| 44 |
+
huggingface-hub==0.36.2
|
| 45 |
+
humanfriendly==10.0
|
| 46 |
+
idna==3.11
|
| 47 |
+
importlib-metadata==8.7.1
|
| 48 |
+
importlib-resources==6.5.2
|
| 49 |
+
jinja2==3.1.6
|
| 50 |
+
joblib==1.5.3
|
| 51 |
+
jsonpatch==1.33
|
| 52 |
+
jsonpointer==3.0.0
|
| 53 |
+
jsonschema==4.26.0
|
| 54 |
+
jsonschema-specifications==2025.9.1
|
| 55 |
+
kiwisolver==1.4.9
|
| 56 |
+
kubernetes==35.0.0
|
| 57 |
+
langchain==1.2.10
|
| 58 |
+
langchain-chroma==1.1.0
|
| 59 |
+
langchain-classic==1.0.1
|
| 60 |
+
langchain-community==0.4.1
|
| 61 |
+
langchain-core==1.2.12
|
| 62 |
+
langchain-groq==1.1.2
|
| 63 |
+
langchain-huggingface==1.2.0
|
| 64 |
+
langchain-text-splitters==1.1.0
|
| 65 |
+
langgraph==1.0.8
|
| 66 |
+
langgraph-checkpoint==4.0.0
|
| 67 |
+
langgraph-prebuilt==1.0.7
|
| 68 |
+
langgraph-sdk==0.3.5
|
| 69 |
+
langsmith==0.7.1
|
| 70 |
+
markdown-it-py==4.0.0
|
| 71 |
+
markupsafe==3.0.3
|
| 72 |
+
marshmallow==3.26.2
|
| 73 |
+
matplotlib==3.10.8
|
| 74 |
+
mdurl==0.1.2
|
| 75 |
+
mmh3==5.2.0
|
| 76 |
+
mpmath==1.3.0
|
| 77 |
+
multidict==6.7.1
|
| 78 |
+
-e file:///Users/human1/Documents/Thesis/POC2/phase1/multimodal-rag-thesis
|
| 79 |
+
mypy-extensions==1.1.0
|
| 80 |
+
narwhals==2.16.0
|
| 81 |
+
networkx==3.4.2
|
| 82 |
+
numpy==1.26.4
|
| 83 |
+
oauthlib==3.3.1
|
| 84 |
+
onnxruntime==1.16.3
|
| 85 |
+
opencv-python-headless==4.8.1.78
|
| 86 |
+
opentelemetry-api==1.39.1
|
| 87 |
+
opentelemetry-exporter-otlp-proto-common==1.39.1
|
| 88 |
+
opentelemetry-exporter-otlp-proto-grpc==1.39.1
|
| 89 |
+
opentelemetry-proto==1.39.1
|
| 90 |
+
opentelemetry-sdk==1.39.1
|
| 91 |
+
opentelemetry-semantic-conventions==0.60b1
|
| 92 |
+
orjson==3.11.7
|
| 93 |
+
ormsgpack==1.12.2
|
| 94 |
+
overrides==7.7.0
|
| 95 |
+
packaging==26.0
|
| 96 |
+
pandas==2.3.3
|
| 97 |
+
pillow==12.1.1
|
| 98 |
+
posthog==5.4.0
|
| 99 |
+
propcache==0.4.1
|
| 100 |
+
protobuf==6.33.5
|
| 101 |
+
pyarrow==23.0.0
|
| 102 |
+
pybase64==1.4.3
|
| 103 |
+
pydantic==2.12.5
|
| 104 |
+
pydantic-core==2.41.5
|
| 105 |
+
pydantic-settings==2.12.0
|
| 106 |
+
pydeck==0.9.1
|
| 107 |
+
pygments==2.19.2
|
| 108 |
+
pymupdf==1.27.1
|
| 109 |
+
pyparsing==3.3.2
|
| 110 |
+
pypika==0.51.1
|
| 111 |
+
pyproject-hooks==1.2.0
|
| 112 |
+
python-dateutil==2.9.0.post0
|
| 113 |
+
python-dotenv==1.2.1
|
| 114 |
+
python-multipart==0.0.22
|
| 115 |
+
pytz==2025.2
|
| 116 |
+
pyyaml==6.0.3
|
| 117 |
+
referencing==0.37.0
|
| 118 |
+
regex==2026.1.15
|
| 119 |
+
requests==2.32.5
|
| 120 |
+
requests-oauthlib==2.0.0
|
| 121 |
+
requests-toolbelt==1.0.0
|
| 122 |
+
rich==14.3.2
|
| 123 |
+
rpds-py==0.30.0
|
| 124 |
+
safetensors==0.7.0
|
| 125 |
+
scikit-learn==1.7.2
|
| 126 |
+
scipy==1.15.3
|
| 127 |
+
sentence-transformers==2.7.0
|
| 128 |
+
shellingham==1.5.4
|
| 129 |
+
six==1.17.0
|
| 130 |
+
smmap==5.0.2
|
| 131 |
+
sniffio==1.3.1
|
| 132 |
+
sqlalchemy==2.0.46
|
| 133 |
+
starlette==0.52.1
|
| 134 |
+
streamlit==1.54.0
|
| 135 |
+
sympy==1.14.0
|
| 136 |
+
tenacity==9.1.4
|
| 137 |
+
threadpoolctl==3.6.0
|
| 138 |
+
tokenizers==0.22.2
|
| 139 |
+
toml==0.10.2
|
| 140 |
+
tomli==2.4.0
|
| 141 |
+
torch==2.2.2
|
| 142 |
+
tornado==6.5.4
|
| 143 |
+
tqdm==4.67.3
|
| 144 |
+
transformers==4.57.6
|
| 145 |
+
typer==0.23.1
|
| 146 |
+
typing-extensions==4.15.0
|
| 147 |
+
typing-inspect==0.9.0
|
| 148 |
+
typing-inspection==0.4.2
|
| 149 |
+
tzdata==2025.3
|
| 150 |
+
urllib3==2.6.3
|
| 151 |
+
uuid-utils==0.14.0
|
| 152 |
+
uvicorn==0.40.0
|
| 153 |
+
uvloop==0.22.1
|
| 154 |
+
watchfiles==1.1.1
|
| 155 |
+
websocket-client==1.9.0
|
| 156 |
+
websockets==16.0
|
| 157 |
+
xxhash==3.6.0
|
| 158 |
+
yarl==1.22.0
|
| 159 |
+
zipp==3.23.0
|
| 160 |
+
zstandard==0.25.0
|
src/.DS_Store
ADDED
|
Binary file (8.2 kB). View file
|
|
|
src/__init__.py
ADDED
|
File without changes
|
src/assessment/annotation_grader.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from src.llm.llm_factory import get_llm
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class AnnotationGrader:
|
| 5 |
+
|
| 6 |
+
def __init__(self):
|
| 7 |
+
self.llm = get_llm()
|
| 8 |
+
|
| 9 |
+
def grade(self, user_labels, reference_text):
|
| 10 |
+
|
| 11 |
+
prompt = f"""
|
| 12 |
+
Compare the student labels with the reference anatomy.
|
| 13 |
+
|
| 14 |
+
Student labels:
|
| 15 |
+
{user_labels}
|
| 16 |
+
|
| 17 |
+
Reference:
|
| 18 |
+
{reference_text}
|
| 19 |
+
|
| 20 |
+
Return:
|
| 21 |
+
|
| 22 |
+
Correct:
|
| 23 |
+
Missing:
|
| 24 |
+
Incorrect:
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
response = self.llm.invoke(prompt)
|
| 28 |
+
|
| 29 |
+
return response.content
|
src/assessment/image_query_retriever.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from src.multimodal.clip_embedding import CLIPEmbedding
|
| 2 |
+
from langchain_chroma import Chroma
|
| 3 |
+
from src.config.settings import settings
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class ImageQueryRetriever:
|
| 7 |
+
|
| 8 |
+
def __init__(self):
|
| 9 |
+
|
| 10 |
+
self.embedding = CLIPEmbedding()
|
| 11 |
+
|
| 12 |
+
self.vectorstore = Chroma(
|
| 13 |
+
collection_name="multimodal_rag",
|
| 14 |
+
persist_directory=f"{settings.processed_data_dir}/multimodal_chroma"
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
def retrieve_similar(self, image_path, k=1):
|
| 18 |
+
|
| 19 |
+
emb = self.embedding.embed_image([image_path])[0]
|
| 20 |
+
|
| 21 |
+
results = self.vectorstore._collection.query(
|
| 22 |
+
query_embeddings=[emb.tolist()],
|
| 23 |
+
n_results=k
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
return results["metadatas"][0]
|
src/assessment/label_extractor.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from src.llm.llm_factory import get_llm
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class LabelExtractor:
|
| 5 |
+
|
| 6 |
+
def __init__(self):
|
| 7 |
+
self.llm = get_llm()
|
| 8 |
+
|
| 9 |
+
def extract(self, image_path):
|
| 10 |
+
|
| 11 |
+
prompt = f"""
|
| 12 |
+
The user uploaded an annotated anatomy image.
|
| 13 |
+
|
| 14 |
+
List the anatomical labels present in the image.
|
| 15 |
+
|
| 16 |
+
Return JSON:
|
| 17 |
+
|
| 18 |
+
{{ "labels": ["label1", "label2"] }}
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
response = self.llm.invoke(prompt)
|
| 22 |
+
|
| 23 |
+
return response.content
|
src/assessment/run_annotation_check.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from src.assessment.image_query_retriever import ImageQueryRetriever
|
| 2 |
+
from src.assessment.label_extractor import LabelExtractor
|
| 3 |
+
from src.assessment.annotation_grader import AnnotationGrader
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
# 🔥 Global instances for API performance
|
| 7 |
+
retriever = ImageQueryRetriever()
|
| 8 |
+
extractor = LabelExtractor()
|
| 9 |
+
grader = AnnotationGrader()
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def run_annotation_grading(image_path: str):
|
| 13 |
+
"""
|
| 14 |
+
FastAPI entry point for grading a student annotation.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
reference = retriever.retrieve_similar(image_path)
|
| 18 |
+
|
| 19 |
+
user_labels = extractor.extract(image_path)
|
| 20 |
+
|
| 21 |
+
result = grader.grade(user_labels, reference)
|
| 22 |
+
|
| 23 |
+
return result
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def main():
|
| 27 |
+
image_path = input("Enter path to annotated image: ")
|
| 28 |
+
|
| 29 |
+
result = run_annotation_grading(image_path)
|
| 30 |
+
|
| 31 |
+
print("\nRESULT:\n", result)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
if __name__ == "__main__":
|
| 35 |
+
main()
|
src/config/__init__.py
ADDED
|
File without changes
|
src/config/settings.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from dotenv import load_dotenv
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
load_dotenv()
|
| 6 |
+
|
| 7 |
+
# 🔥 This becomes /data on Hugging Face, and stays local when developing
|
| 8 |
+
BASE_DATA_DIR = os.getenv("HF_HOME", "data")
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@dataclass
|
| 12 |
+
class Settings:
|
| 13 |
+
app_env: str = os.getenv("APP_ENV", "development")
|
| 14 |
+
log_level: str = os.getenv("LOG_LEVEL", "INFO")
|
| 15 |
+
|
| 16 |
+
# 📂 Data paths
|
| 17 |
+
base_data_dir: str = BASE_DATA_DIR
|
| 18 |
+
raw_data_dir: str = os.path.join(BASE_DATA_DIR, "raw")
|
| 19 |
+
processed_data_dir: str = os.path.join(BASE_DATA_DIR, "processed")
|
| 20 |
+
chroma_dir: str = os.path.join(BASE_DATA_DIR, "chroma")
|
| 21 |
+
|
| 22 |
+
# 🤖 Models
|
| 23 |
+
embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2"
|
| 24 |
+
llm_model: str = "llama-3.1-8b-instant"
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
settings = Settings()
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
|
src/embeddings/__init__.py
ADDED
|
File without changes
|
src/embeddings/embedding_factory.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain_huggingface import HuggingFaceEmbeddings
|
| 2 |
+
from src.config.settings import settings
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def get_text_embedding():
|
| 6 |
+
return HuggingFaceEmbeddings(
|
| 7 |
+
model_name=settings.embedding_model
|
| 8 |
+
)
|
src/ingestion/__init__.py
ADDED
|
File without changes
|
src/ingestion/image_extractor.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import fitz # PyMuPDF
|
| 2 |
+
import json
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
from src.config.settings import settings
|
| 6 |
+
from src.utils.logger import get_logger
|
| 7 |
+
|
| 8 |
+
logger = get_logger(__name__)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class ImageExtractor:
    """Extract content images from PDFs into ``<processed_data_dir>/images``.

    Skips publisher/front-matter pages and tiny decorative images, and records
    per-image metadata (page number, size, nearby page text) in ``self.metadata``
    for later multimodal indexing.
    """

    # Thresholds separating real figures from logos/bullets/icons.
    MIN_DIM = 200        # minimum width and height in pixels
    MIN_BYTES = 20_000   # minimum encoded file size in bytes

    def __init__(self):
        self.output_dir = Path(settings.processed_data_dir) / "images"
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # One dict per extracted image; see extract_from_pdf for the schema.
        self.metadata = []

        # 🚫 pages we never want (publisher / front matter / credits)
        self.page_noise_keywords = [
            "learning resources",
            "about our team",
            "senior contributors",
            "powerpoint slides",
            "pronunciation guide",
            "acknowledgments",
            "reviewers",
            "openstax",
        ]

    def extract_from_pdf(self, pdf_path: Path):
        """Extract qualifying images from one PDF, appending to ``self.metadata``."""
        logger.info(f"Extracting images from {pdf_path.name}")

        doc = fitz.open(pdf_path)
        try:
            for page_index in range(len(doc)):
                page = doc[page_index]
                page_text = page.get_text("text")
                text_lower = page_text.lower()

                # 🚫 Skip non-content pages
                if any(k in text_lower for k in self.page_noise_keywords):
                    continue

                for img_index, img in enumerate(page.get_images(full=True)):
                    xref = img[0]
                    base_image = doc.extract_image(xref)
                    image_bytes = base_image["image"]

                    # extract_image already reports the dimensions; the old
                    # code built a fitz.Pixmap just for width/height and
                    # leaked it on every skip path.
                    width = base_image["width"]
                    height = base_image["height"]

                    # ✅ Skip tiny images (logos, bullets, icons)
                    if width < self.MIN_DIM or height < self.MIN_DIM:
                        continue

                    # ✅ Skip very low file size
                    if len(image_bytes) < self.MIN_BYTES:
                        continue

                    # extract_image returns the image in its native encoding;
                    # name the file by its real format instead of a blanket
                    # ".png" label on (typically) JPEG bytes.
                    ext = base_image.get("ext", "png")
                    image_name = (
                        f"{pdf_path.stem}_page_{page_index+1}_img_{img_index}.{ext}"
                    )
                    image_path = self.output_dir / image_name
                    image_path.write_bytes(image_bytes)

                    self.metadata.append(
                        {
                            "image_path": str(image_path),
                            "page": page_index + 1,
                            "source": pdf_path.name,
                            "image_index": img_index,
                            "width": width,
                            "height": height,
                            # Page text snippet used later as the image's
                            # searchable "document" in the vector store.
                            "nearby_text": page_text[:1000],
                        }
                    )
        finally:
            doc.close()

    def save_metadata(self):
        """Write all collected image metadata to image_metadata.json."""
        metadata_path = Path(settings.processed_data_dir) / "image_metadata.json"

        with open(metadata_path, "w") as f:
            json.dump(self.metadata, f, indent=2)

        logger.info(f"Saved image metadata → {metadata_path}")
|
src/ingestion/loader.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
from langchain_community.document_loaders import PyMuPDFLoader
|
| 3 |
+
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
| 4 |
+
|
| 5 |
+
from src.config.settings import settings
|
| 6 |
+
from src.utils.logger import get_logger
|
| 7 |
+
|
| 8 |
+
logger = get_logger(__name__)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class DocumentLoader:
    """Load PDFs from the raw-data directory and split them into clean chunks."""

    def load_pdfs(self):
        """Load every PDF under ``settings.raw_data_dir``, dropping noise pages.

        Returns:
            A list of LangChain page documents that passed the noise filters.
        """
        pdf_files = list(Path(settings.raw_data_dir).glob("*.pdf"))

        if not pdf_files:
            logger.warning("No PDFs found in data/raw")
            return []

        documents = []

        for pdf in pdf_files:
            logger.info(f"Loading PDF: {pdf.name}")
            pages = PyMuPDFLoader(str(pdf)).load()

            clean_pages = [page for page in pages if self._is_useful_page(page)]

            logger.info(f"Kept {len(clean_pages)} useful pages.")
            documents.extend(clean_pages)

        logger.info(f"Total kept pages: {len(documents)}")
        return documents

    @staticmethod
    def _is_useful_page(page):
        """Return True unless the page looks like an index, TOC, or glossary."""
        text = page.page_content.strip().lower()

        # 🚫 index pages
        if "index" in text[:200]:
            return False

        # 🚫 table of contents (chapter listings with dot leaders)
        if "chapter" in text and "...." in text:
            return False

        # 🚫 glossary-style alphabetical lists (comma-dense, short)
        if text.count(",") > 20 and len(text) < 1500:
            return False

        return True

    def split_documents(self, documents):
        """Chunk pages and drop short or table-of-contents-style chunks."""
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=800,
            chunk_overlap=150,
        )

        filtered_chunks = []

        for chunk in splitter.split_documents(documents):
            stripped = chunk.page_content.strip()

            # Too short to carry useful content.
            if len(stripped) < 200:
                continue

            # Dot-leader runs indicate index/TOC fragments.
            if stripped.count(".....") > 2:
                continue

            filtered_chunks.append(chunk)

        logger.info(f"Split into {len(filtered_chunks)} clean chunks.")
        return filtered_chunks
|
src/ingestion/run.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from src.ingestion.loader import DocumentLoader
|
| 2 |
+
from src.embeddings.embedding_factory import get_text_embedding
|
| 3 |
+
from src.retrieval.vector_store import VectorStoreFactory
|
| 4 |
+
from src.utils.logger import get_logger
|
| 5 |
+
|
| 6 |
+
logger = get_logger(__name__)
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def run_ingestion():
    """Load, clean, chunk, embed, and store all raw PDFs in the vector store.

    Returns:
        A status dict: ``{"status": ..., "message": ...}``.
    """
    logger.info("Starting ingestion pipeline...")

    loader = DocumentLoader()
    documents = loader.load_pdfs()

    if not documents:
        logger.warning("No documents to ingest.")
        return {"status": "warning", "message": "No documents found"}

    chunks = loader.split_documents(documents)

    # Build the vector store with the configured text-embedding model
    # and index every chunk.
    vectordb = VectorStoreFactory.create(get_text_embedding())
    vectordb.add_documents(chunks)

    logger.info("Ingestion complete.")
    return {"status": "success", "message": "Text ingestion complete"}


def main():
    """CLI entry point for the text-ingestion pipeline."""
    run_ingestion()


if __name__ == "__main__":
    main()
|
src/ingestion/run_image_extraction.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
from src.config.settings import settings
|
| 3 |
+
from src.ingestion.image_extractor import ImageExtractor
|
| 4 |
+
from src.utils.logger import get_logger
|
| 5 |
+
|
| 6 |
+
logger = get_logger(__name__)
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def run_image_extraction():
    """Run the ImageExtractor over every raw PDF and persist its metadata.

    Returns:
        A status dict: ``{"status": ..., "message": ...}``.
    """
    pdf_files = list(Path(settings.raw_data_dir).glob("*.pdf"))

    if not pdf_files:
        logger.warning("No PDFs found.")
        return {"status": "warning", "message": "No PDFs for image extraction"}

    extractor = ImageExtractor()

    for pdf in pdf_files:
        extractor.extract_from_pdf(pdf)

    # Write image_metadata.json once, after all PDFs are processed.
    extractor.save_metadata()

    logger.info("Image extraction complete.")
    return {"status": "success", "message": "Image extraction complete"}


def main():
    """CLI entry point for the image-extraction pipeline."""
    run_image_extraction()


if __name__ == "__main__":
    main()
|
src/llm/__init__.py
ADDED
|
File without changes
|
src/llm/llm_factory.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from langchain_groq import ChatGroq
|
| 3 |
+
from src.config.settings import settings
|
| 4 |
+
from src.utils.logger import get_logger
|
| 5 |
+
|
| 6 |
+
logger = get_logger(__name__)
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def get_llm():
    """Return a deterministic (temperature 0) Groq chat model.

    The model name comes from ``settings.llm_model`` and the API key from the
    ``GROQ_API_KEY`` environment variable.
    """
    logger.info(f"Loading Groq model: {settings.llm_model}")

    groq_key = os.getenv("GROQ_API_KEY")
    return ChatGroq(model=settings.llm_model, api_key=groq_key, temperature=0)
|
src/main.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from src.config.settings import settings
|
| 2 |
+
from src.utils.logger import get_logger
|
| 3 |
+
from src.ingestion.loader import DocumentLoader
|
| 4 |
+
from src.embeddings.embedding_factory import get_text_embedding
|
| 5 |
+
from src.retrieval.vector_store import VectorStoreFactory
|
| 6 |
+
from src.llm.llm_factory import get_llm
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
|
| 11 |
+
# Ensure data/storage directories exist before any component touches them.
os.makedirs(settings.raw_data_dir, exist_ok=True)
os.makedirs(settings.processed_data_dir, exist_ok=True)
os.makedirs(settings.chroma_dir, exist_ok=True)


logger = get_logger(__name__)


def main():
    """Smoke-test startup: documents, embedding model, vector store, and LLM."""
    logger.info("Multimodal RAG system initialized.")
    logger.info(f"Running in environment: {settings.app_env}")

    loader = DocumentLoader()
    # BUG FIX: DocumentLoader defines load_pdfs(), not load(); the original
    # call raised AttributeError at runtime.
    loader.load_pdfs()

    embedding = get_text_embedding()
    logger.info("Embedding model loaded.")

    # Instantiated only to verify setup succeeds end-to-end.
    vectordb = VectorStoreFactory.create(embedding)

    llm = get_llm()

    logger.info("System setup complete.")


if __name__ == "__main__":
    main()
|
src/multimodal/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .clip_embedding import CLIPEmbedding
|
| 2 |
+
from .multimodal_indexer import MultimodalIndexer
|
| 3 |
+
|
| 4 |
+
__all__ = [
|
| 5 |
+
"CLIPEmbedding",
|
| 6 |
+
"MultimodalIndexer",
|
| 7 |
+
]
|
src/multimodal/clip_embedding.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from sentence_transformers import SentenceTransformer
|
| 2 |
+
from PIL import Image
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class CLIPEmbedding:
    """Wrap a CLIP model so text and images share one embedding space."""

    def __init__(self):
        # clip-ViT-B-32 embeds both modalities into the same vector space.
        self.model = SentenceTransformer("clip-ViT-B-32")

    def embed_text(self, texts):
        """Encode a list of strings; returns a numpy array of vectors."""
        return self.model.encode(texts, convert_to_numpy=True)

    def embed_image(self, image_paths):
        """Encode images given their file paths; returns a numpy array."""
        loaded = [Image.open(path).convert("RGB") for path in image_paths]
        return self.model.encode(loaded, convert_to_numpy=True)
|
src/multimodal/multimodal_indexer.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from langchain_chroma import Chroma
|
| 4 |
+
|
| 5 |
+
from src.config.settings import settings
|
| 6 |
+
from src.multimodal.clip_embedding import CLIPEmbedding
|
| 7 |
+
from src.utils.logger import get_logger
|
| 8 |
+
|
| 9 |
+
logger = get_logger(__name__)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class MultimodalIndexer:
    """Embed extracted images with CLIP and store them in a Chroma collection."""

    def __init__(self):
        self.embedding = CLIPEmbedding()

        self.vectorstore = Chroma(
            collection_name="multimodal_rag",
            persist_directory=f"{settings.processed_data_dir}/multimodal_chroma"
        )

    def index_images(self):
        """Read image_metadata.json and add one CLIP-embedded entry per image."""
        metadata_path = Path(settings.processed_data_dir) / "image_metadata.json"

        if not metadata_path.exists():
            logger.warning("image_metadata.json not found.")
            return

        with open(metadata_path) as f:
            metadata = json.load(f)

        if not metadata:
            logger.warning("No image metadata found.")
            return

        image_paths = [entry["image_path"] for entry in metadata]

        logger.info(f"Embedding {len(image_paths)} images with CLIP...")
        image_embeddings = self.embedding.embed_image(image_paths)

        ids = []
        documents = []
        metadatas = []

        for i, entry in enumerate(metadata):
            ids.append(f"image_{i}")
            # The page text near the image serves as the stored "document".
            documents.append(entry["nearby_text"])
            metadatas.append({
                "type": "image",
                "image_path": entry["image_path"],
                "page": entry["page"],
                "source": entry["source"],
            })

        # Go through the raw collection so precomputed CLIP vectors can be
        # supplied directly instead of Chroma's own embedding function.
        self.vectorstore._collection.add(
            embeddings=image_embeddings.tolist(),
            documents=documents,
            metadatas=metadatas,
            ids=ids,
        )

        logger.info("Image embeddings stored in Chroma.")
|
src/multimodal/multimodal_rag_chain.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from src.multimodal.multimodal_retriever import MultimodalRetriever
|
| 2 |
+
from src.llm.llm_factory import get_llm
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class MultimodalRAG:
    """Answer questions via the LLM over multimodal (text + image) retrieval."""

    def __init__(self):
        self.retriever = MultimodalRetriever()
        self.llm = get_llm()

    def ask(self, query):
        """Answer ``query`` from retrieved context.

        Returns:
            A ``(answer_text, image_paths)`` tuple, where ``image_paths`` lists
            any retrieved figures to show alongside the answer.
        """
        docs, metas = self.retriever.retrieve(query, k=5)

        context = "\n\n".join(docs)

        prompt = f"""
You are a medical anatomy assistant.

Use the context to answer the question.

Context:
{context}

Question:
{query}
"""

        response = self.llm.invoke(prompt)

        image_paths = [
            meta["image_path"] for meta in metas if meta["type"] == "image"
        ]

        return response.content, image_paths
|
src/multimodal/multimodal_retriever.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain_chroma import Chroma
|
| 2 |
+
from src.config.settings import settings
|
| 3 |
+
from src.multimodal.clip_embedding import CLIPEmbedding
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class MultimodalRetriever:
    """Query the shared CLIP-embedded Chroma collection by text similarity."""

    def __init__(self):
        self.embedding = CLIPEmbedding()

        self.vectorstore = Chroma(
            collection_name="multimodal_rag",
            persist_directory=f"{settings.processed_data_dir}/multimodal_chroma"
        )

    def retrieve(self, query, k=5):
        """Return ``(documents, metadatas)`` for the top-``k`` nearest entries."""
        query_vector = self.embedding.embed_text([query])[0]

        # Query the raw collection so the precomputed CLIP vector is used
        # instead of Chroma's default embedding function.
        hits = self.vectorstore._collection.query(
            query_embeddings=[query_vector.tolist()],
            n_results=k,
        )

        return hits["documents"][0], hits["metadatas"][0]
|
src/multimodal/run_multimodal_indexing.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from src.multimodal.multimodal_indexer import MultimodalIndexer
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def main():
    """CLI entry point: index all extracted images into Chroma."""
    MultimodalIndexer().index_images()


if __name__ == "__main__":
    main()
|