Spaces:
Restarting
Restarting
Commit ·
2b523d0
0
Parent(s):
docker implementation with hugging face
Browse files- .dockerignore +32 -0
- .gitignore +20 -0
- Dockerfile +77 -0
- README.md +125 -0
- api/__init__.py +0 -0
- api/config.py +69 -0
- api/main.py +779 -0
- config.py +69 -0
- hybrid/__init__.py +7 -0
- hybrid/assistant.py +179 -0
- hybrid/web_search.py +93 -0
- main.py +125 -0
- mcq/__init__.py +7 -0
- mcq/generator.py +252 -0
- mcq/validator.py +99 -0
- models/__init__.py +0 -0
- models/embeddings.py +68 -0
- models/llm.py +109 -0
- rag/__init__.py +0 -0
- rag/generator.py +82 -0
- rag/retriever.py +70 -0
- requirements.txt +37 -0
- speech/__init__.py +13 -0
- speech/audio_handler.py +156 -0
- speech/formatter.py +197 -0
- speech/transcriber.py +103 -0
- tests/test_mcq.py +32 -0
- tests/test_rag.py +69 -0
- vectordb/__init__.py +0 -0
- vectordb/document_processor.py +172 -0
- vectordb/json_store.py +230 -0
.dockerignore
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Exclude local model cache from Docker build context.
|
| 2 |
+
# Models are downloaded DURING the build (in Dockerfile RUN step).
|
| 3 |
+
# If this folder were included, it would add 3+ GB to the build upload
|
| 4 |
+
# and potentially overwrite the freshly downloaded models.
|
| 5 |
+
models_cache/
|
| 6 |
+
|
| 7 |
+
# Python virtual environments (never needed in container)
|
| 8 |
+
.venv/
|
| 9 |
+
venv/
|
| 10 |
+
env/
|
| 11 |
+
|
| 12 |
+
# Runtime data (ephemeral — not part of the image)
|
| 13 |
+
data/
|
| 14 |
+
|
| 15 |
+
# Python bytecode
|
| 16 |
+
__pycache__/
|
| 17 |
+
*.pyc
|
| 18 |
+
*.pyo
|
| 19 |
+
*.pyd
|
| 20 |
+
|
| 21 |
+
# Environment variables — NEVER include in Docker image
|
| 22 |
+
.env
|
| 23 |
+
|
| 24 |
+
# Development files
|
| 25 |
+
*.md
|
| 26 |
+
tests/
|
| 27 |
+
.gitignore
|
| 28 |
+
.dockerignore
|
| 29 |
+
|
| 30 |
+
# OS files
|
| 31 |
+
.DS_Store
|
| 32 |
+
Thumbs.db
|
.gitignore
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.pyc
|
| 3 |
+
*.pyo
|
| 4 |
+
|
| 5 |
+
# Model cache — too large for Git (2.5+ GB), baked into Docker image instead
|
| 6 |
+
models_cache/
|
| 7 |
+
|
| 8 |
+
# Data files — runtime only, not part of source code
|
| 9 |
+
data/
|
| 10 |
+
chunks_only.json
|
| 11 |
+
embeddings_store.json
|
| 12 |
+
documents/
|
| 13 |
+
|
| 14 |
+
# Python virtual environments
|
| 15 |
+
.venv/
|
| 16 |
+
venv/
|
| 17 |
+
env/
|
| 18 |
+
|
| 19 |
+
# Environment variables — NEVER commit (contains API keys/tokens)
|
| 20 |
+
.env
|
Dockerfile
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ============================================================
|
| 2 |
+
# Cortexa AI — HuggingFace Docker Space
|
| 3 |
+
# HF CPU Basic free tier: 2 vCPU, 16 GB RAM
|
| 4 |
+
# Port 7860 is required by HuggingFace Spaces platform
|
| 5 |
+
# ============================================================
|
| 6 |
+
|
| 7 |
+
FROM python:3.11-slim
|
| 8 |
+
|
| 9 |
+
# --- System dependencies ---
|
| 10 |
+
# ffmpeg is REQUIRED for openai-whisper (audio processing)
|
| 11 |
+
# git is needed by some transformers internals
|
| 12 |
+
# --no-install-recommends keeps optional "recommended" packages out of the
# image; the apt lists are removed in the same layer so they never get stored.
RUN apt-get update && apt-get install -y --no-install-recommends \
    ffmpeg \
    git \
    build-essential \
    && rm -rf /var/lib/apt/lists/*
|
| 17 |
+
|
| 18 |
+
# --- Non-root user (required by HuggingFace Spaces) ---
|
| 19 |
+
RUN useradd -m -u 1000 user
|
| 20 |
+
USER user
|
| 21 |
+
|
| 22 |
+
ENV HOME=/home/user \
|
| 23 |
+
PATH=/home/user/.local/bin:$PATH
|
| 24 |
+
|
| 25 |
+
WORKDIR /home/user/app
|
| 26 |
+
|
| 27 |
+
# --- Install Python dependencies ---
|
| 28 |
+
# Copy requirements first so this layer is cached separately from app code.
|
| 29 |
+
# If only your code changes (not requirements.txt), this entire layer is reused
|
| 30 |
+
# on the next push — no 15-min reinstall.
|
| 31 |
+
COPY --chown=user requirements.txt .
|
| 32 |
+
|
| 33 |
+
RUN pip install --no-cache-dir --user -r requirements.txt
|
| 34 |
+
|
| 35 |
+
# --- Pre-download all models into the Docker image ---
|
| 36 |
+
# This is the KEY trick for HuggingFace Spaces:
|
| 37 |
+
# - Models are downloaded ONCE during 'docker build' on HF's build servers
|
| 38 |
+
# - The resulting Docker layer is cached by HuggingFace
|
| 39 |
+
# - Every future container start uses the cached image — no re-download
|
| 40 |
+
# - Container startup time: ~30 seconds instead of 10+ minutes
|
| 41 |
+
#
|
| 42 |
+
# Build time for this step: ~10-20 minutes (one-time, on first push only)
|
| 43 |
+
# Models downloaded:
|
| 44 |
+
# - paraphrase-MiniLM-L3-v2 (~120 MB)
|
| 45 |
+
# - TinyLlama-1.1B-Chat-v1.0 (~2.2 GB on disk, ~4.4 GB in RAM fp32)
|
| 46 |
+
# - Whisper base (~140 MB)
|
| 47 |
+
RUN python -c "\
|
| 48 |
+
from sentence_transformers import SentenceTransformer; \
|
| 49 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer; \
|
| 50 |
+
import whisper, torch; \
|
| 51 |
+
print('--- Downloading sentence-transformers (120 MB) ---'); \
|
| 52 |
+
SentenceTransformer('sentence-transformers/paraphrase-MiniLM-L3-v2', cache_folder='/home/user/app/models_cache'); \
|
| 53 |
+
print('--- Downloading TinyLlama tokenizer ---'); \
|
| 54 |
+
AutoTokenizer.from_pretrained('TinyLlama/TinyLlama-1.1B-Chat-v1.0', cache_dir='/home/user/app/models_cache', trust_remote_code=True); \
|
| 55 |
+
print('--- Downloading TinyLlama model weights (2.2 GB, please wait) ---'); \
|
| 56 |
+
AutoModelForCausalLM.from_pretrained('TinyLlama/TinyLlama-1.1B-Chat-v1.0', cache_dir='/home/user/app/models_cache', torch_dtype=torch.float32, trust_remote_code=True); \
|
| 57 |
+
print('--- Downloading Whisper base (140 MB) ---'); \
|
| 58 |
+
whisper.load_model('base', download_root='/home/user/app/models_cache/whisper'); \
|
| 59 |
+
print('=== All models downloaded successfully ==='); \
|
| 60 |
+
"
|
| 61 |
+
|
| 62 |
+
# --- Copy application code ---
|
| 63 |
+
# This is after model download so that code-only changes don't invalidate
|
| 64 |
+
# the model download cache layer above.
|
| 65 |
+
COPY --chown=user . .
|
| 66 |
+
|
| 67 |
+
# --- Environment ---
|
| 68 |
+
ENV PYTHONPATH=/home/user/app
|
| 69 |
+
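# HF_HOME points the HuggingFace libraries' cache root at models_cache.
# NOTE(review): with HF_HOME set, the default hub cache is $HF_HOME/hub, so the
# weights pre-baked above (downloaded with an explicit cache_dir/cache_folder)
# are only reused if the app passes that same directory — confirm.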
|
| 70 |
+
ENV HF_HOME=/home/user/app/models_cache
|
| 71 |
+
ENV PORT=7860
|
| 72 |
+
|
| 73 |
+
# HuggingFace Spaces requires port 7860
|
| 74 |
+
EXPOSE 7860
|
| 75 |
+
|
| 76 |
+
# Start the FastAPI server
|
| 77 |
+
CMD ["python", "-m", "uvicorn", "api.main:app", "--host", "0.0.0.0", "--port", "7860"]
|
README.md
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Cortexa AI
|
| 3 |
+
emoji: 🧠
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: purple
|
| 6 |
+
sdk: docker
|
| 7 |
+
app_port: 7860
|
| 8 |
+
pinned: false
|
| 9 |
+
---
|
| 10 |
+
|
| 11 |
+
# Cortexa RAG System
|
| 12 |
+
|
| 13 |
+
Retrieval-Augmented Generation (RAG) system for educational content Q&A.
|
| 14 |
+
|
| 15 |
+
## Features
|
| 16 |
+
|
| 17 |
+
- 📄 Document processing (PDF, TXT, DOCX)
|
| 18 |
+
- 🔍 Semantic search with embeddings
|
| 19 |
+
- 💬 Citation-backed answers
|
| 20 |
+
- 🚀 No external AI APIs required
|
| 21 |
+
- 🔒 Runs locally
|
| 22 |
+
|
| 23 |
+
## Setup
|
| 24 |
+
|
| 25 |
+
### 1. Install Dependencies
|
| 26 |
+
```
|
| 27 |
+
cd ai
|
| 28 |
+
pip install -r requirements.txt
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
### 2. Add Documents
|
| 32 |
+
|
| 33 |
+
Place your PDF/TXT/DOCX files in `data/documents/`
|
| 34 |
+
|
| 35 |
+
### 3. Run System
|
| 36 |
+
|
| 37 |
+
```
|
| 38 |
+
python main.py
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
### 4. Run API Server
|
| 42 |
+
|
| 43 |
+
```
|
| 44 |
+
python api/main.py
|
| 45 |
+
# or
|
| 46 |
+
python -m api.main
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
Then visit: `http://localhost:8000/docs`
|
| 50 |
+
|
| 51 |
+
## Usage
|
| 52 |
+
|
| 53 |
+
### CLI Mode
|
| 54 |
+
|
| 55 |
+
```
|
| 56 |
+
python main.py
|
| 57 |
+
```
|
| 58 |
+
|
| 59 |
+
### API Mode
|
| 60 |
+
|
| 61 |
+
Start server
|
| 62 |
+
```
|
| 63 |
+
python api/main.py
|
| 64 |
+
```
|
| 65 |
+
Upload document
|
| 66 |
+
```
|
| 67 |
+
curl -X POST "http://localhost:8000/upload" \
  -F "file=@document.pdf" \
  -F "institution_id=mit"
|
| 70 |
+
```
|
| 71 |
+
|
| 72 |
+
Query
|
| 73 |
+
```
|
| 74 |
+
curl -X POST "http://localhost:8000/query" \
  -H "Content-Type: application/json" \
  -d '{"query": "What is machine learning?"}'
|
| 77 |
+
```
|
| 78 |
+
|
| 79 |
+
## Project Structure
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
```
|
| 83 |
+
ai/
|
| 84 |
+
├── models/ # Embedding & LLM models
|
| 85 |
+
├── vectordb/ # Vector store & document processing
|
| 86 |
+
├── rag/ # Retrieval & generation
|
| 87 |
+
├── api/ # FastAPI server
|
| 88 |
+
├── data/ # Documents & processed data
|
| 89 |
+
└── tests/ # Unit tests
|
| 90 |
+
```
|
| 91 |
+
|
| 92 |
+
## Models Used
|
| 93 |
+
|
| 94 |
+
- **Embeddings**: sentence-transformers/paraphrase-MiniLM-L3-v2
|
| 95 |
+
- **LLM**: TinyLlama/TinyLlama-1.1B-Chat-v1.0
|
| 96 |
+
- **Vector DB**: JsonStore
|
| 97 |
+
|
| 98 |
+
## System Requirements
|
| 99 |
+
|
| 100 |
+
- **CPU**: Works on CPU (slower)
|
| 101 |
+
- **GPU**: Recommended for faster inference
|
| 102 |
+
- **RAM**: 8GB minimum, 16GB recommended
|
| 103 |
+
- **Storage**: ~5GB for models
|
| 104 |
+
|
| 105 |
+
### Setup & Running Instructions
|
| 106 |
+
|
| 107 |
+
#### Step 1: Install
|
| 108 |
+
|
| 109 |
+
```
|
| 110 |
+
cd ai
|
| 111 |
+
pip install -r requirements.txt
|
| 112 |
+
```
|
| 113 |
+
#### Step 2: Add Sample Documents
|
| 114 |
+
Place some PDF/TXT files in ai/data/documents/
|
| 115 |
+
|
| 116 |
+
#### Step 3: Run
|
| 117 |
+
```
|
| 118 |
+
python main.py
|
| 119 |
+
```
|
| 120 |
+
#### Step 4: Test API
|
| 121 |
+
```
|
| 122 |
+
python api/main.py
|
| 123 |
+
```
|
| 124 |
+
|
| 125 |
+
## This is a complete, production-ready RAG system that runs entirely locally without any external AI APIs! 🚀
|
api/__init__.py
ADDED
|
File without changes
|
api/config.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Configuration file for RAG system
|
| 3 |
+
"""
|
| 4 |
+
import torch
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
# ---------------------------------------------------------------------------
# Filesystem layout — every path hangs off the directory containing this file.
# NOTE(review): for this copy under api/, BASE_DIR is the api/ directory, while
# the Dockerfile bakes models into <app root>/models_cache — confirm which
# config module the application actually imports.
# ---------------------------------------------------------------------------
BASE_DIR = Path(__file__).parent
DATA_DIR = BASE_DIR.joinpath("data")
DOCUMENTS_DIR = DATA_DIR.joinpath("documents")
PROCESSED_DIR = DATA_DIR.joinpath("processed")
MODELS_DIR = BASE_DIR.joinpath("models_cache")

# Uploaded audio and the transcripts produced from it
AUDIO_DIR = DATA_DIR.joinpath("audio")
TRANSCRIPTS_DIR = DATA_DIR.joinpath("transcripts")

# Importing this module guarantees the whole directory tree exists.
for dir_path in (
    DATA_DIR,
    DOCUMENTS_DIR,
    PROCESSED_DIR,
    MODELS_DIR,
    AUDIO_DIR,
    TRANSCRIPTS_DIR,
):
    dir_path.mkdir(parents=True, exist_ok=True)

# Single JSON file holding the stored chunks and their embeddings
EMBEDDINGS_JSON = PROCESSED_DIR.joinpath("embeddings_store.json")

# ---------------------------------------------------------------------------
# Models
# ---------------------------------------------------------------------------
EMBEDDING_MODEL = "sentence-transformers/paraphrase-MiniLM-L3-v2"  # ~120 MB
LLM_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # 1.1B parameters (~2.2 GB on disk)
WHISPER_MODEL = "base"  # one of: tiny, base, small, medium, large

# Smaller / faster LLM alternatives (uncomment one to use it):
# LLM_MODEL = "distilgpt2"  # 350 MB - RECOMMENDED: 3-5x faster!
# LLM_MODEL = "gpt2"  # 500 MB - 2x faster than TinyLlama

# Whisper model sizes, for choosing WHISPER_MODEL above:
# - tiny: ~75MB, fastest
# - base: ~140MB, good balance (RECOMMENDED)
# - small: ~470MB, better accuracy
# - medium: ~1.5GB, high accuracy
# - large: ~3GB, best accuracy

# ---------------------------------------------------------------------------
# Chunking
# ---------------------------------------------------------------------------
CHUNK_SIZE = 512
CHUNK_OVERLAP = 50
MAX_CHUNKS_PER_DOC = 1000

# ---------------------------------------------------------------------------
# Retrieval
# ---------------------------------------------------------------------------
TOP_K = 3  # lowered from 5 to speed up retrieval
SIMILARITY_THRESHOLD = 0.3

# ---------------------------------------------------------------------------
# Generation
# ---------------------------------------------------------------------------
MAX_NEW_TOKENS = 256  # lowered from 512 to speed up generation
TEMPERATURE = 0.7
TOP_P = 0.9

# MCQ generation (tuned for speed)
MCQ_MAX_TOKENS_PER_QUESTION = 150  # roughly 150 tokens per MCQ
MCQ_MAX_CONTEXT_LENGTH = 1000  # shorter context = faster generation

# ---------------------------------------------------------------------------
# Audio / transcription
# ---------------------------------------------------------------------------
MAX_AUDIO_SIZE_MB = 100  # upper bound on an uploaded audio file
SUPPORTED_AUDIO_FORMATS = ['.wav', '.mp3', '.m4a', '.ogg', '.flac']
WHISPER_LANGUAGE = "en"  # English only, per requirement

# ---------------------------------------------------------------------------
# Device — use the GPU when torch can see one, otherwise the CPU
# ---------------------------------------------------------------------------
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# ---------------------------------------------------------------------------
# Performance switches
# ---------------------------------------------------------------------------
USE_FAST_TOKENIZER = True
LOW_CPU_MEM_USAGE = True
|
api/main.py
ADDED
|
@@ -0,0 +1,779 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FastAPI server for RAG system with Voice-to-Text
|
| 3 |
+
"""
|
| 4 |
+
from fastapi import FastAPI, UploadFile, File, HTTPException, Form
|
| 5 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 6 |
+
from fastapi.responses import FileResponse
|
| 7 |
+
from pydantic import BaseModel
|
| 8 |
+
from typing import List, Optional, Dict
|
| 9 |
+
import shutil
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
from config import DOCUMENTS_DIR, AUDIO_DIR, TRANSCRIPTS_DIR
|
| 13 |
+
from vectordb.document_processor import DocumentProcessor
|
| 14 |
+
from vectordb.json_store import get_json_store
|
| 15 |
+
from rag.retriever import get_retriever
|
| 16 |
+
from rag.generator import get_generator
|
| 17 |
+
from mcq.generator import get_mcq_generator
|
| 18 |
+
from mcq.validator import MCQValidator
|
| 19 |
+
from hybrid.assistant import get_hybrid_assistant
|
| 20 |
+
|
| 21 |
+
# NEW: Import speech modules
|
| 22 |
+
from speech.transcriber import get_transcriber
|
| 23 |
+
from speech.formatter import TextFormatter
|
| 24 |
+
from speech.audio_handler import AudioHandler
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# The ASGI application object; title and version are passed to FastAPI here.
app = FastAPI(title="Cortexa RAG API", version="2.0.0")


# NOTE(review): wildcard origins combined with allow_credentials=True allows
# credentialed cross-origin requests from any site — confirm this is intended,
# or list the permitted origins explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@app.on_event("startup")
async def startup_event():
    """Print the startup banner.

    Nothing is loaded here: the getters in this module build their
    components on first use, so the banner must not claim that models are
    already in memory. The advertised port follows the PORT environment
    variable (the Dockerfile sets it to 7860) and falls back to 8000.
    """
    import os  # local import: only this banner needs it

    port = os.environ.get("PORT", "8000")
    print("="*60)
    print("🚀 Starting Cortexa AI Server...")
    print("="*60)
    print("📦 AI models load on first use (the first request may take 30-60 seconds)...")
    print(f"🌐 Server ready at http://localhost:{port}")
    print(f"📚 API docs at http://localhost:{port}/docs")
    print("="*60)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# ============================================================================
|
| 53 |
+
# PYDANTIC MODELS
|
| 54 |
+
# ============================================================================
|
| 55 |
+
|
| 56 |
+
class QueryRequest(BaseModel):
    # Request body for POST /query.
    query: str
    # Forwarded to retriever.retrieve as top_k.
    top_k: Optional[int] = 5
    # When set, retrieval is filtered on metadata {'institution_id': ...}.
    institution_id: Optional[str] = None
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class QueryResponse(BaseModel):
    # Response body for POST /query.
    query: str
    answer: str
    # One dict per retrieved chunk: source, chunk_index, similarity, text_preview.
    sources: List[dict]
    # The formatted context string that was handed to the generator.
    context: str
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class DocumentUploadResponse(BaseModel):
    # Response body for POST /upload.
    filename: str
    # Number of chunks produced from the uploaded document.
    chunks_added: int
    status: str
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class MCQGenerateRequest(BaseModel):
    # Request body for POST /mcq/generate.
    source_type: str  # "text", "document", "topic" (anything else is rejected with 400)
    source: str  # text content, document name, or topic
    num_questions: int = 5
    difficulty: str = "medium"
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class MCQScoreRequest(BaseModel):
    # Request body for POST /mcq/score; both fields go to MCQValidator.score_answers.
    mcqs: List[dict]
    # presumably keyed by question index — verify against MCQValidator.score_answers
    user_answers: Dict[int, str]
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class HybridQueryRequest(BaseModel):
    # Request body for a hybrid-assistant query (its endpoint is not in this part of the file).
    query: str
    use_web_fallback: bool = True
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
# NEW: Speech-to-Text Models
|
| 93 |
+
class TranscribeRequest(BaseModel):
    # Options for a transcription request (its endpoint is not in this part of the file).
    audio_filename: str
    include_timestamps: bool = True
    format_text: bool = True
    export_format: str = "both"  # "markdown", "docx", "both"
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class TranscribeResponse(BaseModel):
    # Result of a transcription request.
    status: str
    text: str
    duration: float
    formatted_text: Optional[str] = None
    download_links: Dict[str, str] = {}
    segments: Optional[List[Dict]] = None
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
# ============================================================================
|
| 110 |
+
# GLOBAL LAZY LOADING INSTANCES
|
| 111 |
+
# ============================================================================
|
| 112 |
+
|
| 113 |
+
# Module-level singletons. Each starts as None and is filled in by the matching
# get_* function below the first time that function is called.
# RAG / MCQ / hybrid components
_doc_processor = None
_vector_store = None
_retriever = None
_generator = None
_mcq_generator = None
_mcq_validator = None
_hybrid_assistant = None

# Speech (voice-to-text) components
_transcriber = None
_audio_handler = None
_text_formatter = None
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def get_doc_processor():
    """Return the shared DocumentProcessor, creating it on the first call."""
    global _doc_processor
    if _doc_processor is not None:
        return _doc_processor
    _doc_processor = DocumentProcessor()
    return _doc_processor
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def get_vector_store():
    """Return the shared JSON store, obtaining it from get_json_store() on the first call."""
    global _vector_store
    if _vector_store is not None:
        return _vector_store
    _vector_store = get_json_store()
    return _vector_store
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def get_retriever_instance():
    """Return the shared retriever, obtaining it from get_retriever() on the first call."""
    global _retriever
    if _retriever is not None:
        return _retriever
    _retriever = get_retriever()
    return _retriever
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def get_generator_instance():
    """Return the shared generator, obtaining it from get_generator() on the first call."""
    global _generator
    if _generator is not None:
        return _generator
    _generator = get_generator()
    return _generator
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def get_mcq_generator_instance():
    """Return the shared MCQ generator, obtaining it from get_mcq_generator() on the first call."""
    global _mcq_generator
    if _mcq_generator is not None:
        return _mcq_generator
    _mcq_generator = get_mcq_generator()
    return _mcq_generator
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def get_mcq_validator_instance():
    """Return the shared MCQValidator, creating it on the first call."""
    global _mcq_validator
    if _mcq_validator is not None:
        return _mcq_validator
    _mcq_validator = MCQValidator()
    return _mcq_validator
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def get_hybrid_assistant_instance():
    """Return the shared hybrid assistant, obtaining it from get_hybrid_assistant() on the first call."""
    global _hybrid_assistant
    if _hybrid_assistant is not None:
        return _hybrid_assistant
    _hybrid_assistant = get_hybrid_assistant()
    return _hybrid_assistant
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
# NEW: Speech module getters
|
| 178 |
+
def get_transcriber_instance():
    """Return the shared transcriber, obtaining it from get_transcriber() on the first call."""
    global _transcriber
    if _transcriber is not None:
        return _transcriber
    _transcriber = get_transcriber()
    return _transcriber
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def get_audio_handler():
    """Return the shared AudioHandler, creating it on the first call."""
    global _audio_handler
    if _audio_handler is not None:
        return _audio_handler
    _audio_handler = AudioHandler()
    return _audio_handler
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def get_text_formatter():
    """Return the shared TextFormatter, creating it on the first call."""
    global _text_formatter
    if _text_formatter is not None:
        return _text_formatter
    _text_formatter = TextFormatter()
    return _text_formatter
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
# ============================================================================
|
| 200 |
+
# BASIC ENDPOINTS
|
| 201 |
+
# ============================================================================
|
| 202 |
+
|
| 203 |
+
@app.get("/")
def root():
    # Landing endpoint: a static description of the service and its features.
    feature_names = [
        "Document RAG",
        "MCQ Generation",
        "Hybrid Assistant",
        "Voice-to-Text Transcription"
    ]
    info = {
        "message": "Cortexa RAG API with Voice-to-Text",
        "status": "running",
        "version": "2.0.0",
        "features": feature_names
    }
    return info
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
@app.get("/health")
def health_check():
    # Reports "healthy" with the store statistics, or "unhealthy" with the
    # error text when the store cannot be obtained or queried. Never raises.
    try:
        store_stats = get_vector_store().get_stats()
    except Exception as e:
        return {"status": "unhealthy", "error": str(e)}
    return {"status": "healthy", "store": store_stats}
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
# ============================================================================
|
| 229 |
+
# DOCUMENT UPLOAD & QUERY ENDPOINTS
|
| 230 |
+
# ============================================================================
|
| 231 |
+
|
| 232 |
+
@app.post("/upload", response_model=DocumentUploadResponse)
async def upload_document(
    file: UploadFile = File(...),
    institution_id: Optional[str] = None,
    course_id: Optional[str] = None
):
    """Upload and process document for RAG system"""
    # NOTE(review): institution_id / course_id are declared without Form(...),
    # so FastAPI reads them from the query string, while the README example
    # sends institution_id as a multipart field (-F) — confirm which is intended.

    # file.filename is supplied by the client. Keep only its final path
    # component (treating "\" as a separator too) so that a name such as
    # "../../x" cannot make the write below land outside DOCUMENTS_DIR.
    safe_name = Path((file.filename or "").replace("\\", "/")).name
    if safe_name in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="Invalid filename")

    try:
        doc_processor = get_doc_processor()
        vector_store = get_vector_store()

        # Persist the upload to disk first; the processor works from a path.
        file_path = DOCUMENTS_DIR / safe_name
        with open(file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        metadata = {
            'institution_id': institution_id,
            'course_id': course_id
        }

        chunks = doc_processor.process_document(str(file_path), metadata)

        texts = [chunk.text for chunk in chunks]
        metadatas = [chunk.metadata for chunk in chunks]
        # Chunk ids are "<filename>_<index>".
        ids = [f"{safe_name}_{i}" for i in range(len(chunks))]

        vector_store.add_documents(texts, metadatas, ids)

        return DocumentUploadResponse(
            filename=safe_name,
            chunks_added=len(chunks),
            status="success"
        )
    except HTTPException:
        # Deliberate HTTP errors keep their own status code.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
@app.post("/query", response_model=QueryResponse)
async def query_documents(request: QueryRequest):
    """Query RAG system with semantic search"""
    try:
        retriever = get_retriever_instance()
        generator = get_generator_instance()

        # Restrict retrieval to one institution only when the caller named one.
        institution_filter = (
            {'institution_id': request.institution_id}
            if request.institution_id
            else None
        )

        hits = retriever.retrieve(
            query=request.query,
            top_k=request.top_k,
            filter_metadata=institution_filter
        )

        # The same formatted context feeds the generator and is echoed back.
        context = retriever.format_context(hits)
        answer = generator.generate_response(request.query, context)

        # Summarise each hit; the preview is the first 200 characters plus "...".
        sources = []
        for hit in hits:
            sources.append({
                'source': hit['source'],
                'chunk_index': hit['chunk_index'],
                'similarity': hit['similarity'],
                'text_preview': hit['text'][:200] + "..."
            })

        return QueryResponse(
            query=request.query,
            answer=answer,
            sources=sources,
            context=context
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
@app.delete("/documents/all")
def delete_all_documents():
    """Delete all documents from vector store"""
    # Any failure while obtaining or clearing the store is reported as a 500.
    try:
        get_vector_store().delete_all()
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return {"status": "success", "message": "All documents deleted"}
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
@app.get("/export/chunks")
def export_chunks():
    """Export chunks without embeddings"""
    # Any failure while obtaining the store or exporting is reported as a 500.
    try:
        get_vector_store().export_chunks_only()
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return {"status": "success", "message": "Chunks exported to chunks_only.json"}
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
# ============================================================================
|
| 332 |
+
# MCQ GENERATION ENDPOINTS
|
| 333 |
+
# ============================================================================
|
| 334 |
+
|
| 335 |
+
@app.post("/mcq/generate")
async def generate_mcqs(request: MCQGenerateRequest):
    """Generate MCQs from text, document, or topic"""
    try:
        mcq_generator = get_mcq_generator_instance()
        mcq_validator = get_mcq_validator_instance()

        # Dispatch on the kind of source the caller supplied.
        if request.source_type == "text":
            mcqs = mcq_generator.generate_from_text(
                text=request.source,
                num_questions=request.num_questions,
                difficulty=request.difficulty
            )
        elif request.source_type == "document":
            mcqs = mcq_generator.generate_from_document(
                document_name=request.source,
                num_questions=request.num_questions,
                difficulty=request.difficulty
            )
        elif request.source_type == "topic":
            mcqs = mcq_generator.generate_from_topic(
                topic=request.source,
                num_questions=request.num_questions,
                difficulty=request.difficulty
            )
        else:
            raise HTTPException(status_code=400, detail="Invalid source_type")

        # Filter valid MCQs
        valid_mcqs = [mcq for mcq in mcqs if mcq_validator.validate_mcq(mcq)]

        return {
            "status": "success",
            "total_generated": len(mcqs),
            "valid_mcqs": len(valid_mcqs),
            "mcqs": valid_mcqs
        }

    except HTTPException:
        # Without this clause the 400 above would be caught by the generic
        # handler below and returned to the client as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
@app.post("/mcq/score")
async def score_mcqs(request: MCQScoreRequest):
    """Score a user's answers against the supplied MCQs.

    Delegates to the MCQ validator's ``score_answers`` and returns its
    result unchanged. Any exception is reported as an HTTP 500 whose
    detail is the exception text.
    """
    try:
        validator = get_mcq_validator_instance()
        return validator.score_answers(
            mcqs=request.mcqs,
            user_answers=request.user_answers,
        )
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
# ============================================================================
|
| 392 |
+
# HYBRID ASSISTANT ENDPOINT
|
| 393 |
+
# ============================================================================
|
| 394 |
+
|
| 395 |
+
@app.post("/assistant")
async def hybrid_query(request: HybridQueryRequest):
    """
    Hybrid AI Assistant endpoint.

    Forwards the query to the hybrid assistant, passing
    ``request.use_web_fallback`` as its ``use_web`` flag, and returns the
    assistant's result unchanged. Progress is printed to stdout; any
    failure is reported as an HTTP 500 with the exception text as detail.
    """
    try:
        # Only the first 50 characters of the query are echoed to the log.
        print(f"📥 Received query: {request.query[:50]}...")
        print(f"🌐 Web fallback: {request.use_web_fallback}")

        assistant = get_hybrid_assistant_instance()
        outcome = assistant.answer(
            query=request.query,
            use_web=request.use_web_fallback,
        )

        print(f"✅ Query successful! Method: {outcome.get('search_method', 'unknown')}")
        return outcome
    except Exception as exc:
        print(f"❌ Query failed: {str(exc)}")
        raise HTTPException(status_code=500, detail=str(exc))
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
# ============================================================================
|
| 418 |
+
# VOICE-TO-TEXT ENDPOINTS (NEW)
|
| 419 |
+
# ============================================================================
|
| 420 |
+
|
| 421 |
+
@app.post("/speech/upload-audio")
async def upload_audio(
    file: UploadFile = File(...),
    teacher_id: Optional[str] = Form(None),
    lecture_title: Optional[str] = Form(None)
):
    """
    Upload audio file for transcription

    Supported formats: .wav, .mp3, .m4a, .ogg, .flac
    Max size: 100MB (configurable in config.py)

    The upload is stored in AUDIO_DIR under the final path component of
    the client-supplied filename, then validated by the audio handler.

    Returns:
        dict describing the stored file (name, path, duration, size) and
        echoing ``teacher_id`` / ``lecture_title``.

    Raises:
        HTTPException: 400 when the filename is unusable or the audio
            handler rejects the file with ValueError; 500 otherwise.
    """
    # The multipart filename is client-controlled. Keep only its last path
    # component so values such as "../../app.py" cannot escape AUDIO_DIR.
    safe_name = Path(file.filename or "").name
    if safe_name in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="Invalid audio filename")

    file_path = AUDIO_DIR / safe_name

    try:
        audio_handler = get_audio_handler()

        # Save uploaded file
        with open(file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        # Validate audio
        try:
            audio_handler.validate_audio(str(file_path))
        except Exception:
            # A rejected upload must not linger in AUDIO_DIR, where it
            # would show up in listings and be offered for transcription.
            file_path.unlink(missing_ok=True)
            raise
        duration = audio_handler.get_audio_duration(str(file_path))

        return {
            "status": "success",
            "filename": safe_name,
            "path": str(file_path),
            "duration_seconds": round(duration, 2),
            "size_mb": round(file_path.stat().st_size / (1024 * 1024), 2),
            "teacher_id": teacher_id,
            "lecture_title": lecture_title,
            "message": "Audio uploaded successfully. Use /speech/transcribe to convert to text."
        }

    except ValueError as ve:
        raise HTTPException(status_code=400, detail=str(ve))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
@app.post("/speech/transcribe", response_model=TranscribeResponse)
async def transcribe_audio(request: TranscribeRequest):
    """
    Transcribe uploaded audio to text

    Features:
    - Converts speech to English text using Whisper
    - Optional formatting with headings/structure using LLM
    - Export to Markdown and/or DOCX format
    - Returns timestamps for each segment

    ``request.audio_filename`` must be the bare name of a file inside
    AUDIO_DIR. Exports are only produced when ``request.format_text`` is
    set; ``request.export_format`` ("markdown", "docx" or "both") picks
    which files are written.

    Raises:
        HTTPException: 400 for a filename containing path components,
            404 when the audio file does not exist, 500 otherwise.
    """
    # The filename arrives in the request body; refuse anything that is not
    # a plain file name so the lookup cannot leave AUDIO_DIR.
    requested_name = request.audio_filename
    if (
        not requested_name
        or requested_name in (".", "..")
        or Path(requested_name).name != requested_name
    ):
        raise HTTPException(
            status_code=400,
            detail=f"Invalid audio filename: {requested_name}"
        )

    try:
        audio_path = AUDIO_DIR / requested_name

        # is_file() rather than exists(): a directory is not transcribable.
        if not audio_path.is_file():
            raise HTTPException(
                status_code=404,
                detail=f"Audio file not found: {request.audio_filename}"
            )

        # Step 1: Transcribe audio
        print(f"🎙️ Starting transcription: {request.audio_filename}")
        transcriber = get_transcriber_instance()
        result = transcriber.transcribe_audio(
            str(audio_path),
            include_timestamps=request.include_timestamps
        )

        raw_text = result["text"]
        segments = result.get("segments", [])
        duration = result.get("duration", 0)

        # Step 2: Format text if requested
        formatted_text = None
        download_links = {}

        if request.format_text:
            print("📝 Formatting text with structure...")
            formatter = get_text_formatter()
            formatted_text = formatter.format_as_structured_text(raw_text, segments)

            # Export to requested formats
            base_filename = Path(request.audio_filename).stem

            if request.export_format in ["markdown", "both"]:
                md_path = formatter.export_to_markdown(
                    formatted_text,
                    base_filename,
                    title=f"Lecture: {base_filename}"
                )
                download_links["markdown"] = f"/speech/download/{Path(md_path).name}"

            if request.export_format in ["docx", "both"]:
                docx_path = formatter.export_to_docx(
                    formatted_text,
                    base_filename,
                    title=f"Lecture: {base_filename}",
                    segments=segments
                )
                download_links["docx"] = f"/speech/download/{Path(docx_path).name}"

        return TranscribeResponse(
            status="success",
            text=raw_text,
            duration=round(duration, 2),
            formatted_text=formatted_text,
            download_links=download_links,
            segments=segments if request.include_timestamps else None
        )

    except HTTPException:
        # Preserve the status code of deliberate HTTP errors (the 404 above).
        raise
    except Exception as e:
        print(f"❌ Transcription error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
|
| 537 |
+
|
| 538 |
+
|
| 539 |
+
@app.post("/speech/transcribe-and-upload")
async def transcribe_and_upload_to_rag(
    audio_file: UploadFile = File(...),
    institution_id: Optional[str] = Form(None),
    course_id: Optional[str] = Form(None),
    lecture_title: Optional[str] = Form("Untitled Lecture"),
    teacher_id: Optional[str] = Form(None)
):
    """
    Complete workflow for teachers: Upload audio → Transcribe → Format → Add to RAG

    This is the main endpoint for lecture recording feature:
    1. Uploads audio file
    2. Transcribes to English text using Whisper
    3. Formats with headings/structure using LLM
    4. Exports to DOCX document
    5. Adds transcript to RAG system for student queries
    6. Returns formatted text for immediate display

    The audio is stored in AUDIO_DIR under the final path component of the
    client-supplied filename. Chunk ids are derived from that file's stem.

    Raises:
        HTTPException: 400 when the filename is unusable or a step raises
            ValueError (e.g. audio validation); 500 otherwise.
    """
    # The multipart filename is client-controlled. Keep only its last path
    # component so it cannot point outside AUDIO_DIR.
    safe_name = Path(audio_file.filename or "").name
    if safe_name in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="Invalid audio filename")

    try:
        # Step 1: Save audio
        print(f"📤 Uploading audio: {safe_name}")
        audio_path = AUDIO_DIR / safe_name
        with open(audio_path, "wb") as buffer:
            shutil.copyfileobj(audio_file.file, buffer)

        # Step 2: Validate audio
        audio_handler = get_audio_handler()
        try:
            audio_handler.validate_audio(str(audio_path))
        except Exception:
            # Do not leave a rejected upload behind in AUDIO_DIR.
            audio_path.unlink(missing_ok=True)
            raise

        # Step 3: Transcribe
        print(f"🎙️ Transcribing: {safe_name}")
        transcriber = get_transcriber_instance()
        result = transcriber.transcribe_audio(str(audio_path))
        raw_text = result["text"]
        duration = result.get("duration", 0)
        segments = result.get("segments", [])

        print(f"✅ Transcription complete! Duration: {duration:.2f}s")

        # Step 4: Format with structure
        print("📝 Formatting transcript with headings...")
        formatter = get_text_formatter()
        formatted_text = formatter.format_as_structured_text(raw_text, segments)

        # Step 5: Export to DOCX
        base_filename = Path(safe_name).stem
        docx_path = formatter.export_to_docx(
            formatted_text,
            base_filename,
            title=lecture_title,
            segments=segments
        )

        # Step 6: Add transcript to RAG system
        print("🔄 Adding transcript to RAG knowledge base...")
        doc_processor = get_doc_processor()
        vector_store = get_vector_store()

        metadata = {
            'institution_id': institution_id,
            'course_id': course_id,
            'lecture_title': lecture_title,
            'teacher_id': teacher_id,
            'content_type': 'lecture_transcript',
            'audio_filename': safe_name,
            'duration': duration
        }

        # The exported DOCX (not the raw text) is what gets chunked.
        chunks = doc_processor.process_document(docx_path, metadata)
        texts = [chunk.text for chunk in chunks]
        metadatas = [chunk.metadata for chunk in chunks]
        ids = [f"{base_filename}_transcript_{i}" for i in range(len(chunks))]

        vector_store.add_documents(texts, metadatas, ids)

        print(f"✅ Complete! Added {len(chunks)} chunks to knowledge base.")

        return {
            "status": "success",
            "message": "Lecture transcribed, formatted, and added to knowledge base",
            "transcription": {
                "raw_text": raw_text,
                "formatted_text": formatted_text,
                "duration_seconds": round(duration, 2),
                "word_count": len(raw_text.split()),
                "segments_count": len(segments)
            },
            "rag_system": {
                "chunks_added": len(chunks),
                "document_name": Path(docx_path).name,
                "document_path": str(docx_path)
            },
            "metadata": {
                "institution_id": institution_id,
                "course_id": course_id,
                "lecture_title": lecture_title,
                "teacher_id": teacher_id
            },
            "downloads": {
                "docx": f"/speech/download/{Path(docx_path).name}"
            }
        }

    except ValueError as ve:
        raise HTTPException(status_code=400, detail=str(ve))
    except Exception as e:
        print(f"❌ Error in transcribe-and-upload: {str(e)}")
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))
|
| 650 |
+
|
| 651 |
+
|
| 652 |
+
@app.get("/speech/download/{filename}")
async def download_transcript(filename: str):
    """
    Download formatted transcript (Markdown or DOCX)

    ``filename`` must be the bare name of a regular file inside
    TRANSCRIPTS_DIR. The response media type is chosen from the file
    extension (.docx, .md, anything else is served as a generic binary).

    Raises:
        HTTPException: 404 when the name contains path components or does
            not refer to an existing regular file.
    """
    # Only plain file names are served: this keeps the lookup inside
    # TRANSCRIPTS_DIR. ".." must be rejected explicitly because its "name"
    # component is ".." itself.
    if filename in (".", "..") or Path(filename).name != filename:
        raise HTTPException(status_code=404, detail=f"File not found: {filename}")

    file_path = TRANSCRIPTS_DIR / filename

    # is_file() rather than exists(): a directory cannot be sent as a file.
    if not file_path.is_file():
        raise HTTPException(status_code=404, detail=f"File not found: {filename}")

    # Determine media type
    if filename.endswith('.docx'):
        media_type = 'application/vnd.openxmlformats-officedocument.wordprocessingml.document'
    elif filename.endswith('.md'):
        media_type = 'text/markdown'
    else:
        media_type = 'application/octet-stream'

    return FileResponse(
        path=file_path,
        filename=filename,
        media_type=media_type
    )
|
| 675 |
+
|
| 676 |
+
|
| 677 |
+
@app.get("/speech/transcripts")
def list_transcripts():
    """Return metadata for every regular file in TRANSCRIPTS_DIR.

    Each entry carries the file name, its size in KB (two decimals), its
    suffix and its ``st_ctime`` timestamp under the key ``created``.
    Entries are ordered by that timestamp, largest first.
    """
    entries = []

    for candidate in TRANSCRIPTS_DIR.glob("*"):
        if not candidate.is_file():
            continue  # skip sub-directories and other non-file entries
        size_bytes = candidate.stat().st_size
        entries.append({
            "filename": candidate.name,
            "size_kb": round(size_bytes / 1024, 2),
            "format": candidate.suffix,
            "created": candidate.stat().st_ctime,
        })

    # Largest st_ctime first.
    entries.sort(key=lambda entry: entry['created'], reverse=True)

    return {
        "status": "success",
        "transcripts": entries,
        "total": len(entries),
    }
|
| 699 |
+
|
| 700 |
+
|
| 701 |
+
@app.get("/speech/audio-files")
def list_audio_files():
    """Return metadata for every regular file in AUDIO_DIR.

    Each entry carries the file name, its size in MB (two decimals), its
    suffix and its ``st_ctime`` timestamp under the key ``created``.
    Entries are ordered by that timestamp, largest first.
    """
    entries = []

    for candidate in AUDIO_DIR.glob("*"):
        if not candidate.is_file():
            continue  # skip sub-directories and other non-file entries
        size_bytes = candidate.stat().st_size
        entries.append({
            "filename": candidate.name,
            "size_mb": round(size_bytes / (1024 * 1024), 2),
            "format": candidate.suffix,
            "created": candidate.stat().st_ctime,
        })

    # Largest st_ctime first.
    entries.sort(key=lambda entry: entry['created'], reverse=True)

    return {
        "status": "success",
        "audio_files": entries,
        "total": len(entries),
    }
|
| 723 |
+
|
| 724 |
+
|
| 725 |
+
@app.delete("/speech/audio/{filename}")
def delete_audio(filename: str):
    """Delete an uploaded audio file from AUDIO_DIR.

    ``filename`` must be the bare name of a regular file inside AUDIO_DIR.

    Raises:
        HTTPException: 404 when the name contains path components or no
            such regular file exists; 500 when the deletion itself fails.
    """
    try:
        # Only plain file names may be deleted, so the target can never be
        # outside AUDIO_DIR. ".." is listed explicitly because its "name"
        # component is ".." itself.
        if filename in (".", "..") or Path(filename).name != filename:
            raise HTTPException(status_code=404, detail="Audio file not found")

        audio_path = AUDIO_DIR / filename
        # is_file() rather than exists(): unlink() cannot remove a
        # directory, which previously surfaced as a 500.
        if audio_path.is_file():
            audio_path.unlink()
            return {
                "status": "success",
                "message": f"Deleted audio file: {filename}"
            }
        else:
            raise HTTPException(status_code=404, detail="Audio file not found")
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
| 742 |
+
|
| 743 |
+
|
| 744 |
+
@app.delete("/speech/transcript/{filename}")
def delete_transcript(filename: str):
    """Delete a transcript file from TRANSCRIPTS_DIR.

    ``filename`` must be the bare name of a regular file inside
    TRANSCRIPTS_DIR.

    Raises:
        HTTPException: 404 when the name contains path components or no
            such regular file exists; 500 when the deletion itself fails.
    """
    try:
        # Only plain file names may be deleted, so the target can never be
        # outside TRANSCRIPTS_DIR. ".." is listed explicitly because its
        # "name" component is ".." itself.
        if filename in (".", "..") or Path(filename).name != filename:
            raise HTTPException(status_code=404, detail="Transcript not found")

        transcript_path = TRANSCRIPTS_DIR / filename
        # is_file() rather than exists(): unlink() cannot remove a
        # directory, which previously surfaced as a 500.
        if transcript_path.is_file():
            transcript_path.unlink()
            return {
                "status": "success",
                "message": f"Deleted transcript: {filename}"
            }
        else:
            raise HTTPException(status_code=404, detail="Transcript not found")
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
| 761 |
+
|
| 762 |
+
|
| 763 |
+
# ============================================================================
|
| 764 |
+
# SERVER STARTUP
|
| 765 |
+
# ============================================================================
|
| 766 |
+
|
| 767 |
+
# Script entry point: start the API with uvicorn when this module is run
# directly (not when it is imported by another server process).
if __name__ == "__main__":
    import uvicorn
    print("\n" + "="*60)
    print("🚀 Starting Cortexa AI Server with Voice-to-Text")
    print("="*60)

    uvicorn.run(
        app,
        host="0.0.0.0",  # listen on all interfaces
        port=8000,
        # NOTE(review): intended to allow 5 minutes for long audio processing;
        # confirm that uvicorn's keep-alive timeout is the setting that
        # actually governs long-running requests.
        timeout_keep_alive=300,
        timeout_graceful_shutdown=30  # seconds
    )
|
config.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Configuration file for RAG system
|
| 3 |
+
"""
|
| 4 |
+
import torch
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
# Base paths: everything lives under the directory containing this file.
BASE_DIR = Path(__file__).parent
DATA_DIR = BASE_DIR / "data"
DOCUMENTS_DIR = DATA_DIR / "documents"
PROCESSED_DIR = DATA_DIR / "processed"
MODELS_DIR = BASE_DIR / "models_cache"

# Audio storage: uploaded recordings and the transcripts exported from them.
AUDIO_DIR = DATA_DIR / "audio"
TRANSCRIPTS_DIR = DATA_DIR / "transcripts"

# Create directories if they don't exist.
# NOTE: this runs at import time, so importing this module touches the filesystem.
for dir_path in [DATA_DIR, DOCUMENTS_DIR, PROCESSED_DIR, MODELS_DIR, AUDIO_DIR, TRANSCRIPTS_DIR]:
    dir_path.mkdir(parents=True, exist_ok=True)

# JSON storage file
EMBEDDINGS_JSON = PROCESSED_DIR / "embeddings_store.json"

# Model configurations
EMBEDDING_MODEL = "sentence-transformers/paraphrase-MiniLM-L3-v2" # 120 MB
LLM_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" # 1.1 GB
WHISPER_MODEL = "base" # Options: tiny, base, small, medium, large

# Alternative faster LLMs (uncomment to use):
# LLM_MODEL = "distilgpt2" # 350 MB - RECOMMENDED: 3-5x faster!
# LLM_MODEL = "gpt2" # 500 MB - 2x faster than TinyLlama

# Whisper model sizes (choices for WHISPER_MODEL above):
# - tiny: ~75MB, fastest
# - base: ~140MB, good balance (RECOMMENDED)
# - small: ~470MB, better accuracy
# - medium: ~1.5GB, high accuracy
# - large: ~3GB, best accuracy

# Chunking settings
# NOTE(review): units of CHUNK_SIZE / CHUNK_OVERLAP (characters vs. tokens)
# are defined by the document processor — confirm there.
CHUNK_SIZE = 512
CHUNK_OVERLAP = 50
MAX_CHUNKS_PER_DOC = 1000

# Retrieval settings
TOP_K = 3 # Reduced from 5 for faster retrieval
SIMILARITY_THRESHOLD = 0.3

# Generation settings
MAX_NEW_TOKENS = 256 # Reduced from 512 for faster generation
TEMPERATURE = 0.7
TOP_P = 0.9

# MCQ Generation settings (optimized for speed)
MCQ_MAX_TOKENS_PER_QUESTION = 150 # ~150 tokens per MCQ
MCQ_MAX_CONTEXT_LENGTH = 1000 # Shorter context = faster generation

# Audio/Transcription settings
MAX_AUDIO_SIZE_MB = 100 # Maximum audio file size
SUPPORTED_AUDIO_FORMATS = ['.wav', '.mp3', '.m4a', '.ogg', '.flac']
WHISPER_LANGUAGE = "en" # English only as per requirement

# Device settings: use the GPU when torch reports CUDA is available, else CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Performance settings
USE_FAST_TOKENIZER = True
LOW_CPU_MEM_USAGE = True
|
hybrid/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Hybrid AI Assistant Module
|
| 3 |
+
"""
|
| 4 |
+
from .assistant import HybridAssistant, get_hybrid_assistant
|
| 5 |
+
from .web_search import WebSearcher, get_web_searcher
|
| 6 |
+
|
| 7 |
+
__all__ = ['HybridAssistant', 'get_hybrid_assistant', 'WebSearcher', 'get_web_searcher']
|
hybrid/assistant.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Hybrid AI Assistant - RAG + Web Search
|
| 3 |
+
"""
|
| 4 |
+
from typing import List, Dict, Optional
|
| 5 |
+
from models.llm import get_llm_model
|
| 6 |
+
from rag.retriever import get_retriever
|
| 7 |
+
from hybrid.web_search import get_web_searcher
|
| 8 |
+
from config import SIMILARITY_THRESHOLD
|
| 9 |
+
|
| 10 |
+
class HybridAssistant:
    """Answer queries from local documents, falling back to web search."""

    def __init__(self):
        """Fetch the shared LLM, retriever and web searcher instances."""
        self.llm = get_llm_model()
        self.retriever = get_retriever()
        self.web_searcher = get_web_searcher()

    def answer(
        self,
        query: str,
        use_web: bool = True,
        min_similarity: float = SIMILARITY_THRESHOLD
    ) -> Dict:
        """
        Answer a query, trying local documents before the web.

        Args:
            query: User query
            use_web: Whether web search may be used when no documents match
            min_similarity: Minimum similarity passed to the retriever

        Returns:
            Dict with 'query', 'answer', 'sources', 'search_method'
            ("rag", "web" or "none") and 'num_sources'.
        """
        print(f"\n🔍 Processing query: {query}")

        # Step 1: local documents
        print("📚 Searching local documents...")
        hits = self.retriever.retrieve(
            query=query,
            min_similarity=min_similarity
        )

        if hits and len(hits) > 0:
            reply, sources, method = self._answer_from_documents(query, hits)
        elif use_web:
            # Step 2: nothing relevant locally, so try the web
            reply, sources, method = self._answer_from_web(query)
        else:
            reply = "I don't have enough information in my knowledge base to answer this question."
            sources = []
            method = "none"

        return {
            'query': query,
            'answer': reply,
            'sources': sources,
            'search_method': method,
            'num_sources': len(sources)
        }

    def _answer_from_documents(self, query: str, hits: List[Dict]):
        """Build (answer, sources, "rag") from retrieved document chunks."""
        print(f"✓ Found {len(hits)} relevant documents")

        context = self.retriever.format_context(hits)
        reply = self._generate_answer(query, context, source_type="documents")

        sources = []
        for hit in hits:
            sources.append({
                'type': 'document',
                'source': hit['source'],
                'chunk_index': hit['chunk_index'],
                'similarity': hit['similarity'],
                'text_preview': hit['text'][:200]
            })

        return reply, sources, "rag"

    def _answer_from_web(self, query: str):
        """Build (answer, sources, method) from a web search for *query*."""
        print("🌐 No relevant documents found. Searching the web...")

        found = self.web_searcher.search(query, max_results=5)

        if not found:
            print("❌ No web results found")
            reply = "I couldn't find relevant information to answer your question. Please try rephrasing or ask something else."
            return reply, [], "none"

        print(f"✓ Found {len(found)} web results")

        context = self._format_web_context(found)
        reply = self._generate_answer(query, context, source_type="web")

        sources = []
        for item in found:
            sources.append({
                'type': 'web',
                'title': item['title'],
                'url': item['url'],
                'snippet': item['snippet']
            })

        return reply, sources, "web"

    def _format_web_context(self, web_results: List[Dict]) -> str:
        """Render web results as numbered source sections joined by newlines."""
        return "\n".join(
            f"[Web Source {position}: {entry['title']}]\n"
            f"URL: {entry['url']}\n"
            f"{entry['snippet']}\n"
            for position, entry in enumerate(web_results, 1)
        )

    def _generate_answer(
        self,
        query: str,
        context: str,
        source_type: str
    ) -> str:
        """Prompt the LLM with *context* and return its stripped response.

        ``source_type == "documents"`` selects the document prompt; any
        other value selects the web prompt.
        """
        if source_type != "documents":
            # web sources
            prompt = f"""You are a helpful AI assistant. Answer the question using the information from web search results.

Web search results:
{context}

Question: {query}

Instructions:
- Synthesize information from the web sources
- Cite sources using [Web Source 1], [Web Source 2], etc.
- Provide accurate and helpful information
- Be concise

Answer:"""
        else:
            prompt = f"""You are a helpful AI assistant. Answer the question using ONLY the information from the provided context.

Context from uploaded documents:
{context}

Question: {query}

Instructions:
- Answer based on the context above
- Cite sources using [Source 1], [Source 2], etc.
- If the context doesn't fully answer the question, say so
- Be concise and accurate

Answer:"""

        generated = self.llm.generate(
            prompt=prompt,
            max_new_tokens=512,
            temperature=0.7
        )
        return generated.strip()
|
| 170 |
+
|
| 171 |
+
# Module-level singleton, created lazily on first request.
_hybrid_assistant = None


def get_hybrid_assistant() -> HybridAssistant:
    """Return the shared HybridAssistant, constructing it on first use."""
    global _hybrid_assistant
    if _hybrid_assistant is not None:
        return _hybrid_assistant
    _hybrid_assistant = HybridAssistant()
    return _hybrid_assistant
|
hybrid/web_search.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Web search functionality
|
| 3 |
+
"""
|
| 4 |
+
from duckduckgo_search import DDGS
|
| 5 |
+
import requests
|
| 6 |
+
from bs4 import BeautifulSoup
|
| 7 |
+
from typing import List, Dict
|
| 8 |
+
import time
|
| 9 |
+
|
| 10 |
+
class WebSearcher:
    """Thin wrapper around DuckDuckGo text search plus a page-text fetcher."""

    def __init__(self):
        """Create the DuckDuckGo client used by :meth:`search`."""
        self.ddgs = DDGS()

    def search(self, query: str, max_results: int = 5) -> List[Dict]:
        """
        Search the web and return normalised results.

        Args:
            query: Search query
            max_results: Maximum number of results

        Returns:
            List of dicts with 'title', 'snippet', 'url', 'source_type'
            (always 'web') and 'index'. An empty list is returned when the
            search raises; the error is printed.
        """
        try:
            # Search using DuckDuckGo
            raw_hits = self.ddgs.text(query, max_results=max_results)

            return [
                {
                    'title': hit.get('title', 'No title'),
                    'snippet': hit.get('body', 'No description'),
                    'url': hit.get('href', ''),
                    'source_type': 'web',
                    'index': position
                }
                for position, hit in enumerate(raw_hits)
            ]
        except Exception as e:
            print(f"Web search error: {e}")
            return []

    def get_page_content(self, url: str, max_chars: int = 1000) -> str:
        """
        Fetch a web page and return its visible text.

        Args:
            url: URL to fetch
            max_chars: Maximum characters to return

        Returns:
            Whitespace-normalised page text truncated to max_chars, or an
            empty string when fetching or parsing raises (error is printed).
        """
        try:
            response = requests.get(
                url,
                headers={
                    'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
                },
                timeout=5
            )
            response.raise_for_status()

            soup = BeautifulSoup(response.content, 'html.parser')

            # Drop script and style elements so only visible text remains
            for tag in soup(["script", "style"]):
                tag.decompose()

            page_text = soup.get_text()

            # Strip each line, split it on spaces, and keep non-empty pieces
            pieces = []
            for raw_line in page_text.splitlines():
                for phrase in raw_line.strip().split(" "):
                    cleaned = phrase.strip()
                    if cleaned:
                        pieces.append(cleaned)

            return ' '.join(pieces)[:max_chars]
        except Exception as e:
            print(f"Error fetching {url}: {e}")
            return ""
|
| 84 |
+
|
| 85 |
+
# Module-level singleton, created lazily on first request.
_web_searcher = None


def get_web_searcher() -> WebSearcher:
    """Return the shared WebSearcher, constructing it on first use."""
    global _web_searcher
    if _web_searcher is not None:
        return _web_searcher
    _web_searcher = WebSearcher()
    return _web_searcher
|
main.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Main script for testing RAG system
|
| 3 |
+
"""
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import List
|
| 6 |
+
|
| 7 |
+
from config import DOCUMENTS_DIR
|
| 8 |
+
from vectordb.document_processor import DocumentProcessor
|
| 9 |
+
from vectordb.json_store import get_json_store # Changed
|
| 10 |
+
from rag.retriever import get_retriever
|
| 11 |
+
from rag.generator import get_generator
|
| 12 |
+
|
| 13 |
+
def load_documents(file_paths: List[str]):
    """Chunk every file in file_paths and add the chunks to the JSON store."""
    banner = "="*60
    print("\n" + banner)
    print("LOADING DOCUMENTS")
    print(banner)

    processor = DocumentProcessor()
    vector_store = get_json_store()

    for file_path in file_paths:
        print(f"\nProcessing: {file_path}")

        chunks = processor.process_document(file_path)
        print(f"✓ Created {len(chunks)} chunks")

        # Chunk ids are "<file stem>_<position within the file>"
        stem = Path(file_path).stem
        texts = []
        metadatas = []
        ids = []
        for position, chunk in enumerate(chunks):
            texts.append(chunk.text)
            metadatas.append(chunk.metadata)
            ids.append(f"{stem}_{position}")

        vector_store.add_documents(texts, metadatas, ids)

    stats = vector_store.get_stats()
    total = stats['total_documents']
    size_mb = stats['file_size_mb']
    print(f"\n✓ Total chunks in store: {total}")
    print(f"✓ JSON file size: {size_mb:.2f} MB")

    # Also write out the chunks without their embeddings
    vector_store.export_chunks_only()
|
| 40 |
+
|
| 41 |
+
def query_system(query: str):
    """Retrieve context for query, generate an answer, and print both."""
    heavy_rule = "="*60
    light_rule = "-"*60

    print("\n" + heavy_rule)
    print(f"QUERY: {query}")
    print(heavy_rule)

    retriever = get_retriever()
    generator = get_generator()

    print("\n🔍 Retrieving relevant documents...")
    retrieved_docs = retriever.retrieve(query)

    print(f"✓ Found {len(retrieved_docs)} relevant chunks")
    rank = 0
    for doc in retrieved_docs:
        rank += 1
        header = f"\n[{rank}] {doc['source']} (Chunk {doc['chunk_index']}, Similarity: {doc['similarity']:.3f})"
        print(header)
        print(f"Preview: {doc['text'][:150]}...")

    print("\n💬 Generating response...")
    answer = generator.generate_response(
        query,
        retriever.format_context(retrieved_docs),
    )

    print("\n" + light_rule)
    print("ANSWER:")
    print(light_rule)
    print(answer)
    print(light_rule)
|
| 67 |
+
|
| 68 |
+
def interactive_mode():
    """
    Run the interactive question loop on stdin/stdout.

    Supported input:
        - 'quit', 'exit' or 'q': leave the loop
        - 'stats': print the store statistics
        - anything else non-empty: run it through query_system()

    The loop also ends cleanly when stdin is closed (EOF) or the user
    presses Ctrl-C, instead of terminating with a traceback.
    """
    print("\n" + "="*60)
    print("INTERACTIVE MODE")
    print("="*60)
    print("Commands:")
    print(" - Type your question to query")
    print(" - Type 'stats' to see store statistics")
    print(" - Type 'quit' or 'exit' to stop")
    print("="*60 + "\n")

    vector_store = get_json_store()

    while True:
        try:
            query = input("\n💬 Your question: ").strip()
        except (EOFError, KeyboardInterrupt):
            # input() raises EOFError when stdin is exhausted or not attached
            # (piped input, container without a TTY); treat it like 'quit'.
            print("\nGoodbye!")
            break

        command = query.lower()

        if command in ['quit', 'exit', 'q']:
            print("Goodbye!")
            break

        if command == 'stats':
            stats = vector_store.get_stats()
            print("\n📊 Store Statistics:")
            for key, value in stats.items():
                print(f" {key}: {value}")
            continue

        # Ignore blank lines rather than sending an empty query
        if not query:
            continue

        query_system(query)
|
| 99 |
+
|
| 100 |
+
def main():
    """Entry point: list the documents, optionally load them, then start querying."""
    print("\n🚀 Cortexa RAG System (JSON Storage)")
    print("="*60)

    supported_suffixes = ['.pdf', '.txt', '.docx']
    docs = []
    for candidate in DOCUMENTS_DIR.glob("*"):
        if candidate.suffix in supported_suffixes:
            docs.append(candidate)

    # Nothing to work with: tell the user where to put files and stop
    if not docs:
        print(f"\n⚠️ No documents found in {DOCUMENTS_DIR}")
        print("Please add PDF, TXT, or DOCX files to the documents folder.")
        return

    print(f"\n📄 Found {len(docs)} documents:")
    for doc in docs:
        print(f" - {doc.name}")

    load_choice = input("\nLoad documents into store? (y/n): ").strip().lower()
    if load_choice == 'y':
        load_documents([str(path) for path in docs])

    print("\nStarting interactive query mode...")
    interactive_mode()


if __name__ == "__main__":
    main()
|
mcq/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
MCQ Generation Module
|
| 3 |
+
"""
|
| 4 |
+
from .generator import MCQGenerator, get_mcq_generator
|
| 5 |
+
from .validator import MCQValidator
|
| 6 |
+
|
| 7 |
+
__all__ = ['MCQGenerator', 'get_mcq_generator', 'MCQValidator']
|
mcq/generator.py
ADDED
|
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
MCQ Generator using LLM
|
| 3 |
+
"""
|
| 4 |
+
import json
|
| 5 |
+
import re
|
| 6 |
+
from typing import List, Dict, Optional
|
| 7 |
+
from models.llm import get_llm_model
|
| 8 |
+
from vectordb.json_store import get_json_store
|
| 9 |
+
|
| 10 |
+
class MCQGenerator:
    """
    Generate multiple-choice questions (MCQs) with the shared LLM.

    Each MCQ is a dict with the keys 'question', 'options' (letter -> text),
    'correct_answer' (a letter present in 'options'), 'explanation' and
    'difficulty'.
    """

    # A question header at the start of a line: "Q1:", "q2.", "Question 3)".
    # Leading markdown markers ("**", "#") are tolerated.
    _QUESTION_HEADER = re.compile(
        r'^[ \t*#]*(?:Question\s*|Q)\d+\s*[:.)][ \t*]*',
        re.IGNORECASE | re.MULTILINE,
    )
    # Fallback header: a plain numbered list item such as "1." or "2)".
    _NUMBERED_HEADER = re.compile(r'^[ \t*#]*\d+\s*[.)][ \t*]+', re.MULTILINE)
    # An option line: "A. text", "b) text", "C: text", "D text".
    _OPTION_LINE = re.compile(r'^([A-D])[.):\s]+(.+)', re.IGNORECASE)
    _ANSWER_LABEL = re.compile(r'answer', re.IGNORECASE)
    # A stand-alone option letter.
    _ANSWER_LETTER = re.compile(r'\b([A-D])\b', re.IGNORECASE)

    def __init__(self):
        self.llm = get_llm_model()
        self.vector_store = get_json_store()

    def generate_from_text(
        self,
        text: str,
        num_questions: int = 5,
        difficulty: str = "medium",
        topic: Optional[str] = None
    ) -> List[Dict]:
        """
        Generate MCQs from the given text.

        Args:
            text: Source material for the questions
            num_questions: Number of questions requested from the LLM
            difficulty: Requested difficulty; sent to the LLM and recorded
                on every parsed MCQ
            topic: Optional topic to focus the questions on

        Returns:
            List of MCQ dicts; synthetic fill-in-the-blank MCQs when the LLM
            output cannot be parsed
        """
        prompt = self._create_mcq_prompt(text, num_questions, difficulty, topic)

        # Roughly 150 tokens per MCQ plus slack, capped at 800 for speed
        tokens_needed = min(num_questions * 150 + 100, 800)

        response = self.llm.generate(
            prompt=prompt,
            max_new_tokens=tokens_needed,
            temperature=0.8
        )

        print(f"\n🤖 LLM Response:\n{response[:500]}...\n")  # Debug

        return self._parse_mcqs_improved(response, text, difficulty)

    def generate_from_document(
        self,
        document_name: str,
        num_questions: int = 5,
        difficulty: str = "medium",
        topic: Optional[str] = None
    ) -> List[Dict]:
        """
        Generate MCQs from a document already present in the vector store.

        Raises:
            ValueError: If no stored chunk belongs to document_name
        """
        chunks = self._get_document_chunks(document_name, num_chunks=15)

        if not chunks:
            raise ValueError(f"Document '{document_name}' not found in vector store")

        text = "\n\n".join([chunk['text'] for chunk in chunks])
        return self.generate_from_text(text, num_questions, difficulty, topic)

    def generate_from_topic(
        self,
        topic: str,
        num_questions: int = 5,
        difficulty: str = "medium"
    ) -> List[Dict]:
        """
        Generate MCQs about a topic, using vector search to find the source text.

        Raises:
            ValueError: If the search returns no documents for topic
        """
        # A small top_k keeps the search fast
        documents, metadatas, distances = self.vector_store.search(
            query=topic,
            top_k=5
        )

        if not documents:
            raise ValueError(f"No content found for topic: {topic}")

        # Only the three best hits are used as source text
        text = "\n\n".join(documents[:3])
        return self.generate_from_text(text, num_questions, difficulty, topic)

    def _create_mcq_prompt(
        self,
        text: str,
        num_questions: int,
        difficulty: str,
        topic: Optional[str]
    ) -> str:
        """Build the LLM prompt, including the requested difficulty and topic."""

        topic_str = f" about {topic}" if topic else ""
        difficulty_str = f" {difficulty}-difficulty" if difficulty else ""

        # A shorter text excerpt means faster generation
        max_text_length = 800 if num_questions <= 3 else 1200

        prompt = f"""Based on the following text, create {num_questions}{difficulty_str} multiple-choice questions{topic_str}.

TEXT:
{text[:max_text_length]}

Create exactly {num_questions} questions. For each question:
1. Write a clear question
2. Provide exactly 4 options labeled A, B, C, D
3. Mark which option is correct
4. Give a brief explanation

Example format:

Q1: What is the capital of France?
A. London
B. Paris
C. Berlin
D. Rome
ANSWER: B
EXPLANATION: Paris is the capital and largest city of France.

Q2: Which planet is known as the Red Planet?
A. Venus
B. Mars
C. Jupiter
D. Saturn
ANSWER: B
EXPLANATION: Mars appears reddish due to iron oxide on its surface.

Now create {num_questions} questions:

"""
        return prompt

    @classmethod
    def _split_question_blocks(cls, response: str) -> List[str]:
        """Split an LLM response into one text block per question.

        Each block runs from just after a question header to just before the
        next header, so it still contains the options, ANSWER and EXPLANATION
        lines. "Q1:"-style headers are tried first, then plain numbering.
        """
        for header in (cls._QUESTION_HEADER, cls._NUMBERED_HEADER):
            # Element 0 is whatever preceded the first header; discard it
            parts = header.split(response)[1:]
            blocks = [part.strip() for part in parts if part.strip()]
            if blocks:
                return blocks
        return []

    def _parse_mcqs_improved(
        self,
        response: str,
        context: str,
        difficulty: str = "medium"
    ) -> List[Dict]:
        """
        Parse MCQs out of the LLM response, with a synthetic fallback.

        Args:
            response: Raw LLM output
            context: Source text used for the synthetic fallback
            difficulty: Difficulty recorded on every parsed MCQ

        Returns:
            Parsed MCQs, or up to 3 synthetic MCQs if nothing could be parsed
        """
        mcqs = []
        for block in self._split_question_blocks(response):
            mcq = self._parse_question_block(block, difficulty)
            if mcq:
                mcqs.append(mcq)

        # If parsing failed, generate synthetic MCQs from context
        if not mcqs:
            print("⚠️ Parsing failed, generating synthetic MCQs...")
            mcqs = self._generate_synthetic_mcqs(context, 3)

        return mcqs

    def _parse_question_block(self, text: str, difficulty: str = "medium") -> Optional[Dict]:
        """
        Parse one question block (question, options, answer, explanation).

        Returns:
            The MCQ dict, or None unless the block has a question, at least
            three options and an answer letter that is one of those options
        """
        lines = [l.strip() for l in text.split('\n') if l.strip()]
        if not lines:
            return None

        # The first line is the question; drop a leftover "Q1:" label and
        # surrounding markdown emphasis markers.
        question = re.sub(r'^Q\d+[:.]\s*', '', lines[0]).strip(' *')
        options = {}
        correct_answer = None
        explanation = None

        for line in lines[1:]:
            # Handle explanation lines first: they often mention "answer" and
            # option letters and must not be read as the answer line.
            if line.lower().startswith('explanation'):
                explanation = re.sub(r'^explanation[:\s]+', '', line, flags=re.IGNORECASE).strip()
                continue

            option_match = self._OPTION_LINE.match(line)
            if option_match:
                # First definition wins, so later prose starting with
                # "A ..." cannot overwrite a real option.
                options.setdefault(option_match.group(1).upper(), option_match.group(2).strip())
                continue

            label = self._ANSWER_LABEL.search(line)
            if label:
                if correct_answer is None:
                    # Prefer a letter after the word "answer"; otherwise take
                    # one from anywhere on the line ("B is the answer").
                    answer_match = (
                        self._ANSWER_LETTER.search(line[label.end():])
                        or self._ANSWER_LETTER.search(line)
                    )
                    if answer_match:
                        correct_answer = answer_match.group(1).upper()
                continue

            # Explanation label that is not at the very start of the line
            if 'explanation' in line.lower():
                explanation = re.sub(r'^explanation[:\s]+', '', line, flags=re.IGNORECASE).strip()

        if question and len(options) >= 3 and correct_answer and correct_answer in options:
            return {
                'question': question,
                'options': options,
                'correct_answer': correct_answer,
                'explanation': explanation or "Based on the provided context.",
                'difficulty': difficulty
            }

        return None

    def _generate_synthetic_mcqs(self, text: str, num: int) -> List[Dict]:
        """
        Build simple fill-in-the-blank MCQs when LLM output cannot be parsed.

        Up to num * 2 sentences longer than 50 characters are considered, so
        sentences with fewer than five words can be skipped and still leave
        enough candidates. The middle word of each sentence is blanked out.

        Returns:
            At most num MCQ dicts with difficulty 'easy'
        """
        candidates = [s.strip() for s in text.split('.') if len(s.strip()) > 50][:num * 2]

        mcqs = []
        for sentence in candidates:
            if len(mcqs) >= num:
                break

            words = sentence.split()
            if len(words) < 5:
                continue

            # Blank out only the chosen token, not every substring match
            middle = len(words) // 2
            key_word = words[middle]
            question_text = ' '.join(words[:middle] + ["______"] + words[middle + 1:])

            mcqs.append({
                'question': f"Fill in the blank: {question_text}",
                'options': {
                    'A': key_word,
                    'B': f"Not {key_word}",
                    'C': "None of the above",
                    'D': "Cannot be determined"
                },
                'correct_answer': 'A',
                'explanation': f"The correct term is '{key_word}' based on the context.",
                'difficulty': 'easy'
            })

        return mcqs

    def _get_document_chunks(self, document_name: str, num_chunks: int = 10) -> List[Dict]:
        """Return up to num_chunks stored chunks whose 'source' metadata
        contains document_name (case-insensitive)."""
        wanted = document_name.lower()
        matching_chunks = []

        for doc in self.vector_store.data['documents']:
            if wanted in doc['metadata'].get('source', '').lower():
                matching_chunks.append({
                    'text': doc['text'],
                    'metadata': doc['metadata']
                })

        return matching_chunks[:num_chunks]
|
| 244 |
+
|
| 245 |
+
# Module-level slot holding the single shared MCQGenerator
_mcq_generator = None


def get_mcq_generator() -> MCQGenerator:
    """Return the shared MCQGenerator, building it on the first call."""
    global _mcq_generator
    if _mcq_generator is not None:
        return _mcq_generator
    _mcq_generator = MCQGenerator()
    return _mcq_generator
|
mcq/validator.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
MCQ Validator and Scorer
|
| 3 |
+
"""
|
| 4 |
+
from typing import List, Dict
|
| 5 |
+
|
| 6 |
+
class MCQValidator:
    """Stateless helpers for checking MCQ structure and grading answers."""

    @staticmethod
    def validate_mcq(mcq: Dict) -> bool:
        """
        Check that an MCQ has the required structure.

        Args:
            mcq: MCQ dictionary

        Returns:
            True when 'question', 'options' and 'correct_answer' are present,
            'options' is a dict with at least two entries, and
            'correct_answer' is one of its keys; False otherwise
        """
        for field in ('question', 'options', 'correct_answer'):
            if field not in mcq:
                return False

        options = mcq['options']
        if not isinstance(options, dict):
            return False
        if len(options) < 2:
            return False

        return mcq['correct_answer'] in options

    @staticmethod
    def score_answers(
        mcqs: List[Dict],
        user_answers: Dict[int, str]
    ) -> Dict:
        """
        Grade a set of user answers.

        Args:
            mcqs: List of MCQs
            user_answers: Mapping of question index to the chosen answer;
                a missing index counts as unanswered (None)

        Returns:
            Dict with totals, percentage score, letter grade and a
            per-question breakdown under 'results'
        """
        results = []
        for index, mcq in enumerate(mcqs):
            given = user_answers.get(index)
            expected = mcq['correct_answer']
            results.append({
                'question_index': index,
                'question': mcq['question'],
                'user_answer': given,
                'correct_answer': expected,
                'is_correct': given == expected,
                'explanation': mcq.get('explanation', '')
            })

        total_questions = len(mcqs)
        correct_count = sum(1 for entry in results if entry['is_correct'])

        # Guard the division: an empty quiz scores 0
        if total_questions > 0:
            score_percentage = correct_count / total_questions * 100
        else:
            score_percentage = 0

        return {
            'total_questions': total_questions,
            'correct_answers': correct_count,
            'incorrect_answers': total_questions - correct_count,
            'score_percentage': round(score_percentage, 2),
            'grade': MCQValidator._calculate_grade(score_percentage),
            'results': results
        }

    @staticmethod
    def _calculate_grade(score: float) -> str:
        """Map a percentage score to a letter grade."""
        # (minimum score, grade), checked from the highest band down
        bands = ((90, 'A+'), (80, 'A'), (70, 'B'), (60, 'C'), (50, 'D'))
        for minimum, grade in bands:
            if score >= minimum:
                return grade
        return 'F'
|
models/__init__.py
ADDED
|
File without changes
|
models/embeddings.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Embedding model for document and query vectorization
|
| 3 |
+
"""
|
| 4 |
+
import torch
|
| 5 |
+
from sentence_transformers import SentenceTransformer
|
| 6 |
+
from typing import List
|
| 7 |
+
import numpy as np
|
| 8 |
+
from config import EMBEDDING_MODEL, DEVICE, MODELS_DIR
|
| 9 |
+
|
| 10 |
+
class EmbeddingModel:
    """Wrapper around a SentenceTransformer that produces normalized embeddings."""

    def __init__(self):
        print(f"Loading embedding model: {EMBEDDING_MODEL}")
        cache_dir = str(MODELS_DIR)
        self.model = SentenceTransformer(
            EMBEDDING_MODEL,
            cache_folder=cache_dir,
            device=DEVICE
        )
        self.dimension = self.model.get_sentence_embedding_dimension()
        print(f"✓ Embedding model loaded (dimension: {self.dimension})")

    def encode(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """
        Embed a list of texts.

        Args:
            texts: List of text strings
            batch_size: Batch size passed to the underlying model

        Returns:
            Numpy array of embeddings; an empty array when texts is empty
        """
        if texts:
            return self.model.encode(
                texts,
                batch_size=batch_size,
                show_progress_bar=True,
                convert_to_numpy=True,
                # L2-normalize so that a dot product equals cosine similarity
                normalize_embeddings=True
            )

        # Nothing to embed: skip the model call entirely
        return np.array([])

    def encode_query(self, query: str) -> np.ndarray:
        """
        Embed a single query string.

        Args:
            query: Query string

        Returns:
            Numpy array holding the normalized embedding
        """
        query_embedding = self.model.encode(
            query,
            convert_to_numpy=True,
            normalize_embeddings=True
        )
        return query_embedding
|
| 59 |
+
|
| 60 |
+
# Module-level slot holding the single shared EmbeddingModel
_embedding_model = None


def get_embedding_model() -> EmbeddingModel:
    """Return the shared EmbeddingModel, loading it on the first call."""
    global _embedding_model
    if _embedding_model is not None:
        return _embedding_model
    _embedding_model = EmbeddingModel()
    return _embedding_model
|
models/llm.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Language model for text generation
|
| 3 |
+
"""
|
| 4 |
+
import torch
|
| 5 |
+
from transformers import (
|
| 6 |
+
AutoModelForCausalLM,
|
| 7 |
+
AutoTokenizer,
|
| 8 |
+
BitsAndBytesConfig,
|
| 9 |
+
pipeline
|
| 10 |
+
)
|
| 11 |
+
from typing import Optional
|
| 12 |
+
from config import LLM_MODEL, DEVICE, MODELS_DIR, MAX_NEW_TOKENS, TEMPERATURE, TOP_P
|
| 13 |
+
|
| 14 |
+
class LanguageModel:
    """Wrapper around a Hugging Face causal language model and its tokenizer."""

    def __init__(self):
        print(f"Loading language model: {LLM_MODEL}")

        # Quantization is optional and only attempted on GPU
        quantization_config = None

        if DEVICE == "cuda":
            try:
                # Try 8-bit quantization (recommended)
                quantization_config = BitsAndBytesConfig(
                    load_in_8bit=True,
                    llm_int8_threshold=6.0
                )
                print("Using 8-bit quantization")
            except Exception:
                # Narrower than a bare except, so KeyboardInterrupt and
                # SystemExit are no longer swallowed here.
                print("8-bit quantization not available, using full precision")

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            LLM_MODEL,
            cache_dir=str(MODELS_DIR),
            trust_remote_code=True
        )

        # Fall back to the EOS token when the tokenizer defines no pad token
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Load model
        self.model = AutoModelForCausalLM.from_pretrained(
            LLM_MODEL,
            cache_dir=str(MODELS_DIR),
            quantization_config=quantization_config,
            device_map="auto" if DEVICE == "cuda" else None,
            torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
            trust_remote_code=True
        )

        if DEVICE == "cpu":
            self.model = self.model.to(DEVICE)

        self.model.eval()
        print(f"✓ Language model loaded on {DEVICE}")

    def generate(
        self,
        prompt: str,
        max_new_tokens: int = MAX_NEW_TOKENS,
        temperature: float = TEMPERATURE,
        top_p: float = TOP_P
    ) -> str:
        """
        Generate a continuation for prompt.

        Args:
            prompt: Input prompt
            max_new_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature; None or a value <= 0 selects
                greedy decoding instead of sampling
            top_p: Nucleus-sampling threshold (used only when sampling)

        Returns:
            Only the newly generated text, stripped of surrounding whitespace
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        prompt_length = inputs["input_ids"].shape[-1]

        generation_kwargs = {
            "max_new_tokens": max_new_tokens,
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }
        if temperature is not None and temperature > 0:
            generation_kwargs.update(
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
            )
        else:
            # Sampling needs a strictly positive temperature; decode greedily
            generation_kwargs["do_sample"] = False

        with torch.no_grad():
            outputs = self.model.generate(**inputs, **generation_kwargs)

        # The returned sequence starts with the prompt tokens. Slice them off
        # by position: comparing decoded text with the prompt string is
        # unreliable because decoding does not always reproduce the prompt
        # character for character.
        new_tokens = outputs[0][prompt_length:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
|
| 100 |
+
|
| 101 |
+
# Module-level slot holding the single shared LanguageModel
_llm_model = None


def get_llm_model() -> LanguageModel:
    """Return the shared LanguageModel, loading it on the first call."""
    global _llm_model
    if _llm_model is not None:
        return _llm_model
    _llm_model = LanguageModel()
    return _llm_model
|
rag/__init__.py
ADDED
|
File without changes
|
rag/generator.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Response generation component
|
| 3 |
+
"""
|
| 4 |
+
from models.llm import get_llm_model
|
| 5 |
+
from rag.retriever import get_retriever
|
| 6 |
+
from typing import List, Dict
|
| 7 |
+
|
| 8 |
+
class ResponseGenerator:
    """Answer user queries with the LLM, grounded in retrieved context."""

    def __init__(self):
        self.llm = get_llm_model()
        self.retriever = get_retriever()

    def create_prompt(self, query: str, context: str) -> str:
        """
        Build the LLM prompt from the retrieved context and the question.

        Args:
            query: User query
            context: Retrieved context

        Returns:
            Formatted prompt ending in "Answer:"
        """
        return f"""You are a helpful AI assistant that answers questions based on the provided context.

Context Information:
{context}

Question: {query}

Instructions:
1. Answer the question using ONLY the information from the context above
2. If the context doesn't contain enough information, say "I don't have enough information to answer this question."
3. Cite the source numbers (e.g., [Source 1]) when providing information
4. Be concise and accurate

Answer:"""

    def generate_response(
        self,
        query: str,
        context: str = None,
        max_tokens: int = 512
    ) -> str:
        """
        Produce an answer for query.

        Args:
            query: User query
            context: Retrieved context; when None it is retrieved here
            max_tokens: Maximum tokens to generate

        Returns:
            The LLM output with surrounding whitespace removed
        """
        if context is None:
            # No context supplied: retrieve and format it for this query
            context = self.retriever.format_context(
                self.retriever.retrieve(query)
            )

        answer = self.llm.generate(
            prompt=self.create_prompt(query, context),
            max_new_tokens=max_tokens
        )
        return answer.strip()
|
| 73 |
+
|
| 74 |
+
# Module-level slot holding the single shared ResponseGenerator
_generator = None


def get_generator() -> ResponseGenerator:
    """Return the shared ResponseGenerator, building it on the first call."""
    global _generator
    if _generator is not None:
        return _generator
    _generator = ResponseGenerator()
    return _generator
|
rag/retriever.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Document retrieval component
|
| 3 |
+
"""
|
| 4 |
+
from typing import List, Dict
|
| 5 |
+
from vectordb.json_store import get_json_store
|
| 6 |
+
from config import TOP_K, SIMILARITY_THRESHOLD
|
| 7 |
+
|
| 8 |
+
class DocumentRetriever:
    """Look up stored chunks for a query and format them as LLM context."""

    def __init__(self):
        self.vector_store = get_json_store()

    def retrieve(
        self,
        query: str,
        top_k: int = TOP_K,
        filter_metadata: Dict = None,
        min_similarity: float = SIMILARITY_THRESHOLD
    ) -> List[Dict]:
        """
        Retrieve relevant documents for a query.

        Args:
            query: Search text
            top_k: Number of hits requested from the vector store
            filter_metadata: Optional metadata filter forwarded to the store
            min_similarity: Hits with a lower similarity are dropped

        Returns:
            Hit dicts ('text', 'metadata', 'similarity', 'source',
            'chunk_index'), sorted by similarity, highest first
        """
        documents, metadatas, distances = self.vector_store.search(
            query=query,
            top_k=top_k,
            filter_metadata=filter_metadata
        )

        hits = []
        for text, meta, distance in zip(documents, metadatas, distances):
            # The store reports distances; similarity is taken as 1 - distance
            similarity = 1 - distance
            if similarity < min_similarity:
                continue
            hits.append({
                'text': text,
                'metadata': meta,
                'similarity': similarity,
                'source': meta.get('source', 'Unknown'),
                'chunk_index': meta.get('chunk_index', 0)
            })

        return sorted(hits, key=lambda hit: hit['similarity'], reverse=True)

    def format_context(self, retrieved_docs: List[Dict]) -> str:
        """Render retrieved hits as numbered "[Source n: ...]" sections."""
        if not retrieved_docs:
            return "No relevant information found."

        sections = [
            f"[Source {number}: {doc['metadata'].get('source', 'Unknown')}, "
            f"Chunk {doc['metadata'].get('chunk_index', 0)}, "
            f"Relevance: {doc['similarity']:.2f}]\n"
            f"{doc['text']}\n"
            for number, doc in enumerate(retrieved_docs, 1)
        ]
        return "\n".join(sections)
|
| 63 |
+
|
| 64 |
+
# Process-wide retriever, created on first use.
_retriever = None


def get_retriever() -> DocumentRetriever:
    """Return the shared DocumentRetriever, constructing it on the first call."""
    global _retriever
    if _retriever is not None:
        return _retriever
    _retriever = DocumentRetriever()
    return _retriever
|
requirements.txt
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Core dependencies
|
| 2 |
+
torch>=2.0.0
|
| 3 |
+
transformers>=4.30.0
|
| 4 |
+
sentence-transformers>=2.2.2
|
| 5 |
+
chromadb>=0.4.0
|
| 6 |
+
langchain>=0.1.0
|
| 7 |
+
pydantic>=2.0.0
|
| 8 |
+
fastapi>=0.100.0
|
| 9 |
+
uvicorn>=0.23.0
|
| 10 |
+
python-multipart>=0.0.6
|
| 11 |
+
|
| 12 |
+
# Document processing
|
| 13 |
+
PyPDF2>=3.0.0
|
| 14 |
+
pymupdf>=1.23.0
|
| 15 |
+
python-docx>=0.8.11
|
| 16 |
+
pdfplumber>=0.10.0
|
| 17 |
+
|
| 18 |
+
# Utilities
|
| 19 |
+
numpy<2
|
| 20 |
+
pandas>=2.0.0
|
| 21 |
+
tqdm>=4.65.0
|
| 22 |
+
python-dotenv>=1.0.0
|
| 23 |
+
|
| 24 |
+
# Optional but recommended
|
| 25 |
+
accelerate>=0.20.0
|
| 26 |
+
# bitsandbytes>=0.41.0 # For 8-bit quantization
|
| 27 |
+
|
| 28 |
+
# For AI Assistant
|
| 29 |
+
duckduckgo-search>=4.0.0
|
| 30 |
+
requests>=2.31.0
|
| 31 |
+
beautifulsoup4>=4.12.0
|
| 32 |
+
|
| 33 |
+
# For Voice-to-Text
|
| 34 |
+
openai-whisper>=20231117
|
| 35 |
+
|
| 36 |
+
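# System dependency (not installable via pip): FFmpeg is required by openai-whisper. On Windows, install it with Chocolatey: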
|
| 37 |
+
# choco install ffmpeg
|
speech/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Speech-to-Text module for lecture transcription
|
| 3 |
+
"""
|
| 4 |
+
from .transcriber import LectureTranscriber, get_transcriber
|
| 5 |
+
from .formatter import TextFormatter
|
| 6 |
+
from .audio_handler import SimpleAudioHandler as AudioHandler # Use simple version
|
| 7 |
+
|
| 8 |
+
__all__ = [
|
| 9 |
+
'LectureTranscriber',
|
| 10 |
+
'get_transcriber',
|
| 11 |
+
'TextFormatter',
|
| 12 |
+
'AudioHandler'
|
| 13 |
+
]
|
speech/audio_handler.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Handle audio file operations
|
| 3 |
+
"""
|
| 4 |
+
import os
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Optional
|
| 7 |
+
import subprocess
|
| 8 |
+
from config import AUDIO_DIR, MAX_AUDIO_SIZE_MB, SUPPORTED_AUDIO_FORMATS
|
| 9 |
+
|
| 10 |
+
class AudioHandler:
    """Handle audio file processing and validation"""

    @staticmethod
    def validate_audio(file_path: str) -> bool:
        """
        Validate audio file

        Args:
            file_path: Path to audio file

        Returns:
            True if valid

        Raises:
            FileNotFoundError: The file does not exist
            ValueError: The file is larger than MAX_AUDIO_SIZE_MB or its
                extension is not in SUPPORTED_AUDIO_FORMATS
        """
        path = Path(file_path)

        # Check if file exists
        if not path.exists():
            raise FileNotFoundError(f"Audio file not found: {file_path}")

        # Check file size
        size_mb = path.stat().st_size / (1024 * 1024)
        if size_mb > MAX_AUDIO_SIZE_MB:
            raise ValueError(f"Audio file too large: {size_mb:.2f}MB > {MAX_AUDIO_SIZE_MB}MB")

        # Check format (case-insensitive on the extension)
        if path.suffix.lower() not in SUPPORTED_AUDIO_FORMATS:
            raise ValueError(f"Unsupported format: {path.suffix}. Supported: {SUPPORTED_AUDIO_FORMATS}")

        return True

    @staticmethod
    def get_audio_duration(file_path: str) -> float:
        """
        Get audio duration in seconds using ffprobe (part of ffmpeg)

        Args:
            file_path: Path to audio file

        Returns:
            Duration in seconds, or 0.0 when ffprobe is missing, times out,
            fails, or prints something that is not a number
        """
        try:
            # Use ffprobe to get duration
            result = subprocess.run(
                [
                    'ffprobe',
                    '-v', 'error',
                    '-show_entries', 'format=duration',
                    '-of', 'default=noprint_wrappers=1:nokey=1',
                    str(file_path)
                ],
                stdin=subprocess.DEVNULL,  # never let ffprobe read the caller's stdin
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                timeout=30
            )
        except (subprocess.TimeoutExpired, FileNotFoundError) as e:
            print(f"⚠️ Could not determine audio duration: {e}")
            # Return 0 if we can't determine duration
            return 0.0

        if result.returncode != 0:
            # No estimate is attempted; 0.0 signals "unknown" to the caller.
            print("⚠️ Could not get exact duration, returning 0.0")
            return 0.0

        try:
            return float(result.stdout.strip())
        except ValueError as e:
            # ffprobe can print a non-numeric value (e.g. "N/A").
            print(f"⚠️ Could not determine audio duration: {e}")
            return 0.0

    @staticmethod
    def convert_to_wav(input_path: str, output_path: Optional[str] = None) -> str:
        """
        Convert audio to WAV format using ffmpeg (optional, Whisper handles most formats)

        Args:
            input_path: Path to input audio
            output_path: Optional output path; defaults to
                AUDIO_DIR / "<input stem>.wav"

        Returns:
            Path to converted WAV file

        Raises:
            ValueError: ffmpeg is not installed, exits with an error, or does
                not finish within the 300 second timeout
        """
        input_path = Path(input_path)

        if output_path is None:
            output_path = AUDIO_DIR / f"{input_path.stem}.wav"

        # ffmpeg does not create missing directories for its output file.
        Path(output_path).parent.mkdir(parents=True, exist_ok=True)

        print(f"🔄 Converting {input_path.name} to WAV...")

        try:
            # Use ffmpeg to convert
            subprocess.run(
                [
                    'ffmpeg',
                    '-i', str(input_path),
                    '-ar', '16000',  # 16kHz sample rate (good for speech)
                    '-ac', '1',      # Mono
                    '-y',            # Overwrite output
                    str(output_path)
                ],
                check=True,
                stdin=subprocess.DEVNULL,  # ffmpeg must not wait on interactive input
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                timeout=300
            )
        except subprocess.CalledProcessError as e:
            print(f"❌ Conversion failed: {e}")
            raise ValueError(f"Could not convert audio file: {e}") from e
        except subprocess.TimeoutExpired as e:
            # Report a timeout the same way as any other conversion failure.
            print(f"❌ Conversion failed: {e}")
            raise ValueError(f"Could not convert audio file: {e}") from e
        except FileNotFoundError:
            raise ValueError("FFmpeg not found. Please install FFmpeg to convert audio files.")

        print(f"✅ Converted to: {output_path}")
        return str(output_path)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
# Simplified version that doesn't require ffmpeg for basic validation
|
| 128 |
+
class SimpleAudioHandler:
    """Audio handler that only inspects the file itself and never runs ffmpeg."""

    @staticmethod
    def validate_audio(file_path: str) -> bool:
        """Check existence, size limit and extension; return True or raise."""
        candidate = Path(file_path)
        if not candidate.exists():
            raise FileNotFoundError(f"Audio file not found: {file_path}")

        bytes_per_mb = 1024 * 1024
        size_mb = candidate.stat().st_size / bytes_per_mb
        if size_mb > MAX_AUDIO_SIZE_MB:
            raise ValueError(f"Audio file too large: {size_mb:.2f}MB > {MAX_AUDIO_SIZE_MB}MB")

        extension = candidate.suffix.lower()
        if extension not in SUPPORTED_AUDIO_FORMATS:
            raise ValueError(f"Unsupported format: {candidate.suffix}. Supported: {SUPPORTED_AUDIO_FORMATS}")

        return True

    @staticmethod
    def get_audio_duration(file_path: str) -> float:
        """Always report 0.0; no external tool is used to measure the file."""
        unknown_duration = 0.0
        return unknown_duration

    @staticmethod
    def convert_to_wav(input_path: str, output_path: Optional[str] = None) -> str:
        """Skip conversion and hand back the input path as a string."""
        unchanged = str(input_path)
        return unchanged
|
speech/formatter.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Format transcribed text into structured documents
|
| 3 |
+
"""
|
| 4 |
+
from typing import Dict, List, Optional
|
| 5 |
+
from docx import Document
|
| 6 |
+
from docx.shared import Pt, Inches
|
| 7 |
+
from docx.enum.text import WD_ALIGN_PARAGRAPH
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from config import TRANSCRIPTS_DIR
|
| 10 |
+
|
| 11 |
+
class TextFormatter:
    """Format transcribed text into structured documents"""

    # A final sentence ending in one of these does not get an extra period.
    _SENTENCE_ENDINGS = ('.', '!', '?')
    # How many sentences are grouped into one paragraph.
    _SENTENCES_PER_PARAGRAPH = 3

    def __init__(self):
        """Initialize formatter (stateless; nothing to set up)."""

    def format_as_structured_text(self, text: str, segments: List[Dict] = None) -> str:
        """
        Format text with basic structure

        The text is split on ". ", empty pieces are discarded, and the
        remaining sentences are grouped three to a paragraph under a
        "## Lecture Transcript" heading.

        Args:
            text: Transcribed text
            segments: Optional timestamp segments (not used by this basic formatter)

        Returns:
            Formatted text with basic structure
        """
        pieces = (text or "").split('. ')
        final_index = len(pieces) - 1

        sentences = []
        for index, piece in enumerate(pieces):
            piece = piece.strip()
            if not piece:
                continue
            # split('. ') removed the period from every piece except the final
            # one, which keeps whatever punctuation the text ended with.
            if index != final_index or not piece.endswith(self._SENTENCE_ENDINGS):
                piece += '.'
            sentences.append(piece)

        formatted_lines = ["## Lecture Transcript\n"]
        step = self._SENTENCES_PER_PARAGRAPH
        # Chunking the filtered list guarantees the trailing partial
        # paragraph is emitted even when the text ends with ". ".
        for start in range(0, len(sentences), step):
            formatted_lines.append(' '.join(sentences[start:start + step]) + '\n')

        return '\n'.join(formatted_lines)

    def format_with_timestamps(self, segments: List[Dict]) -> str:
        """
        Format text with timestamps for each segment

        Args:
            segments: List of segments; each may carry 'start', 'end' (seconds)
                and 'text' keys, missing keys default to 0 / ''

        Returns:
            Formatted text with timestamps
        """
        formatted = ["## Lecture Transcript (with timestamps)\n"]

        for seg in segments:
            start_time = self._format_time(seg.get('start', 0))
            end_time = self._format_time(seg.get('end', 0))
            text = seg.get('text', '').strip()

            formatted.append(f"**[{start_time} - {end_time}]**")
            formatted.append(f"{text}\n")

        return '\n'.join(formatted)

    def _format_time(self, seconds: float) -> str:
        """Convert seconds to MM:SS format (minutes are not wrapped at 60)."""
        minutes = int(seconds // 60)
        secs = int(seconds % 60)
        return f"{minutes:02d}:{secs:02d}"

    def export_to_docx(
        self,
        text: str,
        filename: str,
        title: str = "Lecture Transcript",
        segments: List[Dict] = None
    ) -> str:
        """
        Export formatted text to DOCX document

        Lines starting with "## " / "### " become level 1 / level 2 headings,
        "**[...]**" timestamp lines use the 'Intense Quote' style, and every
        other non-empty line becomes a plain paragraph.

        Args:
            text: Formatted text
            filename: Output filename (without extension)
            title: Document title
            segments: Optional timestamp segments (currently unused)

        Returns:
            Path to saved document
        """
        doc = Document()

        # Add title
        title_para = doc.add_heading(title, level=0)
        title_para.alignment = WD_ALIGN_PARAGRAPH.CENTER

        # Add content
        for line in text.split('\n'):
            line = line.strip()
            if not line:
                continue

            if line.startswith('## '):
                # Drop only the leading marker; "## " inside the heading stays.
                doc.add_heading(line[len('## '):], level=1)
            elif line.startswith('### '):
                doc.add_heading(line[len('### '):], level=2)
            elif line.startswith('**[') and ']**' in line:
                # Timestamp line
                doc.add_paragraph(line, style='Intense Quote')
            else:
                doc.add_paragraph(line)

        # Save document
        output_path = TRANSCRIPTS_DIR / f"{filename}.docx"
        # Saving fails if the transcripts directory has not been created yet.
        Path(output_path).parent.mkdir(parents=True, exist_ok=True)
        doc.save(str(output_path))

        print(f"📄 Document saved: {output_path}")
        return str(output_path)

    def export_to_markdown(
        self,
        text: str,
        filename: str,
        title: str = "Lecture Transcript"
    ) -> str:
        """
        Export formatted text to Markdown

        Args:
            text: Formatted text
            filename: Output filename (without extension)
            title: Document title, written as a level 1 heading

        Returns:
            Path to saved document
        """
        output_path = TRANSCRIPTS_DIR / f"{filename}.md"
        # open(..., 'w') does not create missing parent directories.
        Path(output_path).parent.mkdir(parents=True, exist_ok=True)

        with open(output_path, 'w', encoding='utf-8') as f:
            f.write(f"# {title}\n\n")
            f.write(text)

        print(f"📝 Markdown saved: {output_path}")
        return str(output_path)
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
# Optional: Advanced formatter with LLM (if you want to add later)
|
| 157 |
+
class AdvancedTextFormatter(TextFormatter):
    """TextFormatter variant that asks the RAG generator to add headings."""

    def __init__(self):
        """Try to obtain the generator; remember whether that worked."""
        super().__init__()
        try:
            # Imported lazily so a broken generator import is caught here too.
            from rag.generator import get_generator
            self.generator = get_generator()
            self.use_llm = True
        except Exception as e:
            print(f"⚠️ LLM not available for formatting: {e}")
            self.use_llm = False

    def format_as_structured_text(self, text: str, segments: List[Dict] = None) -> str:
        """Return LLM-structured text, or the basic formatting when the LLM is unavailable or fails."""
        if self.use_llm:
            # NOTE(review): only the first 2000 characters of the transcript are
            # placed in the prompt, so the LLM never sees the rest of a longer
            # lecture - confirm this truncation is intended.
            prompt = f"""Format this lecture transcript with headings and structure.

Rules:
1. Add main headings (##) for major topics
2. Add subheadings (###) for subtopics
3. Keep original text
4. Organize into paragraphs

Transcript:
{text[:2000]}

Formatted:"""

            try:
                # The generator takes a context argument; none is supplied here.
                return self.generator.generate_response(prompt, "")
            except Exception as e:
                print(f"⚠️ LLM formatting failed: {e}. Using basic formatting.")

        return super().format_as_structured_text(text, segments)
|
speech/transcriber.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Whisper-based transcription for lecture audio
|
| 3 |
+
"""
|
| 4 |
+
import whisper
|
| 5 |
+
import torch
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Dict, List, Optional
|
| 8 |
+
from config import WHISPER_MODEL, DEVICE, WHISPER_LANGUAGE
|
| 9 |
+
|
| 10 |
+
class LectureTranscriber:
    """Wrap a Whisper model for turning lecture audio into text."""

    def __init__(self, model_name: str = WHISPER_MODEL):
        """
        Load the Whisper model onto the configured device.

        Args:
            model_name: Whisper model size (tiny, base, small, medium, large)
        """
        print(f"🎙️ Loading Whisper model '{model_name}'...")
        self.model = whisper.load_model(model_name, device=DEVICE)
        # Language used when a transcribe call does not pass its own.
        self.language = WHISPER_LANGUAGE
        print(f"✅ Whisper model loaded on {DEVICE}")

    def transcribe_audio(
        self,
        audio_path: str,
        language: Optional[str] = None,
        include_timestamps: bool = True
    ) -> Dict:
        """
        Run Whisper on one audio file.

        Args:
            audio_path: Path to audio file
            language: Language code; falls back to the instance default when empty
            include_timestamps: Forwarded to Whisper as ``word_timestamps``

        Returns:
            Dict with ``text``, ``segments``, ``language`` and ``duration`` keys

        Raises:
            Exception: Whatever the underlying call raises, re-raised after logging
        """
        try:
            print(f"🎧 Transcribing: {Path(audio_path).name}")

            chosen_language = language or self.language
            outcome = self.model.transcribe(
                audio_path,
                language=chosen_language,
                task="transcribe",
                verbose=False,
                word_timestamps=include_timestamps
            )

            print(f"✅ Transcription complete!")

            found_segments = outcome.get("segments", [])
            return {
                "text": outcome["text"].strip(),
                "segments": found_segments,
                "language": outcome.get("language", chosen_language),
                "duration": self._calculate_duration(found_segments)
            }

        except Exception as e:
            print(f"❌ Transcription failed: {str(e)}")
            raise

    def transcribe_with_timestamps(self, audio_path: str) -> List[Dict]:
        """
        Transcribe and keep only start, end and text of every segment.

        Args:
            audio_path: Path to audio file

        Returns:
            List of ``{"start", "end", "text"}`` dicts, one per segment
        """
        transcription = self.transcribe_audio(audio_path, include_timestamps=True)
        return [
            {
                "start": piece.get("start", 0),
                "end": piece.get("end", 0),
                "text": piece.get("text", "").strip()
            }
            for piece in transcription.get("segments", [])
        ]

    def _calculate_duration(self, segments: List[Dict]) -> float:
        """Return the end time of the last segment, or 0.0 when there are none."""
        return segments[-1].get("end", 0) if segments else 0.0
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
# Global instance for lazy loading
|
| 96 |
+
_transcriber = None


def get_transcriber(model_name: str = WHISPER_MODEL) -> LectureTranscriber:
    """
    Return the shared transcriber, building it on the first call.

    ``model_name`` is only used when the instance is created; later calls
    return the existing instance regardless of the name passed.
    """
    global _transcriber
    if _transcriber is not None:
        return _transcriber
    _transcriber = LectureTranscriber(model_name)
    return _transcriber
|
tests/test_mcq.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Test MCQ Generation
|
| 3 |
+
"""
|
| 4 |
+
import sys
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 7 |
+
|
| 8 |
+
from mcq.generator import get_mcq_generator
|
| 9 |
+
|
| 10 |
+
def test_generate_from_topic():
    """Generate three medium MCQs on one topic and print them for inspection."""
    print("\n🧪 Testing MCQ Generation from Topic...")

    # Generate MCQs about Big Data
    questions = get_mcq_generator().generate_from_topic(
        topic="Big Data Analytics",
        num_questions=3,
        difficulty="medium"
    )

    print(f"\n✓ Generated {len(questions)} MCQs\n")

    for number, item in enumerate(questions, start=1):
        print(f"Question {number}: {item['question']}")
        for label, choice in item['options'].items():
            print(f"  {label}) {choice}")
        print(f"  ✓ Correct Answer: {item['correct_answer']}")
        print(f"  📝 Explanation: {item['explanation']}\n")


if __name__ == "__main__":
    test_generate_from_topic()
|
tests/test_rag.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Test script for RAG system
|
| 3 |
+
"""
|
| 4 |
+
import unittest
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
import sys
|
| 7 |
+
|
| 8 |
+
# Add parent directory to path
|
| 9 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 10 |
+
|
| 11 |
+
from models.embeddings import get_embedding_model
|
| 12 |
+
from models.llm import get_llm_model
|
| 13 |
+
from vectordb.document_processor import DocumentProcessor
|
| 14 |
+
# from vectordb.chroma_store import get_chroma_store
|
| 15 |
+
from rag.retriever import get_retriever
|
| 16 |
+
from rag.generator import get_generator
|
| 17 |
+
|
| 18 |
+
class TestRAGSystem(unittest.TestCase):
    """Smoke tests for the embedding, chunking, retrieval and generation pieces."""

    def test_embeddings(self):
        """Two input texts yield two vectors of the model's dimension."""
        print("\n🧪 Testing embedding model...")
        embedder = get_embedding_model()

        vectors = embedder.encode(["This is a test", "Another test sentence"])

        self.assertEqual(len(vectors), 2)
        self.assertEqual(vectors.shape[1], embedder.dimension)
        print("✓ Embeddings test passed")

    def test_document_processor(self):
        """Chunking a repeated sentence produces at least one chunk."""
        print("\n🧪 Testing document processor...")
        sample = "This is a test document. " * 100

        chunks = DocumentProcessor().chunk_text(sample, chunk_size=100, overlap=20)

        self.assertGreater(len(chunks), 0)
        print(f"✓ Created {len(chunks)} chunks")

    def test_retrieval(self):
        """Retrieval returns a list for an arbitrary query."""
        print("\n🧪 Testing retrieval...")
        hits = get_retriever().retrieve("test query", top_k=3)

        self.assertIsInstance(hits, list)
        print(f"✓ Retrieved {len(hits)} documents")

    def test_generation(self):
        """Generation returns a non-empty string for a query with context."""
        print("\n🧪 Testing generation...")
        generator = get_generator()

        answer = generator.generate_response(
            "What is machine learning?",
            "Machine learning is a subset of artificial intelligence.",
            max_tokens=50
        )

        self.assertIsInstance(answer, str)
        self.assertGreater(len(answer), 0)
        print(f"✓ Generated response: {answer[:100]}...")


if __name__ == "__main__":
    unittest.main(verbosity=2)
|
vectordb/__init__.py
ADDED
|
File without changes
|
vectordb/document_processor.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Document processing and chunking
|
| 3 |
+
"""
|
| 4 |
+
import os
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import List, Dict
|
| 7 |
+
import PyPDF2
|
| 8 |
+
import pdfplumber
|
| 9 |
+
from docx import Document
|
| 10 |
+
from config import CHUNK_SIZE, CHUNK_OVERLAP
|
| 11 |
+
|
| 12 |
+
class DocumentChunk:
    """One piece of a document: its text, its metadata and its id."""

    def __init__(self, text: str, metadata: Dict, chunk_id: int):
        # Stored exactly as given; no copying or validation happens here.
        self.chunk_id = chunk_id
        self.metadata = metadata
        self.text = text
|
| 22 |
+
|
| 23 |
+
class DocumentProcessor:
|
| 24 |
+
def __init__(self):
|
| 25 |
+
self.supported_formats = ['.pdf', '.txt', '.docx']
|
| 26 |
+
|
| 27 |
+
def load_document(self, file_path: str) -> str:
|
| 28 |
+
"""Load document content based on file type"""
|
| 29 |
+
path = Path(file_path)
|
| 30 |
+
|
| 31 |
+
if not path.exists():
|
| 32 |
+
raise FileNotFoundError(f"File not found: {file_path}")
|
| 33 |
+
|
| 34 |
+
ext = path.suffix.lower()
|
| 35 |
+
|
| 36 |
+
if ext == '.pdf':
|
| 37 |
+
return self._load_pdf(file_path)
|
| 38 |
+
elif ext == '.txt':
|
| 39 |
+
return self._load_txt(file_path)
|
| 40 |
+
elif ext == '.docx':
|
| 41 |
+
return self._load_docx(file_path)
|
| 42 |
+
else:
|
| 43 |
+
raise ValueError(f"Unsupported file format: {ext}")
|
| 44 |
+
|
| 45 |
+
def _load_pdf(self, file_path: str) -> str:
|
| 46 |
+
"""Extract text from PDF"""
|
| 47 |
+
text = ""
|
| 48 |
+
try:
|
| 49 |
+
# Try pdfplumber first (better for tables)
|
| 50 |
+
with pdfplumber.open(file_path) as pdf:
|
| 51 |
+
for page in pdf.pages:
|
| 52 |
+
page_text = page.extract_text()
|
| 53 |
+
if page_text:
|
| 54 |
+
text += page_text + "\n"
|
| 55 |
+
except:
|
| 56 |
+
# Fallback to PyPDF2
|
| 57 |
+
with open(file_path, 'rb') as file:
|
| 58 |
+
pdf_reader = PyPDF2.PdfReader(file)
|
| 59 |
+
for page in pdf_reader.pages:
|
| 60 |
+
text += page.extract_text() + "\n"
|
| 61 |
+
|
| 62 |
+
return text.strip()
|
| 63 |
+
|
| 64 |
+
def _load_txt(self, file_path: str) -> str:
|
| 65 |
+
"""Load text file"""
|
| 66 |
+
with open(file_path, 'r', encoding='utf-8') as file:
|
| 67 |
+
return file.read()
|
| 68 |
+
|
| 69 |
+
def _load_docx(self, file_path: str) -> str:
|
| 70 |
+
"""Extract text from DOCX"""
|
| 71 |
+
doc = Document(file_path)
|
| 72 |
+
text = "\n".join([para.text for para in doc.paragraphs])
|
| 73 |
+
return text
|
| 74 |
+
|
| 75 |
+
def chunk_text(
|
| 76 |
+
self,
|
| 77 |
+
text: str,
|
| 78 |
+
chunk_size: int = CHUNK_SIZE,
|
| 79 |
+
overlap: int = CHUNK_OVERLAP
|
| 80 |
+
) -> List[str]:
|
| 81 |
+
"""
|
| 82 |
+
Split text into overlapping chunks
|
| 83 |
+
|
| 84 |
+
Args:
|
| 85 |
+
text: Input text
|
| 86 |
+
chunk_size: Maximum chunk size in characters
|
| 87 |
+
overlap: Overlap between chunks
|
| 88 |
+
|
| 89 |
+
Returns:
|
| 90 |
+
List of text chunks
|
| 91 |
+
"""
|
| 92 |
+
if not text:
|
| 93 |
+
return []
|
| 94 |
+
|
| 95 |
+
# Split by sentences first (simple approach)
|
| 96 |
+
sentences = text.replace('\n', ' ').split('. ')
|
| 97 |
+
|
| 98 |
+
chunks = []
|
| 99 |
+
current_chunk = ""
|
| 100 |
+
|
| 101 |
+
for sentence in sentences:
|
| 102 |
+
sentence = sentence.strip() + ". "
|
| 103 |
+
|
| 104 |
+
# If adding this sentence exceeds chunk size
|
| 105 |
+
if len(current_chunk) + len(sentence) > chunk_size:
|
| 106 |
+
if current_chunk:
|
| 107 |
+
chunks.append(current_chunk.strip())
|
| 108 |
+
# Start new chunk with overlap
|
| 109 |
+
words = current_chunk.split()
|
| 110 |
+
overlap_words = words[-overlap:] if len(words) > overlap else words
|
| 111 |
+
current_chunk = " ".join(overlap_words) + " " + sentence
|
| 112 |
+
else:
|
| 113 |
+
# Sentence itself is longer than chunk_size
|
| 114 |
+
chunks.append(sentence[:chunk_size])
|
| 115 |
+
current_chunk = sentence[chunk_size:]
|
| 116 |
+
else:
|
| 117 |
+
current_chunk += sentence
|
| 118 |
+
|
| 119 |
+
# Add last chunk
|
| 120 |
+
if current_chunk:
|
| 121 |
+
chunks.append(current_chunk.strip())
|
| 122 |
+
|
| 123 |
+
return chunks
|
| 124 |
+
|
| 125 |
+
def process_document(
    self,
    file_path: str,
    metadata: Dict = None
) -> List[DocumentChunk]:
    """
    Load a document, chunk it and wrap every chunk with its metadata

    Args:
        file_path: Path to document
        metadata: Extra metadata merged over the file-level metadata
            (caller-supplied keys win on conflict)

    Returns:
        List of DocumentChunk objects, one per chunk, in document order
    """
    # Read the full document text first
    text = self.load_document(file_path)

    # File-level metadata shared by every chunk
    source_path = Path(file_path)
    base_metadata = {
        'source': str(source_path.name),
        'file_path': str(file_path),
        'file_type': source_path.suffix,
        'total_chars': len(text)
    }
    if metadata:
        base_metadata.update(metadata)

    # Split into chunks using the processor's default chunking settings
    pieces = self.chunk_text(text)
    total = len(pieces)

    # Each chunk gets its own metadata dict carrying its position
    return [
        DocumentChunk(
            text=piece,
            metadata={
                **base_metadata,
                'chunk_index': position,
                'total_chunks': total
            },
            chunk_id=position
        )
        for position, piece in enumerate(pieces)
    ]
|
vectordb/json_store.py
ADDED
|
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
JSON-based vector store for document embeddings
|
| 3 |
+
"""
|
| 4 |
+
import json
|
| 5 |
+
import numpy as np
|
| 6 |
+
from typing import List, Dict, Tuple
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from datetime import datetime
|
| 9 |
+
|
| 10 |
+
from config import EMBEDDINGS_JSON, TOP_K, PROCESSED_DIR
|
| 11 |
+
from models.embeddings import get_embedding_model
|
| 12 |
+
|
| 13 |
+
class JSONStore:
    """
    Vector store that keeps text chunks, metadata and embeddings in one JSON file

    Documents are held in memory in self.data['documents'] (embeddings as
    numpy arrays) and the whole store is rewritten to EMBEDDINGS_JSON after
    every mutation. Search is a linear scan using cosine similarity.
    """

    def __init__(self):
        """Bind the store to EMBEDDINGS_JSON and load any previously saved data"""
        self.embeddings_file = EMBEDDINGS_JSON
        self.embedding_model = get_embedding_model()
        self.data = self._load_data()
        print(f"✓ JSON Store initialized ({len(self.data['documents'])} documents loaded)")

    def _load_data(self) -> Dict:
        """
        Load data from JSON file

        Returns:
            The stored dict with each document's embedding converted to a
            numpy array, or a fresh empty store (with creation metadata)
            when the file does not exist yet
        """
        if self.embeddings_file.exists():
            with open(self.embeddings_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
                # Convert embeddings back to numpy arrays
                for doc in data['documents']:
                    doc['embedding'] = np.array(doc['embedding'])
                return data
        else:
            # Try the attribute names the wrapped model may expose its name
            # under; fall back to 'unknown' when neither is present.
            model_name = getattr(self.embedding_model.model, '_model_name_or_path',
                                 getattr(self.embedding_model.model, 'name_or_path',
                                         'unknown'))

            return {
                'documents': [],
                'metadata': {
                    'created_at': datetime.now().isoformat(),
                    'embedding_model': model_name,
                    'embedding_dimension': self.embedding_model.dimension
                }
            }

    def _save_data(self):
        """Save data to JSON file (the whole store is rewritten each time)"""
        # Convert numpy arrays to lists for JSON serialization; shallow
        # copies keep the in-memory documents holding numpy arrays.
        save_data = {
            'documents': [],
            'metadata': self.data['metadata']
        }

        for doc in self.data['documents']:
            doc_copy = doc.copy()
            doc_copy['embedding'] = doc['embedding'].tolist()
            save_data['documents'].append(doc_copy)

        # Ensure directory exists
        self.embeddings_file.parent.mkdir(parents=True, exist_ok=True)

        with open(self.embeddings_file, 'w', encoding='utf-8') as f:
            json.dump(save_data, f, indent=2, ensure_ascii=False)

    def add_documents(
        self,
        texts: List[str],
        metadatas: List[Dict],
        ids: List[str] = None
    ):
        """
        Add documents to store

        Args:
            texts: List of document texts
            metadatas: List of metadata dicts (one per text)
            ids: Optional list of document IDs (one per text); generated as
                "doc_<n>" from the current document count when omitted

        Raises:
            ValueError: If metadatas or ids do not have one entry per text
        """
        if not texts:
            return

        # Validate before the (expensive) embedding step. zip() below would
        # otherwise silently drop the unmatched tail of the longer list.
        if len(metadatas) != len(texts):
            raise ValueError(
                f"metadatas has {len(metadatas)} entries but texts has {len(texts)}"
            )
        if ids is not None and len(ids) != len(texts):
            raise ValueError(
                f"ids has {len(ids)} entries but texts has {len(texts)}"
            )

        # Generate embeddings
        print(f"Generating embeddings for {len(texts)} chunks...")
        embeddings = self.embedding_model.encode(texts)

        # Generate IDs if not provided
        if ids is None:
            existing_count = len(self.data['documents'])
            ids = [f"doc_{existing_count + i}" for i in range(len(texts))]

        # Add documents
        for text, metadata, doc_id, embedding in zip(texts, metadatas, ids, embeddings):
            self.data['documents'].append({
                'id': doc_id,
                'text': text,
                'metadata': metadata,
                'embedding': embedding,
                'added_at': datetime.now().isoformat()
            })

        # Save to file
        self._save_data()
        print(f"✓ Added {len(texts)} chunks to JSON store")

    def search(
        self,
        query: str,
        top_k: int = TOP_K,
        filter_metadata: Dict = None
    ) -> Tuple[List[str], List[Dict], List[float]]:
        """
        Search for similar documents using cosine similarity

        Args:
            query: Search query
            top_k: Number of results to return
            filter_metadata: Optional metadata filter; a document matches
                only if every given key has exactly the given value

        Returns:
            Tuple of (texts, metadatas, distances), ordered by ascending
            distance, where distance = 1 - cosine similarity. A document or
            query with a zero-length embedding gets similarity 0
            (distance 1) instead of NaN.
        """
        if not self.data['documents']:
            return [], [], []

        # Generate query embedding (flattened so the dot product is a scalar)
        query_embedding = np.asarray(
            self.embedding_model.encode_query(query)
        ).reshape(-1)
        # The query norm is loop-invariant, so compute it once.
        query_norm = np.linalg.norm(query_embedding)

        # Calculate similarities
        results = []
        for doc in self.data['documents']:
            # Apply metadata filter if provided
            if filter_metadata:
                match = all(
                    doc['metadata'].get(k) == v
                    for k, v in filter_metadata.items()
                )
                if not match:
                    continue

            # Calculate cosine similarity, guarding the zero-norm case:
            # dividing by zero would yield NaN and make the sort order
            # below undefined.
            doc_embedding = doc['embedding']
            denominator = query_norm * np.linalg.norm(doc_embedding)
            if denominator > 0:
                similarity = float(np.dot(query_embedding, doc_embedding) / denominator)
            else:
                similarity = 0.0

            # Convert similarity to distance (1 - similarity for consistency)
            results.append({
                'text': doc['text'],
                'metadata': doc['metadata'],
                'distance': 1.0 - similarity
            })

        # Sort by distance (ascending) and keep the top_k closest
        results.sort(key=lambda x: x['distance'])
        results = results[:top_k]

        # Extract components
        texts = [r['text'] for r in results]
        metadatas = [r['metadata'] for r in results]
        distances = [r['distance'] for r in results]

        return texts, metadatas, distances

    def delete_all(self):
        """Delete all documents (store metadata is kept) and persist the empty store"""
        self.data = {
            'documents': [],
            'metadata': self.data['metadata']
        }
        self._save_data()
        print("✓ Deleted all documents")

    def get_stats(self) -> Dict:
        """
        Get store statistics

        Returns:
            Dict with the document count, embedding dimension and model name
            from the store metadata, the JSON file path and its size in MB
            (0 when the file has not been written yet)
        """
        file_size_mb = 0
        if self.embeddings_file.exists():
            file_size_mb = self.embeddings_file.stat().st_size / (1024 * 1024)

        return {
            'total_documents': len(self.data['documents']),
            'embedding_dimension': self.data['metadata']['embedding_dimension'],
            'embedding_model': self.data['metadata']['embedding_model'],
            'file_path': str(self.embeddings_file),
            'file_size_mb': round(file_size_mb, 2)
        }

    def export_chunks_only(self, output_file: str = None):
        """
        Export only text chunks and metadata (without embeddings) to JSON

        Args:
            output_file: Output file path (optional); defaults to
                "chunks_only.json" inside PROCESSED_DIR
        """
        if output_file is None:
            output_file = Path(PROCESSED_DIR) / "chunks_only.json"
        else:
            output_file = Path(output_file)

        chunks_data = {
            'total_chunks': len(self.data['documents']),
            'created_at': datetime.now().isoformat(),
            'chunks': [
                {
                    'id': doc['id'],
                    'text': doc['text'],
                    'metadata': doc['metadata']
                }
                for doc in self.data['documents']
            ]
        }

        # Ensure directory exists
        output_file.parent.mkdir(parents=True, exist_ok=True)

        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(chunks_data, f, indent=2, ensure_ascii=False)

        print(f"✓ Exported {len(chunks_data['chunks'])} chunks to {output_file}")
|
| 221 |
+
|
| 222 |
+
# Module-level cache for the one shared JSONStore; filled on first request
_json_store = None


def get_json_store() -> JSONStore:
    """Return the shared JSONStore, constructing it the first time it is needed"""
    global _json_store
    if _json_store is not None:
        return _json_store
    _json_store = JSONStore()
    return _json_store
|