Spaces:
Sleeping
Sleeping
first init
Browse files- .gitignore +27 -0
- Dockerfile +24 -0
- config.py +44 -0
- data/spaces.json +12 -0
- main.py +896 -0
- models/__pycache__/studio_models.cpython-311.pyc +0 -0
- models/studio_models.py +219 -0
- requirements.txt +39 -0
- runtime.txt +1 -0
- start_ngrok_tunnel.py +69 -0
- utils/__init__.py +1 -0
- utils/__pycache__/__init__.cpython-311.pyc +0 -0
- utils/__pycache__/__init__.cpython-314.pyc +0 -0
- utils/__pycache__/config_manager.cpython-311.pyc +0 -0
- utils/__pycache__/config_manager.cpython-314.pyc +0 -0
- utils/__pycache__/document_processor.cpython-311.pyc +0 -0
- utils/__pycache__/document_processor.cpython-314.pyc +0 -0
- utils/__pycache__/hybrid_retriever.cpython-311.pyc +0 -0
- utils/__pycache__/hybrid_retriever.cpython-314.pyc +0 -0
- utils/__pycache__/llm_generator.cpython-311.pyc +0 -0
- utils/__pycache__/llm_generator.cpython-314.pyc +0 -0
- utils/__pycache__/model_inference.cpython-311.pyc +0 -0
- utils/__pycache__/simple_generator.cpython-311.pyc +0 -0
- utils/__pycache__/spaces_manager.cpython-311.pyc +0 -0
- utils/__pycache__/spaces_manager.cpython-314.pyc +0 -0
- utils/__pycache__/studio_generator.cpython-311.pyc +0 -0
- utils/__pycache__/studio_manager.cpython-311.pyc +0 -0
- utils/__pycache__/vector_db.cpython-311.pyc +0 -0
- utils/__pycache__/vector_db.cpython-314.pyc +0 -0
- utils/chat_manager.py +123 -0
- utils/config_manager.py +80 -0
- utils/document_processor.py +222 -0
- utils/hybrid_retriever.py +149 -0
- utils/llm_generator.py +297 -0
- utils/model_inference.py +156 -0
- utils/simple_generator.py +444 -0
- utils/spaces_manager.py +124 -0
- utils/studio_generator.py +309 -0
- utils/studio_manager.py +473 -0
- utils/vector_db.py +148 -0
.gitignore
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
*.so
|
| 6 |
+
.Python
|
| 7 |
+
*.pyc
|
| 8 |
+
|
| 9 |
+
# Virtual environment
|
| 10 |
+
venv/
|
| 11 |
+
env/
|
| 12 |
+
ENV/
|
| 13 |
+
|
| 14 |
+
# Environment variables
|
| 15 |
+
.env
|
| 16 |
+
.env.local
|
| 17 |
+
|
| 18 |
+
# IDE
|
| 19 |
+
.vscode/
|
| 20 |
+
.idea/
|
| 21 |
+
*.swp
|
| 22 |
+
*.swo
|
| 23 |
+
*~
|
| 24 |
+
.DS_Store
|
| 25 |
+
|
| 26 |
+
# Logs
|
| 27 |
+
*.log
|
Dockerfile
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim
|
| 2 |
+
|
| 3 |
+
# Set working directory
|
| 4 |
+
WORKDIR /app
|
| 5 |
+
|
| 6 |
+
# Install system dependencies needed for some ML libraries and PyPDF2
|
| 7 |
+
RUN apt-get update && apt-get install -y \
|
| 8 |
+
build-essential \
|
| 9 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 10 |
+
|
| 11 |
+
# Copy requirements first to leverage Docker cache
|
| 12 |
+
COPY requirements.txt .
|
| 13 |
+
|
| 14 |
+
# Install Python dependencies
|
| 15 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 16 |
+
|
| 17 |
+
# Copy the rest of the application
|
| 18 |
+
COPY . .
|
| 19 |
+
|
| 20 |
+
# Expose the port FastAPI will run on
|
| 21 |
+
EXPOSE 7860
|
| 22 |
+
|
| 23 |
+
# Command to run the application (Hugging Face routes to 7860 by default)
|
| 24 |
+
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
|
config.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from dotenv import load_dotenv
|
| 4 |
+
|
| 5 |
+
# Load environment variables
|
| 6 |
+
load_dotenv()
|
| 7 |
+
|
| 8 |
+
# Project paths
|
| 9 |
+
PROJECT_ROOT = Path(__file__).parent
|
| 10 |
+
DATA_DIR = PROJECT_ROOT.parent / "data" # Use project root's data folder, not backend/data
|
| 11 |
+
MODELS_DIR = PROJECT_ROOT / "models"
|
| 12 |
+
UPLOADS_DIR = DATA_DIR / "uploads"
|
| 13 |
+
VECTOR_DB_DIR = DATA_DIR / "vector_db"
|
| 14 |
+
CHATS_DIR = DATA_DIR / "chats"
|
| 15 |
+
|
| 16 |
+
# Create directories if they don't exist
|
| 17 |
+
for dir_path in [DATA_DIR, MODELS_DIR, UPLOADS_DIR, VECTOR_DB_DIR, CHATS_DIR]:
|
| 18 |
+
dir_path.mkdir(parents=True, exist_ok=True)
|
| 19 |
+
|
| 20 |
+
# Model configuration
|
| 21 |
+
# RAG uses pre-trained models directly - no training required!
|
| 22 |
+
MODEL_NAME = os.getenv("MODEL_NAME", "microsoft/phi-2") # Pre-trained model
|
| 23 |
+
USE_PRETRAINED = os.getenv("USE_PRETRAINED", "true").lower() == "true" # Use pre-trained by default
|
| 24 |
+
MODEL_PATH = os.getenv("MODEL_PATH", str(MODELS_DIR / "trained_model")) # Only if fine-tuned
|
| 25 |
+
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2" # For document embeddings
|
| 26 |
+
|
| 27 |
+
# API Keys
|
| 28 |
+
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
|
| 29 |
+
HUGGINGFACE_API_KEY = os.getenv("HUGGINGFACE_API_KEY", "")
|
| 30 |
+
|
| 31 |
+
# Application settings
|
| 32 |
+
MAX_UPLOAD_SIZE = int(os.getenv("MAX_UPLOAD_SIZE", "200")) # MB
|
| 33 |
+
TEMPERATURE = float(os.getenv("TEMPERATURE", "0.7"))
|
| 34 |
+
MAX_TOKENS = int(os.getenv("MAX_TOKENS", "2048"))
|
| 35 |
+
CHUNK_SIZE = 512
|
| 36 |
+
CHUNK_OVERLAP = 50
|
| 37 |
+
|
| 38 |
+
# Use cases
|
| 39 |
+
USE_CASES = {
|
| 40 |
+
"explanation": "Provide detailed explanation of concepts",
|
| 41 |
+
"summary": "Generate concise summary of content",
|
| 42 |
+
"qa": "Answer questions based on content",
|
| 43 |
+
"notes": "Create structured study notes"
|
| 44 |
+
}
|
data/spaces.json
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"spaces": [
|
| 3 |
+
{
|
| 4 |
+
"id": "general",
|
| 5 |
+
"name": "General",
|
| 6 |
+
"description": "General study materials",
|
| 7 |
+
"created_at": "2026-03-12T10:42:37.952166",
|
| 8 |
+
"file_count": 0,
|
| 9 |
+
"chat_count": 0
|
| 10 |
+
}
|
| 11 |
+
]
|
| 12 |
+
}
|
main.py
ADDED
|
@@ -0,0 +1,896 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FastAPI Backend for NotebookPRO
|
| 3 |
+
Handles RAG, LLM, file processing, and chat management
|
| 4 |
+
"""
|
| 5 |
+
from fastapi import FastAPI, File, UploadFile, HTTPException
|
| 6 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 7 |
+
from pydantic import BaseModel
|
| 8 |
+
from typing import List, Optional, Dict, Any
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
import json
|
| 11 |
+
from datetime import datetime
|
| 12 |
+
import uuid
|
| 13 |
+
import sys
|
| 14 |
+
import warnings
|
| 15 |
+
import logging
|
| 16 |
+
import os
|
| 17 |
+
import shutil
|
| 18 |
+
|
| 19 |
+
# Suppress warnings
|
| 20 |
+
warnings.filterwarnings('ignore')
|
| 21 |
+
os.environ['PYTHONWARNINGS'] = 'ignore'
|
| 22 |
+
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
| 23 |
+
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
|
| 24 |
+
os.environ.setdefault('OMP_NUM_THREADS', '2')
|
| 25 |
+
os.environ.setdefault('MKL_NUM_THREADS', '2')
|
| 26 |
+
os.environ.setdefault('OPENBLAS_NUM_THREADS', '2')
|
| 27 |
+
os.environ.setdefault('NUMEXPR_NUM_THREADS', '2')
|
| 28 |
+
#logging.getLogger().setLevel(logging.ERROR)
|
| 29 |
+
|
| 30 |
+
# Add project root to path
|
| 31 |
+
sys.path.append(str(Path(__file__).parent.parent))
|
| 32 |
+
|
| 33 |
+
import config
|
| 34 |
+
from utils.document_processor import DocumentProcessor
|
| 35 |
+
from utils.vector_db import VectorDatabase
|
| 36 |
+
from utils.hybrid_retriever import HybridRetriever
|
| 37 |
+
from utils.llm_generator import LLMGenerator
|
| 38 |
+
from utils.config_manager import ConfigManager
|
| 39 |
+
from utils.spaces_manager import SpacesManager
|
| 40 |
+
from utils.studio_manager import StudioManager
|
| 41 |
+
from utils.studio_generator import StudioGenerator
|
| 42 |
+
|
| 43 |
+
# Initialize FastAPI
|
| 44 |
+
app = FastAPI(title="NotebookPRO API", version="2.0.0")
|
| 45 |
+
|
| 46 |
+
# CORS - Allow Flutter web to connect
|
| 47 |
+
app.add_middleware(
|
| 48 |
+
CORSMiddleware,
|
| 49 |
+
allow_origins=["*"], # In production, specify your Flutter web URL
|
| 50 |
+
allow_credentials=True,
|
| 51 |
+
allow_methods=["*"],
|
| 52 |
+
allow_headers=["*"],
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
# Global instances
|
| 56 |
+
config_manager = ConfigManager()
|
| 57 |
+
spaces_manager = SpacesManager()
|
| 58 |
+
studio_manager = StudioManager()
|
| 59 |
+
studio_generator = None # Will be initialized after LLM
|
| 60 |
+
vector_db = None
|
| 61 |
+
llm_generator = None
|
| 62 |
+
current_space = None
|
| 63 |
+
|
| 64 |
+
# ==================== Pydantic Models ====================
|
| 65 |
+
|
| 66 |
+
class ChatMessage(BaseModel):
|
| 67 |
+
role: str
|
| 68 |
+
content: str
|
| 69 |
+
timestamp: str
|
| 70 |
+
sources: Optional[List[Dict[str, Any]]] = None
|
| 71 |
+
|
| 72 |
+
class ChatRequest(BaseModel):
|
| 73 |
+
query: str
|
| 74 |
+
space_id: str
|
| 75 |
+
chat_id: Optional[str] = None
|
| 76 |
+
workflow: str = "chat"
|
| 77 |
+
|
| 78 |
+
class ChatResponse(BaseModel):
|
| 79 |
+
response: str
|
| 80 |
+
sources: List[Dict[str, Any]]
|
| 81 |
+
chat_id: str
|
| 82 |
+
timestamp: str
|
| 83 |
+
|
| 84 |
+
class SpaceCreate(BaseModel):
|
| 85 |
+
name: str
|
| 86 |
+
|
| 87 |
+
class SpaceResponse(BaseModel):
|
| 88 |
+
id: str
|
| 89 |
+
name: str
|
| 90 |
+
created_at: str
|
| 91 |
+
file_count: int
|
| 92 |
+
|
| 93 |
+
class ChatInfo(BaseModel):
|
| 94 |
+
id: str
|
| 95 |
+
title: str
|
| 96 |
+
preview: str
|
| 97 |
+
created_at: str
|
| 98 |
+
updated_at: str
|
| 99 |
+
message_count: int
|
| 100 |
+
|
| 101 |
+
class ConfigResponse(BaseModel):
|
| 102 |
+
groq_api_key: Optional[str]
|
| 103 |
+
gemini_api_key: Optional[str]
|
| 104 |
+
|
| 105 |
+
class ConfigUpdate(BaseModel):
|
| 106 |
+
groq_api_key: Optional[str] = None
|
| 107 |
+
gemini_api_key: Optional[str] = None
|
| 108 |
+
|
| 109 |
+
class ChatToNotebookRequest(BaseModel):
|
| 110 |
+
space_id: str
|
| 111 |
+
question: str
|
| 112 |
+
answer: str
|
| 113 |
+
chat_id: Optional[str] = None
|
| 114 |
+
assistant_timestamp: Optional[str] = None
|
| 115 |
+
tags: List[str] = []
|
| 116 |
+
space_name: Optional[str] = None
|
| 117 |
+
|
| 118 |
+
# ==================== Helper Functions ====================
|
| 119 |
+
|
| 120 |
+
def get_data_dir():
|
| 121 |
+
"""Get data directory path"""
|
| 122 |
+
return Path(__file__).parent.parent / "data"
|
| 123 |
+
|
| 124 |
+
def get_space_dir(space_id: str):
|
| 125 |
+
"""Get space-specific directory"""
|
| 126 |
+
return get_data_dir() / "spaces" / space_id
|
| 127 |
+
|
| 128 |
+
def load_chats_for_space(space_id: str) -> List[Dict]:
|
| 129 |
+
"""Load all chats for a space"""
|
| 130 |
+
chats_file = get_space_dir(space_id) / "chats.json"
|
| 131 |
+
if chats_file.exists():
|
| 132 |
+
with open(chats_file, 'r', encoding='utf-8') as f:
|
| 133 |
+
return json.load(f)
|
| 134 |
+
return []
|
| 135 |
+
|
| 136 |
+
def save_chats_for_space(space_id: str, chats: List[Dict]):
|
| 137 |
+
"""Save chats for a space"""
|
| 138 |
+
chats_file = get_space_dir(space_id) / "chats.json"
|
| 139 |
+
chats_file.parent.mkdir(parents=True, exist_ok=True)
|
| 140 |
+
with open(chats_file, 'w', encoding='utf-8') as f:
|
| 141 |
+
json.dump(chats, f, indent=2, ensure_ascii=False)
|
| 142 |
+
|
| 143 |
+
def get_chat_title(messages: List[Dict]) -> str:
|
| 144 |
+
"""Generate chat title from first user message"""
|
| 145 |
+
for msg in messages:
|
| 146 |
+
if msg['role'] == 'user':
|
| 147 |
+
content = msg['content'][:50]
|
| 148 |
+
return content + "..." if len(msg['content']) > 50 else content
|
| 149 |
+
return "New Chat"
|
| 150 |
+
|
| 151 |
+
def ensure_notebooks_for_existing_spaces() -> int:
|
| 152 |
+
"""Ensure every existing space has an associated notebook metadata record."""
|
| 153 |
+
created_count = 0
|
| 154 |
+
spaces = spaces_manager.get_all_spaces()
|
| 155 |
+
|
| 156 |
+
for space in spaces:
|
| 157 |
+
space_id = space.get('id')
|
| 158 |
+
if not space_id:
|
| 159 |
+
continue
|
| 160 |
+
|
| 161 |
+
existing_notebook = studio_manager.get_space_notebook(space_id)
|
| 162 |
+
if existing_notebook:
|
| 163 |
+
continue
|
| 164 |
+
|
| 165 |
+
studio_manager.ensure_space_notebook(space_id, space.get('name', space_id))
|
| 166 |
+
created_count += 1
|
| 167 |
+
|
| 168 |
+
return created_count
|
| 169 |
+
|
| 170 |
+
def rebuild_space_index_if_missing(space_id: str) -> int:
|
| 171 |
+
"""Rebuild a space index from uploaded files if the current index is empty."""
|
| 172 |
+
if not vector_db:
|
| 173 |
+
return 0
|
| 174 |
+
|
| 175 |
+
try:
|
| 176 |
+
if vector_db.get_collection_count() > 0:
|
| 177 |
+
return 0
|
| 178 |
+
except Exception:
|
| 179 |
+
# If count check fails, continue with a best-effort rebuild.
|
| 180 |
+
pass
|
| 181 |
+
|
| 182 |
+
uploads_dir = get_space_dir(space_id) / "uploads"
|
| 183 |
+
if not uploads_dir.exists():
|
| 184 |
+
return 0
|
| 185 |
+
|
| 186 |
+
files = [
|
| 187 |
+
p for p in uploads_dir.iterdir()
|
| 188 |
+
if p.is_file() and p.suffix.lower() in {".pdf", ".docx", ".txt"}
|
| 189 |
+
]
|
| 190 |
+
if not files:
|
| 191 |
+
return 0
|
| 192 |
+
|
| 193 |
+
processor = DocumentProcessor()
|
| 194 |
+
texts: List[str] = []
|
| 195 |
+
metadatas: List[Dict[str, Any]] = []
|
| 196 |
+
ids: List[str] = []
|
| 197 |
+
|
| 198 |
+
for file_path in files:
|
| 199 |
+
try:
|
| 200 |
+
file_data = processor.process_file(file_path)
|
| 201 |
+
chunks = processor.chunk_text(
|
| 202 |
+
file_data['content'],
|
| 203 |
+
chunk_size=512,
|
| 204 |
+
overlap=50,
|
| 205 |
+
semantic=True,
|
| 206 |
+
)
|
| 207 |
+
total_chunks = len(chunks)
|
| 208 |
+
for idx, chunk in enumerate(chunks):
|
| 209 |
+
texts.append(chunk)
|
| 210 |
+
metadatas.append({
|
| 211 |
+
'filename': file_path.name,
|
| 212 |
+
'chunk_index': idx,
|
| 213 |
+
'total_chunks': total_chunks,
|
| 214 |
+
'source_type': file_data['format'],
|
| 215 |
+
})
|
| 216 |
+
ids.append(f"{space_id}_rebuild_{len(ids)}_{uuid.uuid4().hex[:8]}")
|
| 217 |
+
except Exception as e:
|
| 218 |
+
print(f"Index rebuild skipped {file_path.name}: {e}")
|
| 219 |
+
|
| 220 |
+
if not texts:
|
| 221 |
+
return 0
|
| 222 |
+
|
| 223 |
+
batch_size = 5000
|
| 224 |
+
for i in range(0, len(texts), batch_size):
|
| 225 |
+
vector_db.add_documents(
|
| 226 |
+
texts[i:i + batch_size],
|
| 227 |
+
metadatas[i:i + batch_size],
|
| 228 |
+
ids[i:i + batch_size],
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
print(f"Rebuilt index for space '{space_id}' with {len(texts)} chunks")
|
| 232 |
+
return len(texts)
|
| 233 |
+
|
| 234 |
+
def initialize_space(space_id: str):
|
| 235 |
+
"""Initialize vector DB and components for a space"""
|
| 236 |
+
global vector_db, llm_generator, studio_generator, current_space
|
| 237 |
+
|
| 238 |
+
# Fast path: reuse already initialized components for the active space.
|
| 239 |
+
if current_space == space_id and vector_db is not None and llm_generator is not None:
|
| 240 |
+
return
|
| 241 |
+
|
| 242 |
+
# Get API keys
|
| 243 |
+
import os
|
| 244 |
+
# Try the config manager first, but fallback to the .env file variables
|
| 245 |
+
groq_key = config_manager.get_api_key('groq') or os.getenv('GROQ_API_KEY')
|
| 246 |
+
gemini_key = config_manager.get_api_key('gemini') or os.getenv('GOOGLE_API_KEY') or os.getenv('GEMINI_API_KEY')
|
| 247 |
+
|
| 248 |
+
if not groq_key and not gemini_key:
|
| 249 |
+
raise HTTPException(status_code=400, detail="No API keys configured. Please add Groq or Gemini API key.")
|
| 250 |
+
|
| 251 |
+
# Initialize vector database for this space (space-local persistence path).
|
| 252 |
+
# Initialize Qdrant cloud database for this space
|
| 253 |
+
vector_db = VectorDatabase(
|
| 254 |
+
collection_name=f"space_{space_id}"
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
# Backward-compatibility: rebuild embeddings from uploaded files if index is empty.
|
| 258 |
+
rebuild_space_index_if_missing(space_id)
|
| 259 |
+
|
| 260 |
+
# Initialize LLM generator - choose provider based on available keys
|
| 261 |
+
if groq_key:
|
| 262 |
+
llm_generator = LLMGenerator(provider="groq", api_key=groq_key)
|
| 263 |
+
else:
|
| 264 |
+
llm_generator = LLMGenerator(provider="gemini", api_key=gemini_key)
|
| 265 |
+
|
| 266 |
+
# Initialize studio generator with LLM
|
| 267 |
+
studio_generator = StudioGenerator(llm_generator, studio_manager)
|
| 268 |
+
current_space = space_id
|
| 269 |
+
|
| 270 |
+
@app.on_event("startup")
|
| 271 |
+
async def startup_sync_notebooks():
|
| 272 |
+
"""Auto-create missing notebooks for pre-existing spaces when backend starts."""
|
| 273 |
+
try:
|
| 274 |
+
created = ensure_notebooks_for_existing_spaces()
|
| 275 |
+
if created > 0:
|
| 276 |
+
print(f"Created {created} missing notebook(s) for existing spaces")
|
| 277 |
+
except Exception as e:
|
| 278 |
+
# Keep server startup resilient even if sync fails.
|
| 279 |
+
print(f"Notebook startup sync failed: {e}")
|
| 280 |
+
|
| 281 |
+
# ==================== API Endpoints ====================
|
| 282 |
+
|
| 283 |
+
@app.get("/")
|
| 284 |
+
async def root():
|
| 285 |
+
"""Health check"""
|
| 286 |
+
return {"status": "NotebookPRO API is running", "version": "2.0.0"}
|
| 287 |
+
|
| 288 |
+
@app.get("/api/config", response_model=ConfigResponse)
|
| 289 |
+
async def get_config():
|
| 290 |
+
"""Get current API keys (masked)"""
|
| 291 |
+
groq_key = config_manager.get_api_key('groq')
|
| 292 |
+
gemini_key = config_manager.get_api_key('gemini')
|
| 293 |
+
|
| 294 |
+
return ConfigResponse(
|
| 295 |
+
groq_api_key="***" + groq_key[-4:] if groq_key else None,
|
| 296 |
+
gemini_api_key="***" + gemini_key[-4:] if gemini_key else None
|
| 297 |
+
)
|
| 298 |
+
|
| 299 |
+
@app.post("/api/config")
|
| 300 |
+
async def update_config(config_update: ConfigUpdate):
|
| 301 |
+
"""Update API keys"""
|
| 302 |
+
if config_update.groq_api_key:
|
| 303 |
+
config_manager.set_api_key('groq', config_update.groq_api_key)
|
| 304 |
+
if config_update.gemini_api_key:
|
| 305 |
+
config_manager.set_api_key('gemini', config_update.gemini_api_key)
|
| 306 |
+
|
| 307 |
+
return {"status": "success", "message": "Configuration updated"}
|
| 308 |
+
|
| 309 |
+
@app.get("/api/spaces", response_model=List[SpaceResponse])
|
| 310 |
+
async def get_spaces():
|
| 311 |
+
"""Get all spaces"""
|
| 312 |
+
# Self-healing check in case spaces were created externally while server is running.
|
| 313 |
+
ensure_notebooks_for_existing_spaces()
|
| 314 |
+
spaces = spaces_manager.get_all_spaces()
|
| 315 |
+
|
| 316 |
+
result = []
|
| 317 |
+
for space in spaces:
|
| 318 |
+
space_id = space['id']
|
| 319 |
+
space_dir = get_space_dir(space_id)
|
| 320 |
+
processed_file = space_dir / "processed_files.json"
|
| 321 |
+
|
| 322 |
+
file_count = 0
|
| 323 |
+
if processed_file.exists():
|
| 324 |
+
with open(processed_file, 'r') as f:
|
| 325 |
+
file_count = len(json.load(f))
|
| 326 |
+
|
| 327 |
+
result.append(SpaceResponse(
|
| 328 |
+
id=space_id,
|
| 329 |
+
name=space['name'],
|
| 330 |
+
created_at=space['created_at'],
|
| 331 |
+
file_count=file_count
|
| 332 |
+
))
|
| 333 |
+
|
| 334 |
+
return result
|
| 335 |
+
|
| 336 |
+
@app.post("/api/spaces", response_model=SpaceResponse)
|
| 337 |
+
async def create_space(space_data: SpaceCreate):
|
| 338 |
+
"""Create a new space"""
|
| 339 |
+
try:
|
| 340 |
+
space = spaces_manager.create_space(space_data.name)
|
| 341 |
+
|
| 342 |
+
# Create associated notebook metadata with the same name as the space.
|
| 343 |
+
studio_manager.ensure_space_notebook(space['id'], space['name'])
|
| 344 |
+
|
| 345 |
+
return SpaceResponse(
|
| 346 |
+
id=space['id'],
|
| 347 |
+
name=space['name'],
|
| 348 |
+
created_at=space['created_at'],
|
| 349 |
+
file_count=0
|
| 350 |
+
)
|
| 351 |
+
except ValueError as e:
|
| 352 |
+
raise HTTPException(status_code=400, detail=str(e))
|
| 353 |
+
|
| 354 |
+
@app.delete("/api/spaces/{space_id}")
|
| 355 |
+
async def delete_space(space_id: str):
|
| 356 |
+
"""Delete a space"""
|
| 357 |
+
try:
|
| 358 |
+
spaces_manager.delete_space(space_id)
|
| 359 |
+
|
| 360 |
+
# Delete space directory
|
| 361 |
+
space_dir = get_space_dir(space_id)
|
| 362 |
+
if space_dir.exists():
|
| 363 |
+
shutil.rmtree(space_dir)
|
| 364 |
+
|
| 365 |
+
return {"status": "success", "message": f"Space {space_id} deleted"}
|
| 366 |
+
except ValueError as e:
|
| 367 |
+
raise HTTPException(status_code=400, detail=str(e))
|
| 368 |
+
except Exception as e:
|
| 369 |
+
raise HTTPException(status_code=500, detail=f"Error deleting space: {str(e)}")
|
| 370 |
+
|
| 371 |
+
@app.get("/api/spaces/{space_id}/chats", response_model=List[ChatInfo])
|
| 372 |
+
async def get_chats(space_id: str):
|
| 373 |
+
"""Get all chats for a space"""
|
| 374 |
+
chats = load_chats_for_space(space_id)
|
| 375 |
+
|
| 376 |
+
result = []
|
| 377 |
+
for chat in chats:
|
| 378 |
+
messages = chat.get('messages', [])
|
| 379 |
+
result.append(ChatInfo(
|
| 380 |
+
id=chat['id'],
|
| 381 |
+
title=get_chat_title(messages),
|
| 382 |
+
preview=messages[0]['content'][:100] if messages else "",
|
| 383 |
+
created_at=chat.get('created_at', ''),
|
| 384 |
+
updated_at=chat.get('updated_at', ''),
|
| 385 |
+
message_count=len(messages)
|
| 386 |
+
))
|
| 387 |
+
|
| 388 |
+
return result
|
| 389 |
+
|
| 390 |
+
@app.get("/api/spaces/{space_id}/chats/{chat_id}")
|
| 391 |
+
async def get_chat(space_id: str, chat_id: str):
|
| 392 |
+
"""Get specific chat by ID"""
|
| 393 |
+
chats = load_chats_for_space(space_id)
|
| 394 |
+
|
| 395 |
+
for chat in chats:
|
| 396 |
+
if chat['id'] == chat_id:
|
| 397 |
+
return chat
|
| 398 |
+
|
| 399 |
+
raise HTTPException(status_code=404, detail="Chat not found")
|
| 400 |
+
|
| 401 |
+
@app.delete("/api/spaces/{space_id}/chats/{chat_id}")
|
| 402 |
+
async def delete_chat(space_id: str, chat_id: str):
|
| 403 |
+
"""Delete a chat"""
|
| 404 |
+
chats = load_chats_for_space(space_id)
|
| 405 |
+
chats = [c for c in chats if c['id'] != chat_id]
|
| 406 |
+
save_chats_for_space(space_id, chats)
|
| 407 |
+
|
| 408 |
+
return {"status": "success", "message": f"Chat {chat_id} deleted"}
|
| 409 |
+
|
| 410 |
+
@app.post("/api/chat", response_model=ChatResponse)
|
| 411 |
+
async def chat(request: ChatRequest):
|
| 412 |
+
"""Process a chat message with RAG"""
|
| 413 |
+
try:
|
| 414 |
+
# Initialize space if needed
|
| 415 |
+
initialize_space(request.space_id)
|
| 416 |
+
|
| 417 |
+
# Create hybrid retriever with 60% vector, 40% BM25
|
| 418 |
+
hybrid_retriever = HybridRetriever(vector_db, alpha=0.6)
|
| 419 |
+
|
| 420 |
+
# Retrieve relevant documents
|
| 421 |
+
documents, metadatas, scores = hybrid_retriever.retrieve(
|
| 422 |
+
query=request.query,
|
| 423 |
+
n_results=5
|
| 424 |
+
)
|
| 425 |
+
|
| 426 |
+
# Build context from retrieved documents
|
| 427 |
+
context_parts = []
|
| 428 |
+
sources = []
|
| 429 |
+
|
| 430 |
+
for idx, (doc, meta, score) in enumerate(zip(documents, metadatas, scores), 1):
|
| 431 |
+
# Extract clean filename for source citation
|
| 432 |
+
filename = meta.get('filename', 'Unknown')
|
| 433 |
+
clean_name = filename.replace('.pdf', '').replace('.docx', '').replace('.txt', '')
|
| 434 |
+
context_parts.append(f"Source [{idx}] ({clean_name}):\n{doc}\n")
|
| 435 |
+
sources.append({
|
| 436 |
+
"content": doc[:200] + "..." if len(doc) > 200 else doc,
|
| 437 |
+
"metadata": meta,
|
| 438 |
+
"score": float(score)
|
| 439 |
+
})
|
| 440 |
+
|
| 441 |
+
context = "\n".join(context_parts)
|
| 442 |
+
|
| 443 |
+
# Use the advanced generate_response method which has the new NotebookLM-style prompt
|
| 444 |
+
response = llm_generator.generate_response(
|
| 445 |
+
prompt=request.query,
|
| 446 |
+
context=context,
|
| 447 |
+
use_case=request.workflow if request.workflow in ["summary", "explanation", "qa", "notes"] else "qa",
|
| 448 |
+
metadatas=metadatas,
|
| 449 |
+
temperature=0.3
|
| 450 |
+
)
|
| 451 |
+
|
| 452 |
+
# Create or update chat
|
| 453 |
+
chat_id = request.chat_id or str(uuid.uuid4())
|
| 454 |
+
chats = load_chats_for_space(request.space_id)
|
| 455 |
+
|
| 456 |
+
# Find existing chat or create new
|
| 457 |
+
chat = None
|
| 458 |
+
for c in chats:
|
| 459 |
+
if c['id'] == chat_id:
|
| 460 |
+
chat = c
|
| 461 |
+
break
|
| 462 |
+
|
| 463 |
+
if not chat:
|
| 464 |
+
chat = {
|
| 465 |
+
'id': chat_id,
|
| 466 |
+
'messages': [],
|
| 467 |
+
'created_at': datetime.now().isoformat(),
|
| 468 |
+
'updated_at': datetime.now().isoformat()
|
| 469 |
+
}
|
| 470 |
+
chats.append(chat)
|
| 471 |
+
|
| 472 |
+
# Add messages
|
| 473 |
+
timestamp = datetime.now().isoformat()
|
| 474 |
+
chat['messages'].extend([
|
| 475 |
+
{'role': 'user', 'content': request.query, 'timestamp': timestamp},
|
| 476 |
+
{
|
| 477 |
+
'role': 'assistant',
|
| 478 |
+
'content': response,
|
| 479 |
+
'timestamp': timestamp,
|
| 480 |
+
'sources': sources
|
| 481 |
+
}
|
| 482 |
+
])
|
| 483 |
+
chat['updated_at'] = timestamp
|
| 484 |
+
|
| 485 |
+
# Save chats
|
| 486 |
+
save_chats_for_space(request.space_id, chats)
|
| 487 |
+
|
| 488 |
+
return ChatResponse(
|
| 489 |
+
response=response,
|
| 490 |
+
sources=sources,
|
| 491 |
+
chat_id=chat_id,
|
| 492 |
+
timestamp=timestamp
|
| 493 |
+
)
|
| 494 |
+
|
| 495 |
+
except Exception as e:
|
| 496 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 497 |
+
|
| 498 |
+
@app.post("/api/spaces/{space_id}/upload")
|
| 499 |
+
async def upload_files(space_id: str, files: List[UploadFile] = File(...)):
|
| 500 |
+
"""Upload and process files for a space"""
|
| 501 |
+
try:
|
| 502 |
+
# Initialize space
|
| 503 |
+
initialize_space(space_id)
|
| 504 |
+
|
| 505 |
+
# Save uploaded files temporarily
|
| 506 |
+
space_dir = get_space_dir(space_id)
|
| 507 |
+
uploads_dir = space_dir / "uploads"
|
| 508 |
+
uploads_dir.mkdir(parents=True, exist_ok=True)
|
| 509 |
+
|
| 510 |
+
processor = DocumentProcessor()
|
| 511 |
+
all_chunks = []
|
| 512 |
+
processed_files = []
|
| 513 |
+
|
| 514 |
+
for file in files:
|
| 515 |
+
# Save file
|
| 516 |
+
file_path = uploads_dir / file.filename
|
| 517 |
+
with open(file_path, "wb") as f:
|
| 518 |
+
content = await file.read()
|
| 519 |
+
f.write(content)
|
| 520 |
+
|
| 521 |
+
# Process file and extract content
|
| 522 |
+
try:
|
| 523 |
+
file_data = processor.process_file(file_path)
|
| 524 |
+
content = file_data['content']
|
| 525 |
+
|
| 526 |
+
# Chunk the content
|
| 527 |
+
chunks = processor.chunk_text(content, chunk_size=512, overlap=50, semantic=True)
|
| 528 |
+
|
| 529 |
+
# Format chunks for vector database
|
| 530 |
+
formatted_chunks = []
|
| 531 |
+
for idx, chunk in enumerate(chunks):
|
| 532 |
+
formatted_chunks.append({
|
| 533 |
+
'content': chunk,
|
| 534 |
+
'metadata': {
|
| 535 |
+
'filename': file.filename,
|
| 536 |
+
'chunk_index': idx,
|
| 537 |
+
'total_chunks': len(chunks),
|
| 538 |
+
'source_type': file_data['format']
|
| 539 |
+
}
|
| 540 |
+
})
|
| 541 |
+
|
| 542 |
+
all_chunks.extend(formatted_chunks)
|
| 543 |
+
processed_files.append({
|
| 544 |
+
'filename': file.filename,
|
| 545 |
+
'chunks': len(chunks),
|
| 546 |
+
'processed_at': datetime.now().isoformat()
|
| 547 |
+
})
|
| 548 |
+
except Exception as e:
|
| 549 |
+
# Log error but continue with other files
|
| 550 |
+
print(f"Error processing {file.filename}: {str(e)}")
|
| 551 |
+
continue
|
| 552 |
+
|
| 553 |
+
# Add to vector database in batches to avoid size limits
|
| 554 |
+
if all_chunks:
|
| 555 |
+
# Extract texts, metadatas, and generate IDs
|
| 556 |
+
texts = [chunk['content'] for chunk in all_chunks]
|
| 557 |
+
metadatas = [chunk['metadata'] for chunk in all_chunks]
|
| 558 |
+
ids = [f"{space_id}_{idx}_{uuid.uuid4().hex[:8]}" for idx in range(len(all_chunks))]
|
| 559 |
+
|
| 560 |
+
# Process in batches of 5000 to avoid ChromaDB batch size limit
|
| 561 |
+
batch_size = 5000
|
| 562 |
+
for i in range(0, len(texts), batch_size):
|
| 563 |
+
batch_texts = texts[i:i + batch_size]
|
| 564 |
+
batch_metadatas = metadatas[i:i + batch_size]
|
| 565 |
+
batch_ids = ids[i:i + batch_size]
|
| 566 |
+
|
| 567 |
+
vector_db.add_documents(batch_texts, batch_metadatas, batch_ids)
|
| 568 |
+
print(f"Processed batch {i//batch_size + 1}/{(len(texts)-1)//batch_size + 1}")
|
| 569 |
+
|
| 570 |
+
# Save processed files info
|
| 571 |
+
processed_file = space_dir / "processed_files.json"
|
| 572 |
+
existing = []
|
| 573 |
+
if processed_file.exists():
|
| 574 |
+
with open(processed_file, 'r') as f:
|
| 575 |
+
existing = json.load(f)
|
| 576 |
+
|
| 577 |
+
existing.extend(processed_files)
|
| 578 |
+
with open(processed_file, 'w') as f:
|
| 579 |
+
json.dump(existing, f, indent=2)
|
| 580 |
+
|
| 581 |
+
return {
|
| 582 |
+
"status": "success",
|
| 583 |
+
"files_processed": len(processed_files),
|
| 584 |
+
"total_chunks": len(all_chunks)
|
| 585 |
+
}
|
| 586 |
+
|
| 587 |
+
except Exception as e:
|
| 588 |
+
raise e # This strips the wrapper and forces FastAPI to log the raw stack trace
|
| 589 |
+
|
| 590 |
+
@app.get("/api/spaces/{space_id}/files")
|
| 591 |
+
async def get_files(space_id: str):
|
| 592 |
+
"""Get processed files for a space"""
|
| 593 |
+
processed_file = get_space_dir(space_id) / "processed_files.json"
|
| 594 |
+
|
| 595 |
+
if processed_file.exists():
|
| 596 |
+
with open(processed_file, 'r') as f:
|
| 597 |
+
return json.load(f)
|
| 598 |
+
|
| 599 |
+
return []
|
| 600 |
+
|
| 601 |
+
@app.delete("/api/spaces/{space_id}/files/{filename}")
|
| 602 |
+
async def delete_file(space_id: str, filename: str):
|
| 603 |
+
"""Delete a specific file from a space"""
|
| 604 |
+
try:
|
| 605 |
+
# Remove from processed_files.json
|
| 606 |
+
processed_file = get_space_dir(space_id) / "processed_files.json"
|
| 607 |
+
files_data = []
|
| 608 |
+
|
| 609 |
+
if processed_file.exists():
|
| 610 |
+
with open(processed_file, 'r') as f:
|
| 611 |
+
files_data = json.load(f)
|
| 612 |
+
|
| 613 |
+
# Filter out the file to delete
|
| 614 |
+
files_data = [f for f in files_data if f.get('filename') != filename]
|
| 615 |
+
|
| 616 |
+
with open(processed_file, 'w') as f:
|
| 617 |
+
json.dump(files_data, f, indent=2)
|
| 618 |
+
|
| 619 |
+
# Delete the actual file
|
| 620 |
+
file_path = get_space_dir(space_id) / "uploads" / filename
|
| 621 |
+
if file_path.exists():
|
| 622 |
+
file_path.unlink()
|
| 623 |
+
|
| 624 |
+
# Remove from vector database (if initialized)
|
| 625 |
+
# Note: This removes all chunks with this filename from metadata
|
| 626 |
+
if vector_db:
|
| 627 |
+
try:
|
| 628 |
+
# Get all documents in the collection
|
| 629 |
+
collection = vector_db.collection
|
| 630 |
+
results = collection.get()
|
| 631 |
+
|
| 632 |
+
# Find IDs of documents with matching filename
|
| 633 |
+
ids_to_delete = []
|
| 634 |
+
for idx, metadata in enumerate(results['metadatas']):
|
| 635 |
+
if metadata and metadata.get('filename') == filename:
|
| 636 |
+
ids_to_delete.append(results['ids'][idx])
|
| 637 |
+
|
| 638 |
+
# Delete those documents
|
| 639 |
+
if ids_to_delete:
|
| 640 |
+
collection.delete(ids=ids_to_delete)
|
| 641 |
+
print(f"Deleted {len(ids_to_delete)} chunks for {filename}")
|
| 642 |
+
except Exception as e:
|
| 643 |
+
print(f"Error removing from vector DB: {e}")
|
| 644 |
+
|
| 645 |
+
return {
|
| 646 |
+
"status": "success",
|
| 647 |
+
"message": f"File {filename} deleted"
|
| 648 |
+
}
|
| 649 |
+
|
| 650 |
+
except Exception as e:
|
| 651 |
+
raise HTTPException(status_code=500, detail=f"Error deleting file: {str(e)}")
|
| 652 |
+
|
| 653 |
+
|
| 654 |
+
# ==================== STUDIO API ROUTES ====================
|
| 655 |
+
# Routes for Notebook, Flashcards, and Quiz features
|
| 656 |
+
|
| 657 |
+
# Import studio models
|
| 658 |
+
from models.studio_models import (
|
| 659 |
+
NotebookEntry, NotebookEntryCreate, NotebookEntryUpdate,
|
| 660 |
+
Flashcard, FlashcardCreate, FlashcardUpdate, FlashcardReview,
|
| 661 |
+
FlashcardGenerateRequest,
|
| 662 |
+
Quiz, QuizCreate, QuizGenerateRequest, QuizSubmission, QuizResult, QuizHistory,
|
| 663 |
+
MasteryLevel
|
| 664 |
+
)
|
| 665 |
+
|
| 666 |
+
# ===== NOTEBOOK ROUTES =====
|
| 667 |
+
|
| 668 |
+
@app.post("/api/studio/notebook", response_model=NotebookEntry)
|
| 669 |
+
async def create_notebook_entry(entry_data: NotebookEntryCreate):
|
| 670 |
+
"""Create a new notebook entry"""
|
| 671 |
+
try:
|
| 672 |
+
entry = studio_manager.create_notebook_entry(entry_data)
|
| 673 |
+
return entry
|
| 674 |
+
except Exception as e:
|
| 675 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 676 |
+
|
| 677 |
+
@app.get("/api/studio/notebook/space/{space_id}")
|
| 678 |
+
async def get_space_notebook(space_id: str):
|
| 679 |
+
"""Get or create notebook metadata for a space."""
|
| 680 |
+
try:
|
| 681 |
+
space = spaces_manager.get_space(space_id)
|
| 682 |
+
space_name = space['name'] if space else space_id
|
| 683 |
+
notebook = studio_manager.ensure_space_notebook(space_id, space_name)
|
| 684 |
+
return notebook
|
| 685 |
+
except Exception as e:
|
| 686 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 687 |
+
|
| 688 |
+
@app.post("/api/studio/notebook/from-chat", response_model=NotebookEntry)
|
| 689 |
+
async def add_chat_to_notebook(request: ChatToNotebookRequest):
|
| 690 |
+
"""Add a chat question/answer pair into a space notebook."""
|
| 691 |
+
try:
|
| 692 |
+
space = spaces_manager.get_space(request.space_id)
|
| 693 |
+
resolved_space_name = request.space_name or (space['name'] if space else request.space_id)
|
| 694 |
+
|
| 695 |
+
entry = studio_manager.create_notebook_entry_from_chat(
|
| 696 |
+
space_id=request.space_id,
|
| 697 |
+
question=request.question,
|
| 698 |
+
answer=request.answer,
|
| 699 |
+
chat_id=request.chat_id,
|
| 700 |
+
assistant_timestamp=request.assistant_timestamp,
|
| 701 |
+
tags=request.tags,
|
| 702 |
+
space_name=resolved_space_name
|
| 703 |
+
)
|
| 704 |
+
return entry
|
| 705 |
+
except Exception as e:
|
| 706 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 707 |
+
|
| 708 |
+
@app.get("/api/studio/notebook", response_model=List[NotebookEntry])
|
| 709 |
+
async def list_notebook_entries(space_id: Optional[str] = None):
|
| 710 |
+
"""List all notebook entries, optionally filtered by space"""
|
| 711 |
+
try:
|
| 712 |
+
entries = studio_manager.list_notebook_entries(space_id)
|
| 713 |
+
return entries
|
| 714 |
+
except Exception as e:
|
| 715 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 716 |
+
|
| 717 |
+
@app.get("/api/studio/notebook/{entry_id}", response_model=NotebookEntry)
|
| 718 |
+
async def get_notebook_entry(entry_id: str):
|
| 719 |
+
"""Get a single notebook entry"""
|
| 720 |
+
entry = studio_manager.get_notebook_entry(entry_id)
|
| 721 |
+
if not entry:
|
| 722 |
+
raise HTTPException(status_code=404, detail="Notebook entry not found")
|
| 723 |
+
return entry
|
| 724 |
+
|
| 725 |
+
@app.put("/api/studio/notebook/{entry_id}", response_model=NotebookEntry)
|
| 726 |
+
async def update_notebook_entry(entry_id: str, update_data: NotebookEntryUpdate):
|
| 727 |
+
"""Update a notebook entry"""
|
| 728 |
+
entry = studio_manager.update_notebook_entry(entry_id, update_data)
|
| 729 |
+
if not entry:
|
| 730 |
+
raise HTTPException(status_code=404, detail="Notebook entry not found")
|
| 731 |
+
return entry
|
| 732 |
+
|
| 733 |
+
@app.delete("/api/studio/notebook/{entry_id}")
|
| 734 |
+
async def delete_notebook_entry(entry_id: str):
|
| 735 |
+
"""Delete a notebook entry"""
|
| 736 |
+
success = studio_manager.delete_notebook_entry(entry_id)
|
| 737 |
+
if not success:
|
| 738 |
+
raise HTTPException(status_code=404, detail="Notebook entry not found")
|
| 739 |
+
return {"status": "success", "message": "Notebook entry deleted"}
|
| 740 |
+
|
| 741 |
+
|
| 742 |
+
# ===== FLASHCARD ROUTES =====
|
| 743 |
+
|
| 744 |
+
@app.post("/api/studio/flashcards", response_model=Flashcard)
|
| 745 |
+
async def create_flashcard(card_data: FlashcardCreate):
|
| 746 |
+
"""Create a new flashcard"""
|
| 747 |
+
try:
|
| 748 |
+
card = studio_manager.create_flashcard(card_data)
|
| 749 |
+
return card
|
| 750 |
+
except Exception as e:
|
| 751 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 752 |
+
|
| 753 |
+
@app.get("/api/studio/flashcards", response_model=List[Flashcard])
|
| 754 |
+
async def list_flashcards(
|
| 755 |
+
space_id: Optional[str] = None,
|
| 756 |
+
mastery: Optional[MasteryLevel] = None
|
| 757 |
+
):
|
| 758 |
+
"""List all flashcards, optionally filtered"""
|
| 759 |
+
try:
|
| 760 |
+
cards = studio_manager.list_flashcards(space_id, mastery)
|
| 761 |
+
return cards
|
| 762 |
+
except Exception as e:
|
| 763 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 764 |
+
|
| 765 |
+
@app.get("/api/studio/flashcards/{card_id}", response_model=Flashcard)
|
| 766 |
+
async def get_flashcard(card_id: str):
|
| 767 |
+
"""Get a single flashcard"""
|
| 768 |
+
card = studio_manager.get_flashcard(card_id)
|
| 769 |
+
if not card:
|
| 770 |
+
raise HTTPException(status_code=404, detail="Flashcard not found")
|
| 771 |
+
return card
|
| 772 |
+
|
| 773 |
+
@app.put("/api/studio/flashcards/{card_id}", response_model=Flashcard)
|
| 774 |
+
async def update_flashcard(card_id: str, update_data: FlashcardUpdate):
|
| 775 |
+
"""Update a flashcard"""
|
| 776 |
+
card = studio_manager.update_flashcard(card_id, update_data)
|
| 777 |
+
if not card:
|
| 778 |
+
raise HTTPException(status_code=404, detail="Flashcard not found")
|
| 779 |
+
return card
|
| 780 |
+
|
| 781 |
+
@app.post("/api/studio/flashcards/{card_id}/review", response_model=Flashcard)
|
| 782 |
+
async def review_flashcard(card_id: str, review: FlashcardReview):
|
| 783 |
+
"""Record a flashcard review"""
|
| 784 |
+
card = studio_manager.review_flashcard(card_id, review)
|
| 785 |
+
if not card:
|
| 786 |
+
raise HTTPException(status_code=404, detail="Flashcard not found")
|
| 787 |
+
return card
|
| 788 |
+
|
| 789 |
+
@app.delete("/api/studio/flashcards/{card_id}")
|
| 790 |
+
async def delete_flashcard(card_id: str):
|
| 791 |
+
"""Delete a flashcard"""
|
| 792 |
+
success = studio_manager.delete_flashcard(card_id)
|
| 793 |
+
if not success:
|
| 794 |
+
raise HTTPException(status_code=404, detail="Flashcard not found")
|
| 795 |
+
return {"status": "success", "message": "Flashcard deleted"}
|
| 796 |
+
|
| 797 |
+
@app.post("/api/studio/flashcards/generate", response_model=List[Flashcard])
|
| 798 |
+
async def generate_flashcards(request: FlashcardGenerateRequest):
|
| 799 |
+
"""Generate flashcards from content using LLM"""
|
| 800 |
+
global studio_generator
|
| 801 |
+
|
| 802 |
+
if not studio_generator:
|
| 803 |
+
raise HTTPException(status_code=503, detail="LLM not initialized")
|
| 804 |
+
|
| 805 |
+
try:
|
| 806 |
+
cards = await studio_generator.generate_flashcards(request)
|
| 807 |
+
return cards
|
| 808 |
+
except Exception as e:
|
| 809 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 810 |
+
|
| 811 |
+
|
| 812 |
+
# ===== QUIZ ROUTES =====
|
| 813 |
+
|
| 814 |
+
@app.post("/api/studio/quizzes", response_model=Quiz)
|
| 815 |
+
async def create_quiz(quiz_data: QuizCreate):
|
| 816 |
+
"""Create a new quiz"""
|
| 817 |
+
try:
|
| 818 |
+
quiz = studio_manager.create_quiz(quiz_data)
|
| 819 |
+
return quiz
|
| 820 |
+
except Exception as e:
|
| 821 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 822 |
+
|
| 823 |
+
@app.get("/api/studio/quizzes", response_model=List[Quiz])
|
| 824 |
+
async def list_quizzes(space_id: Optional[str] = None):
|
| 825 |
+
"""List all quizzes, optionally filtered by space"""
|
| 826 |
+
try:
|
| 827 |
+
quizzes = studio_manager.list_quizzes(space_id)
|
| 828 |
+
return quizzes
|
| 829 |
+
except Exception as e:
|
| 830 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 831 |
+
|
| 832 |
+
@app.get("/api/studio/quizzes/{quiz_id}", response_model=Quiz)
|
| 833 |
+
async def get_quiz(quiz_id: str):
|
| 834 |
+
"""Get a quiz by ID"""
|
| 835 |
+
quiz = studio_manager.get_quiz(quiz_id)
|
| 836 |
+
if not quiz:
|
| 837 |
+
raise HTTPException(status_code=404, detail="Quiz not found")
|
| 838 |
+
return quiz
|
| 839 |
+
|
| 840 |
+
@app.delete("/api/studio/quizzes/{quiz_id}")
|
| 841 |
+
async def delete_quiz(quiz_id: str):
|
| 842 |
+
"""Delete a quiz"""
|
| 843 |
+
success = studio_manager.delete_quiz(quiz_id)
|
| 844 |
+
if not success:
|
| 845 |
+
raise HTTPException(status_code=404, detail="Quiz not found")
|
| 846 |
+
return {"status": "success", "message": "Quiz deleted"}
|
| 847 |
+
|
| 848 |
+
@app.post("/api/studio/quizzes/generate", response_model=Quiz)
|
| 849 |
+
async def generate_quiz(request: QuizGenerateRequest):
|
| 850 |
+
"""Generate a quiz from content using LLM"""
|
| 851 |
+
global studio_generator
|
| 852 |
+
|
| 853 |
+
if not studio_generator:
|
| 854 |
+
raise HTTPException(status_code=503, detail="LLM not initialized")
|
| 855 |
+
|
| 856 |
+
try:
|
| 857 |
+
quiz = await studio_generator.generate_quiz(request)
|
| 858 |
+
if not quiz:
|
| 859 |
+
raise HTTPException(status_code=500, detail="Failed to generate quiz")
|
| 860 |
+
return quiz
|
| 861 |
+
except Exception as e:
|
| 862 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 863 |
+
|
| 864 |
+
@app.post("/api/studio/quizzes/{quiz_id}/submit", response_model=QuizResult)
|
| 865 |
+
async def submit_quiz(quiz_id: str, submission: QuizSubmission):
|
| 866 |
+
"""Submit quiz answers and get results"""
|
| 867 |
+
try:
|
| 868 |
+
result = studio_manager.submit_quiz(quiz_id, submission.answers)
|
| 869 |
+
if not result:
|
| 870 |
+
raise HTTPException(status_code=404, detail="Quiz not found")
|
| 871 |
+
return result
|
| 872 |
+
except Exception as e:
|
| 873 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 874 |
+
|
| 875 |
+
@app.get("/api/studio/quizzes/{quiz_id}/history", response_model=QuizHistory)
|
| 876 |
+
async def get_quiz_history(quiz_id: str):
|
| 877 |
+
"""Get quiz attempt history"""
|
| 878 |
+
try:
|
| 879 |
+
history = studio_manager.get_quiz_history(quiz_id)
|
| 880 |
+
if not history:
|
| 881 |
+
raise HTTPException(status_code=404, detail="Quiz not found")
|
| 882 |
+
return history
|
| 883 |
+
except HTTPException as he:
|
| 884 |
+
# If the error is already an HTTPException (like the missing API key error), pass it through directly
|
| 885 |
+
raise he
|
| 886 |
+
except Exception as e:
|
| 887 |
+
# For all other crashes, print the actual traceback to the terminal so you can see what broke
|
| 888 |
+
import traceback
|
| 889 |
+
traceback.print_exc()
|
| 890 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 891 |
+
|
| 892 |
+
# ==================== Run Server ====================
|
| 893 |
+
|
| 894 |
+
if __name__ == "__main__":
|
| 895 |
+
import uvicorn
|
| 896 |
+
uvicorn.run(app, host="0.0.0.0", port=8000, log_level="error")
|
models/__pycache__/studio_models.cpython-311.pyc
ADDED
|
Binary file (15.4 kB). View file
|
|
|
models/studio_models.py
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Studio Models - Notebook, Flashcards, Quiz
|
| 3 |
+
These models represent the core Studio features for NotebookPRO
|
| 4 |
+
"""
|
| 5 |
+
from pydantic import BaseModel, Field
|
| 6 |
+
from typing import List, Optional, Dict, Any
|
| 7 |
+
from datetime import datetime
|
| 8 |
+
from enum import Enum
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# ============================================================================
|
| 12 |
+
# NOTEBOOK MODELS
|
| 13 |
+
# ============================================================================
|
| 14 |
+
|
| 15 |
+
class NotebookEntry(BaseModel):
|
| 16 |
+
"""A single note entry in the notebook"""
|
| 17 |
+
id: str = Field(..., description="Unique identifier for the note")
|
| 18 |
+
space_id: str = Field(..., description="Space this note belongs to")
|
| 19 |
+
title: str = Field(..., description="Title of the note")
|
| 20 |
+
content: str = Field(..., description="Main content/body of the note")
|
| 21 |
+
source_type: str = Field(default="manual", description="Source: manual, chat, generated")
|
| 22 |
+
source_id: Optional[str] = Field(None, description="ID of source (e.g., chat message ID)")
|
| 23 |
+
tags: List[str] = Field(default_factory=list, description="Tags for categorization")
|
| 24 |
+
created_at: datetime = Field(default_factory=datetime.now)
|
| 25 |
+
updated_at: datetime = Field(default_factory=datetime.now)
|
| 26 |
+
metadata: Dict[str, Any] = Field(default_factory=dict)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class NotebookEntryCreate(BaseModel):
|
| 30 |
+
"""Request model for creating a notebook entry"""
|
| 31 |
+
space_id: str
|
| 32 |
+
title: str
|
| 33 |
+
content: str
|
| 34 |
+
source_type: str = "manual"
|
| 35 |
+
source_id: Optional[str] = None
|
| 36 |
+
tags: List[str] = []
|
| 37 |
+
metadata: Dict[str, Any] = {}
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class NotebookEntryUpdate(BaseModel):
|
| 41 |
+
"""Request model for updating a notebook entry"""
|
| 42 |
+
title: Optional[str] = None
|
| 43 |
+
content: Optional[str] = None
|
| 44 |
+
tags: Optional[List[str]] = None
|
| 45 |
+
metadata: Optional[Dict[str, Any]] = None
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# ============================================================================
|
| 49 |
+
# FLASHCARD MODELS
|
| 50 |
+
# ============================================================================
|
| 51 |
+
|
| 52 |
+
class DifficultyLevel(str, Enum):
|
| 53 |
+
"""Difficulty level for flashcards"""
|
| 54 |
+
EASY = "easy"
|
| 55 |
+
MEDIUM = "medium"
|
| 56 |
+
HARD = "hard"
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class MasteryLevel(str, Enum):
|
| 60 |
+
"""User's mastery level for a flashcard"""
|
| 61 |
+
NEW = "new"
|
| 62 |
+
LEARNING = "learning"
|
| 63 |
+
REVIEWING = "reviewing"
|
| 64 |
+
MASTERED = "mastered"
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class Flashcard(BaseModel):
|
| 68 |
+
"""A single flashcard for memorization"""
|
| 69 |
+
id: str = Field(..., description="Unique identifier")
|
| 70 |
+
space_id: str = Field(..., description="Space this flashcard belongs to")
|
| 71 |
+
question: str = Field(..., description="Front of the card (question/prompt)")
|
| 72 |
+
answer: str = Field(..., description="Back of the card (answer/explanation)")
|
| 73 |
+
difficulty: DifficultyLevel = Field(default=DifficultyLevel.MEDIUM)
|
| 74 |
+
mastery: MasteryLevel = Field(default=MasteryLevel.NEW)
|
| 75 |
+
source_type: str = Field(default="manual", description="Source: manual, generated, notebook")
|
| 76 |
+
source_id: Optional[str] = Field(None, description="Source ID (e.g., notebook entry ID)")
|
| 77 |
+
tags: List[str] = Field(default_factory=list)
|
| 78 |
+
review_count: int = Field(default=0, description="Number of times reviewed")
|
| 79 |
+
correct_count: int = Field(default=0, description="Number of times answered correctly")
|
| 80 |
+
last_reviewed: Optional[datetime] = None
|
| 81 |
+
next_review: Optional[datetime] = None
|
| 82 |
+
created_at: datetime = Field(default_factory=datetime.now)
|
| 83 |
+
metadata: Dict[str, Any] = Field(default_factory=dict)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class FlashcardCreate(BaseModel):
|
| 87 |
+
"""Request model for creating a flashcard"""
|
| 88 |
+
space_id: str
|
| 89 |
+
question: str
|
| 90 |
+
answer: str
|
| 91 |
+
difficulty: DifficultyLevel = DifficultyLevel.MEDIUM
|
| 92 |
+
source_type: str = "manual"
|
| 93 |
+
source_id: Optional[str] = None
|
| 94 |
+
tags: List[str] = []
|
| 95 |
+
metadata: Dict[str, Any] = {}
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class FlashcardUpdate(BaseModel):
|
| 99 |
+
"""Request model for updating a flashcard"""
|
| 100 |
+
question: Optional[str] = None
|
| 101 |
+
answer: Optional[str] = None
|
| 102 |
+
difficulty: Optional[DifficultyLevel] = None
|
| 103 |
+
mastery: Optional[MasteryLevel] = None
|
| 104 |
+
tags: Optional[List[str]] = None
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class FlashcardReview(BaseModel):
|
| 108 |
+
"""Request model for reviewing a flashcard"""
|
| 109 |
+
correct: bool = Field(..., description="Whether the user answered correctly")
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class FlashcardGenerateRequest(BaseModel):
|
| 113 |
+
"""Request to generate flashcards from content"""
|
| 114 |
+
space_id: str
|
| 115 |
+
source_type: str = Field(..., description="Source type: notebook, file, text")
|
| 116 |
+
source_ids: Optional[List[str]] = Field(None, description="IDs of notebook entries or files")
|
| 117 |
+
text_content: Optional[str] = Field(None, description="Direct text content to generate from")
|
| 118 |
+
num_cards: int = Field(default=5, description="Number of flashcards to generate")
|
| 119 |
+
difficulty: DifficultyLevel = DifficultyLevel.MEDIUM
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
# ============================================================================
|
| 123 |
+
# QUIZ MODELS
|
| 124 |
+
# ============================================================================
|
| 125 |
+
|
| 126 |
+
class QuestionType(str, Enum):
|
| 127 |
+
"""Type of quiz question"""
|
| 128 |
+
MULTIPLE_CHOICE = "multiple_choice"
|
| 129 |
+
TRUE_FALSE = "true_false"
|
| 130 |
+
SHORT_ANSWER = "short_answer"
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class QuizQuestion(BaseModel):
|
| 134 |
+
"""A single quiz question"""
|
| 135 |
+
id: str = Field(..., description="Unique identifier")
|
| 136 |
+
question: str = Field(..., description="Question text")
|
| 137 |
+
type: QuestionType = Field(..., description="Question type")
|
| 138 |
+
options: Optional[List[str]] = Field(None, description="Options for multiple choice")
|
| 139 |
+
correct_answer: str = Field(..., description="Correct answer")
|
| 140 |
+
explanation: Optional[str] = Field(None, description="Explanation of the answer")
|
| 141 |
+
points: int = Field(default=1, description="Points for this question")
|
| 142 |
+
difficulty: DifficultyLevel = Field(default=DifficultyLevel.MEDIUM)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
class Quiz(BaseModel):
|
| 146 |
+
"""A quiz session"""
|
| 147 |
+
id: str = Field(..., description="Unique identifier")
|
| 148 |
+
space_id: str = Field(..., description="Space this quiz belongs to")
|
| 149 |
+
title: str = Field(..., description="Quiz title")
|
| 150 |
+
description: Optional[str] = None
|
| 151 |
+
questions: List[QuizQuestion] = Field(..., description="List of questions")
|
| 152 |
+
source_type: str = Field(default="manual", description="Source: manual, generated, notebook, file")
|
| 153 |
+
source_ids: Optional[List[str]] = None
|
| 154 |
+
created_at: datetime = Field(default_factory=datetime.now)
|
| 155 |
+
metadata: Dict[str, Any] = Field(default_factory=dict)
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
class QuizCreate(BaseModel):
|
| 159 |
+
"""Request model for creating a quiz"""
|
| 160 |
+
space_id: str
|
| 161 |
+
title: str
|
| 162 |
+
description: Optional[str] = None
|
| 163 |
+
questions: List[QuizQuestion] = []
|
| 164 |
+
source_type: str = "manual"
|
| 165 |
+
source_ids: Optional[List[str]] = None
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
class QuizGenerateRequest(BaseModel):
|
| 169 |
+
"""Request to generate a quiz from content"""
|
| 170 |
+
space_id: str
|
| 171 |
+
title: str
|
| 172 |
+
source_type: str = Field(..., description="Source type: notebook, file, text")
|
| 173 |
+
source_ids: Optional[List[str]] = Field(None, description="IDs of notebook entries or files")
|
| 174 |
+
text_content: Optional[str] = Field(None, description="Direct text content")
|
| 175 |
+
num_questions: int = Field(default=5, description="Number of questions")
|
| 176 |
+
question_types: List[QuestionType] = Field(
|
| 177 |
+
default=[QuestionType.MULTIPLE_CHOICE],
|
| 178 |
+
description="Types of questions to include"
|
| 179 |
+
)
|
| 180 |
+
difficulty: DifficultyLevel = DifficultyLevel.MEDIUM
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
class QuizAnswer(BaseModel):
|
| 184 |
+
"""User's answer to a quiz question"""
|
| 185 |
+
question_id: str
|
| 186 |
+
answer: str
|
| 187 |
+
time_spent: Optional[int] = Field(None, description="Time spent in seconds")
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
class QuizSubmission(BaseModel):
|
| 191 |
+
"""User's quiz submission"""
|
| 192 |
+
quiz_id: str
|
| 193 |
+
answers: List[QuizAnswer]
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
class QuizResult(BaseModel):
|
| 197 |
+
"""Result of a quiz submission"""
|
| 198 |
+
quiz_id: str
|
| 199 |
+
submission_id: str = Field(..., description="Unique submission ID")
|
| 200 |
+
total_questions: int
|
| 201 |
+
correct_answers: int
|
| 202 |
+
incorrect_answers: int
|
| 203 |
+
score_percentage: float
|
| 204 |
+
total_points: int
|
| 205 |
+
earned_points: int
|
| 206 |
+
answers: List[Dict[str, Any]] = Field(..., description="Detailed answer results")
|
| 207 |
+
completed_at: datetime = Field(default_factory=datetime.now)
|
| 208 |
+
time_taken: Optional[int] = Field(None, description="Total time in seconds")
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
class QuizHistory(BaseModel):
|
| 212 |
+
"""Quiz attempt history"""
|
| 213 |
+
quiz_id: str
|
| 214 |
+
space_id: str
|
| 215 |
+
quiz_title: str
|
| 216 |
+
results: List[QuizResult] = Field(default_factory=list)
|
| 217 |
+
best_score: float = Field(default=0.0)
|
| 218 |
+
average_score: float = Field(default=0.0)
|
| 219 |
+
attempts_count: int = Field(default=0)
|
requirements.txt
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# FastAPI Framework - version compatible with Pydantic v1
|
| 2 |
+
fastapi==0.104.1
|
| 3 |
+
uvicorn[standard]==0.24.0
|
| 4 |
+
python-multipart==0.0.6
|
| 5 |
+
pydantic==1.10.13
|
| 6 |
+
|
| 7 |
+
# HTTP client - version compatible with groq
|
| 8 |
+
httpx==0.25.2
|
| 9 |
+
|
| 10 |
+
# Ngrok tunnel for public access
|
| 11 |
+
pyngrok==7.0.0
|
| 12 |
+
|
| 13 |
+
# Vector database - Pydantic v1 compatible version
|
| 14 |
+
|
| 15 |
+
# ML/AI dependencies - compatible versions
|
| 16 |
+
sentence-transformers==2.7.0
|
| 17 |
+
huggingface-hub==0.23.0
|
| 18 |
+
rank-bm25==0.2.2
|
| 19 |
+
|
| 20 |
+
# LLM providers
|
| 21 |
+
groq==1.1.1
|
| 22 |
+
google-generativeai==0.3.2
|
| 23 |
+
|
| 24 |
+
# Streamlit for UI components
|
| 25 |
+
streamlit==1.31.0
|
| 26 |
+
|
| 27 |
+
# File processing
|
| 28 |
+
PyPDF2==3.0.1
|
| 29 |
+
pdfplumber==0.10.3
|
| 30 |
+
python-docx==1.1.0
|
| 31 |
+
|
| 32 |
+
# Environment variables
|
| 33 |
+
python-dotenv==1.0.0
|
| 34 |
+
|
| 35 |
+
# REMOVE this:
|
| 36 |
+
# chromadb==0.3.21
|
| 37 |
+
|
| 38 |
+
# ADD these:
|
| 39 |
+
qdrant-client==1.8.0
|
runtime.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
3.11.0
|
start_ngrok_tunnel.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Alternative tunnel using ngrok (requires one-time free signup).
|
| 3 |
+
|
| 4 |
+
Setup (one time only):
|
| 5 |
+
1. Sign up at https://ngrok.com (free)
|
| 6 |
+
2. Get your authtoken from dashboard
|
| 7 |
+
3. Run this once: ngrok config add-authtoken YOUR_TOKEN_HERE
|
| 8 |
+
|
| 9 |
+
Then run this script to start the tunnel.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from pyngrok import ngrok, conf
|
| 13 |
+
import time
|
| 14 |
+
import json
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
|
| 17 |
+
def start_ngrok_tunnel(port=8000):
|
| 18 |
+
"""Start ngrok tunnel for the backend."""
|
| 19 |
+
|
| 20 |
+
try:
|
| 21 |
+
# Close any existing tunnels
|
| 22 |
+
ngrok.kill()
|
| 23 |
+
|
| 24 |
+
# Start tunnel (use default free tier settings - no subdomain)
|
| 25 |
+
print(f"Starting ngrok tunnel for port {port}...")
|
| 26 |
+
# Just use basic http connection without any domain options
|
| 27 |
+
tunnel = ngrok.connect(port, bind_tls=True)
|
| 28 |
+
public_url = tunnel.public_url
|
| 29 |
+
|
| 30 |
+
print("\n" + "="*60)
|
| 31 |
+
print("🚀 Backend Tunnel Started!")
|
| 32 |
+
print("="*60)
|
| 33 |
+
print(f"Public URL: {public_url}")
|
| 34 |
+
print(f"Local URL: http://localhost:{port}")
|
| 35 |
+
print("="*60)
|
| 36 |
+
print("\n✅ Update your Flutter app:")
|
| 37 |
+
print(f' static const String baseUrl = "{public_url}";')
|
| 38 |
+
print("=" * 60)
|
| 39 |
+
print("\nPress Ctrl+C to stop the tunnel")
|
| 40 |
+
print("="*60 + "\n")
|
| 41 |
+
|
| 42 |
+
# Save config
|
| 43 |
+
config_file = Path(__file__).parent.parent / "tunnel_config.json"
|
| 44 |
+
config = {
|
| 45 |
+
"backend_url": public_url,
|
| 46 |
+
"created_at": time.strftime("%Y-%m-%d %H:%M:%S")
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
with open(config_file, 'w') as f:
|
| 50 |
+
json.dump(config, f, indent=2)
|
| 51 |
+
|
| 52 |
+
try:
|
| 53 |
+
while True:
|
| 54 |
+
time.sleep(1)
|
| 55 |
+
except KeyboardInterrupt:
|
| 56 |
+
print("\n\n✨ Shutting down tunnel...")
|
| 57 |
+
ngrok.kill()
|
| 58 |
+
print("✅ Tunnel closed\n")
|
| 59 |
+
|
| 60 |
+
except Exception as e:
|
| 61 |
+
print(f"\n❌ Error: {e}\n")
|
| 62 |
+
print("Setup ngrok authentication:")
|
| 63 |
+
print("1. Sign up at https://ngrok.com (free)")
|
| 64 |
+
print("2. Get your authtoken from the dashboard")
|
| 65 |
+
print("3. Run: ngrok config add-authtoken YOUR_TOKEN")
|
| 66 |
+
print("4. Run this script again\n")
|
| 67 |
+
|
| 68 |
+
if __name__ == "__main__":
|
| 69 |
+
start_ngrok_tunnel()
|
utils/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Utility modules for document processing and vector database operations."""
|
utils/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (231 Bytes). View file
|
|
|
utils/__pycache__/__init__.cpython-314.pyc
ADDED
|
Binary file (226 Bytes). View file
|
|
|
utils/__pycache__/config_manager.cpython-311.pyc
ADDED
|
Binary file (5.2 kB). View file
|
|
|
utils/__pycache__/config_manager.cpython-314.pyc
ADDED
|
Binary file (6.09 kB). View file
|
|
|
utils/__pycache__/document_processor.cpython-311.pyc
ADDED
|
Binary file (10.9 kB). View file
|
|
|
utils/__pycache__/document_processor.cpython-314.pyc
ADDED
|
Binary file (11.1 kB). View file
|
|
|
utils/__pycache__/hybrid_retriever.cpython-311.pyc
ADDED
|
Binary file (7.29 kB). View file
|
|
|
utils/__pycache__/hybrid_retriever.cpython-314.pyc
ADDED
|
Binary file (6.87 kB). View file
|
|
|
utils/__pycache__/llm_generator.cpython-311.pyc
ADDED
|
Binary file (13.5 kB). View file
|
|
|
utils/__pycache__/llm_generator.cpython-314.pyc
ADDED
|
Binary file (13.8 kB). View file
|
|
|
utils/__pycache__/model_inference.cpython-311.pyc
ADDED
|
Binary file (6.86 kB). View file
|
|
|
utils/__pycache__/simple_generator.cpython-311.pyc
ADDED
|
Binary file (23.9 kB). View file
|
|
|
utils/__pycache__/spaces_manager.cpython-311.pyc
ADDED
|
Binary file (7.54 kB). View file
|
|
|
utils/__pycache__/spaces_manager.cpython-314.pyc
ADDED
|
Binary file (8.57 kB). View file
|
|
|
utils/__pycache__/studio_generator.cpython-311.pyc
ADDED
|
Binary file (12 kB). View file
|
|
|
utils/__pycache__/studio_manager.cpython-311.pyc
ADDED
|
Binary file (26.7 kB). View file
|
|
|
utils/__pycache__/vector_db.cpython-311.pyc
ADDED
|
Binary file (8.14 kB). View file
|
|
|
utils/__pycache__/vector_db.cpython-314.pyc
ADDED
|
Binary file (6.6 kB). View file
|
|
|
utils/chat_manager.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Chat management utilities for NotebookPRO.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import List, Dict, Optional
|
| 9 |
+
import config
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class ChatManager:
|
| 13 |
+
"""Manage chat sessions and history."""
|
| 14 |
+
|
| 15 |
+
def __init__(self):
|
| 16 |
+
self.chats_dir = config.CHATS_DIR
|
| 17 |
+
self.chats_dir.mkdir(parents=True, exist_ok=True)
|
| 18 |
+
|
| 19 |
+
def save_chat(self, chat_id: str, messages: List[Dict], space: Optional[str] = None) -> None:
|
| 20 |
+
"""
|
| 21 |
+
Save a chat session.
|
| 22 |
+
|
| 23 |
+
Args:
|
| 24 |
+
chat_id: Unique chat identifier
|
| 25 |
+
messages: List of message dictionaries
|
| 26 |
+
space: Optional space/subject name
|
| 27 |
+
"""
|
| 28 |
+
chat_data = {
|
| 29 |
+
'id': chat_id,
|
| 30 |
+
'messages': messages,
|
| 31 |
+
'space': space,
|
| 32 |
+
'created_at': datetime.now().isoformat(),
|
| 33 |
+
'updated_at': datetime.now().isoformat()
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
chat_file = self.chats_dir / f"{chat_id}.json"
|
| 37 |
+
with open(chat_file, 'w', encoding='utf-8') as f:
|
| 38 |
+
json.dump(chat_data, f, indent=2, ensure_ascii=False)
|
| 39 |
+
|
| 40 |
+
def load_chat(self, chat_id: str) -> Optional[Dict]:
|
| 41 |
+
"""
|
| 42 |
+
Load a chat session.
|
| 43 |
+
|
| 44 |
+
Args:
|
| 45 |
+
chat_id: Unique chat identifier
|
| 46 |
+
|
| 47 |
+
Returns:
|
| 48 |
+
Chat data dictionary or None if not found
|
| 49 |
+
"""
|
| 50 |
+
chat_file = self.chats_dir / f"{chat_id}.json"
|
| 51 |
+
|
| 52 |
+
if not chat_file.exists():
|
| 53 |
+
return None
|
| 54 |
+
|
| 55 |
+
with open(chat_file, 'r', encoding='utf-8') as f:
|
| 56 |
+
return json.load(f)
|
| 57 |
+
|
| 58 |
+
def list_chats(self, space: Optional[str] = None) -> List[Dict]:
|
| 59 |
+
"""
|
| 60 |
+
List all chats, optionally filtered by space.
|
| 61 |
+
|
| 62 |
+
Args:
|
| 63 |
+
space: Optional space filter
|
| 64 |
+
|
| 65 |
+
Returns:
|
| 66 |
+
List of chat metadata dictionaries
|
| 67 |
+
"""
|
| 68 |
+
chats = []
|
| 69 |
+
|
| 70 |
+
for chat_file in self.chats_dir.glob("*.json"):
|
| 71 |
+
with open(chat_file, 'r', encoding='utf-8') as f:
|
| 72 |
+
chat_data = json.load(f)
|
| 73 |
+
|
| 74 |
+
if space is None or chat_data.get('space') == space:
|
| 75 |
+
chats.append({
|
| 76 |
+
'id': chat_data['id'],
|
| 77 |
+
'space': chat_data.get('space'),
|
| 78 |
+
'message_count': len(chat_data['messages']),
|
| 79 |
+
'created_at': chat_data.get('created_at'),
|
| 80 |
+
'updated_at': chat_data.get('updated_at')
|
| 81 |
+
})
|
| 82 |
+
|
| 83 |
+
# Sort by updated time (most recent first)
|
| 84 |
+
chats.sort(key=lambda x: x.get('updated_at', ''), reverse=True)
|
| 85 |
+
|
| 86 |
+
return chats
|
| 87 |
+
|
| 88 |
+
def delete_chat(self, chat_id: str) -> bool:
|
| 89 |
+
"""
|
| 90 |
+
Delete a chat session.
|
| 91 |
+
|
| 92 |
+
Args:
|
| 93 |
+
chat_id: Unique chat identifier
|
| 94 |
+
|
| 95 |
+
Returns:
|
| 96 |
+
True if deleted, False if not found
|
| 97 |
+
"""
|
| 98 |
+
chat_file = self.chats_dir / f"{chat_id}.json"
|
| 99 |
+
|
| 100 |
+
if chat_file.exists():
|
| 101 |
+
chat_file.unlink()
|
| 102 |
+
return True
|
| 103 |
+
|
| 104 |
+
return False
|
| 105 |
+
|
| 106 |
+
def get_chat_preview(self, chat_id: str, max_messages: int = 5) -> Optional[List[Dict]]:
|
| 107 |
+
"""
|
| 108 |
+
Get a preview of recent messages from a chat.
|
| 109 |
+
|
| 110 |
+
Args:
|
| 111 |
+
chat_id: Unique chat identifier
|
| 112 |
+
max_messages: Maximum number of messages to return
|
| 113 |
+
|
| 114 |
+
Returns:
|
| 115 |
+
List of recent messages or None if chat not found
|
| 116 |
+
"""
|
| 117 |
+
chat_data = self.load_chat(chat_id)
|
| 118 |
+
|
| 119 |
+
if chat_data is None:
|
| 120 |
+
return None
|
| 121 |
+
|
| 122 |
+
messages = chat_data['messages']
|
| 123 |
+
return messages[-max_messages:] if len(messages) > max_messages else messages
|
utils/config_manager.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Configuration manager for persistent settings (API keys, preferences).
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Dict, Optional
|
| 8 |
+
import config
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class ConfigManager:
|
| 12 |
+
"""Manages persistent user configuration."""
|
| 13 |
+
|
| 14 |
+
def __init__(self):
|
| 15 |
+
self.config_file = config.DATA_DIR / "user_config.json"
|
| 16 |
+
self.config_data = self._load_config()
|
| 17 |
+
|
| 18 |
+
def _load_config(self) -> Dict:
|
| 19 |
+
"""Load configuration from file."""
|
| 20 |
+
if self.config_file.exists():
|
| 21 |
+
try:
|
| 22 |
+
with open(self.config_file, 'r', encoding='utf-8') as f:
|
| 23 |
+
return json.load(f)
|
| 24 |
+
except Exception:
|
| 25 |
+
return self._default_config()
|
| 26 |
+
return self._default_config()
|
| 27 |
+
|
| 28 |
+
def _default_config(self) -> Dict:
|
| 29 |
+
"""Default configuration."""
|
| 30 |
+
return {
|
| 31 |
+
"api_keys": {
|
| 32 |
+
"groq": "",
|
| 33 |
+
"gemini": ""
|
| 34 |
+
},
|
| 35 |
+
"preferences": {
|
| 36 |
+
"llm_provider": "groq",
|
| 37 |
+
"temperature": 0.7,
|
| 38 |
+
"workflow": "Auto-Detect"
|
| 39 |
+
},
|
| 40 |
+
"current_space": "General"
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
def save_config(self):
|
| 44 |
+
"""Save configuration to file."""
|
| 45 |
+
try:
|
| 46 |
+
with open(self.config_file, 'w', encoding='utf-8') as f:
|
| 47 |
+
json.dump(self.config_data, f, indent=2)
|
| 48 |
+
except Exception as e:
|
| 49 |
+
print(f"Error saving config: {e}")
|
| 50 |
+
|
| 51 |
+
def get_api_key(self, provider: str) -> str:
|
| 52 |
+
"""Get API key for provider."""
|
| 53 |
+
return self.config_data.get("api_keys", {}).get(provider, "")
|
| 54 |
+
|
| 55 |
+
def set_api_key(self, provider: str, api_key: str):
|
| 56 |
+
"""Save API key for provider."""
|
| 57 |
+
if "api_keys" not in self.config_data:
|
| 58 |
+
self.config_data["api_keys"] = {}
|
| 59 |
+
self.config_data["api_keys"][provider] = api_key
|
| 60 |
+
self.save_config()
|
| 61 |
+
|
| 62 |
+
def get_preference(self, key: str, default=None):
|
| 63 |
+
"""Get user preference."""
|
| 64 |
+
return self.config_data.get("preferences", {}).get(key, default)
|
| 65 |
+
|
| 66 |
+
def set_preference(self, key: str, value):
|
| 67 |
+
"""Save user preference."""
|
| 68 |
+
if "preferences" not in self.config_data:
|
| 69 |
+
self.config_data["preferences"] = {}
|
| 70 |
+
self.config_data["preferences"][key] = value
|
| 71 |
+
self.save_config()
|
| 72 |
+
|
| 73 |
+
def get_current_space(self) -> str:
|
| 74 |
+
"""Get current workspace."""
|
| 75 |
+
return self.config_data.get("current_space", "General")
|
| 76 |
+
|
| 77 |
+
def set_current_space(self, space_name: str):
|
| 78 |
+
"""Set current workspace."""
|
| 79 |
+
self.config_data["current_space"] = space_name
|
| 80 |
+
self.save_config()
|
utils/document_processor.py
ADDED
|
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import PyPDF2
|
| 2 |
+
import pdfplumber
|
| 3 |
+
from docx import Document
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import List, Dict
|
| 6 |
+
import re
|
| 7 |
+
import warnings
|
| 8 |
+
import logging
|
| 9 |
+
|
| 10 |
+
# Suppress PyPDF2 warnings about font descriptors
|
| 11 |
+
warnings.filterwarnings('ignore', category=UserWarning, module='PyPDF2')
|
| 12 |
+
logging.getLogger('PyPDF2').setLevel(logging.ERROR)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class DocumentProcessor:
|
| 16 |
+
"""Process various document types and extract text content."""
|
| 17 |
+
|
| 18 |
+
def __init__(self):
|
| 19 |
+
self.supported_formats = ['.pdf', '.txt', '.docx']
|
| 20 |
+
|
| 21 |
+
def process_file(self, file_path: Path) -> Dict[str, any]:
|
| 22 |
+
"""
|
| 23 |
+
Process a single file and extract its content.
|
| 24 |
+
|
| 25 |
+
Args:
|
| 26 |
+
file_path: Path to the file
|
| 27 |
+
|
| 28 |
+
Returns:
|
| 29 |
+
Dictionary containing file metadata and content
|
| 30 |
+
"""
|
| 31 |
+
suffix = file_path.suffix.lower()
|
| 32 |
+
|
| 33 |
+
if suffix == '.pdf':
|
| 34 |
+
content = self._extract_pdf(file_path)
|
| 35 |
+
elif suffix == '.txt':
|
| 36 |
+
content = self._extract_txt(file_path)
|
| 37 |
+
elif suffix == '.docx':
|
| 38 |
+
content = self._extract_docx(file_path)
|
| 39 |
+
else:
|
| 40 |
+
raise ValueError(f"Unsupported file format: {suffix}")
|
| 41 |
+
|
| 42 |
+
return {
|
| 43 |
+
'filename': file_path.name,
|
| 44 |
+
'path': str(file_path),
|
| 45 |
+
'content': content,
|
| 46 |
+
'format': suffix
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
def _extract_pdf(self, file_path: Path) -> str:
|
| 50 |
+
"""Extract text from PDF using pdfplumber with PyPDF2 fallback."""
|
| 51 |
+
text = ""
|
| 52 |
+
try:
|
| 53 |
+
# Primary: Use pdfplumber (better for complex PDFs)
|
| 54 |
+
with pdfplumber.open(file_path) as pdf:
|
| 55 |
+
for page in pdf.pages:
|
| 56 |
+
page_text = page.extract_text()
|
| 57 |
+
if page_text:
|
| 58 |
+
text += page_text + "\n"
|
| 59 |
+
except Exception as e:
|
| 60 |
+
# Fallback: Use PyPDF2 with warnings suppressed
|
| 61 |
+
try:
|
| 62 |
+
with warnings.catch_warnings():
|
| 63 |
+
warnings.simplefilter("ignore")
|
| 64 |
+
with open(file_path, 'rb') as file:
|
| 65 |
+
pdf_reader = PyPDF2.PdfReader(file)
|
| 66 |
+
for page in pdf_reader.pages:
|
| 67 |
+
try:
|
| 68 |
+
page_text = page.extract_text()
|
| 69 |
+
if page_text:
|
| 70 |
+
text += page_text + "\n"
|
| 71 |
+
except Exception:
|
| 72 |
+
continue # Skip problematic pages
|
| 73 |
+
except Exception as e2:
|
| 74 |
+
raise ValueError(f"Could not extract text from PDF: {file_path.name}")
|
| 75 |
+
|
| 76 |
+
return self._clean_text(text)
|
| 77 |
+
|
| 78 |
+
def _extract_txt(self, file_path: Path) -> str:
|
| 79 |
+
"""Extract text from TXT file."""
|
| 80 |
+
try:
|
| 81 |
+
with open(file_path, 'r', encoding='utf-8') as file:
|
| 82 |
+
text = file.read()
|
| 83 |
+
except UnicodeDecodeError:
|
| 84 |
+
with open(file_path, 'r', encoding='latin-1') as file:
|
| 85 |
+
text = file.read()
|
| 86 |
+
|
| 87 |
+
return self._clean_text(text)
|
| 88 |
+
|
| 89 |
+
def _extract_docx(self, file_path: Path) -> str:
|
| 90 |
+
"""Extract text from DOCX file."""
|
| 91 |
+
doc = Document(file_path)
|
| 92 |
+
text = "\n".join([paragraph.text for paragraph in doc.paragraphs])
|
| 93 |
+
return self._clean_text(text)
|
| 94 |
+
|
| 95 |
+
def _clean_text(self, text: str) -> str:
|
| 96 |
+
"""Clean and normalize text."""
|
| 97 |
+
# Remove excessive whitespace
|
| 98 |
+
text = re.sub(r'\s+', ' ', text)
|
| 99 |
+
# Remove special characters but keep punctuation
|
| 100 |
+
text = re.sub(r'[^\w\s.,!?;:()\-\'\"]+', '', text)
|
| 101 |
+
return text.strip()
|
| 102 |
+
|
| 103 |
+
def chunk_text(self, text: str, chunk_size: int = 512, overlap: int = 50, semantic: bool = True) -> List[str]:
|
| 104 |
+
"""
|
| 105 |
+
Split text into chunks using semantic or simple chunking.
|
| 106 |
+
|
| 107 |
+
Args:
|
| 108 |
+
text: The text to chunk
|
| 109 |
+
chunk_size: Target size of each chunk in characters
|
| 110 |
+
overlap: Number of overlapping characters between chunks
|
| 111 |
+
semantic: Use semantic chunking (by headers/concepts) if True
|
| 112 |
+
|
| 113 |
+
Returns:
|
| 114 |
+
List of text chunks
|
| 115 |
+
"""
|
| 116 |
+
if semantic:
|
| 117 |
+
return self._semantic_chunk(text, chunk_size, overlap)
|
| 118 |
+
else:
|
| 119 |
+
return self._simple_chunk(text, chunk_size, overlap)
|
| 120 |
+
|
| 121 |
+
def _semantic_chunk(self, text: str, target_size: int = 512, overlap: int = 50) -> List[str]:
|
| 122 |
+
"""
|
| 123 |
+
Chunk text by detecting headers and logical sections.
|
| 124 |
+
Perfect for lecture slides and structured documents.
|
| 125 |
+
"""
|
| 126 |
+
chunks = []
|
| 127 |
+
|
| 128 |
+
# Split by common header patterns
|
| 129 |
+
# Pattern 1: Lines that are ALL CAPS or Title Case followed by newline
|
| 130 |
+
# Pattern 2: Lines starting with numbers like "1.", "1.1", etc.
|
| 131 |
+
# Pattern 3: Lines with clear visual separators
|
| 132 |
+
|
| 133 |
+
# First, split by double newlines (paragraphs)
|
| 134 |
+
sections = text.split('\n\n')
|
| 135 |
+
|
| 136 |
+
current_chunk = ""
|
| 137 |
+
current_header = ""
|
| 138 |
+
|
| 139 |
+
for section in sections:
|
| 140 |
+
section = section.strip()
|
| 141 |
+
if not section:
|
| 142 |
+
continue
|
| 143 |
+
|
| 144 |
+
# Check if this looks like a header
|
| 145 |
+
is_header = self._is_likely_header(section)
|
| 146 |
+
|
| 147 |
+
if is_header and len(current_chunk) > 100:
|
| 148 |
+
# Save previous chunk and start new one with this header
|
| 149 |
+
if current_chunk:
|
| 150 |
+
chunks.append(current_chunk.strip())
|
| 151 |
+
current_chunk = section + "\n\n"
|
| 152 |
+
current_header = section
|
| 153 |
+
else:
|
| 154 |
+
# Add to current chunk
|
| 155 |
+
potential_chunk = current_chunk + section + "\n\n"
|
| 156 |
+
|
| 157 |
+
# If chunk is getting too large, split it
|
| 158 |
+
if len(potential_chunk) > target_size * 1.5:
|
| 159 |
+
if current_chunk:
|
| 160 |
+
chunks.append(current_chunk.strip())
|
| 161 |
+
current_chunk = section + "\n\n"
|
| 162 |
+
else:
|
| 163 |
+
current_chunk = potential_chunk
|
| 164 |
+
|
| 165 |
+
# Add final chunk
|
| 166 |
+
if current_chunk:
|
| 167 |
+
chunks.append(current_chunk.strip())
|
| 168 |
+
|
| 169 |
+
# If semantic chunking produced too few chunks, fall back to simple chunking
|
| 170 |
+
if len(chunks) < len(text) / (target_size * 2):
|
| 171 |
+
return self._simple_chunk(text, target_size, overlap)
|
| 172 |
+
|
| 173 |
+
return chunks
|
| 174 |
+
|
| 175 |
+
def _is_likely_header(self, text: str) -> bool:
|
| 176 |
+
"""Detect if text is likely a header/title."""
|
| 177 |
+
# Too long to be a header
|
| 178 |
+
if len(text) > 200:
|
| 179 |
+
return False
|
| 180 |
+
|
| 181 |
+
# Single line headers
|
| 182 |
+
if '\n' not in text:
|
| 183 |
+
# ALL CAPS
|
| 184 |
+
if text.isupper() and len(text.split()) <= 10:
|
| 185 |
+
return True
|
| 186 |
+
|
| 187 |
+
# Title Case
|
| 188 |
+
if text.istitle() and len(text.split()) <= 10:
|
| 189 |
+
return True
|
| 190 |
+
|
| 191 |
+
# Numbered sections like "1.", "1.1", "Chapter 1"
|
| 192 |
+
if re.match(r'^(\d+\.)+\s+', text) or re.match(r'^(Chapter|Section|Part)\s+\d+', text, re.IGNORECASE):
|
| 193 |
+
return True
|
| 194 |
+
|
| 195 |
+
return False
|
| 196 |
+
|
| 197 |
+
def _simple_chunk(self, text: str, chunk_size: int = 512, overlap: int = 50) -> List[str]:
|
| 198 |
+
"""
|
| 199 |
+
Split text into overlapping chunks (original method).
|
| 200 |
+
"""
|
| 201 |
+
chunks = []
|
| 202 |
+
start = 0
|
| 203 |
+
text_length = len(text)
|
| 204 |
+
|
| 205 |
+
while start < text_length:
|
| 206 |
+
end = start + chunk_size
|
| 207 |
+
chunk = text[start:end]
|
| 208 |
+
|
| 209 |
+
# Try to break at sentence boundary
|
| 210 |
+
if end < text_length:
|
| 211 |
+
last_period = chunk.rfind('.')
|
| 212 |
+
last_newline = chunk.rfind('\n')
|
| 213 |
+
break_point = max(last_period, last_newline)
|
| 214 |
+
|
| 215 |
+
if break_point > chunk_size * 0.5: # At least 50% through the chunk
|
| 216 |
+
chunk = chunk[:break_point + 1]
|
| 217 |
+
end = start + break_point + 1
|
| 218 |
+
|
| 219 |
+
chunks.append(chunk.strip())
|
| 220 |
+
start = end - overlap
|
| 221 |
+
|
| 222 |
+
return chunks
|
utils/hybrid_retriever.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Hybrid Retriever: Combines Vector (ChromaDB) + Keyword (BM25) search.
|
| 3 |
+
This is the "secret sauce" that makes NotebookLM so accurate.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from typing import List, Dict, Tuple
|
| 7 |
+
from rank_bm25 import BM25Okapi
|
| 8 |
+
import numpy as np
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class HybridRetriever:
|
| 12 |
+
"""
|
| 13 |
+
Combines Dense Retrieval (Embeddings) with Sparse Retrieval (BM25).
|
| 14 |
+
|
| 15 |
+
This is crucial for accuracy:
|
| 16 |
+
- Vector search finds conceptually similar content
|
| 17 |
+
- BM25 finds exact keyword matches (formulas, terms, names)
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
def __init__(self, vector_db, alpha: float = 0.5):
|
| 21 |
+
"""
|
| 22 |
+
Initialize hybrid retriever.
|
| 23 |
+
|
| 24 |
+
Args:
|
| 25 |
+
vector_db: VectorDatabase instance
|
| 26 |
+
alpha: Weight balance (0=only BM25, 1=only vector, 0.5=balanced)
|
| 27 |
+
"""
|
| 28 |
+
self.vector_db = vector_db
|
| 29 |
+
self.alpha = alpha # Weight for vector search
|
| 30 |
+
self.bm25 = None
|
| 31 |
+
self.bm25_corpus = []
|
| 32 |
+
self.bm25_metadata = []
|
| 33 |
+
|
| 34 |
+
def index_documents(self, documents: List[str], metadatas: List[Dict]):
|
| 35 |
+
"""
|
| 36 |
+
Index documents for BM25 keyword search.
|
| 37 |
+
|
| 38 |
+
Args:
|
| 39 |
+
documents: List of document chunks
|
| 40 |
+
metadatas: List of metadata dicts for each chunk
|
| 41 |
+
"""
|
| 42 |
+
# Tokenize documents for BM25
|
| 43 |
+
tokenized_corpus = [doc.lower().split() for doc in documents]
|
| 44 |
+
|
| 45 |
+
# Create BM25 index
|
| 46 |
+
self.bm25 = BM25Okapi(tokenized_corpus)
|
| 47 |
+
self.bm25_corpus = documents
|
| 48 |
+
self.bm25_metadata = metadatas
|
| 49 |
+
|
| 50 |
+
def retrieve(
|
| 51 |
+
self,
|
| 52 |
+
query: str,
|
| 53 |
+
n_results: int = 5,
|
| 54 |
+
score_threshold: float = 0.0
|
| 55 |
+
) -> Tuple[List[str], List[Dict], List[float]]:
|
| 56 |
+
"""
|
| 57 |
+
Hybrid retrieval: combines vector + keyword search.
|
| 58 |
+
|
| 59 |
+
Args:
|
| 60 |
+
query: User's question
|
| 61 |
+
n_results: Number of chunks to retrieve
|
| 62 |
+
score_threshold: Minimum score threshold
|
| 63 |
+
|
| 64 |
+
Returns:
|
| 65 |
+
Tuple of (documents, metadatas, scores)
|
| 66 |
+
"""
|
| 67 |
+
if not self.bm25:
|
| 68 |
+
# Fallback to pure vector search if BM25 not initialized
|
| 69 |
+
results = self.vector_db.query(query, n_results=n_results * 2)
|
| 70 |
+
return (
|
| 71 |
+
results['documents'][0] if results['documents'] else [],
|
| 72 |
+
results['metadatas'][0] if results['metadatas'] else [],
|
| 73 |
+
results.get('distances', [[]])[0]
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
# Get more results than needed for reranking
|
| 77 |
+
fetch_size = n_results * 3
|
| 78 |
+
|
| 79 |
+
# 1. Vector search (semantic similarity)
|
| 80 |
+
vector_results = self.vector_db.query(query, n_results=fetch_size)
|
| 81 |
+
vector_docs = vector_results['documents'][0] if vector_results['documents'] else []
|
| 82 |
+
vector_meta = vector_results['metadatas'][0] if vector_results['metadatas'] else []
|
| 83 |
+
vector_distances = vector_results.get('distances', [[]])[0]
|
| 84 |
+
|
| 85 |
+
# Convert distances to similarity scores (chromadb uses cosine distance)
|
| 86 |
+
vector_scores = [1 / (1 + d) for d in vector_distances]
|
| 87 |
+
|
| 88 |
+
# 2. BM25 search (keyword matching)
|
| 89 |
+
tokenized_query = query.lower().split()
|
| 90 |
+
bm25_scores = self.bm25.get_scores(tokenized_query)
|
| 91 |
+
|
| 92 |
+
# Get top BM25 results
|
| 93 |
+
top_bm25_indices = np.argsort(bm25_scores)[::-1][:fetch_size]
|
| 94 |
+
|
| 95 |
+
# 3. Combine results with weighted scoring
|
| 96 |
+
combined_docs = {} # Use dict to deduplicate by content
|
| 97 |
+
|
| 98 |
+
# Add vector results
|
| 99 |
+
for doc, meta, score in zip(vector_docs, vector_meta, vector_scores):
|
| 100 |
+
combined_docs[doc] = {
|
| 101 |
+
'doc': doc,
|
| 102 |
+
'meta': meta,
|
| 103 |
+
'score': self.alpha * score
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
# Add BM25 results (normalize scores to 0-1 range)
|
| 107 |
+
max_bm25_score = max(bm25_scores) if max(bm25_scores) > 0 else 1
|
| 108 |
+
for idx in top_bm25_indices:
|
| 109 |
+
doc = self.bm25_corpus[idx]
|
| 110 |
+
meta = self.bm25_metadata[idx]
|
| 111 |
+
bm25_score = bm25_scores[idx] / max_bm25_score
|
| 112 |
+
|
| 113 |
+
if doc in combined_docs:
|
| 114 |
+
# Average if document found by both methods
|
| 115 |
+
combined_docs[doc]['score'] += (1 - self.alpha) * bm25_score
|
| 116 |
+
else:
|
| 117 |
+
combined_docs[doc] = {
|
| 118 |
+
'doc': doc,
|
| 119 |
+
'meta': meta,
|
| 120 |
+
'score': (1 - self.alpha) * bm25_score
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
# 4. Rank by combined score
|
| 124 |
+
ranked_results = sorted(
|
| 125 |
+
combined_docs.values(),
|
| 126 |
+
key=lambda x: x['score'],
|
| 127 |
+
reverse=True
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
# 5. Filter by threshold and limit results
|
| 131 |
+
filtered_results = [
|
| 132 |
+
r for r in ranked_results
|
| 133 |
+
if r['score'] >= score_threshold
|
| 134 |
+
][:n_results]
|
| 135 |
+
|
| 136 |
+
# 6. Return in expected format
|
| 137 |
+
documents = [r['doc'] for r in filtered_results]
|
| 138 |
+
metadatas = [r['meta'] for r in filtered_results]
|
| 139 |
+
scores = [r['score'] for r in filtered_results]
|
| 140 |
+
|
| 141 |
+
return documents, metadatas, scores
|
| 142 |
+
|
| 143 |
+
def get_stats(self) -> Dict:
|
| 144 |
+
"""Get retriever statistics."""
|
| 145 |
+
return {
|
| 146 |
+
'bm25_indexed': len(self.bm25_corpus) if self.bm25 else 0,
|
| 147 |
+
'vector_count': self.vector_db.get_collection_count(),
|
| 148 |
+
'alpha': self.alpha
|
| 149 |
+
}
|
utils/llm_generator.py
ADDED
|
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Real LLM-based generator using Groq or Google Gemini API.
|
| 3 |
+
This ACTUALLY generates responses (unlike SimpleGenerator which just extracts text).
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
from typing import List, Dict, Optional
|
| 8 |
+
import streamlit as st
|
| 9 |
+
|
| 10 |
+
try:
|
| 11 |
+
from groq import Groq
|
| 12 |
+
GROQ_AVAILABLE = True
|
| 13 |
+
except ImportError:
|
| 14 |
+
GROQ_AVAILABLE = False
|
| 15 |
+
|
| 16 |
+
try:
|
| 17 |
+
import google.generativeai as genai
|
| 18 |
+
GEMINI_AVAILABLE = True
|
| 19 |
+
except ImportError:
|
| 20 |
+
GEMINI_AVAILABLE = False
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class LLMGenerator:
|
| 24 |
+
"""
|
| 25 |
+
Actual LLM-based response generation using Groq (Llama-3-70B) or Gemini.
|
| 26 |
+
This is what NotebookLM uses - real AI generation, not text extraction.
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
def __init__(self, provider: str = "groq", api_key: Optional[str] = None):
|
| 30 |
+
"""
|
| 31 |
+
Initialize LLM generator.
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
provider: "groq" or "gemini"
|
| 35 |
+
api_key: API key (if None, reads from environment or asks user)
|
| 36 |
+
"""
|
| 37 |
+
self.provider = provider
|
| 38 |
+
self.client = None
|
| 39 |
+
self.ready = False
|
| 40 |
+
|
| 41 |
+
# Get API key
|
| 42 |
+
if api_key:
|
| 43 |
+
self.api_key = api_key
|
| 44 |
+
elif provider == "groq":
|
| 45 |
+
self.api_key = os.getenv("GROQ_API_KEY", "")
|
| 46 |
+
elif provider == "gemini":
|
| 47 |
+
self.api_key = os.getenv("GEMINI_API_KEY", "")
|
| 48 |
+
else:
|
| 49 |
+
self.api_key = ""
|
| 50 |
+
|
| 51 |
+
# Initialize client
|
| 52 |
+
self._initialize_client()
|
| 53 |
+
|
| 54 |
+
def _initialize_client(self):
|
| 55 |
+
"""Initialize the LLM client."""
|
| 56 |
+
if not self.api_key:
|
| 57 |
+
return
|
| 58 |
+
|
| 59 |
+
try:
|
| 60 |
+
if self.provider == "groq" and GROQ_AVAILABLE:
|
| 61 |
+
# Initialize Groq client with explicit parameters
|
| 62 |
+
# Avoid potential proxies kwarg issue by not passing extra config
|
| 63 |
+
import os
|
| 64 |
+
os.environ["GROQ_API_KEY"] = self.api_key
|
| 65 |
+
self.client = Groq() # Will read from environment
|
| 66 |
+
self.ready = True
|
| 67 |
+
elif self.provider == "gemini" and GEMINI_AVAILABLE:
|
| 68 |
+
genai.configure(api_key=self.api_key)
|
| 69 |
+
self.client = genai.GenerativeModel('gemini-1.5-flash')
|
| 70 |
+
self.ready = True
|
| 71 |
+
except Exception as e:
|
| 72 |
+
print(f"Failed to initialize {self.provider}: {e}")
|
| 73 |
+
self.ready = False
|
| 74 |
+
|
| 75 |
+
def set_api_key(self, api_key: str):
|
| 76 |
+
"""Update API key and reinitialize."""
|
| 77 |
+
self.api_key = api_key
|
| 78 |
+
self._initialize_client()
|
| 79 |
+
|
| 80 |
+
def generate_response(
|
| 81 |
+
self,
|
| 82 |
+
prompt: str,
|
| 83 |
+
context: str = "",
|
| 84 |
+
use_case: str = "explanation",
|
| 85 |
+
metadatas: List[Dict] = None,
|
| 86 |
+
temperature: float = 0.7,
|
| 87 |
+
max_tokens: int = 1500,
|
| 88 |
+
**kwargs
|
| 89 |
+
) -> str:
|
| 90 |
+
"""
|
| 91 |
+
Generate response using actual LLM (NotebookLM-style).
|
| 92 |
+
|
| 93 |
+
Args:
|
| 94 |
+
prompt: User's question
|
| 95 |
+
context: Retrieved context from documents
|
| 96 |
+
use_case: Response type (explanation, summary, qa, notes)
|
| 97 |
+
metadatas: Metadata for citations
|
| 98 |
+
temperature: LLM temperature (0.0-1.0)
|
| 99 |
+
max_tokens: Maximum response length
|
| 100 |
+
|
| 101 |
+
Returns:
|
| 102 |
+
Generated response with inline citations
|
| 103 |
+
"""
|
| 104 |
+
if not self.ready:
|
| 105 |
+
return (
|
| 106 |
+
"⚠️ **LLM not configured.** Please add your API key in the sidebar.\n\n"
|
| 107 |
+
"Get a free key:\n"
|
| 108 |
+
"- **Groq** (recommended, very fast): https://console.groq.com/keys\n"
|
| 109 |
+
"- **Gemini** (Google): https://makersuite.google.com/app/apikey"
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
if not context:
|
| 113 |
+
return (
|
| 114 |
+
"I don't have enough information from your uploaded documents to answer this question. "
|
| 115 |
+
"Please upload relevant study materials first."
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
# Build NotebookLM-style system prompt with strict source grounding
|
| 119 |
+
system_prompt = self._build_system_prompt(use_case)
|
| 120 |
+
|
| 121 |
+
# Build user message with context
|
| 122 |
+
user_message = self._build_user_message(prompt, context, metadatas)
|
| 123 |
+
|
| 124 |
+
try:
|
| 125 |
+
# Generate with LLM
|
| 126 |
+
if self.provider == "groq":
|
| 127 |
+
response = self._generate_groq(system_prompt, user_message, temperature, max_tokens)
|
| 128 |
+
elif self.provider == "gemini":
|
| 129 |
+
response = self._generate_gemini(system_prompt, user_message, temperature, max_tokens)
|
| 130 |
+
else:
|
| 131 |
+
return "Error: Unknown provider"
|
| 132 |
+
|
| 133 |
+
return response
|
| 134 |
+
|
| 135 |
+
except Exception as e:
|
| 136 |
+
return f"Error generating response: {str(e)}\n\nPlease check your API key and try again."
|
| 137 |
+
|
| 138 |
+
def _build_system_prompt(self, use_case: str) -> str:
|
| 139 |
+
"""Build specialized system prompt based on use case."""
|
| 140 |
+
base_prompt = (
|
| 141 |
+
"You are an expert academic assistant for students, acting like a highly intelligent study buddy. "
|
| 142 |
+
"⚠️ CRITICAL RULE: You MUST ONLY use information from the provided context below. "
|
| 143 |
+
"DO NOT use your training knowledge. DO NOT infer beyond what's explicitly stated. "
|
| 144 |
+
"If the context doesn't contain adequate information to answer the question, you MUST respond: "
|
| 145 |
+
"'I cannot find sufficient information about this in the uploaded documents. Please upload materials covering this topic or rephrase your question.'\n\n"
|
| 146 |
+
"⚠️ GROUNDING REQUIREMENT: Every statement must be traceable to the provided context. "
|
| 147 |
+
"If you cannot find it in the context below, DO NOT answer from general knowledge.\n\n"
|
| 148 |
+
"✨ FORMATTING RULES (NotebookLM Style):\n"
|
| 149 |
+
"- Use clean, hierarchical Markdown (### Headers, **Bold** terms).\n"
|
| 150 |
+
"- Break down long paragraphs into easily readable bullet points.\n"
|
| 151 |
+
"- Be direct and concise. Avoid conversational fluff like 'Certainly!' or 'Here is the answer'.\n"
|
| 152 |
+
"- If applicable to the prompt, always try to extract a **Real-World Example** from the text to aid understanding.\n\n"
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
if use_case == "explanation":
|
| 156 |
+
base_prompt += (
|
| 157 |
+
"**Your task:** Explain the concept in a clear, step-by-step manner suitable for students.\n"
|
| 158 |
+
"1. Start with a concise, one-sentence definition.\n"
|
| 159 |
+
"2. Break down the core mechanics or components using bullet points.\n"
|
| 160 |
+
"3. Provide an example (only if found in the text).\n"
|
| 161 |
+
"4. Add a 'Key Takeaway' at the end.\n"
|
| 162 |
+
)
|
| 163 |
+
elif use_case == "summary":
|
| 164 |
+
base_prompt += (
|
| 165 |
+
"**Your task:** Create a highly structured summary.\n"
|
| 166 |
+
"- Start with a brief high-level overview (2 sentences max).\n"
|
| 167 |
+
"- Use '### Key Themes' and list the main points as bulleted items.\n"
|
| 168 |
+
"- Keep each point concise but factually dense.\n"
|
| 169 |
+
)
|
| 170 |
+
elif use_case == "qa":
|
| 171 |
+
base_prompt += (
|
| 172 |
+
"**Your task:** Answer the question directly and comprehensively.\n"
|
| 173 |
+
"- Provide the direct answer immediately in the first sentence.\n"
|
| 174 |
+
"- Use numbered lists or bullet points to provide supporting details from the context.\n"
|
| 175 |
+
"- Use **bold** for key facts, numbers, and formulas.\n"
|
| 176 |
+
)
|
| 177 |
+
elif use_case == "notes":
|
| 178 |
+
base_prompt += (
|
| 179 |
+
"**Your task:** Create comprehensive, structured study notes.\n"
|
| 180 |
+
"- Use clear section headers (###).\n"
|
| 181 |
+
"- Organize information hierarchically (using nested bullet points).\n"
|
| 182 |
+
"- Explicitly highlight **Definitions**, **Formulas**, and **Important Dates/Names**.\n"
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
base_prompt += (
|
| 186 |
+
"\n**Citation Rules:**\n"
|
| 187 |
+
"- You MUST cite your source at the end of every major claim or paragraph using numbered brackets like **[1]**, **[2]** based on the Source number provided in the context.\n"
|
| 188 |
+
"- If a claim comes from multiple sources, use **[1, 2]**.\n"
|
| 189 |
+
"- Do NOT use the document filename in the citation, ONLY the number.\n"
|
| 190 |
+
"- Do NOT make up information - stick strictly to the provided context.\n"
|
| 191 |
+
)
|
| 192 |
+
|
| 193 |
+
return base_prompt
|
| 194 |
+
|
| 195 |
+
def _build_user_message(self, prompt: str, context: str, metadatas: List[Dict] = None) -> str:
|
| 196 |
+
"""Build user message with context and question."""
|
| 197 |
+
# Extract source names from metadata
|
| 198 |
+
sources = []
|
| 199 |
+
if metadatas:
|
| 200 |
+
for meta in metadatas:
|
| 201 |
+
filename = meta.get('filename', 'Unknown')
|
| 202 |
+
clean_name = filename.replace('.pdf', '').replace('.docx', '').replace('.txt', '')
|
| 203 |
+
if clean_name not in sources:
|
| 204 |
+
sources.append(clean_name)
|
| 205 |
+
|
| 206 |
+
message = "**Available Sources (USE ONLY THESE):**\n"
|
| 207 |
+
for source in sources[:5]: # Show up to 5 sources
|
| 208 |
+
message += f"- {source}\n"
|
| 209 |
+
|
| 210 |
+
message += f"\n**===== START OF CONTEXT (ANSWER ONLY FROM THIS) =====**\n\n{context}\n\n"
|
| 211 |
+
message += f"**===== END OF CONTEXT =====**\n\n"
|
| 212 |
+
message += f"**Student's Question:** {prompt}\n\n"
|
| 213 |
+
message += "**Instructions:** Answer ONLY using the context between the markers above. If the context doesn't contain the answer, say you don't have that information. Cite sources in brackets."
|
| 214 |
+
|
| 215 |
+
return message
|
| 216 |
+
|
| 217 |
+
def _generate_groq(self, system_prompt: str, user_message: str, temperature: float, max_tokens: int) -> str:
|
| 218 |
+
"""Generate using Groq API (Llama-3.3-70B)."""
|
| 219 |
+
completion = self.client.chat.completions.create(
|
| 220 |
+
model="llama-3.3-70b-versatile", # Latest 70B model (Dec 2024)
|
| 221 |
+
messages=[
|
| 222 |
+
{"role": "system", "content": system_prompt},
|
| 223 |
+
{"role": "user", "content": user_message}
|
| 224 |
+
],
|
| 225 |
+
temperature=temperature,
|
| 226 |
+
max_tokens=max_tokens,
|
| 227 |
+
top_p=0.95,
|
| 228 |
+
stream=False
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
return completion.choices[0].message.content
|
| 232 |
+
|
| 233 |
+
def _generate_gemini(self, system_prompt: str, user_message: str, temperature: float, max_tokens: int) -> str:
|
| 234 |
+
"""Generate using Google Gemini API."""
|
| 235 |
+
full_prompt = f"{system_prompt}\n\n{user_message}"
|
| 236 |
+
|
| 237 |
+
response = self.client.generate_content(
|
| 238 |
+
full_prompt,
|
| 239 |
+
generation_config=genai.GenerationConfig(
|
| 240 |
+
temperature=temperature,
|
| 241 |
+
max_output_tokens=max_tokens,
|
| 242 |
+
top_p=0.95
|
| 243 |
+
)
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
return response.text
|
| 247 |
+
|
| 248 |
+
def is_ready(self) -> bool:
|
| 249 |
+
"""Check if LLM is ready to generate."""
|
| 250 |
+
return self.ready
|
| 251 |
+
|
| 252 |
+
def get_provider(self) -> str:
|
| 253 |
+
"""Get current provider name."""
|
| 254 |
+
if self.provider == "groq":
|
| 255 |
+
return "Groq (Llama-3.3-70B)"
|
| 256 |
+
elif self.provider == "gemini":
|
| 257 |
+
return "Google Gemini 1.5 Flash"
|
| 258 |
+
return "Unknown"
|
| 259 |
+
|
| 260 |
+
def generate(self, prompt: str, temperature: float = 0.3, max_tokens: int = 1500) -> str:
|
| 261 |
+
"""
|
| 262 |
+
Simple wrapper for backend compatibility.
|
| 263 |
+
Generates response from a complete prompt that already includes context.
|
| 264 |
+
|
| 265 |
+
Args:
|
| 266 |
+
prompt: Complete prompt with context already embedded
|
| 267 |
+
temperature: LLM temperature (0.0-1.0)
|
| 268 |
+
max_tokens: Maximum response length
|
| 269 |
+
|
| 270 |
+
Returns:
|
| 271 |
+
Generated response
|
| 272 |
+
"""
|
| 273 |
+
if not self.ready:
|
| 274 |
+
return (
|
| 275 |
+
"⚠️ **LLM not configured.** Please add your API key.\n\n"
|
| 276 |
+
"Get a free key:\n"
|
| 277 |
+
"- **Groq** (recommended, very fast): https://console.groq.com/keys\n"
|
| 278 |
+
"- **Gemini** (Google): https://makersuite.google.com/app/apikey"
|
| 279 |
+
)
|
| 280 |
+
|
| 281 |
+
try:
|
| 282 |
+
if self.provider == "groq":
|
| 283 |
+
return self._generate_groq(
|
| 284 |
+
system_prompt="You are a helpful AI assistant.",
|
| 285 |
+
user_message=prompt,
|
| 286 |
+
temperature=temperature,
|
| 287 |
+
max_tokens=max_tokens
|
| 288 |
+
)
|
| 289 |
+
elif self.provider == "gemini":
|
| 290 |
+
return self._generate_gemini(
|
| 291 |
+
system_prompt="You are a helpful AI assistant.",
|
| 292 |
+
user_message=prompt,
|
| 293 |
+
temperature=temperature,
|
| 294 |
+
max_tokens=max_tokens
|
| 295 |
+
)
|
| 296 |
+
except Exception as e:
|
| 297 |
+
return f"Error generating response: {str(e)}"
|
utils/model_inference.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
|
| 3 |
+
from typing import List, Dict, Optional
|
| 4 |
+
import config
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class ModelInference:
|
| 8 |
+
"""Handle model loading and inference for text generation."""
|
| 9 |
+
|
| 10 |
+
def __init__(self, model_name: str = None, use_4bit: bool = True):
|
| 11 |
+
"""
|
| 12 |
+
Initialize the model for inference.
|
| 13 |
+
RAG Mode: Uses pre-trained model directly (no training needed!).
|
| 14 |
+
|
| 15 |
+
Args:
|
| 16 |
+
model_name: Name or path of the model (uses pre-trained by default)
|
| 17 |
+
use_4bit: Whether to use 4-bit quantization for efficiency
|
| 18 |
+
"""
|
| 19 |
+
# Use pre-trained model if specified, otherwise check for fine-tuned model
|
| 20 |
+
if config.USE_PRETRAINED or not Path(config.MODEL_PATH).exists():
|
| 21 |
+
self.model_name = model_name or config.MODEL_NAME
|
| 22 |
+
else:
|
| 23 |
+
self.model_name = model_name or config.MODEL_PATH
|
| 24 |
+
|
| 25 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 26 |
+
|
| 27 |
+
print(f"Loading model: {self.model_name}")
|
| 28 |
+
print(f"Device: {self.device}")
|
| 29 |
+
|
| 30 |
+
# Configure quantization for efficiency
|
| 31 |
+
if use_4bit and self.device == "cuda":
|
| 32 |
+
bnb_config = BitsAndBytesConfig(
|
| 33 |
+
load_in_4bit=True,
|
| 34 |
+
bnb_4bit_quant_type="nf4",
|
| 35 |
+
bnb_4bit_compute_dtype=torch.float16,
|
| 36 |
+
bnb_4bit_use_double_quant=True,
|
| 37 |
+
)
|
| 38 |
+
self.model = AutoModelForCausalLM.from_pretrained(
|
| 39 |
+
self.model_name,
|
| 40 |
+
quantization_config=bnb_config,
|
| 41 |
+
device_map="auto",
|
| 42 |
+
trust_remote_code=True
|
| 43 |
+
)
|
| 44 |
+
else:
|
| 45 |
+
self.model = AutoModelForCausalLM.from_pretrained(
|
| 46 |
+
self.model_name,
|
| 47 |
+
torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
|
| 48 |
+
device_map="auto" if self.device == "cuda" else None,
|
| 49 |
+
trust_remote_code=True
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
| 53 |
+
self.model_name,
|
| 54 |
+
trust_remote_code=True
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
if self.tokenizer.pad_token is None:
|
| 58 |
+
self.tokenizer.pad_token = self.tokenizer.eos_token
|
| 59 |
+
|
| 60 |
+
self.model.eval()
|
| 61 |
+
|
| 62 |
+
def generate_response(
|
| 63 |
+
self,
|
| 64 |
+
prompt: str,
|
| 65 |
+
context: str = "",
|
| 66 |
+
use_case: str = "explanation",
|
| 67 |
+
temperature: float = None,
|
| 68 |
+
max_tokens: int = None
|
| 69 |
+
) -> str:
|
| 70 |
+
"""
|
| 71 |
+
Generate a response based on the prompt and context.
|
| 72 |
+
|
| 73 |
+
Args:
|
| 74 |
+
prompt: User query
|
| 75 |
+
context: Retrieved context from documents
|
| 76 |
+
use_case: Type of response (explanation, summary, qa, notes)
|
| 77 |
+
temperature: Sampling temperature
|
| 78 |
+
max_tokens: Maximum number of tokens to generate
|
| 79 |
+
|
| 80 |
+
Returns:
|
| 81 |
+
Generated text response
|
| 82 |
+
"""
|
| 83 |
+
temperature = temperature or config.TEMPERATURE
|
| 84 |
+
max_tokens = max_tokens or config.MAX_TOKENS
|
| 85 |
+
|
| 86 |
+
# Create system prompt based on use case
|
| 87 |
+
system_prompts = {
|
| 88 |
+
"explanation": "You are an expert tutor. Provide detailed, clear explanations of concepts based on the given context.",
|
| 89 |
+
"summary": "You are a summarization expert. Create concise, well-structured summaries of the provided content.",
|
| 90 |
+
"qa": "You are a knowledgeable assistant. Answer questions accurately based on the given context.",
|
| 91 |
+
"notes": "You are a study notes specialist. Create well-organized, structured study notes from the content."
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
system_prompt = system_prompts.get(use_case, system_prompts["explanation"])
|
| 95 |
+
|
| 96 |
+
# Format the full prompt
|
| 97 |
+
full_prompt = self._format_prompt(system_prompt, context, prompt)
|
| 98 |
+
|
| 99 |
+
# Tokenize
|
| 100 |
+
inputs = self.tokenizer(
|
| 101 |
+
full_prompt,
|
| 102 |
+
return_tensors="pt",
|
| 103 |
+
truncation=True,
|
| 104 |
+
max_length=2048
|
| 105 |
+
).to(self.device)
|
| 106 |
+
|
| 107 |
+
# Generate
|
| 108 |
+
with torch.no_grad():
|
| 109 |
+
outputs = self.model.generate(
|
| 110 |
+
**inputs,
|
| 111 |
+
max_new_tokens=max_tokens,
|
| 112 |
+
temperature=temperature,
|
| 113 |
+
do_sample=True,
|
| 114 |
+
top_p=0.95,
|
| 115 |
+
top_k=50,
|
| 116 |
+
repetition_penalty=1.1,
|
| 117 |
+
pad_token_id=self.tokenizer.pad_token_id,
|
| 118 |
+
eos_token_id=self.tokenizer.eos_token_id
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
# Decode
|
| 122 |
+
response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 123 |
+
|
| 124 |
+
# Extract only the new generated text
|
| 125 |
+
response = response[len(full_prompt):].strip()
|
| 126 |
+
|
| 127 |
+
return response
|
| 128 |
+
|
| 129 |
+
def _format_prompt(self, system_prompt: str, context: str, query: str) -> str:
|
| 130 |
+
"""Format the prompt with system instructions, context, and query."""
|
| 131 |
+
prompt = f"{system_prompt}\n\n"
|
| 132 |
+
|
| 133 |
+
if context:
|
| 134 |
+
prompt += f"Context from your study materials:\n{context}\n\n"
|
| 135 |
+
|
| 136 |
+
prompt += f"Query: {query}\n\nResponse:"
|
| 137 |
+
|
| 138 |
+
return prompt
|
| 139 |
+
|
| 140 |
+
def batch_generate(self, prompts: List[str], **kwargs) -> List[str]:
|
| 141 |
+
"""
|
| 142 |
+
Generate responses for multiple prompts.
|
| 143 |
+
|
| 144 |
+
Args:
|
| 145 |
+
prompts: List of prompts
|
| 146 |
+
**kwargs: Additional arguments for generate_response
|
| 147 |
+
|
| 148 |
+
Returns:
|
| 149 |
+
List of generated responses
|
| 150 |
+
"""
|
| 151 |
+
responses = []
|
| 152 |
+
for prompt in prompts:
|
| 153 |
+
response = self.generate_response(prompt, **kwargs)
|
| 154 |
+
responses.append(response)
|
| 155 |
+
|
| 156 |
+
return responses
|
utils/simple_generator.py
ADDED
|
@@ -0,0 +1,444 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
NotebookLM-style response generator with professional formatting.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from typing import List, Dict
|
| 6 |
+
import config
|
| 7 |
+
import re
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class SimpleGenerator:
|
| 11 |
+
"""Lightweight generator with NotebookLM-quality formatting."""
|
| 12 |
+
|
| 13 |
+
def __init__(self):
|
| 14 |
+
self.ready = True
|
| 15 |
+
|
| 16 |
+
def _clean_and_format_text(self, text: str) -> str:
|
| 17 |
+
"""Clean and format text with proper spacing like NotebookLM."""
|
| 18 |
+
# Fix spacing after punctuation
|
| 19 |
+
text = re.sub(r'([.!?])([A-Z])', r'\1 \2', text)
|
| 20 |
+
# Remove multiple spaces
|
| 21 |
+
text = re.sub(r'\s+', ' ', text)
|
| 22 |
+
# Add proper line breaks after sentences
|
| 23 |
+
text = re.sub(r'([.!?])\s+', r'\1\n\n', text)
|
| 24 |
+
return text.strip()
|
| 25 |
+
|
| 26 |
+
def _extract_key_terms(self, text: str) -> List[str]:
|
| 27 |
+
"""Extract key terms that should be bolded."""
|
| 28 |
+
# Look for capitalized terms, technical terms
|
| 29 |
+
terms = []
|
| 30 |
+
|
| 31 |
+
# Find terms in quotes
|
| 32 |
+
quoted = re.findall(r'"([^"]+)"', text)
|
| 33 |
+
terms.extend(quoted)
|
| 34 |
+
|
| 35 |
+
# Find repeated important words (appear 2+ times)
|
| 36 |
+
words = re.findall(r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b', text)
|
| 37 |
+
word_count = {}
|
| 38 |
+
for word in words:
|
| 39 |
+
word_count[word] = word_count.get(word, 0) + 1
|
| 40 |
+
|
| 41 |
+
# Add words that appear multiple times
|
| 42 |
+
terms.extend([w for w, count in word_count.items() if count >= 2])
|
| 43 |
+
|
| 44 |
+
return list(set(terms))
|
| 45 |
+
|
| 46 |
+
def _apply_bold_formatting(self, text: str) -> str:
|
| 47 |
+
"""Apply bold formatting to key terms like NotebookLM."""
|
| 48 |
+
key_terms = self._extract_key_terms(text)
|
| 49 |
+
|
| 50 |
+
# Bold key terms
|
| 51 |
+
for term in key_terms:
|
| 52 |
+
if len(term) > 3: # Skip very short terms
|
| 53 |
+
text = re.sub(rf'\b({re.escape(term)})\b', r'**\1**', text, count=1)
|
| 54 |
+
|
| 55 |
+
# Bold specific patterns
|
| 56 |
+
# Numbers with context
|
| 57 |
+
text = re.sub(r'\b(\d+)\s+(observations?|years?|months?|quarters?)', r'**\1 \2**', text)
|
| 58 |
+
|
| 59 |
+
return text
|
| 60 |
+
|
| 61 |
+
def _create_structured_response(self, context: str, query: str) -> str:
|
| 62 |
+
"""Create a NotebookLM-style structured response."""
|
| 63 |
+
# Split into paragraphs
|
| 64 |
+
paragraphs = [p.strip() for p in context.split('\n\n') if len(p.strip()) > 50]
|
| 65 |
+
|
| 66 |
+
# Remove duplicates
|
| 67 |
+
unique_paras = []
|
| 68 |
+
seen = set()
|
| 69 |
+
for para in paragraphs:
|
| 70 |
+
para_key = para.lower()[:150]
|
| 71 |
+
if para_key not in seen:
|
| 72 |
+
unique_paras.append(para)
|
| 73 |
+
seen.add(para_key)
|
| 74 |
+
if len(unique_paras) >= 5:
|
| 75 |
+
break
|
| 76 |
+
|
| 77 |
+
if not unique_paras:
|
| 78 |
+
return context[:1000]
|
| 79 |
+
|
| 80 |
+
# Build NotebookLM-style response
|
| 81 |
+
response = ""
|
| 82 |
+
|
| 83 |
+
# Main explanation (first paragraph - cleaned and formatted)
|
| 84 |
+
main_para = self._clean_and_format_text(unique_paras[0])
|
| 85 |
+
main_para = self._apply_bold_formatting(main_para)
|
| 86 |
+
response += main_para + "\n\n"
|
| 87 |
+
|
| 88 |
+
# Add structured details if more content available
|
| 89 |
+
if len(unique_paras) > 1:
|
| 90 |
+
response += "### Key Points:\n\n"
|
| 91 |
+
|
| 92 |
+
for i, para in enumerate(unique_paras[1:4], 1):
|
| 93 |
+
# Extract first 2-3 sentences
|
| 94 |
+
sentences = [s.strip() for s in para.split('.') if len(s.strip()) > 20]
|
| 95 |
+
if sentences:
|
| 96 |
+
detail = self._clean_and_format_text('. '.join(sentences[:2]) + '.')
|
| 97 |
+
detail = self._apply_bold_formatting(detail)
|
| 98 |
+
response += f"{i}. {detail}\n\n"
|
| 99 |
+
|
| 100 |
+
return response.strip()
|
| 101 |
+
|
| 102 |
+
def generate_response(
|
| 103 |
+
self,
|
| 104 |
+
prompt: str,
|
| 105 |
+
context: str = "",
|
| 106 |
+
use_case: str = "explanation",
|
| 107 |
+
metadatas: List[Dict] = None,
|
| 108 |
+
**kwargs
|
| 109 |
+
) -> str:
|
| 110 |
+
"""
|
| 111 |
+
Generate a NotebookLM-quality response with strict citations.
|
| 112 |
+
|
| 113 |
+
Args:
|
| 114 |
+
prompt: User query
|
| 115 |
+
context: Retrieved context from documents
|
| 116 |
+
use_case: Type of response (explanation, summary, qa,notes)
|
| 117 |
+
metadatas: Metadata for each context chunk (for citations)
|
| 118 |
+
|
| 119 |
+
Returns:
|
| 120 |
+
Professional formatted response with inline citations
|
| 121 |
+
"""
|
| 122 |
+
if not context:
|
| 123 |
+
return (
|
| 124 |
+
"I don't have enough information from your uploaded documents to answer this question. "
|
| 125 |
+
"Please upload relevant study materials first, or try rephrasing your question."
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
# Use specialized prompts based on use case
|
| 129 |
+
if use_case == "summary":
|
| 130 |
+
response = self._create_summary_with_citations(context, prompt, metadatas)
|
| 131 |
+
elif use_case == "notes":
|
| 132 |
+
response = self._create_notes_with_citations(context, prompt, metadatas)
|
| 133 |
+
elif use_case == "qa":
|
| 134 |
+
response = self._create_qa_with_citations(context, prompt, metadatas)
|
| 135 |
+
else: # Default to explanation
|
| 136 |
+
response = self._create_structured_response_with_citations(context, prompt, metadatas)
|
| 137 |
+
|
| 138 |
+
return response
|
| 139 |
+
|
| 140 |
+
def _create_structured_response_with_citations(
|
| 141 |
+
self,
|
| 142 |
+
context: str,
|
| 143 |
+
query: str,
|
| 144 |
+
metadatas: List[Dict] = None
|
| 145 |
+
) -> str:
|
| 146 |
+
"""Create NotebookLM-style response with inline citations."""
|
| 147 |
+
# Split into paragraphs
|
| 148 |
+
paragraphs = [p.strip() for p in context.split('\n\n') if len(p.strip()) > 50]
|
| 149 |
+
|
| 150 |
+
# Remove duplicates
|
| 151 |
+
unique_paras = []
|
| 152 |
+
seen = set()
|
| 153 |
+
for para in paragraphs:
|
| 154 |
+
para_key = para.lower()[:150]
|
| 155 |
+
if para_key not in seen:
|
| 156 |
+
unique_paras.append(para)
|
| 157 |
+
seen.add(para_key)
|
| 158 |
+
if len(unique_paras) >= 5:
|
| 159 |
+
break
|
| 160 |
+
|
| 161 |
+
if not unique_paras:
|
| 162 |
+
return context[:1000]
|
| 163 |
+
|
| 164 |
+
# Build response with citations
|
| 165 |
+
response = ""
|
| 166 |
+
|
| 167 |
+
# Main explanation (first paragraph - cleaned and formatted)
|
| 168 |
+
main_para = self._clean_and_format_text(unique_paras[0])
|
| 169 |
+
main_para = self._apply_bold_formatting(main_para)
|
| 170 |
+
|
| 171 |
+
# Add citation to end of main paragraph
|
| 172 |
+
cite_text = self._get_citation(0, metadatas) if metadatas else ""
|
| 173 |
+
response += main_para + cite_text + "\n\n"
|
| 174 |
+
|
| 175 |
+
# Add structured details if more content available
|
| 176 |
+
if len(unique_paras) > 1:
|
| 177 |
+
response += "### Key Points:\n\n"
|
| 178 |
+
|
| 179 |
+
for i, para in enumerate(unique_paras[1:4], 1):
|
| 180 |
+
# Extract first 2-3 sentences
|
| 181 |
+
sentences = [s.strip() for s in para.split('.') if len(s.strip()) > 20]
|
| 182 |
+
if sentences:
|
| 183 |
+
detail = self._clean_and_format_text('. '.join(sentences[:2]) + '.')
|
| 184 |
+
detail = self._apply_bold_formatting(detail)
|
| 185 |
+
|
| 186 |
+
# Add citation
|
| 187 |
+
cite_text = self._get_citation(i, metadatas) if metadatas and i < len(metadatas) else ""
|
| 188 |
+
response += f"{i}. {detail}{cite_text}\n\n"
|
| 189 |
+
|
| 190 |
+
return response.strip()
|
| 191 |
+
|
| 192 |
+
def _get_citation(self, index: int, metadatas: List[Dict] = None) -> str:
|
| 193 |
+
"""Generate inline citation from metadata."""
|
| 194 |
+
if not metadatas or index >= len(metadatas):
|
| 195 |
+
return ""
|
| 196 |
+
|
| 197 |
+
meta = metadatas[index]
|
| 198 |
+
filename = meta.get('filename', 'Unknown')
|
| 199 |
+
|
| 200 |
+
# Remove file extension for cleaner citation
|
| 201 |
+
clean_name = filename.replace('.pdf', '').replace('.docx', '').replace('.txt', '')
|
| 202 |
+
|
| 203 |
+
return f" **[{clean_name}]**"
|
| 204 |
+
|
| 205 |
+
def _create_summary_with_citations(
|
| 206 |
+
self,
|
| 207 |
+
context: str,
|
| 208 |
+
query: str,
|
| 209 |
+
metadatas: List[Dict] = None
|
| 210 |
+
) -> str:
|
| 211 |
+
"""Create a summary with citations."""
|
| 212 |
+
sentences = []
|
| 213 |
+
seen = set()
|
| 214 |
+
for s in context.split('.'):
|
| 215 |
+
s_clean = s.strip()
|
| 216 |
+
if len(s_clean) > 40 and s_clean.lower() not in seen:
|
| 217 |
+
sentences.append(s_clean)
|
| 218 |
+
seen.add(s_clean.lower())
|
| 219 |
+
if len(sentences) >= 6:
|
| 220 |
+
break
|
| 221 |
+
|
| 222 |
+
if not sentences:
|
| 223 |
+
return context[:800]
|
| 224 |
+
|
| 225 |
+
response = "## Summary\n\n"
|
| 226 |
+
for i, point in enumerate(sentences, 1):
|
| 227 |
+
cite = self._get_citation(i-1, metadatas) if metadatas else ""
|
| 228 |
+
response += f"{i}. {point}.{cite}\n\n"
|
| 229 |
+
|
| 230 |
+
return response.strip()
|
| 231 |
+
|
| 232 |
+
def _create_qa_with_citations(
|
| 233 |
+
self,
|
| 234 |
+
context: str,
|
| 235 |
+
query: str,
|
| 236 |
+
metadatas: List[Dict] = None
|
| 237 |
+
) -> str:
|
| 238 |
+
"""Answer with strict source grounding."""
|
| 239 |
+
paragraphs = [p.strip() for p in context.split('\n\n') if len(p.strip()) > 50]
|
| 240 |
+
|
| 241 |
+
if not paragraphs:
|
| 242 |
+
sentences = [s.strip() + '.' for s in context.split('.') if len(s.strip()) > 30]
|
| 243 |
+
response = ' '.join(sentences[:6])
|
| 244 |
+
cite = self._get_citation(0, metadatas) if metadatas else ""
|
| 245 |
+
return response + cite
|
| 246 |
+
|
| 247 |
+
# Remove duplicates
|
| 248 |
+
unique_paras = []
|
| 249 |
+
seen = set()
|
| 250 |
+
for para in paragraphs:
|
| 251 |
+
para_key = para.lower()[:150]
|
| 252 |
+
if para_key not in seen:
|
| 253 |
+
unique_paras.append(para)
|
| 254 |
+
seen.add(para_key)
|
| 255 |
+
if len(unique_paras) >= 3:
|
| 256 |
+
break
|
| 257 |
+
|
| 258 |
+
# Fix spacing and add citations
|
| 259 |
+
response = unique_paras[0] if unique_paras else context[:800]
|
| 260 |
+
response = re.sub(r'([.!?])([A-Z])', r'\1 \2', response)
|
| 261 |
+
cite = self._get_citation(0, metadatas) if metadatas else ""
|
| 262 |
+
response += cite
|
| 263 |
+
|
| 264 |
+
# Add supporting details if available
|
| 265 |
+
if len(unique_paras) > 1:
|
| 266 |
+
second_para = re.sub(r'([.!?])([A-Z])', r'\1 \2', unique_paras[1])
|
| 267 |
+
cite2 = self._get_citation(1, metadatas) if metadatas and len(metadatas) > 1 else ""
|
| 268 |
+
response += "\n\n" + second_para + cite2
|
| 269 |
+
|
| 270 |
+
return response.strip()
|
| 271 |
+
|
| 272 |
+
def _create_notes_with_citations(
|
| 273 |
+
self,
|
| 274 |
+
context: str,
|
| 275 |
+
query: str,
|
| 276 |
+
metadatas: List[Dict] = None
|
| 277 |
+
) -> str:
|
| 278 |
+
"""Create study notes with source attribution."""
|
| 279 |
+
sections = [s.strip() for s in context.split('\n\n') if len(s.strip()) > 40]
|
| 280 |
+
|
| 281 |
+
# Remove duplicates
|
| 282 |
+
unique_sections = []
|
| 283 |
+
seen = set()
|
| 284 |
+
for section in sections:
|
| 285 |
+
section_key = section.lower()[:100]
|
| 286 |
+
if section_key not in seen:
|
| 287 |
+
unique_sections.append(section)
|
| 288 |
+
seen.add(section_key)
|
| 289 |
+
if len(unique_sections) >= 6:
|
| 290 |
+
break
|
| 291 |
+
|
| 292 |
+
if not unique_sections:
|
| 293 |
+
return context[:1000]
|
| 294 |
+
|
| 295 |
+
response = "## Study Notes\n\n"
|
| 296 |
+
|
| 297 |
+
for i, section in enumerate(unique_sections, 1):
|
| 298 |
+
sentences = [s.strip() for s in section.split('.') if len(s.strip()) > 20]
|
| 299 |
+
|
| 300 |
+
if sentences:
|
| 301 |
+
heading = sentences[0]
|
| 302 |
+
cite = self._get_citation(i-1, metadatas) if metadatas else ""
|
| 303 |
+
response += f"### {i}. {heading}{cite}\n\n"
|
| 304 |
+
|
| 305 |
+
for sent in sentences[1:3]:
|
| 306 |
+
response += f"- {sent}\n"
|
| 307 |
+
response += "\n"
|
| 308 |
+
|
| 309 |
+
return response.strip()
|
| 310 |
+
|
| 311 |
+
def _create_summary(self, context: str, query: str) -> str:
|
| 312 |
+
"""Create a clean summary from retrieved context."""
|
| 313 |
+
# Extract key sentences - remove duplicates
|
| 314 |
+
sentences = []
|
| 315 |
+
seen = set()
|
| 316 |
+
for s in context.split('.'):
|
| 317 |
+
s_clean = s.strip()
|
| 318 |
+
# Remove duplicates and filter short/low-quality sentences
|
| 319 |
+
if len(s_clean) > 40 and s_clean.lower() not in seen:
|
| 320 |
+
sentences.append(s_clean)
|
| 321 |
+
seen.add(s_clean.lower())
|
| 322 |
+
if len(sentences) >= 6:
|
| 323 |
+
break
|
| 324 |
+
|
| 325 |
+
if not sentences:
|
| 326 |
+
return context[:800]
|
| 327 |
+
|
| 328 |
+
response = "## Summary\n\n"
|
| 329 |
+
for i, point in enumerate(sentences, 1):
|
| 330 |
+
response += f"{i}. {point}.\n\n"
|
| 331 |
+
|
| 332 |
+
return response.strip()
|
| 333 |
+
|
| 334 |
+
def _create_explanation(self, context: str, query: str) -> str:
|
| 335 |
+
"""Create a well-formatted explanation from retrieved context."""
|
| 336 |
+
# Remove duplicate paragraphs
|
| 337 |
+
paragraphs = []
|
| 338 |
+
seen = set()
|
| 339 |
+
for para in context.split('\n\n'):
|
| 340 |
+
para_clean = para.strip()
|
| 341 |
+
# Keep unique, substantial paragraphs
|
| 342 |
+
if len(para_clean) > 50:
|
| 343 |
+
para_lower = para_clean.lower()[:200] # Check first 200 chars for duplicates
|
| 344 |
+
if para_lower not in seen:
|
| 345 |
+
paragraphs.append(para_clean)
|
| 346 |
+
seen.add(para_lower)
|
| 347 |
+
|
| 348 |
+
if not paragraphs:
|
| 349 |
+
# Fallback: split by sentence
|
| 350 |
+
sentences = [s.strip() + '.' for s in context.split('.') if len(s.strip()) > 30]
|
| 351 |
+
return ' '.join(sentences[:8])
|
| 352 |
+
|
| 353 |
+
# Build clean, formatted response with proper spacing
|
| 354 |
+
response = ""
|
| 355 |
+
|
| 356 |
+
# Add first paragraph as main explanation (ensure spacing between sentences)
|
| 357 |
+
first_para = paragraphs[0]
|
| 358 |
+
# Add space after punctuation if missing
|
| 359 |
+
import re
|
| 360 |
+
first_para = re.sub(r'([.!?])([A-Z])', r'\1 \2', first_para)
|
| 361 |
+
response += first_para
|
| 362 |
+
|
| 363 |
+
# Add additional details if available
|
| 364 |
+
if len(paragraphs) > 1:
|
| 365 |
+
response += "\n\n### Key Points:\n\n"
|
| 366 |
+
for i, para in enumerate(paragraphs[1:4], 1): # Max 3 additional points
|
| 367 |
+
# Extract first sentence as bullet
|
| 368 |
+
sentences = [s.strip() for s in para.split('.') if len(s.strip()) > 20]
|
| 369 |
+
if sentences:
|
| 370 |
+
response += f"• {sentences[0]}.\n"
|
| 371 |
+
if len(sentences) > 1 and len(sentences[1]) > 20:
|
| 372 |
+
response += f" {sentences[1]}.\n"
|
| 373 |
+
response += "\n"
|
| 374 |
+
|
| 375 |
+
return response.strip()
|
| 376 |
+
|
| 377 |
+
def _create_qa(self, context: str, query: str) -> str:
|
| 378 |
+
"""Answer a question with clean formatting."""
|
| 379 |
+
# Find most relevant paragraphs
|
| 380 |
+
paragraphs = [p.strip() for p in context.split('\n\n') if len(p.strip()) > 50]
|
| 381 |
+
|
| 382 |
+
if not paragraphs:
|
| 383 |
+
sentences = [s.strip() + '.' for s in context.split('.') if len(s.strip()) > 30]
|
| 384 |
+
return ' '.join(sentences[:6])
|
| 385 |
+
|
| 386 |
+
# Remove duplicates
|
| 387 |
+
unique_paras = []
|
| 388 |
+
seen = set()
|
| 389 |
+
for para in paragraphs:
|
| 390 |
+
para_key = para.lower()[:150]
|
| 391 |
+
if para_key not in seen:
|
| 392 |
+
unique_paras.append(para)
|
| 393 |
+
seen.add(para_key)
|
| 394 |
+
if len(unique_paras) >= 3:
|
| 395 |
+
break
|
| 396 |
+
|
| 397 |
+
# Fix spacing in response
|
| 398 |
+
import re
|
| 399 |
+
response = unique_paras[0] if unique_paras else context[:800]
|
| 400 |
+
response = re.sub(r'([.!?])([A-Z])', r'\1 \2', response)
|
| 401 |
+
|
| 402 |
+
# Add supporting details if available
|
| 403 |
+
if len(unique_paras) > 1:
|
| 404 |
+
second_para = re.sub(r'([.!?])([A-Z])', r'\1 \2', unique_paras[1])
|
| 405 |
+
response += "\n\n" + second_para
|
| 406 |
+
|
| 407 |
+
return response.strip()
|
| 408 |
+
|
| 409 |
+
def _create_notes(self, context: str, query: str) -> str:
|
| 410 |
+
"""Create well-structured study notes."""
|
| 411 |
+
# Split and clean sections
|
| 412 |
+
sections = [s.strip() for s in context.split('\n\n') if len(s.strip()) > 40]
|
| 413 |
+
|
| 414 |
+
# Remove duplicates
|
| 415 |
+
unique_sections = []
|
| 416 |
+
seen = set()
|
| 417 |
+
for section in sections:
|
| 418 |
+
section_key = section.lower()[:100]
|
| 419 |
+
if section_key not in seen:
|
| 420 |
+
unique_sections.append(section)
|
| 421 |
+
seen.add(section_key)
|
| 422 |
+
if len(unique_sections) >= 6:
|
| 423 |
+
break
|
| 424 |
+
|
| 425 |
+
if not unique_sections:
|
| 426 |
+
return context[:1000]
|
| 427 |
+
|
| 428 |
+
response = "## Study Notes\n\n"
|
| 429 |
+
|
| 430 |
+
for i, section in enumerate(unique_sections, 1):
|
| 431 |
+
# Extract key information
|
| 432 |
+
sentences = [s.strip() for s in section.split('.') if len(s.strip()) > 20]
|
| 433 |
+
|
| 434 |
+
if sentences:
|
| 435 |
+
# Use first sentence as heading
|
| 436 |
+
heading = sentences[0]
|
| 437 |
+
response += f"### {i}. {heading}\n\n"
|
| 438 |
+
|
| 439 |
+
# Add bullet points for remaining content
|
| 440 |
+
for sent in sentences[1:3]: # Max 2 additional sentences
|
| 441 |
+
response += f"- {sent}\n"
|
| 442 |
+
response += "\n"
|
| 443 |
+
|
| 444 |
+
return response.strip()
|
utils/spaces_manager.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Spaces (Workspaces) manager for organizing chats and files by subject.
|
| 3 |
+
Each space has its own vector DB and chat history.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import List, Dict, Optional
|
| 9 |
+
from datetime import datetime
|
| 10 |
+
import config
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class SpacesManager:
|
| 14 |
+
"""Manages workspaces (Spaces) for organizing study materials by subject."""
|
| 15 |
+
|
| 16 |
+
def __init__(self):
|
| 17 |
+
self.spaces_file = config.DATA_DIR / "spaces.json"
|
| 18 |
+
self.spaces_data = self._load_spaces()
|
| 19 |
+
|
| 20 |
+
def _load_spaces(self) -> Dict:
|
| 21 |
+
"""Load spaces from file."""
|
| 22 |
+
if self.spaces_file.exists():
|
| 23 |
+
try:
|
| 24 |
+
with open(self.spaces_file, 'r', encoding='utf-8') as f:
|
| 25 |
+
return json.load(f)
|
| 26 |
+
except Exception:
|
| 27 |
+
return self._create_default_spaces()
|
| 28 |
+
return self._create_default_spaces()
|
| 29 |
+
|
| 30 |
+
def _create_default_spaces(self) -> Dict:
|
| 31 |
+
"""Create default spaces structure."""
|
| 32 |
+
return {
|
| 33 |
+
"spaces": [
|
| 34 |
+
{
|
| 35 |
+
"id": "general",
|
| 36 |
+
"name": "General",
|
| 37 |
+
"description": "General study materials",
|
| 38 |
+
"created_at": datetime.now().isoformat(),
|
| 39 |
+
"file_count": 0,
|
| 40 |
+
"chat_count": 0
|
| 41 |
+
}
|
| 42 |
+
]
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
def save_spaces(self):
|
| 46 |
+
"""Save spaces to file."""
|
| 47 |
+
try:
|
| 48 |
+
with open(self.spaces_file, 'w', encoding='utf-8') as f:
|
| 49 |
+
json.dump(self.spaces_data, f, indent=2)
|
| 50 |
+
except Exception as e:
|
| 51 |
+
print(f"Error saving spaces: {e}")
|
| 52 |
+
|
| 53 |
+
def get_all_spaces(self) -> List[Dict]:
|
| 54 |
+
"""Get all spaces."""
|
| 55 |
+
return self.spaces_data.get("spaces", [])
|
| 56 |
+
|
| 57 |
+
def get_space(self, space_id: str) -> Optional[Dict]:
|
| 58 |
+
"""Get specific space by ID."""
|
| 59 |
+
for space in self.spaces_data.get("spaces", []):
|
| 60 |
+
if space["id"] == space_id:
|
| 61 |
+
return space
|
| 62 |
+
return None
|
| 63 |
+
|
| 64 |
+
def create_space(self, name: str, description: str = "") -> Dict:
|
| 65 |
+
"""Create a new space."""
|
| 66 |
+
space_id = name.lower().replace(" ", "_")
|
| 67 |
+
|
| 68 |
+
# Check if space already exists
|
| 69 |
+
if self.get_space(space_id):
|
| 70 |
+
raise ValueError(f"Space '{name}' already exists")
|
| 71 |
+
|
| 72 |
+
new_space = {
|
| 73 |
+
"id": space_id,
|
| 74 |
+
"name": name,
|
| 75 |
+
"description": description,
|
| 76 |
+
"created_at": datetime.now().isoformat(),
|
| 77 |
+
"file_count": 0,
|
| 78 |
+
"chat_count": 0
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
self.spaces_data["spaces"].append(new_space)
|
| 82 |
+
self.save_spaces()
|
| 83 |
+
|
| 84 |
+
# Create dedicated directories for this space
|
| 85 |
+
space_dir = config.DATA_DIR / "spaces" / space_id
|
| 86 |
+
space_dir.mkdir(parents=True, exist_ok=True)
|
| 87 |
+
(space_dir / "chats").mkdir(exist_ok=True)
|
| 88 |
+
(space_dir / "vector_db").mkdir(exist_ok=True)
|
| 89 |
+
(space_dir / "uploads").mkdir(exist_ok=True)
|
| 90 |
+
|
| 91 |
+
return new_space
|
| 92 |
+
|
| 93 |
+
def delete_space(self, space_id: str):
|
| 94 |
+
"""Delete a space (except General)."""
|
| 95 |
+
if space_id == "general":
|
| 96 |
+
raise ValueError("Cannot delete General space")
|
| 97 |
+
|
| 98 |
+
self.spaces_data["spaces"] = [
|
| 99 |
+
s for s in self.spaces_data["spaces"]
|
| 100 |
+
if s["id"] != space_id
|
| 101 |
+
]
|
| 102 |
+
self.save_spaces()
|
| 103 |
+
|
| 104 |
+
def update_space_counts(self, space_id: str, file_count: int = None, chat_count: int = None):
|
| 105 |
+
"""Update file/chat counts for a space."""
|
| 106 |
+
space = self.get_space(space_id)
|
| 107 |
+
if space:
|
| 108 |
+
if file_count is not None:
|
| 109 |
+
space["file_count"] = file_count
|
| 110 |
+
if chat_count is not None:
|
| 111 |
+
space["chat_count"] = chat_count
|
| 112 |
+
self.save_spaces()
|
| 113 |
+
|
| 114 |
+
def get_space_chats_dir(self, space_id: str) -> Path:
|
| 115 |
+
"""Get chats directory for a space."""
|
| 116 |
+
return config.DATA_DIR / "spaces" / space_id / "chats"
|
| 117 |
+
|
| 118 |
+
def get_space_vector_db_dir(self, space_id: str) -> Path:
|
| 119 |
+
"""Get vector DB directory for a space."""
|
| 120 |
+
return config.DATA_DIR / "spaces" / space_id / "vector_db"
|
| 121 |
+
|
| 122 |
+
def get_space_uploads_dir(self, space_id: str) -> Path:
|
| 123 |
+
"""Get uploads directory for a space."""
|
| 124 |
+
return config.DATA_DIR / "spaces" / space_id / "uploads"
|
utils/studio_generator.py
ADDED
|
@@ -0,0 +1,309 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Studio Generator - Uses LLM to generate flashcards and quiz questions
|
| 3 |
+
"""
|
| 4 |
+
from typing import List, Optional
|
| 5 |
+
import json
|
| 6 |
+
import re
|
| 7 |
+
|
| 8 |
+
from models.studio_models import (
|
| 9 |
+
Flashcard, FlashcardCreate, FlashcardGenerateRequest,
|
| 10 |
+
Quiz, QuizQuestion, QuizGenerateRequest,
|
| 11 |
+
DifficultyLevel, QuestionType
|
| 12 |
+
)
|
| 13 |
+
from utils.llm_generator import LLMGenerator
|
| 14 |
+
from utils.studio_manager import StudioManager
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class StudioGenerator:
|
| 18 |
+
"""Generate flashcards and quizzes using LLM"""
|
| 19 |
+
|
| 20 |
+
def __init__(self, llm_generator: LLMGenerator, studio_manager: StudioManager):
|
| 21 |
+
self.llm = llm_generator
|
| 22 |
+
self.studio = studio_manager
|
| 23 |
+
|
| 24 |
+
async def generate_flashcards(self, request: FlashcardGenerateRequest) -> List[Flashcard]:
|
| 25 |
+
"""Generate flashcards from content using LLM"""
|
| 26 |
+
|
| 27 |
+
# Gather source content
|
| 28 |
+
content = await self._gather_content(
|
| 29 |
+
request.space_id,
|
| 30 |
+
request.source_type,
|
| 31 |
+
request.source_ids,
|
| 32 |
+
request.text_content
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
if not content:
|
| 36 |
+
return []
|
| 37 |
+
|
| 38 |
+
# Create prompt for LLM
|
| 39 |
+
prompt = self._create_flashcard_prompt(content, request.num_cards, request.difficulty)
|
| 40 |
+
|
| 41 |
+
# Generate flashcards using LLM
|
| 42 |
+
response = await self.llm.generate(prompt, max_tokens=2000)
|
| 43 |
+
|
| 44 |
+
if not response:
|
| 45 |
+
return []
|
| 46 |
+
|
| 47 |
+
# Parse flashcards from response
|
| 48 |
+
flashcards = self._parse_flashcards(
|
| 49 |
+
response,
|
| 50 |
+
request.space_id,
|
| 51 |
+
request.source_type,
|
| 52 |
+
request.source_ids,
|
| 53 |
+
request.difficulty
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
# Save flashcards to storage
|
| 57 |
+
saved_cards = []
|
| 58 |
+
for card_data in flashcards:
|
| 59 |
+
card = self.studio.create_flashcard(card_data)
|
| 60 |
+
saved_cards.append(card)
|
| 61 |
+
|
| 62 |
+
return saved_cards
|
| 63 |
+
|
| 64 |
+
async def generate_quiz(self, request: QuizGenerateRequest) -> Optional[Quiz]:
|
| 65 |
+
"""Generate a quiz from content using LLM"""
|
| 66 |
+
|
| 67 |
+
# Gather source content
|
| 68 |
+
content = await self._gather_content(
|
| 69 |
+
request.space_id,
|
| 70 |
+
request.source_type,
|
| 71 |
+
request.source_ids,
|
| 72 |
+
request.text_content
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
if not content:
|
| 76 |
+
return None
|
| 77 |
+
|
| 78 |
+
# Create prompt for LLM
|
| 79 |
+
prompt = self._create_quiz_prompt(
|
| 80 |
+
content,
|
| 81 |
+
request.num_questions,
|
| 82 |
+
request.question_types,
|
| 83 |
+
request.difficulty
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
# Generate quiz using LLM
|
| 87 |
+
response = await self.llm.generate(prompt, max_tokens=3000)
|
| 88 |
+
|
| 89 |
+
if not response:
|
| 90 |
+
return None
|
| 91 |
+
|
| 92 |
+
# Parse quiz questions from response
|
| 93 |
+
questions = self._parse_quiz_questions(response, request.question_types, request.difficulty)
|
| 94 |
+
|
| 95 |
+
if not questions:
|
| 96 |
+
return None
|
| 97 |
+
|
| 98 |
+
# Create quiz
|
| 99 |
+
from models.studio_models import QuizCreate
|
| 100 |
+
quiz_data = QuizCreate(
|
| 101 |
+
space_id=request.space_id,
|
| 102 |
+
title=request.title,
|
| 103 |
+
description=f"Generated quiz with {len(questions)} questions",
|
| 104 |
+
questions=questions,
|
| 105 |
+
source_type=request.source_type,
|
| 106 |
+
source_ids=request.source_ids
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
quiz = self.studio.create_quiz(quiz_data)
|
| 110 |
+
return quiz
|
| 111 |
+
|
| 112 |
+
async def _gather_content(
|
| 113 |
+
self,
|
| 114 |
+
space_id: str,
|
| 115 |
+
source_type: str,
|
| 116 |
+
source_ids: Optional[List[str]],
|
| 117 |
+
text_content: Optional[str]
|
| 118 |
+
) -> str:
|
| 119 |
+
"""Gather content from various sources"""
|
| 120 |
+
|
| 121 |
+
if text_content:
|
| 122 |
+
return text_content
|
| 123 |
+
|
| 124 |
+
content_parts = []
|
| 125 |
+
|
| 126 |
+
if source_type == "notebook" and source_ids:
|
| 127 |
+
# Get notebook entries
|
| 128 |
+
for entry_id in source_ids:
|
| 129 |
+
entry = self.studio.get_notebook_entry(entry_id)
|
| 130 |
+
if entry:
|
| 131 |
+
content_parts.append(f"# {entry.title}\n\n{entry.content}")
|
| 132 |
+
|
| 133 |
+
elif source_type == "file" and source_ids:
|
| 134 |
+
# TODO: Integrate with file retriever to get file content
|
| 135 |
+
# For now, just return a placeholder
|
| 136 |
+
content_parts.append("File content retrieval not yet implemented")
|
| 137 |
+
|
| 138 |
+
return "\n\n---\n\n".join(content_parts)
|
| 139 |
+
|
| 140 |
+
def _create_flashcard_prompt(self, content: str, num_cards: int, difficulty: DifficultyLevel) -> str:
|
| 141 |
+
"""Create prompt for flashcard generation"""
|
| 142 |
+
|
| 143 |
+
difficulty_desc = {
|
| 144 |
+
DifficultyLevel.EASY: "basic concepts and definitions",
|
| 145 |
+
DifficultyLevel.MEDIUM: "key concepts and applications",
|
| 146 |
+
DifficultyLevel.HARD: "advanced concepts and critical thinking"
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
prompt = f"""Based on the following content, create {num_cards} flashcards focusing on {difficulty_desc[difficulty]}.
|
| 150 |
+
|
| 151 |
+
Content:
|
| 152 |
+
{content[:3000]} # Limit content length
|
| 153 |
+
|
| 154 |
+
Format your response as a JSON array of flashcards, where each flashcard has:
|
| 155 |
+
- "question": The question or prompt (front of card)
|
| 156 |
+
- "answer": The answer or explanation (back of card)
|
| 157 |
+
|
| 158 |
+
Example format:
|
| 159 |
+
[
|
| 160 |
+
{{"question": "What is...", "answer": "It is..."}},
|
| 161 |
+
{{"question": "How does...", "answer": "It works by..."}}
|
| 162 |
+
]
|
| 163 |
+
|
| 164 |
+
Generate exactly {num_cards} flashcards:"""
|
| 165 |
+
|
| 166 |
+
return prompt
|
| 167 |
+
|
| 168 |
+
def _create_quiz_prompt(
|
| 169 |
+
self,
|
| 170 |
+
content: str,
|
| 171 |
+
num_questions: int,
|
| 172 |
+
question_types: List[QuestionType],
|
| 173 |
+
difficulty: DifficultyLevel
|
| 174 |
+
) -> str:
|
| 175 |
+
"""Create prompt for quiz generation"""
|
| 176 |
+
|
| 177 |
+
types_str = ", ".join(qt.value for qt in question_types)
|
| 178 |
+
|
| 179 |
+
prompt = f"""Based on the following content, create a quiz with {num_questions} questions.
|
| 180 |
+
|
| 181 |
+
Content:
|
| 182 |
+
{content[:3000]} # Limit content length
|
| 183 |
+
|
| 184 |
+
Question types to include: {types_str}
|
| 185 |
+
Difficulty level: {difficulty.value}
|
| 186 |
+
|
| 187 |
+
Format your response as a JSON array of questions, where each question has:
|
| 188 |
+
- "question": The question text
|
| 189 |
+
- "type": One of: {types_str}
|
| 190 |
+
- "options": Array of 4 options (for multiple_choice only)
|
| 191 |
+
- "correct_answer": The correct answer
|
| 192 |
+
- "explanation": Brief explanation of why this is correct
|
| 193 |
+
|
| 194 |
+
Example format:
|
| 195 |
+
[
|
| 196 |
+
{{
|
| 197 |
+
"question": "What is...",
|
| 198 |
+
"type": "multiple_choice",
|
| 199 |
+
"options": ["Option A", "Option B", "Option C", "Option D"],
|
| 200 |
+
"correct_answer": "Option A",
|
| 201 |
+
"explanation": "This is correct because..."
|
| 202 |
+
}},
|
| 203 |
+
{{
|
| 204 |
+
"question": "True or False: ...",
|
| 205 |
+
"type": "true_false",
|
| 206 |
+
"options": ["True", "False"],
|
| 207 |
+
"correct_answer": "True",
|
| 208 |
+
"explanation": "This is true because..."
|
| 209 |
+
}}
|
| 210 |
+
]
|
| 211 |
+
|
| 212 |
+
Generate exactly {num_questions} questions:"""
|
| 213 |
+
|
| 214 |
+
return prompt
|
| 215 |
+
|
| 216 |
+
def _parse_flashcards(
|
| 217 |
+
self,
|
| 218 |
+
response: str,
|
| 219 |
+
space_id: str,
|
| 220 |
+
source_type: str,
|
| 221 |
+
source_ids: Optional[List[str]],
|
| 222 |
+
difficulty: DifficultyLevel
|
| 223 |
+
) -> List[FlashcardCreate]:
|
| 224 |
+
"""Parse flashcards from LLM response"""
|
| 225 |
+
|
| 226 |
+
flashcards = []
|
| 227 |
+
|
| 228 |
+
try:
|
| 229 |
+
# Try to extract JSON from response
|
| 230 |
+
json_match = re.search(r'\[[\s\S]*\]', response)
|
| 231 |
+
if json_match:
|
| 232 |
+
cards_data = json.loads(json_match.group(0))
|
| 233 |
+
|
| 234 |
+
for card_data in cards_data:
|
| 235 |
+
if 'question' in card_data and 'answer' in card_data:
|
| 236 |
+
flashcards.append(FlashcardCreate(
|
| 237 |
+
space_id=space_id,
|
| 238 |
+
question=card_data['question'],
|
| 239 |
+
answer=card_data['answer'],
|
| 240 |
+
difficulty=difficulty,
|
| 241 |
+
source_type=source_type,
|
| 242 |
+
source_id=source_ids[0] if source_ids else None
|
| 243 |
+
))
|
| 244 |
+
except Exception as e:
|
| 245 |
+
print(f"Error parsing flashcards: {e}")
|
| 246 |
+
# Fallback: Try to parse as simple Q&A pairs
|
| 247 |
+
lines = response.split('\n')
|
| 248 |
+
current_question = None
|
| 249 |
+
|
| 250 |
+
for line in lines:
|
| 251 |
+
line = line.strip()
|
| 252 |
+
if line.startswith('Q:') or line.startswith('Question:'):
|
| 253 |
+
current_question = line.split(':', 1)[1].strip()
|
| 254 |
+
elif line.startswith('A:') or line.startswith('Answer:'):
|
| 255 |
+
if current_question:
|
| 256 |
+
answer = line.split(':', 1)[1].strip()
|
| 257 |
+
flashcards.append(FlashcardCreate(
|
| 258 |
+
space_id=space_id,
|
| 259 |
+
question=current_question,
|
| 260 |
+
answer=answer,
|
| 261 |
+
difficulty=difficulty,
|
| 262 |
+
source_type=source_type,
|
| 263 |
+
source_id=source_ids[0] if source_ids else None
|
| 264 |
+
))
|
| 265 |
+
current_question = None
|
| 266 |
+
|
| 267 |
+
return flashcards
|
| 268 |
+
|
| 269 |
+
def _parse_quiz_questions(
|
| 270 |
+
self,
|
| 271 |
+
response: str,
|
| 272 |
+
question_types: List[QuestionType],
|
| 273 |
+
difficulty: DifficultyLevel
|
| 274 |
+
) -> List[QuizQuestion]:
|
| 275 |
+
"""Parse quiz questions from LLM response"""
|
| 276 |
+
|
| 277 |
+
questions = []
|
| 278 |
+
|
| 279 |
+
try:
|
| 280 |
+
# Try to extract JSON from response
|
| 281 |
+
json_match = re.search(r'\[[\s\S]*\]', response)
|
| 282 |
+
if json_match:
|
| 283 |
+
questions_data = json.loads(json_match.group(0))
|
| 284 |
+
|
| 285 |
+
for idx, q_data in enumerate(questions_data):
|
| 286 |
+
import uuid
|
| 287 |
+
|
| 288 |
+
# Parse question type
|
| 289 |
+
q_type = QuestionType.MULTIPLE_CHOICE
|
| 290 |
+
if 'type' in q_data:
|
| 291 |
+
try:
|
| 292 |
+
q_type = QuestionType(q_data['type'])
|
| 293 |
+
except ValueError:
|
| 294 |
+
q_type = QuestionType.MULTIPLE_CHOICE
|
| 295 |
+
|
| 296 |
+
questions.append(QuizQuestion(
|
| 297 |
+
id=str(uuid.uuid4()),
|
| 298 |
+
question=q_data.get('question', ''),
|
| 299 |
+
type=q_type,
|
| 300 |
+
options=q_data.get('options'),
|
| 301 |
+
correct_answer=q_data.get('correct_answer', ''),
|
| 302 |
+
explanation=q_data.get('explanation'),
|
| 303 |
+
points=1,
|
| 304 |
+
difficulty=difficulty
|
| 305 |
+
))
|
| 306 |
+
except Exception as e:
|
| 307 |
+
print(f"Error parsing quiz questions: {e}")
|
| 308 |
+
|
| 309 |
+
return questions
|
utils/studio_manager.py
ADDED
|
@@ -0,0 +1,473 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Studio Manager - Handles Notebook, Flashcards, and Quiz storage and operations
|
| 3 |
+
"""
|
| 4 |
+
import json
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import List, Optional, Dict, Any
|
| 7 |
+
from datetime import datetime, timedelta
|
| 8 |
+
import uuid
|
| 9 |
+
|
| 10 |
+
from models.studio_models import (
|
| 11 |
+
NotebookEntry, NotebookEntryCreate, NotebookEntryUpdate,
|
| 12 |
+
Flashcard, FlashcardCreate, FlashcardUpdate, FlashcardReview,
|
| 13 |
+
Quiz, QuizCreate, QuizResult, QuizHistory, QuizAnswer,
|
| 14 |
+
MasteryLevel, DifficultyLevel
|
| 15 |
+
)
|
| 16 |
+
import config
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class StudioManager:
|
| 20 |
+
"""Manages all Studio features: Notebook, Flashcards, Quiz"""
|
| 21 |
+
|
| 22 |
+
def __init__(self):
|
| 23 |
+
"""Initialize studio manager with data directories"""
|
| 24 |
+
self.studio_dir = config.DATA_DIR / "studio"
|
| 25 |
+
self.notebooks_dir = self.studio_dir / "notebooks"
|
| 26 |
+
self.notebook_dir = self.studio_dir / "notebook"
|
| 27 |
+
self.flashcards_dir = self.studio_dir / "flashcards"
|
| 28 |
+
self.quizzes_dir = self.studio_dir / "quizzes"
|
| 29 |
+
self.quiz_results_dir = self.studio_dir / "quiz_results"
|
| 30 |
+
|
| 31 |
+
# Create directories
|
| 32 |
+
for directory in [self.notebooks_dir, self.notebook_dir, self.flashcards_dir,
|
| 33 |
+
self.quizzes_dir, self.quiz_results_dir]:
|
| 34 |
+
directory.mkdir(parents=True, exist_ok=True)
|
| 35 |
+
|
| 36 |
+
def _get_notebook_file_path(self, space_id: str) -> Path:
|
| 37 |
+
"""Get the metadata file path for a space notebook."""
|
| 38 |
+
return self.notebooks_dir / f"{space_id}.json"
|
| 39 |
+
|
| 40 |
+
def ensure_space_notebook(self, space_id: str, space_name: str = "") -> Dict[str, Any]:
|
| 41 |
+
"""Create notebook metadata for a space if it does not exist."""
|
| 42 |
+
file_path = self._get_notebook_file_path(space_id)
|
| 43 |
+
|
| 44 |
+
if file_path.exists():
|
| 45 |
+
with open(file_path, 'r', encoding='utf-8') as f:
|
| 46 |
+
return json.load(f)
|
| 47 |
+
|
| 48 |
+
now = datetime.now().isoformat()
|
| 49 |
+
notebook_name = space_name.strip() if space_name and space_name.strip() else space_id
|
| 50 |
+
notebook_data = {
|
| 51 |
+
"id": space_id,
|
| 52 |
+
"space_id": space_id,
|
| 53 |
+
"name": notebook_name,
|
| 54 |
+
"created_at": now,
|
| 55 |
+
"updated_at": now
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
with open(file_path, 'w', encoding='utf-8') as f:
|
| 59 |
+
json.dump(notebook_data, f, indent=2)
|
| 60 |
+
|
| 61 |
+
return notebook_data
|
| 62 |
+
|
| 63 |
+
def get_space_notebook(self, space_id: str) -> Optional[Dict[str, Any]]:
|
| 64 |
+
"""Get notebook metadata for a specific space."""
|
| 65 |
+
file_path = self._get_notebook_file_path(space_id)
|
| 66 |
+
if not file_path.exists():
|
| 67 |
+
return None
|
| 68 |
+
|
| 69 |
+
with open(file_path, 'r', encoding='utf-8') as f:
|
| 70 |
+
return json.load(f)
|
| 71 |
+
|
| 72 |
+
def _derive_title_from_question(self, question: str) -> str:
|
| 73 |
+
"""Generate a readable title from a chat question."""
|
| 74 |
+
question = (question or "").strip()
|
| 75 |
+
if not question:
|
| 76 |
+
return "Chat Note"
|
| 77 |
+
|
| 78 |
+
title = question.replace('\n', ' ')
|
| 79 |
+
return title[:80] + "..." if len(title) > 80 else title
|
| 80 |
+
|
| 81 |
+
# ========================================================================
|
| 82 |
+
# NOTEBOOK OPERATIONS
|
| 83 |
+
# ========================================================================
|
| 84 |
+
|
| 85 |
+
def create_notebook_entry(self, entry_data: NotebookEntryCreate) -> NotebookEntry:
|
| 86 |
+
"""Create a new notebook entry"""
|
| 87 |
+
# Ensure a notebook record exists for this space.
|
| 88 |
+
self.ensure_space_notebook(entry_data.space_id)
|
| 89 |
+
|
| 90 |
+
entry = NotebookEntry(
|
| 91 |
+
id=str(uuid.uuid4()),
|
| 92 |
+
**entry_data.dict()
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
# Save to file
|
| 96 |
+
file_path = self.notebook_dir / f"{entry.id}.json"
|
| 97 |
+
with open(file_path, 'w', encoding='utf-8') as f:
|
| 98 |
+
json.dump(entry.dict(), f, indent=2, default=str)
|
| 99 |
+
|
| 100 |
+
return entry
|
| 101 |
+
|
| 102 |
+
def create_notebook_entry_from_chat(
|
| 103 |
+
self,
|
| 104 |
+
space_id: str,
|
| 105 |
+
question: str,
|
| 106 |
+
answer: str,
|
| 107 |
+
chat_id: Optional[str] = None,
|
| 108 |
+
assistant_timestamp: Optional[str] = None,
|
| 109 |
+
tags: Optional[List[str]] = None,
|
| 110 |
+
space_name: str = ""
|
| 111 |
+
) -> NotebookEntry:
|
| 112 |
+
"""Create a notebook entry from a chat Q/A pair."""
|
| 113 |
+
self.ensure_space_notebook(space_id, space_name=space_name)
|
| 114 |
+
|
| 115 |
+
metadata: Dict[str, Any] = {
|
| 116 |
+
"question": question,
|
| 117 |
+
"assistant_timestamp": assistant_timestamp,
|
| 118 |
+
}
|
| 119 |
+
if chat_id:
|
| 120 |
+
metadata["chat_id"] = chat_id
|
| 121 |
+
|
| 122 |
+
entry_data = NotebookEntryCreate(
|
| 123 |
+
space_id=space_id,
|
| 124 |
+
title=self._derive_title_from_question(question),
|
| 125 |
+
content=f"Q: {question.strip()}\n\nA: {answer.strip()}",
|
| 126 |
+
source_type="chat",
|
| 127 |
+
source_id=chat_id,
|
| 128 |
+
tags=tags or ["chat"],
|
| 129 |
+
metadata=metadata
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
entry = self.create_notebook_entry(entry_data)
|
| 133 |
+
|
| 134 |
+
# Update notebook metadata timestamp.
|
| 135 |
+
notebook_data = self.ensure_space_notebook(space_id, space_name=space_name)
|
| 136 |
+
notebook_data["updated_at"] = datetime.now().isoformat()
|
| 137 |
+
with open(self._get_notebook_file_path(space_id), 'w', encoding='utf-8') as f:
|
| 138 |
+
json.dump(notebook_data, f, indent=2)
|
| 139 |
+
|
| 140 |
+
return entry
|
| 141 |
+
|
| 142 |
+
def get_notebook_entry(self, entry_id: str) -> Optional[NotebookEntry]:
|
| 143 |
+
"""Get a single notebook entry by ID"""
|
| 144 |
+
file_path = self.notebook_dir / f"{entry_id}.json"
|
| 145 |
+
if not file_path.exists():
|
| 146 |
+
return None
|
| 147 |
+
|
| 148 |
+
with open(file_path, 'r', encoding='utf-8') as f:
|
| 149 |
+
data = json.load(f)
|
| 150 |
+
|
| 151 |
+
return NotebookEntry(**data)
|
| 152 |
+
|
| 153 |
+
def list_notebook_entries(self, space_id: Optional[str] = None) -> List[NotebookEntry]:
|
| 154 |
+
"""List all notebook entries, optionally filtered by space"""
|
| 155 |
+
entries = []
|
| 156 |
+
|
| 157 |
+
for file_path in self.notebook_dir.glob("*.json"):
|
| 158 |
+
try:
|
| 159 |
+
with open(file_path, 'r', encoding='utf-8') as f:
|
| 160 |
+
data = json.load(f)
|
| 161 |
+
|
| 162 |
+
entry = NotebookEntry(**data)
|
| 163 |
+
|
| 164 |
+
# Filter by space if specified
|
| 165 |
+
if space_id is None or entry.space_id == space_id:
|
| 166 |
+
entries.append(entry)
|
| 167 |
+
except Exception as e:
|
| 168 |
+
print(f"Error loading notebook entry {file_path}: {e}")
|
| 169 |
+
|
| 170 |
+
# Sort by updated_at descending
|
| 171 |
+
entries.sort(key=lambda x: x.updated_at, reverse=True)
|
| 172 |
+
return entries
|
| 173 |
+
|
| 174 |
+
def update_notebook_entry(self, entry_id: str, update_data: NotebookEntryUpdate) -> Optional[NotebookEntry]:
|
| 175 |
+
"""Update an existing notebook entry"""
|
| 176 |
+
entry = self.get_notebook_entry(entry_id)
|
| 177 |
+
if not entry:
|
| 178 |
+
return None
|
| 179 |
+
|
| 180 |
+
# Update fields
|
| 181 |
+
update_dict = update_data.dict(exclude_unset=True)
|
| 182 |
+
for key, value in update_dict.items():
|
| 183 |
+
setattr(entry, key, value)
|
| 184 |
+
|
| 185 |
+
entry.updated_at = datetime.now()
|
| 186 |
+
|
| 187 |
+
# Save
|
| 188 |
+
file_path = self.notebook_dir / f"{entry_id}.json"
|
| 189 |
+
with open(file_path, 'w', encoding='utf-8') as f:
|
| 190 |
+
json.dump(entry.dict(), f, indent=2, default=str)
|
| 191 |
+
|
| 192 |
+
return entry
|
| 193 |
+
|
| 194 |
+
def delete_notebook_entry(self, entry_id: str) -> bool:
|
| 195 |
+
"""Delete a notebook entry"""
|
| 196 |
+
file_path = self.notebook_dir / f"{entry_id}.json"
|
| 197 |
+
if file_path.exists():
|
| 198 |
+
file_path.unlink()
|
| 199 |
+
return True
|
| 200 |
+
return False
|
| 201 |
+
|
| 202 |
+
# ========================================================================
|
| 203 |
+
# FLASHCARD OPERATIONS
|
| 204 |
+
# ========================================================================
|
| 205 |
+
|
| 206 |
+
def create_flashcard(self, card_data: FlashcardCreate) -> Flashcard:
|
| 207 |
+
"""Create a new flashcard"""
|
| 208 |
+
card = Flashcard(
|
| 209 |
+
id=str(uuid.uuid4()),
|
| 210 |
+
**card_data.dict()
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
# Save to file
|
| 214 |
+
file_path = self.flashcards_dir / f"{card.id}.json"
|
| 215 |
+
with open(file_path, 'w', encoding='utf-8') as f:
|
| 216 |
+
json.dump(card.dict(), f, indent=2, default=str)
|
| 217 |
+
|
| 218 |
+
return card
|
| 219 |
+
|
| 220 |
+
def get_flashcard(self, card_id: str) -> Optional[Flashcard]:
|
| 221 |
+
"""Get a single flashcard by ID"""
|
| 222 |
+
file_path = self.flashcards_dir / f"{card_id}.json"
|
| 223 |
+
if not file_path.exists():
|
| 224 |
+
return None
|
| 225 |
+
|
| 226 |
+
with open(file_path, 'r', encoding='utf-8') as f:
|
| 227 |
+
data = json.load(f)
|
| 228 |
+
|
| 229 |
+
return Flashcard(**data)
|
| 230 |
+
|
| 231 |
+
def list_flashcards(self, space_id: Optional[str] = None,
|
| 232 |
+
mastery: Optional[MasteryLevel] = None) -> List[Flashcard]:
|
| 233 |
+
"""List all flashcards, optionally filtered"""
|
| 234 |
+
cards = []
|
| 235 |
+
|
| 236 |
+
for file_path in self.flashcards_dir.glob("*.json"):
|
| 237 |
+
try:
|
| 238 |
+
with open(file_path, 'r', encoding='utf-8') as f:
|
| 239 |
+
data = json.load(f)
|
| 240 |
+
|
| 241 |
+
card = Flashcard(**data)
|
| 242 |
+
|
| 243 |
+
# Apply filters
|
| 244 |
+
if space_id and card.space_id != space_id:
|
| 245 |
+
continue
|
| 246 |
+
if mastery and card.mastery != mastery:
|
| 247 |
+
continue
|
| 248 |
+
|
| 249 |
+
cards.append(card)
|
| 250 |
+
except Exception as e:
|
| 251 |
+
print(f"Error loading flashcard {file_path}: {e}")
|
| 252 |
+
|
| 253 |
+
# Sort by next_review date (cards due for review first)
|
| 254 |
+
cards.sort(key=lambda x: x.next_review or datetime.now())
|
| 255 |
+
return cards
|
| 256 |
+
|
| 257 |
+
def update_flashcard(self, card_id: str, update_data: FlashcardUpdate) -> Optional[Flashcard]:
|
| 258 |
+
"""Update a flashcard"""
|
| 259 |
+
card = self.get_flashcard(card_id)
|
| 260 |
+
if not card:
|
| 261 |
+
return None
|
| 262 |
+
|
| 263 |
+
# Update fields
|
| 264 |
+
update_dict = update_data.dict(exclude_unset=True)
|
| 265 |
+
for key, value in update_dict.items():
|
| 266 |
+
setattr(card, key, value)
|
| 267 |
+
|
| 268 |
+
# Save
|
| 269 |
+
file_path = self.flashcards_dir / f"{card_id}.json"
|
| 270 |
+
with open(file_path, 'w', encoding='utf-8') as f:
|
| 271 |
+
json.dump(card.dict(), f, indent=2, default=str)
|
| 272 |
+
|
| 273 |
+
return card
|
| 274 |
+
|
| 275 |
+
def review_flashcard(self, card_id: str, review: FlashcardReview) -> Optional[Flashcard]:
|
| 276 |
+
"""Record a flashcard review and update mastery level"""
|
| 277 |
+
card = self.get_flashcard(card_id)
|
| 278 |
+
if not card:
|
| 279 |
+
return None
|
| 280 |
+
|
| 281 |
+
# Update review stats
|
| 282 |
+
card.review_count += 1
|
| 283 |
+
if review.correct:
|
| 284 |
+
card.correct_count += 1
|
| 285 |
+
|
| 286 |
+
card.last_reviewed = datetime.now()
|
| 287 |
+
|
| 288 |
+
# Update mastery level based on performance
|
| 289 |
+
accuracy = card.correct_count / card.review_count if card.review_count > 0 else 0
|
| 290 |
+
|
| 291 |
+
if accuracy >= 0.9 and card.review_count >= 5:
|
| 292 |
+
card.mastery = MasteryLevel.MASTERED
|
| 293 |
+
card.next_review = datetime.now() + timedelta(days=30)
|
| 294 |
+
elif accuracy >= 0.7 and card.review_count >= 3:
|
| 295 |
+
card.mastery = MasteryLevel.REVIEWING
|
| 296 |
+
card.next_review = datetime.now() + timedelta(days=7)
|
| 297 |
+
elif card.review_count >= 1:
|
| 298 |
+
card.mastery = MasteryLevel.LEARNING
|
| 299 |
+
card.next_review = datetime.now() + timedelta(days=1)
|
| 300 |
+
else:
|
| 301 |
+
card.mastery = MasteryLevel.NEW
|
| 302 |
+
card.next_review = datetime.now()
|
| 303 |
+
|
| 304 |
+
# Save
|
| 305 |
+
file_path = self.flashcards_dir / f"{card_id}.json"
|
| 306 |
+
with open(file_path, 'w', encoding='utf-8') as f:
|
| 307 |
+
json.dump(card.dict(), f, indent=2, default=str)
|
| 308 |
+
|
| 309 |
+
return card
|
| 310 |
+
|
| 311 |
+
def delete_flashcard(self, card_id: str) -> bool:
|
| 312 |
+
"""Delete a flashcard"""
|
| 313 |
+
file_path = self.flashcards_dir / f"{card_id}.json"
|
| 314 |
+
if file_path.exists():
|
| 315 |
+
file_path.unlink()
|
| 316 |
+
return True
|
| 317 |
+
return False
|
| 318 |
+
|
| 319 |
+
# ========================================================================
|
| 320 |
+
# QUIZ OPERATIONS
|
| 321 |
+
# ========================================================================
|
| 322 |
+
|
| 323 |
+
def create_quiz(self, quiz_data: QuizCreate) -> Quiz:
|
| 324 |
+
"""Create a new quiz"""
|
| 325 |
+
quiz = Quiz(
|
| 326 |
+
id=str(uuid.uuid4()),
|
| 327 |
+
**quiz_data.dict()
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
# Save to file
|
| 331 |
+
file_path = self.quizzes_dir / f"{quiz.id}.json"
|
| 332 |
+
with open(file_path, 'w', encoding='utf-8') as f:
|
| 333 |
+
json.dump(quiz.dict(), f, indent=2, default=str)
|
| 334 |
+
|
| 335 |
+
return quiz
|
| 336 |
+
|
| 337 |
+
def get_quiz(self, quiz_id: str) -> Optional[Quiz]:
|
| 338 |
+
"""Get a quiz by ID"""
|
| 339 |
+
file_path = self.quizzes_dir / f"{quiz_id}.json"
|
| 340 |
+
if not file_path.exists():
|
| 341 |
+
return None
|
| 342 |
+
|
| 343 |
+
with open(file_path, 'r', encoding='utf-8') as f:
|
| 344 |
+
data = json.load(f)
|
| 345 |
+
|
| 346 |
+
return Quiz(**data)
|
| 347 |
+
|
| 348 |
+
def list_quizzes(self, space_id: Optional[str] = None) -> List[Quiz]:
|
| 349 |
+
"""List all quizzes, optionally filtered by space"""
|
| 350 |
+
quizzes = []
|
| 351 |
+
|
| 352 |
+
for file_path in self.quizzes_dir.glob("*.json"):
|
| 353 |
+
try:
|
| 354 |
+
with open(file_path, 'r', encoding='utf-8') as f:
|
| 355 |
+
data = json.load(f)
|
| 356 |
+
|
| 357 |
+
quiz = Quiz(**data)
|
| 358 |
+
|
| 359 |
+
if space_id is None or quiz.space_id == space_id:
|
| 360 |
+
quizzes.append(quiz)
|
| 361 |
+
except Exception as e:
|
| 362 |
+
print(f"Error loading quiz {file_path}: {e}")
|
| 363 |
+
|
| 364 |
+
# Sort by created_at descending
|
| 365 |
+
quizzes.sort(key=lambda x: x.created_at, reverse=True)
|
| 366 |
+
return quizzes
|
| 367 |
+
|
| 368 |
+
def delete_quiz(self, quiz_id: str) -> bool:
|
| 369 |
+
"""Delete a quiz"""
|
| 370 |
+
file_path = self.quizzes_dir / f"{quiz_id}.json"
|
| 371 |
+
if file_path.exists():
|
| 372 |
+
file_path.unlink()
|
| 373 |
+
return True
|
| 374 |
+
return False
|
| 375 |
+
|
| 376 |
+
def submit_quiz(self, quiz_id: str, answers: List[QuizAnswer]) -> Optional[QuizResult]:
|
| 377 |
+
"""Submit quiz answers and calculate results"""
|
| 378 |
+
quiz = self.get_quiz(quiz_id)
|
| 379 |
+
if not quiz:
|
| 380 |
+
return None
|
| 381 |
+
|
| 382 |
+
# Create answer lookup
|
| 383 |
+
answer_dict = {ans.question_id: ans for ans in answers}
|
| 384 |
+
|
| 385 |
+
# Calculate results
|
| 386 |
+
total_points = sum(q.points for q in quiz.questions)
|
| 387 |
+
correct_count = 0
|
| 388 |
+
incorrect_count = 0
|
| 389 |
+
earned_points = 0
|
| 390 |
+
detailed_answers = []
|
| 391 |
+
|
| 392 |
+
for question in quiz.questions:
|
| 393 |
+
user_answer = answer_dict.get(question.id)
|
| 394 |
+
is_correct = False
|
| 395 |
+
|
| 396 |
+
if user_answer:
|
| 397 |
+
# Normalize answers for comparison
|
| 398 |
+
correct_ans = question.correct_answer.strip().lower()
|
| 399 |
+
user_ans = user_answer.answer.strip().lower()
|
| 400 |
+
|
| 401 |
+
is_correct = correct_ans == user_ans
|
| 402 |
+
|
| 403 |
+
if is_correct:
|
| 404 |
+
correct_count += 1
|
| 405 |
+
earned_points += question.points
|
| 406 |
+
else:
|
| 407 |
+
incorrect_count += 1
|
| 408 |
+
else:
|
| 409 |
+
incorrect_count += 1
|
| 410 |
+
|
| 411 |
+
detailed_answers.append({
|
| 412 |
+
"question_id": question.id,
|
| 413 |
+
"question": question.question,
|
| 414 |
+
"user_answer": user_answer.answer if user_answer else None,
|
| 415 |
+
"correct_answer": question.correct_answer,
|
| 416 |
+
"is_correct": is_correct,
|
| 417 |
+
"explanation": question.explanation,
|
| 418 |
+
"points": question.points if is_correct else 0
|
| 419 |
+
})
|
| 420 |
+
|
| 421 |
+
# Create result
|
| 422 |
+
result = QuizResult(
|
| 423 |
+
quiz_id=quiz_id,
|
| 424 |
+
submission_id=str(uuid.uuid4()),
|
| 425 |
+
total_questions=len(quiz.questions),
|
| 426 |
+
correct_answers=correct_count,
|
| 427 |
+
incorrect_answers=incorrect_count,
|
| 428 |
+
score_percentage=(correct_count / len(quiz.questions) * 100) if quiz.questions else 0,
|
| 429 |
+
total_points=total_points,
|
| 430 |
+
earned_points=earned_points,
|
| 431 |
+
answers=detailed_answers
|
| 432 |
+
)
|
| 433 |
+
|
| 434 |
+
# Save result
|
| 435 |
+
result_file = self.quiz_results_dir / f"{result.submission_id}.json"
|
| 436 |
+
with open(result_file, 'w', encoding='utf-8') as f:
|
| 437 |
+
json.dump(result.dict(), f, indent=2, default=str)
|
| 438 |
+
|
| 439 |
+
return result
|
| 440 |
+
|
| 441 |
+
def get_quiz_history(self, quiz_id: str) -> QuizHistory:
|
| 442 |
+
"""Get quiz attempt history"""
|
| 443 |
+
quiz = self.get_quiz(quiz_id)
|
| 444 |
+
if not quiz:
|
| 445 |
+
return None
|
| 446 |
+
|
| 447 |
+
# Load all results for this quiz
|
| 448 |
+
results = []
|
| 449 |
+
for file_path in self.quiz_results_dir.glob("*.json"):
|
| 450 |
+
try:
|
| 451 |
+
with open(file_path, 'r', encoding='utf-8') as f:
|
| 452 |
+
data = json.load(f)
|
| 453 |
+
|
| 454 |
+
result = QuizResult(**data)
|
| 455 |
+
if result.quiz_id == quiz_id:
|
| 456 |
+
results.append(result)
|
| 457 |
+
except Exception as e:
|
| 458 |
+
print(f"Error loading quiz result {file_path}: {e}")
|
| 459 |
+
|
| 460 |
+
# Calculate statistics
|
| 461 |
+
scores = [r.score_percentage for r in results] if results else [0]
|
| 462 |
+
|
| 463 |
+
history = QuizHistory(
|
| 464 |
+
quiz_id=quiz_id,
|
| 465 |
+
space_id=quiz.space_id,
|
| 466 |
+
quiz_title=quiz.title,
|
| 467 |
+
results=results,
|
| 468 |
+
best_score=max(scores),
|
| 469 |
+
average_score=sum(scores) / len(scores) if scores else 0,
|
| 470 |
+
attempts_count=len(results)
|
| 471 |
+
)
|
| 472 |
+
|
| 473 |
+
return history
|
utils/vector_db.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import uuid
|
| 3 |
+
import hashlib
|
| 4 |
+
from qdrant_client import QdrantClient
|
| 5 |
+
from qdrant_client.http import models
|
| 6 |
+
from sentence_transformers import SentenceTransformer
|
| 7 |
+
from typing import List, Dict, Optional
|
| 8 |
+
import threading
|
| 9 |
+
import logging
|
| 10 |
+
import warnings
|
| 11 |
+
|
| 12 |
+
warnings.filterwarnings('ignore', category=FutureWarning)
|
| 13 |
+
logging.getLogger('sentence_transformers').setLevel(logging.WARNING)
|
| 14 |
+
|
| 15 |
+
class VectorDatabase:
|
| 16 |
+
"""Manage vector database for document embeddings using Qdrant Cloud."""
|
| 17 |
+
|
| 18 |
+
_embedding_model = None
|
| 19 |
+
_embedding_model_name = None
|
| 20 |
+
_embedding_model_lock = threading.Lock()
|
| 21 |
+
|
| 22 |
+
def __init__(self, collection_name: str = "documents", persist_directory: str = None):
|
| 23 |
+
"""Initialize Qdrant Client (persist_directory is ignored for Cloud)"""
|
| 24 |
+
|
| 25 |
+
qdrant_url = os.getenv("QDRANT_URL")
|
| 26 |
+
qdrant_api_key = os.getenv("QDRANT_API_KEY")
|
| 27 |
+
|
| 28 |
+
if not qdrant_url or not qdrant_api_key:
|
| 29 |
+
raise ValueError("QDRANT_URL and QDRANT_API_KEY must be set in environment variables.")
|
| 30 |
+
|
| 31 |
+
self.client = QdrantClient(
|
| 32 |
+
url=qdrant_url,
|
| 33 |
+
api_key=qdrant_api_key,
|
| 34 |
+
timeout=60.0
|
| 35 |
+
)
|
| 36 |
+
self.collection_name = collection_name
|
| 37 |
+
self.vector_size = 384 # Size for standard sentence-transformers (e.g. all-MiniLM-L6-v2)
|
| 38 |
+
|
| 39 |
+
# Ensure collection exists
|
| 40 |
+
self._ensure_collection()
|
| 41 |
+
|
| 42 |
+
# Load embedding model
|
| 43 |
+
self.embedding_model = self._get_or_create_embedding_model()
|
| 44 |
+
|
| 45 |
+
def _ensure_collection(self):
|
| 46 |
+
"""Creates the collection in Qdrant if it doesn't exist."""
|
| 47 |
+
try:
|
| 48 |
+
collections = self.client.get_collections().collections
|
| 49 |
+
exists = any(c.name == self.collection_name for c in collections)
|
| 50 |
+
|
| 51 |
+
if not exists:
|
| 52 |
+
self.client.create_collection(
|
| 53 |
+
collection_name=self.collection_name,
|
| 54 |
+
vectors_config=models.VectorParams(
|
| 55 |
+
size=self.vector_size,
|
| 56 |
+
distance=models.Distance.COSINE
|
| 57 |
+
)
|
| 58 |
+
)
|
| 59 |
+
except Exception as e:
|
| 60 |
+
print(f"Error checking/creating collection: {e}")
|
| 61 |
+
|
| 62 |
+
@classmethod
|
| 63 |
+
def _get_or_create_embedding_model(cls):
|
| 64 |
+
with cls._embedding_model_lock:
|
| 65 |
+
# Assuming you set EMBEDDING_MODEL in your config, defaulting to MiniLM
|
| 66 |
+
model_name = os.getenv("EMBEDDING_MODEL", "all-MiniLM-L6-v2")
|
| 67 |
+
if cls._embedding_model is None or cls._embedding_model_name != model_name:
|
| 68 |
+
import torch
|
| 69 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 70 |
+
print(f"Loading embedding model on {device}...")
|
| 71 |
+
cls._embedding_model = SentenceTransformer(model_name, device=device)
|
| 72 |
+
cls._embedding_model_name = model_name
|
| 73 |
+
return cls._embedding_model
|
| 74 |
+
|
| 75 |
+
def _string_to_uuid(self, string_id: str) -> str:
|
| 76 |
+
"""Qdrant requires proper UUIDs. This hashes your custom string IDs into UUIDs."""
|
| 77 |
+
return str(uuid.UUID(hashlib.md5(string_id.encode()).hexdigest()))
|
| 78 |
+
|
| 79 |
+
def add_documents(self, texts: List[str], metadatas: List[Dict], ids: List[str]):
|
| 80 |
+
if not texts:
|
| 81 |
+
return
|
| 82 |
+
|
| 83 |
+
embeddings = self.embedding_model.encode(texts, show_progress_bar=False, batch_size=64).tolist()
|
| 84 |
+
|
| 85 |
+
points = []
|
| 86 |
+
for i in range(len(texts)):
|
| 87 |
+
payload = metadatas[i] if metadatas[i] else {}
|
| 88 |
+
payload['text'] = texts[i] # Store actual text in payload for retrieval
|
| 89 |
+
|
| 90 |
+
points.append(models.PointStruct(
|
| 91 |
+
id=self._string_to_uuid(ids[i]),
|
| 92 |
+
vector=embeddings[i],
|
| 93 |
+
payload=payload
|
| 94 |
+
))
|
| 95 |
+
|
| 96 |
+
self.client.upsert(
|
| 97 |
+
collection_name=self.collection_name,
|
| 98 |
+
points=points
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
def query(self, query_text: str, n_results: int = 5, filter_dict: Optional[Dict] = None) -> Dict:
|
| 102 |
+
# Check if collection is empty
|
| 103 |
+
count = self.get_collection_count()
|
| 104 |
+
if count == 0:
|
| 105 |
+
return {"documents": [[]], "metadatas": [[]], "distances": [[]], "ids": [[]]}
|
| 106 |
+
|
| 107 |
+
query_embedding = self.embedding_model.encode([query_text])[0].tolist()
|
| 108 |
+
|
| 109 |
+
# Build Qdrant filter if provided
|
| 110 |
+
qdrant_filter = None
|
| 111 |
+
if filter_dict:
|
| 112 |
+
conditions = [
|
| 113 |
+
models.FieldCondition(key=k, match=models.MatchValue(value=v))
|
| 114 |
+
for k, v in filter_dict.items()
|
| 115 |
+
]
|
| 116 |
+
qdrant_filter = models.Filter(must=conditions)
|
| 117 |
+
|
| 118 |
+
search_result = self.client.search(
|
| 119 |
+
collection_name=self.collection_name,
|
| 120 |
+
query_vector=query_embedding,
|
| 121 |
+
query_filter=qdrant_filter,
|
| 122 |
+
limit=n_results
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
# Format output to match exactly what your HybridRetriever expects (ChromaDB style)
|
| 126 |
+
docs, metas, scores, ids = [], [], [], []
|
| 127 |
+
for hit in search_result:
|
| 128 |
+
docs.append(hit.payload.get('text', ''))
|
| 129 |
+
|
| 130 |
+
# Remove text from metadata so it mimics Chroma
|
| 131 |
+
meta = {k: v for k, v in hit.payload.items() if k != 'text'}
|
| 132 |
+
metas.append(meta)
|
| 133 |
+
|
| 134 |
+
scores.append(hit.score)
|
| 135 |
+
ids.append(str(hit.id))
|
| 136 |
+
|
| 137 |
+
return {
|
| 138 |
+
"documents": [docs],
|
| 139 |
+
"metadatas": [metas],
|
| 140 |
+
"distances": [scores], # Note: Qdrant uses cosine similarity (higher is better), Chroma uses distance.
|
| 141 |
+
"ids": [ids]
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
def get_collection_count(self) -> int:
|
| 145 |
+
try:
|
| 146 |
+
return self.client.count(collection_name=self.collection_name).count
|
| 147 |
+
except Exception:
|
| 148 |
+
return 0
|