mohhhhhit commited on
Commit
3736c33
·
verified ·
1 Parent(s): b8aaf69

first init

Browse files
Files changed (40) hide show
  1. .gitignore +27 -0
  2. Dockerfile +24 -0
  3. config.py +44 -0
  4. data/spaces.json +12 -0
  5. main.py +896 -0
  6. models/__pycache__/studio_models.cpython-311.pyc +0 -0
  7. models/studio_models.py +219 -0
  8. requirements.txt +39 -0
  9. runtime.txt +1 -0
  10. start_ngrok_tunnel.py +69 -0
  11. utils/__init__.py +1 -0
  12. utils/__pycache__/__init__.cpython-311.pyc +0 -0
  13. utils/__pycache__/__init__.cpython-314.pyc +0 -0
  14. utils/__pycache__/config_manager.cpython-311.pyc +0 -0
  15. utils/__pycache__/config_manager.cpython-314.pyc +0 -0
  16. utils/__pycache__/document_processor.cpython-311.pyc +0 -0
  17. utils/__pycache__/document_processor.cpython-314.pyc +0 -0
  18. utils/__pycache__/hybrid_retriever.cpython-311.pyc +0 -0
  19. utils/__pycache__/hybrid_retriever.cpython-314.pyc +0 -0
  20. utils/__pycache__/llm_generator.cpython-311.pyc +0 -0
  21. utils/__pycache__/llm_generator.cpython-314.pyc +0 -0
  22. utils/__pycache__/model_inference.cpython-311.pyc +0 -0
  23. utils/__pycache__/simple_generator.cpython-311.pyc +0 -0
  24. utils/__pycache__/spaces_manager.cpython-311.pyc +0 -0
  25. utils/__pycache__/spaces_manager.cpython-314.pyc +0 -0
  26. utils/__pycache__/studio_generator.cpython-311.pyc +0 -0
  27. utils/__pycache__/studio_manager.cpython-311.pyc +0 -0
  28. utils/__pycache__/vector_db.cpython-311.pyc +0 -0
  29. utils/__pycache__/vector_db.cpython-314.pyc +0 -0
  30. utils/chat_manager.py +123 -0
  31. utils/config_manager.py +80 -0
  32. utils/document_processor.py +222 -0
  33. utils/hybrid_retriever.py +149 -0
  34. utils/llm_generator.py +297 -0
  35. utils/model_inference.py +156 -0
  36. utils/simple_generator.py +444 -0
  37. utils/spaces_manager.py +124 -0
  38. utils/studio_generator.py +309 -0
  39. utils/studio_manager.py +473 -0
  40. utils/vector_db.py +148 -0
.gitignore ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ .Python
7
+ *.pyc
8
+
9
+ # Virtual environment
10
+ venv/
11
+ env/
12
+ ENV/
13
+
14
+ # Environment variables
15
+ .env
16
+ .env.local
17
+
18
+ # IDE
19
+ .vscode/
20
+ .idea/
21
+ *.swp
22
+ *.swo
23
+ *~
24
+ .DS_Store
25
+
26
+ # Logs
27
+ *.log
Dockerfile ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.11-slim
2
+
3
+ # Set working directory
4
+ WORKDIR /app
5
+
6
+ # Install system dependencies needed for some ML libraries and PyPDF2
7
+ RUN apt-get update && apt-get install -y \
8
+ build-essential \
9
+ && rm -rf /var/lib/apt/lists/*
10
+
11
+ # Copy requirements first to leverage Docker cache
12
+ COPY requirements.txt .
13
+
14
+ # Install Python dependencies
15
+ RUN pip install --no-cache-dir -r requirements.txt
16
+
17
+ # Copy the rest of the application
18
+ COPY . .
19
+
20
+ # Expose the port FastAPI will run on
21
+ EXPOSE 7860
22
+
23
+ # Command to run the application (Hugging Face routes to 7860 by default)
24
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
config.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ from dotenv import load_dotenv
4
+
5
+ # Load environment variables
6
+ load_dotenv()
7
+
8
+ # Project paths
9
+ PROJECT_ROOT = Path(__file__).parent
10
+ DATA_DIR = PROJECT_ROOT.parent / "data" # Use project root's data folder, not backend/data
11
+ MODELS_DIR = PROJECT_ROOT / "models"
12
+ UPLOADS_DIR = DATA_DIR / "uploads"
13
+ VECTOR_DB_DIR = DATA_DIR / "vector_db"
14
+ CHATS_DIR = DATA_DIR / "chats"
15
+
16
+ # Create directories if they don't exist
17
+ for dir_path in [DATA_DIR, MODELS_DIR, UPLOADS_DIR, VECTOR_DB_DIR, CHATS_DIR]:
18
+ dir_path.mkdir(parents=True, exist_ok=True)
19
+
20
+ # Model configuration
21
+ # RAG uses pre-trained models directly - no training required!
22
+ MODEL_NAME = os.getenv("MODEL_NAME", "microsoft/phi-2") # Pre-trained model
23
+ USE_PRETRAINED = os.getenv("USE_PRETRAINED", "true").lower() == "true" # Use pre-trained by default
24
+ MODEL_PATH = os.getenv("MODEL_PATH", str(MODELS_DIR / "trained_model")) # Only if fine-tuned
25
+ EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2" # For document embeddings
26
+
27
+ # API Keys
28
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
29
+ HUGGINGFACE_API_KEY = os.getenv("HUGGINGFACE_API_KEY", "")
30
+
31
+ # Application settings
32
+ MAX_UPLOAD_SIZE = int(os.getenv("MAX_UPLOAD_SIZE", "200")) # MB
33
+ TEMPERATURE = float(os.getenv("TEMPERATURE", "0.7"))
34
+ MAX_TOKENS = int(os.getenv("MAX_TOKENS", "2048"))
35
+ CHUNK_SIZE = 512
36
+ CHUNK_OVERLAP = 50
37
+
38
+ # Use cases
39
+ USE_CASES = {
40
+ "explanation": "Provide detailed explanation of concepts",
41
+ "summary": "Generate concise summary of content",
42
+ "qa": "Answer questions based on content",
43
+ "notes": "Create structured study notes"
44
+ }
data/spaces.json ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "spaces": [
3
+ {
4
+ "id": "general",
5
+ "name": "General",
6
+ "description": "General study materials",
7
+ "created_at": "2026-03-12T10:42:37.952166",
8
+ "file_count": 0,
9
+ "chat_count": 0
10
+ }
11
+ ]
12
+ }
main.py ADDED
@@ -0,0 +1,896 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FastAPI Backend for NotebookPRO
3
+ Handles RAG, LLM, file processing, and chat management
4
+ """
5
+ from fastapi import FastAPI, File, UploadFile, HTTPException
6
+ from fastapi.middleware.cors import CORSMiddleware
7
+ from pydantic import BaseModel
8
+ from typing import List, Optional, Dict, Any
9
+ from pathlib import Path
10
+ import json
11
+ from datetime import datetime
12
+ import uuid
13
+ import sys
14
+ import warnings
15
+ import logging
16
+ import os
17
+ import shutil
18
+
19
+ # Suppress warnings
20
+ warnings.filterwarnings('ignore')
21
+ os.environ['PYTHONWARNINGS'] = 'ignore'
22
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
23
+ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
24
+ os.environ.setdefault('OMP_NUM_THREADS', '2')
25
+ os.environ.setdefault('MKL_NUM_THREADS', '2')
26
+ os.environ.setdefault('OPENBLAS_NUM_THREADS', '2')
27
+ os.environ.setdefault('NUMEXPR_NUM_THREADS', '2')
28
+ #logging.getLogger().setLevel(logging.ERROR)
29
+
30
+ # Add project root to path
31
+ sys.path.append(str(Path(__file__).parent.parent))
32
+
33
+ import config
34
+ from utils.document_processor import DocumentProcessor
35
+ from utils.vector_db import VectorDatabase
36
+ from utils.hybrid_retriever import HybridRetriever
37
+ from utils.llm_generator import LLMGenerator
38
+ from utils.config_manager import ConfigManager
39
+ from utils.spaces_manager import SpacesManager
40
+ from utils.studio_manager import StudioManager
41
+ from utils.studio_generator import StudioGenerator
42
+
43
+ # Initialize FastAPI
44
+ app = FastAPI(title="NotebookPRO API", version="2.0.0")
45
+
46
+ # CORS - Allow Flutter web to connect
47
+ app.add_middleware(
48
+ CORSMiddleware,
49
+ allow_origins=["*"], # In production, specify your Flutter web URL
50
+ allow_credentials=True,
51
+ allow_methods=["*"],
52
+ allow_headers=["*"],
53
+ )
54
+
55
+ # Global instances
56
+ config_manager = ConfigManager()
57
+ spaces_manager = SpacesManager()
58
+ studio_manager = StudioManager()
59
+ studio_generator = None # Will be initialized after LLM
60
+ vector_db = None
61
+ llm_generator = None
62
+ current_space = None
63
+
64
+ # ==================== Pydantic Models ====================
65
+
66
+ class ChatMessage(BaseModel):
67
+ role: str
68
+ content: str
69
+ timestamp: str
70
+ sources: Optional[List[Dict[str, Any]]] = None
71
+
72
+ class ChatRequest(BaseModel):
73
+ query: str
74
+ space_id: str
75
+ chat_id: Optional[str] = None
76
+ workflow: str = "chat"
77
+
78
+ class ChatResponse(BaseModel):
79
+ response: str
80
+ sources: List[Dict[str, Any]]
81
+ chat_id: str
82
+ timestamp: str
83
+
84
+ class SpaceCreate(BaseModel):
85
+ name: str
86
+
87
+ class SpaceResponse(BaseModel):
88
+ id: str
89
+ name: str
90
+ created_at: str
91
+ file_count: int
92
+
93
+ class ChatInfo(BaseModel):
94
+ id: str
95
+ title: str
96
+ preview: str
97
+ created_at: str
98
+ updated_at: str
99
+ message_count: int
100
+
101
+ class ConfigResponse(BaseModel):
102
+ groq_api_key: Optional[str]
103
+ gemini_api_key: Optional[str]
104
+
105
+ class ConfigUpdate(BaseModel):
106
+ groq_api_key: Optional[str] = None
107
+ gemini_api_key: Optional[str] = None
108
+
109
+ class ChatToNotebookRequest(BaseModel):
110
+ space_id: str
111
+ question: str
112
+ answer: str
113
+ chat_id: Optional[str] = None
114
+ assistant_timestamp: Optional[str] = None
115
+ tags: List[str] = []
116
+ space_name: Optional[str] = None
117
+
118
+ # ==================== Helper Functions ====================
119
+
120
+ def get_data_dir():
121
+ """Get data directory path"""
122
+ return Path(__file__).parent.parent / "data"
123
+
124
+ def get_space_dir(space_id: str):
125
+ """Get space-specific directory"""
126
+ return get_data_dir() / "spaces" / space_id
127
+
128
+ def load_chats_for_space(space_id: str) -> List[Dict]:
129
+ """Load all chats for a space"""
130
+ chats_file = get_space_dir(space_id) / "chats.json"
131
+ if chats_file.exists():
132
+ with open(chats_file, 'r', encoding='utf-8') as f:
133
+ return json.load(f)
134
+ return []
135
+
136
+ def save_chats_for_space(space_id: str, chats: List[Dict]):
137
+ """Save chats for a space"""
138
+ chats_file = get_space_dir(space_id) / "chats.json"
139
+ chats_file.parent.mkdir(parents=True, exist_ok=True)
140
+ with open(chats_file, 'w', encoding='utf-8') as f:
141
+ json.dump(chats, f, indent=2, ensure_ascii=False)
142
+
143
+ def get_chat_title(messages: List[Dict]) -> str:
144
+ """Generate chat title from first user message"""
145
+ for msg in messages:
146
+ if msg['role'] == 'user':
147
+ content = msg['content'][:50]
148
+ return content + "..." if len(msg['content']) > 50 else content
149
+ return "New Chat"
150
+
151
+ def ensure_notebooks_for_existing_spaces() -> int:
152
+ """Ensure every existing space has an associated notebook metadata record."""
153
+ created_count = 0
154
+ spaces = spaces_manager.get_all_spaces()
155
+
156
+ for space in spaces:
157
+ space_id = space.get('id')
158
+ if not space_id:
159
+ continue
160
+
161
+ existing_notebook = studio_manager.get_space_notebook(space_id)
162
+ if existing_notebook:
163
+ continue
164
+
165
+ studio_manager.ensure_space_notebook(space_id, space.get('name', space_id))
166
+ created_count += 1
167
+
168
+ return created_count
169
+
170
+ def rebuild_space_index_if_missing(space_id: str) -> int:
171
+ """Rebuild a space index from uploaded files if the current index is empty."""
172
+ if not vector_db:
173
+ return 0
174
+
175
+ try:
176
+ if vector_db.get_collection_count() > 0:
177
+ return 0
178
+ except Exception:
179
+ # If count check fails, continue with a best-effort rebuild.
180
+ pass
181
+
182
+ uploads_dir = get_space_dir(space_id) / "uploads"
183
+ if not uploads_dir.exists():
184
+ return 0
185
+
186
+ files = [
187
+ p for p in uploads_dir.iterdir()
188
+ if p.is_file() and p.suffix.lower() in {".pdf", ".docx", ".txt"}
189
+ ]
190
+ if not files:
191
+ return 0
192
+
193
+ processor = DocumentProcessor()
194
+ texts: List[str] = []
195
+ metadatas: List[Dict[str, Any]] = []
196
+ ids: List[str] = []
197
+
198
+ for file_path in files:
199
+ try:
200
+ file_data = processor.process_file(file_path)
201
+ chunks = processor.chunk_text(
202
+ file_data['content'],
203
+ chunk_size=512,
204
+ overlap=50,
205
+ semantic=True,
206
+ )
207
+ total_chunks = len(chunks)
208
+ for idx, chunk in enumerate(chunks):
209
+ texts.append(chunk)
210
+ metadatas.append({
211
+ 'filename': file_path.name,
212
+ 'chunk_index': idx,
213
+ 'total_chunks': total_chunks,
214
+ 'source_type': file_data['format'],
215
+ })
216
+ ids.append(f"{space_id}_rebuild_{len(ids)}_{uuid.uuid4().hex[:8]}")
217
+ except Exception as e:
218
+ print(f"Index rebuild skipped {file_path.name}: {e}")
219
+
220
+ if not texts:
221
+ return 0
222
+
223
+ batch_size = 5000
224
+ for i in range(0, len(texts), batch_size):
225
+ vector_db.add_documents(
226
+ texts[i:i + batch_size],
227
+ metadatas[i:i + batch_size],
228
+ ids[i:i + batch_size],
229
+ )
230
+
231
+ print(f"Rebuilt index for space '{space_id}' with {len(texts)} chunks")
232
+ return len(texts)
233
+
234
+ def initialize_space(space_id: str):
235
+ """Initialize vector DB and components for a space"""
236
+ global vector_db, llm_generator, studio_generator, current_space
237
+
238
+ # Fast path: reuse already initialized components for the active space.
239
+ if current_space == space_id and vector_db is not None and llm_generator is not None:
240
+ return
241
+
242
+ # Get API keys
243
+ import os
244
+ # Try the config manager first, but fallback to the .env file variables
245
+ groq_key = config_manager.get_api_key('groq') or os.getenv('GROQ_API_KEY')
246
+ gemini_key = config_manager.get_api_key('gemini') or os.getenv('GOOGLE_API_KEY') or os.getenv('GEMINI_API_KEY')
247
+
248
+ if not groq_key and not gemini_key:
249
+ raise HTTPException(status_code=400, detail="No API keys configured. Please add Groq or Gemini API key.")
250
+
251
+ # Initialize vector database for this space (space-local persistence path).
252
+ # Initialize Qdrant cloud database for this space
253
+ vector_db = VectorDatabase(
254
+ collection_name=f"space_{space_id}"
255
+ )
256
+
257
+ # Backward-compatibility: rebuild embeddings from uploaded files if index is empty.
258
+ rebuild_space_index_if_missing(space_id)
259
+
260
+ # Initialize LLM generator - choose provider based on available keys
261
+ if groq_key:
262
+ llm_generator = LLMGenerator(provider="groq", api_key=groq_key)
263
+ else:
264
+ llm_generator = LLMGenerator(provider="gemini", api_key=gemini_key)
265
+
266
+ # Initialize studio generator with LLM
267
+ studio_generator = StudioGenerator(llm_generator, studio_manager)
268
+ current_space = space_id
269
+
270
+ @app.on_event("startup")
271
+ async def startup_sync_notebooks():
272
+ """Auto-create missing notebooks for pre-existing spaces when backend starts."""
273
+ try:
274
+ created = ensure_notebooks_for_existing_spaces()
275
+ if created > 0:
276
+ print(f"Created {created} missing notebook(s) for existing spaces")
277
+ except Exception as e:
278
+ # Keep server startup resilient even if sync fails.
279
+ print(f"Notebook startup sync failed: {e}")
280
+
281
+ # ==================== API Endpoints ====================
282
+
283
+ @app.get("/")
284
+ async def root():
285
+ """Health check"""
286
+ return {"status": "NotebookPRO API is running", "version": "2.0.0"}
287
+
288
+ @app.get("/api/config", response_model=ConfigResponse)
289
+ async def get_config():
290
+ """Get current API keys (masked)"""
291
+ groq_key = config_manager.get_api_key('groq')
292
+ gemini_key = config_manager.get_api_key('gemini')
293
+
294
+ return ConfigResponse(
295
+ groq_api_key="***" + groq_key[-4:] if groq_key else None,
296
+ gemini_api_key="***" + gemini_key[-4:] if gemini_key else None
297
+ )
298
+
299
+ @app.post("/api/config")
300
+ async def update_config(config_update: ConfigUpdate):
301
+ """Update API keys"""
302
+ if config_update.groq_api_key:
303
+ config_manager.set_api_key('groq', config_update.groq_api_key)
304
+ if config_update.gemini_api_key:
305
+ config_manager.set_api_key('gemini', config_update.gemini_api_key)
306
+
307
+ return {"status": "success", "message": "Configuration updated"}
308
+
309
+ @app.get("/api/spaces", response_model=List[SpaceResponse])
310
+ async def get_spaces():
311
+ """Get all spaces"""
312
+ # Self-healing check in case spaces were created externally while server is running.
313
+ ensure_notebooks_for_existing_spaces()
314
+ spaces = spaces_manager.get_all_spaces()
315
+
316
+ result = []
317
+ for space in spaces:
318
+ space_id = space['id']
319
+ space_dir = get_space_dir(space_id)
320
+ processed_file = space_dir / "processed_files.json"
321
+
322
+ file_count = 0
323
+ if processed_file.exists():
324
+ with open(processed_file, 'r') as f:
325
+ file_count = len(json.load(f))
326
+
327
+ result.append(SpaceResponse(
328
+ id=space_id,
329
+ name=space['name'],
330
+ created_at=space['created_at'],
331
+ file_count=file_count
332
+ ))
333
+
334
+ return result
335
+
336
+ @app.post("/api/spaces", response_model=SpaceResponse)
337
+ async def create_space(space_data: SpaceCreate):
338
+ """Create a new space"""
339
+ try:
340
+ space = spaces_manager.create_space(space_data.name)
341
+
342
+ # Create associated notebook metadata with the same name as the space.
343
+ studio_manager.ensure_space_notebook(space['id'], space['name'])
344
+
345
+ return SpaceResponse(
346
+ id=space['id'],
347
+ name=space['name'],
348
+ created_at=space['created_at'],
349
+ file_count=0
350
+ )
351
+ except ValueError as e:
352
+ raise HTTPException(status_code=400, detail=str(e))
353
+
354
+ @app.delete("/api/spaces/{space_id}")
355
+ async def delete_space(space_id: str):
356
+ """Delete a space"""
357
+ try:
358
+ spaces_manager.delete_space(space_id)
359
+
360
+ # Delete space directory
361
+ space_dir = get_space_dir(space_id)
362
+ if space_dir.exists():
363
+ shutil.rmtree(space_dir)
364
+
365
+ return {"status": "success", "message": f"Space {space_id} deleted"}
366
+ except ValueError as e:
367
+ raise HTTPException(status_code=400, detail=str(e))
368
+ except Exception as e:
369
+ raise HTTPException(status_code=500, detail=f"Error deleting space: {str(e)}")
370
+
371
+ @app.get("/api/spaces/{space_id}/chats", response_model=List[ChatInfo])
372
+ async def get_chats(space_id: str):
373
+ """Get all chats for a space"""
374
+ chats = load_chats_for_space(space_id)
375
+
376
+ result = []
377
+ for chat in chats:
378
+ messages = chat.get('messages', [])
379
+ result.append(ChatInfo(
380
+ id=chat['id'],
381
+ title=get_chat_title(messages),
382
+ preview=messages[0]['content'][:100] if messages else "",
383
+ created_at=chat.get('created_at', ''),
384
+ updated_at=chat.get('updated_at', ''),
385
+ message_count=len(messages)
386
+ ))
387
+
388
+ return result
389
+
390
+ @app.get("/api/spaces/{space_id}/chats/{chat_id}")
391
+ async def get_chat(space_id: str, chat_id: str):
392
+ """Get specific chat by ID"""
393
+ chats = load_chats_for_space(space_id)
394
+
395
+ for chat in chats:
396
+ if chat['id'] == chat_id:
397
+ return chat
398
+
399
+ raise HTTPException(status_code=404, detail="Chat not found")
400
+
401
+ @app.delete("/api/spaces/{space_id}/chats/{chat_id}")
402
+ async def delete_chat(space_id: str, chat_id: str):
403
+ """Delete a chat"""
404
+ chats = load_chats_for_space(space_id)
405
+ chats = [c for c in chats if c['id'] != chat_id]
406
+ save_chats_for_space(space_id, chats)
407
+
408
+ return {"status": "success", "message": f"Chat {chat_id} deleted"}
409
+
410
+ @app.post("/api/chat", response_model=ChatResponse)
411
+ async def chat(request: ChatRequest):
412
+ """Process a chat message with RAG"""
413
+ try:
414
+ # Initialize space if needed
415
+ initialize_space(request.space_id)
416
+
417
+ # Create hybrid retriever with 60% vector, 40% BM25
418
+ hybrid_retriever = HybridRetriever(vector_db, alpha=0.6)
419
+
420
+ # Retrieve relevant documents
421
+ documents, metadatas, scores = hybrid_retriever.retrieve(
422
+ query=request.query,
423
+ n_results=5
424
+ )
425
+
426
+ # Build context from retrieved documents
427
+ context_parts = []
428
+ sources = []
429
+
430
+ for idx, (doc, meta, score) in enumerate(zip(documents, metadatas, scores), 1):
431
+ # Extract clean filename for source citation
432
+ filename = meta.get('filename', 'Unknown')
433
+ clean_name = filename.replace('.pdf', '').replace('.docx', '').replace('.txt', '')
434
+ context_parts.append(f"Source [{idx}] ({clean_name}):\n{doc}\n")
435
+ sources.append({
436
+ "content": doc[:200] + "..." if len(doc) > 200 else doc,
437
+ "metadata": meta,
438
+ "score": float(score)
439
+ })
440
+
441
+ context = "\n".join(context_parts)
442
+
443
+ # Use the advanced generate_response method which has the new NotebookLM-style prompt
444
+ response = llm_generator.generate_response(
445
+ prompt=request.query,
446
+ context=context,
447
+ use_case=request.workflow if request.workflow in ["summary", "explanation", "qa", "notes"] else "qa",
448
+ metadatas=metadatas,
449
+ temperature=0.3
450
+ )
451
+
452
+ # Create or update chat
453
+ chat_id = request.chat_id or str(uuid.uuid4())
454
+ chats = load_chats_for_space(request.space_id)
455
+
456
+ # Find existing chat or create new
457
+ chat = None
458
+ for c in chats:
459
+ if c['id'] == chat_id:
460
+ chat = c
461
+ break
462
+
463
+ if not chat:
464
+ chat = {
465
+ 'id': chat_id,
466
+ 'messages': [],
467
+ 'created_at': datetime.now().isoformat(),
468
+ 'updated_at': datetime.now().isoformat()
469
+ }
470
+ chats.append(chat)
471
+
472
+ # Add messages
473
+ timestamp = datetime.now().isoformat()
474
+ chat['messages'].extend([
475
+ {'role': 'user', 'content': request.query, 'timestamp': timestamp},
476
+ {
477
+ 'role': 'assistant',
478
+ 'content': response,
479
+ 'timestamp': timestamp,
480
+ 'sources': sources
481
+ }
482
+ ])
483
+ chat['updated_at'] = timestamp
484
+
485
+ # Save chats
486
+ save_chats_for_space(request.space_id, chats)
487
+
488
+ return ChatResponse(
489
+ response=response,
490
+ sources=sources,
491
+ chat_id=chat_id,
492
+ timestamp=timestamp
493
+ )
494
+
495
+ except Exception as e:
496
+ raise HTTPException(status_code=500, detail=str(e))
497
+
498
+ @app.post("/api/spaces/{space_id}/upload")
499
+ async def upload_files(space_id: str, files: List[UploadFile] = File(...)):
500
+ """Upload and process files for a space"""
501
+ try:
502
+ # Initialize space
503
+ initialize_space(space_id)
504
+
505
+ # Save uploaded files temporarily
506
+ space_dir = get_space_dir(space_id)
507
+ uploads_dir = space_dir / "uploads"
508
+ uploads_dir.mkdir(parents=True, exist_ok=True)
509
+
510
+ processor = DocumentProcessor()
511
+ all_chunks = []
512
+ processed_files = []
513
+
514
+ for file in files:
515
+ # Save file
516
+ file_path = uploads_dir / file.filename
517
+ with open(file_path, "wb") as f:
518
+ content = await file.read()
519
+ f.write(content)
520
+
521
+ # Process file and extract content
522
+ try:
523
+ file_data = processor.process_file(file_path)
524
+ content = file_data['content']
525
+
526
+ # Chunk the content
527
+ chunks = processor.chunk_text(content, chunk_size=512, overlap=50, semantic=True)
528
+
529
+ # Format chunks for vector database
530
+ formatted_chunks = []
531
+ for idx, chunk in enumerate(chunks):
532
+ formatted_chunks.append({
533
+ 'content': chunk,
534
+ 'metadata': {
535
+ 'filename': file.filename,
536
+ 'chunk_index': idx,
537
+ 'total_chunks': len(chunks),
538
+ 'source_type': file_data['format']
539
+ }
540
+ })
541
+
542
+ all_chunks.extend(formatted_chunks)
543
+ processed_files.append({
544
+ 'filename': file.filename,
545
+ 'chunks': len(chunks),
546
+ 'processed_at': datetime.now().isoformat()
547
+ })
548
+ except Exception as e:
549
+ # Log error but continue with other files
550
+ print(f"Error processing {file.filename}: {str(e)}")
551
+ continue
552
+
553
+ # Add to vector database in batches to avoid size limits
554
+ if all_chunks:
555
+ # Extract texts, metadatas, and generate IDs
556
+ texts = [chunk['content'] for chunk in all_chunks]
557
+ metadatas = [chunk['metadata'] for chunk in all_chunks]
558
+ ids = [f"{space_id}_{idx}_{uuid.uuid4().hex[:8]}" for idx in range(len(all_chunks))]
559
+
560
+ # Process in batches of 5000 to avoid ChromaDB batch size limit
561
+ batch_size = 5000
562
+ for i in range(0, len(texts), batch_size):
563
+ batch_texts = texts[i:i + batch_size]
564
+ batch_metadatas = metadatas[i:i + batch_size]
565
+ batch_ids = ids[i:i + batch_size]
566
+
567
+ vector_db.add_documents(batch_texts, batch_metadatas, batch_ids)
568
+ print(f"Processed batch {i//batch_size + 1}/{(len(texts)-1)//batch_size + 1}")
569
+
570
+ # Save processed files info
571
+ processed_file = space_dir / "processed_files.json"
572
+ existing = []
573
+ if processed_file.exists():
574
+ with open(processed_file, 'r') as f:
575
+ existing = json.load(f)
576
+
577
+ existing.extend(processed_files)
578
+ with open(processed_file, 'w') as f:
579
+ json.dump(existing, f, indent=2)
580
+
581
+ return {
582
+ "status": "success",
583
+ "files_processed": len(processed_files),
584
+ "total_chunks": len(all_chunks)
585
+ }
586
+
587
+ except Exception as e:
588
+ raise e # This strips the wrapper and forces FastAPI to log the raw stack trace
589
+
590
+ @app.get("/api/spaces/{space_id}/files")
591
+ async def get_files(space_id: str):
592
+ """Get processed files for a space"""
593
+ processed_file = get_space_dir(space_id) / "processed_files.json"
594
+
595
+ if processed_file.exists():
596
+ with open(processed_file, 'r') as f:
597
+ return json.load(f)
598
+
599
+ return []
600
+
601
+ @app.delete("/api/spaces/{space_id}/files/{filename}")
602
+ async def delete_file(space_id: str, filename: str):
603
+ """Delete a specific file from a space"""
604
+ try:
605
+ # Remove from processed_files.json
606
+ processed_file = get_space_dir(space_id) / "processed_files.json"
607
+ files_data = []
608
+
609
+ if processed_file.exists():
610
+ with open(processed_file, 'r') as f:
611
+ files_data = json.load(f)
612
+
613
+ # Filter out the file to delete
614
+ files_data = [f for f in files_data if f.get('filename') != filename]
615
+
616
+ with open(processed_file, 'w') as f:
617
+ json.dump(files_data, f, indent=2)
618
+
619
+ # Delete the actual file
620
+ file_path = get_space_dir(space_id) / "uploads" / filename
621
+ if file_path.exists():
622
+ file_path.unlink()
623
+
624
+ # Remove from vector database (if initialized)
625
+ # Note: This removes all chunks with this filename from metadata
626
+ if vector_db:
627
+ try:
628
+ # Get all documents in the collection
629
+ collection = vector_db.collection
630
+ results = collection.get()
631
+
632
+ # Find IDs of documents with matching filename
633
+ ids_to_delete = []
634
+ for idx, metadata in enumerate(results['metadatas']):
635
+ if metadata and metadata.get('filename') == filename:
636
+ ids_to_delete.append(results['ids'][idx])
637
+
638
+ # Delete those documents
639
+ if ids_to_delete:
640
+ collection.delete(ids=ids_to_delete)
641
+ print(f"Deleted {len(ids_to_delete)} chunks for {filename}")
642
+ except Exception as e:
643
+ print(f"Error removing from vector DB: {e}")
644
+
645
+ return {
646
+ "status": "success",
647
+ "message": f"File {filename} deleted"
648
+ }
649
+
650
+ except Exception as e:
651
+ raise HTTPException(status_code=500, detail=f"Error deleting file: {str(e)}")
652
+
653
+
654
+ # ==================== STUDIO API ROUTES ====================
655
+ # Routes for Notebook, Flashcards, and Quiz features
656
+
657
+ # Import studio models
658
+ from models.studio_models import (
659
+ NotebookEntry, NotebookEntryCreate, NotebookEntryUpdate,
660
+ Flashcard, FlashcardCreate, FlashcardUpdate, FlashcardReview,
661
+ FlashcardGenerateRequest,
662
+ Quiz, QuizCreate, QuizGenerateRequest, QuizSubmission, QuizResult, QuizHistory,
663
+ MasteryLevel
664
+ )
665
+
666
+ # ===== NOTEBOOK ROUTES =====
667
+
668
+ @app.post("/api/studio/notebook", response_model=NotebookEntry)
669
+ async def create_notebook_entry(entry_data: NotebookEntryCreate):
670
+ """Create a new notebook entry"""
671
+ try:
672
+ entry = studio_manager.create_notebook_entry(entry_data)
673
+ return entry
674
+ except Exception as e:
675
+ raise HTTPException(status_code=500, detail=str(e))
676
+
677
+ @app.get("/api/studio/notebook/space/{space_id}")
678
+ async def get_space_notebook(space_id: str):
679
+ """Get or create notebook metadata for a space."""
680
+ try:
681
+ space = spaces_manager.get_space(space_id)
682
+ space_name = space['name'] if space else space_id
683
+ notebook = studio_manager.ensure_space_notebook(space_id, space_name)
684
+ return notebook
685
+ except Exception as e:
686
+ raise HTTPException(status_code=500, detail=str(e))
687
+
688
+ @app.post("/api/studio/notebook/from-chat", response_model=NotebookEntry)
689
+ async def add_chat_to_notebook(request: ChatToNotebookRequest):
690
+ """Add a chat question/answer pair into a space notebook."""
691
+ try:
692
+ space = spaces_manager.get_space(request.space_id)
693
+ resolved_space_name = request.space_name or (space['name'] if space else request.space_id)
694
+
695
+ entry = studio_manager.create_notebook_entry_from_chat(
696
+ space_id=request.space_id,
697
+ question=request.question,
698
+ answer=request.answer,
699
+ chat_id=request.chat_id,
700
+ assistant_timestamp=request.assistant_timestamp,
701
+ tags=request.tags,
702
+ space_name=resolved_space_name
703
+ )
704
+ return entry
705
+ except Exception as e:
706
+ raise HTTPException(status_code=500, detail=str(e))
707
+
708
+ @app.get("/api/studio/notebook", response_model=List[NotebookEntry])
709
+ async def list_notebook_entries(space_id: Optional[str] = None):
710
+ """List all notebook entries, optionally filtered by space"""
711
+ try:
712
+ entries = studio_manager.list_notebook_entries(space_id)
713
+ return entries
714
+ except Exception as e:
715
+ raise HTTPException(status_code=500, detail=str(e))
716
+
717
+ @app.get("/api/studio/notebook/{entry_id}", response_model=NotebookEntry)
718
+ async def get_notebook_entry(entry_id: str):
719
+ """Get a single notebook entry"""
720
+ entry = studio_manager.get_notebook_entry(entry_id)
721
+ if not entry:
722
+ raise HTTPException(status_code=404, detail="Notebook entry not found")
723
+ return entry
724
+
725
+ @app.put("/api/studio/notebook/{entry_id}", response_model=NotebookEntry)
726
+ async def update_notebook_entry(entry_id: str, update_data: NotebookEntryUpdate):
727
+ """Update a notebook entry"""
728
+ entry = studio_manager.update_notebook_entry(entry_id, update_data)
729
+ if not entry:
730
+ raise HTTPException(status_code=404, detail="Notebook entry not found")
731
+ return entry
732
+
733
+ @app.delete("/api/studio/notebook/{entry_id}")
734
+ async def delete_notebook_entry(entry_id: str):
735
+ """Delete a notebook entry"""
736
+ success = studio_manager.delete_notebook_entry(entry_id)
737
+ if not success:
738
+ raise HTTPException(status_code=404, detail="Notebook entry not found")
739
+ return {"status": "success", "message": "Notebook entry deleted"}
740
+
741
+
742
+ # ===== FLASHCARD ROUTES =====
743
+
744
+ @app.post("/api/studio/flashcards", response_model=Flashcard)
745
+ async def create_flashcard(card_data: FlashcardCreate):
746
+ """Create a new flashcard"""
747
+ try:
748
+ card = studio_manager.create_flashcard(card_data)
749
+ return card
750
+ except Exception as e:
751
+ raise HTTPException(status_code=500, detail=str(e))
752
+
753
+ @app.get("/api/studio/flashcards", response_model=List[Flashcard])
754
+ async def list_flashcards(
755
+ space_id: Optional[str] = None,
756
+ mastery: Optional[MasteryLevel] = None
757
+ ):
758
+ """List all flashcards, optionally filtered"""
759
+ try:
760
+ cards = studio_manager.list_flashcards(space_id, mastery)
761
+ return cards
762
+ except Exception as e:
763
+ raise HTTPException(status_code=500, detail=str(e))
764
+
765
+ @app.get("/api/studio/flashcards/{card_id}", response_model=Flashcard)
766
+ async def get_flashcard(card_id: str):
767
+ """Get a single flashcard"""
768
+ card = studio_manager.get_flashcard(card_id)
769
+ if not card:
770
+ raise HTTPException(status_code=404, detail="Flashcard not found")
771
+ return card
772
+
773
+ @app.put("/api/studio/flashcards/{card_id}", response_model=Flashcard)
774
+ async def update_flashcard(card_id: str, update_data: FlashcardUpdate):
775
+ """Update a flashcard"""
776
+ card = studio_manager.update_flashcard(card_id, update_data)
777
+ if not card:
778
+ raise HTTPException(status_code=404, detail="Flashcard not found")
779
+ return card
780
+
781
+ @app.post("/api/studio/flashcards/{card_id}/review", response_model=Flashcard)
782
+ async def review_flashcard(card_id: str, review: FlashcardReview):
783
+ """Record a flashcard review"""
784
+ card = studio_manager.review_flashcard(card_id, review)
785
+ if not card:
786
+ raise HTTPException(status_code=404, detail="Flashcard not found")
787
+ return card
788
+
789
+ @app.delete("/api/studio/flashcards/{card_id}")
790
+ async def delete_flashcard(card_id: str):
791
+ """Delete a flashcard"""
792
+ success = studio_manager.delete_flashcard(card_id)
793
+ if not success:
794
+ raise HTTPException(status_code=404, detail="Flashcard not found")
795
+ return {"status": "success", "message": "Flashcard deleted"}
796
+
797
+ @app.post("/api/studio/flashcards/generate", response_model=List[Flashcard])
798
+ async def generate_flashcards(request: FlashcardGenerateRequest):
799
+ """Generate flashcards from content using LLM"""
800
+ global studio_generator
801
+
802
+ if not studio_generator:
803
+ raise HTTPException(status_code=503, detail="LLM not initialized")
804
+
805
+ try:
806
+ cards = await studio_generator.generate_flashcards(request)
807
+ return cards
808
+ except Exception as e:
809
+ raise HTTPException(status_code=500, detail=str(e))
810
+
811
+
812
+ # ===== QUIZ ROUTES =====
813
+
814
+ @app.post("/api/studio/quizzes", response_model=Quiz)
815
+ async def create_quiz(quiz_data: QuizCreate):
816
+ """Create a new quiz"""
817
+ try:
818
+ quiz = studio_manager.create_quiz(quiz_data)
819
+ return quiz
820
+ except Exception as e:
821
+ raise HTTPException(status_code=500, detail=str(e))
822
+
823
+ @app.get("/api/studio/quizzes", response_model=List[Quiz])
824
+ async def list_quizzes(space_id: Optional[str] = None):
825
+ """List all quizzes, optionally filtered by space"""
826
+ try:
827
+ quizzes = studio_manager.list_quizzes(space_id)
828
+ return quizzes
829
+ except Exception as e:
830
+ raise HTTPException(status_code=500, detail=str(e))
831
+
832
+ @app.get("/api/studio/quizzes/{quiz_id}", response_model=Quiz)
833
+ async def get_quiz(quiz_id: str):
834
+ """Get a quiz by ID"""
835
+ quiz = studio_manager.get_quiz(quiz_id)
836
+ if not quiz:
837
+ raise HTTPException(status_code=404, detail="Quiz not found")
838
+ return quiz
839
+
840
+ @app.delete("/api/studio/quizzes/{quiz_id}")
841
+ async def delete_quiz(quiz_id: str):
842
+ """Delete a quiz"""
843
+ success = studio_manager.delete_quiz(quiz_id)
844
+ if not success:
845
+ raise HTTPException(status_code=404, detail="Quiz not found")
846
+ return {"status": "success", "message": "Quiz deleted"}
847
+
848
+ @app.post("/api/studio/quizzes/generate", response_model=Quiz)
849
+ async def generate_quiz(request: QuizGenerateRequest):
850
+ """Generate a quiz from content using LLM"""
851
+ global studio_generator
852
+
853
+ if not studio_generator:
854
+ raise HTTPException(status_code=503, detail="LLM not initialized")
855
+
856
+ try:
857
+ quiz = await studio_generator.generate_quiz(request)
858
+ if not quiz:
859
+ raise HTTPException(status_code=500, detail="Failed to generate quiz")
860
+ return quiz
861
+ except Exception as e:
862
+ raise HTTPException(status_code=500, detail=str(e))
863
+
864
+ @app.post("/api/studio/quizzes/{quiz_id}/submit", response_model=QuizResult)
865
+ async def submit_quiz(quiz_id: str, submission: QuizSubmission):
866
+ """Submit quiz answers and get results"""
867
+ try:
868
+ result = studio_manager.submit_quiz(quiz_id, submission.answers)
869
+ if not result:
870
+ raise HTTPException(status_code=404, detail="Quiz not found")
871
+ return result
872
+ except Exception as e:
873
+ raise HTTPException(status_code=500, detail=str(e))
874
+
875
+ @app.get("/api/studio/quizzes/{quiz_id}/history", response_model=QuizHistory)
876
+ async def get_quiz_history(quiz_id: str):
877
+ """Get quiz attempt history"""
878
+ try:
879
+ history = studio_manager.get_quiz_history(quiz_id)
880
+ if not history:
881
+ raise HTTPException(status_code=404, detail="Quiz not found")
882
+ return history
883
+ except HTTPException as he:
884
+ # If the error is already an HTTPException (like the missing API key error), pass it through directly
885
+ raise he
886
+ except Exception as e:
887
+ # For all other crashes, print the actual traceback to the terminal so you can see what broke
888
+ import traceback
889
+ traceback.print_exc()
890
+ raise HTTPException(status_code=500, detail=str(e))
891
+
892
+ # ==================== Run Server ====================
893
+
894
+ if __name__ == "__main__":
895
+ import uvicorn
896
+ uvicorn.run(app, host="0.0.0.0", port=8000, log_level="error")
models/__pycache__/studio_models.cpython-311.pyc ADDED
Binary file (15.4 kB). View file
 
models/studio_models.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Studio Models - Notebook, Flashcards, Quiz
3
+ These models represent the core Studio features for NotebookPRO
4
+ """
5
+ from pydantic import BaseModel, Field
6
+ from typing import List, Optional, Dict, Any
7
+ from datetime import datetime
8
+ from enum import Enum
9
+
10
+
11
+ # ============================================================================
12
+ # NOTEBOOK MODELS
13
+ # ============================================================================
14
+
15
+ class NotebookEntry(BaseModel):
16
+ """A single note entry in the notebook"""
17
+ id: str = Field(..., description="Unique identifier for the note")
18
+ space_id: str = Field(..., description="Space this note belongs to")
19
+ title: str = Field(..., description="Title of the note")
20
+ content: str = Field(..., description="Main content/body of the note")
21
+ source_type: str = Field(default="manual", description="Source: manual, chat, generated")
22
+ source_id: Optional[str] = Field(None, description="ID of source (e.g., chat message ID)")
23
+ tags: List[str] = Field(default_factory=list, description="Tags for categorization")
24
+ created_at: datetime = Field(default_factory=datetime.now)
25
+ updated_at: datetime = Field(default_factory=datetime.now)
26
+ metadata: Dict[str, Any] = Field(default_factory=dict)
27
+
28
+
29
+ class NotebookEntryCreate(BaseModel):
30
+ """Request model for creating a notebook entry"""
31
+ space_id: str
32
+ title: str
33
+ content: str
34
+ source_type: str = "manual"
35
+ source_id: Optional[str] = None
36
+ tags: List[str] = []
37
+ metadata: Dict[str, Any] = {}
38
+
39
+
40
+ class NotebookEntryUpdate(BaseModel):
41
+ """Request model for updating a notebook entry"""
42
+ title: Optional[str] = None
43
+ content: Optional[str] = None
44
+ tags: Optional[List[str]] = None
45
+ metadata: Optional[Dict[str, Any]] = None
46
+
47
+
48
+ # ============================================================================
49
+ # FLASHCARD MODELS
50
+ # ============================================================================
51
+
52
+ class DifficultyLevel(str, Enum):
53
+ """Difficulty level for flashcards"""
54
+ EASY = "easy"
55
+ MEDIUM = "medium"
56
+ HARD = "hard"
57
+
58
+
59
+ class MasteryLevel(str, Enum):
60
+ """User's mastery level for a flashcard"""
61
+ NEW = "new"
62
+ LEARNING = "learning"
63
+ REVIEWING = "reviewing"
64
+ MASTERED = "mastered"
65
+
66
+
67
+ class Flashcard(BaseModel):
68
+ """A single flashcard for memorization"""
69
+ id: str = Field(..., description="Unique identifier")
70
+ space_id: str = Field(..., description="Space this flashcard belongs to")
71
+ question: str = Field(..., description="Front of the card (question/prompt)")
72
+ answer: str = Field(..., description="Back of the card (answer/explanation)")
73
+ difficulty: DifficultyLevel = Field(default=DifficultyLevel.MEDIUM)
74
+ mastery: MasteryLevel = Field(default=MasteryLevel.NEW)
75
+ source_type: str = Field(default="manual", description="Source: manual, generated, notebook")
76
+ source_id: Optional[str] = Field(None, description="Source ID (e.g., notebook entry ID)")
77
+ tags: List[str] = Field(default_factory=list)
78
+ review_count: int = Field(default=0, description="Number of times reviewed")
79
+ correct_count: int = Field(default=0, description="Number of times answered correctly")
80
+ last_reviewed: Optional[datetime] = None
81
+ next_review: Optional[datetime] = None
82
+ created_at: datetime = Field(default_factory=datetime.now)
83
+ metadata: Dict[str, Any] = Field(default_factory=dict)
84
+
85
+
86
+ class FlashcardCreate(BaseModel):
87
+ """Request model for creating a flashcard"""
88
+ space_id: str
89
+ question: str
90
+ answer: str
91
+ difficulty: DifficultyLevel = DifficultyLevel.MEDIUM
92
+ source_type: str = "manual"
93
+ source_id: Optional[str] = None
94
+ tags: List[str] = []
95
+ metadata: Dict[str, Any] = {}
96
+
97
+
98
+ class FlashcardUpdate(BaseModel):
99
+ """Request model for updating a flashcard"""
100
+ question: Optional[str] = None
101
+ answer: Optional[str] = None
102
+ difficulty: Optional[DifficultyLevel] = None
103
+ mastery: Optional[MasteryLevel] = None
104
+ tags: Optional[List[str]] = None
105
+
106
+
107
+ class FlashcardReview(BaseModel):
108
+ """Request model for reviewing a flashcard"""
109
+ correct: bool = Field(..., description="Whether the user answered correctly")
110
+
111
+
112
+ class FlashcardGenerateRequest(BaseModel):
113
+ """Request to generate flashcards from content"""
114
+ space_id: str
115
+ source_type: str = Field(..., description="Source type: notebook, file, text")
116
+ source_ids: Optional[List[str]] = Field(None, description="IDs of notebook entries or files")
117
+ text_content: Optional[str] = Field(None, description="Direct text content to generate from")
118
+ num_cards: int = Field(default=5, description="Number of flashcards to generate")
119
+ difficulty: DifficultyLevel = DifficultyLevel.MEDIUM
120
+
121
+
122
+ # ============================================================================
123
+ # QUIZ MODELS
124
+ # ============================================================================
125
+
126
+ class QuestionType(str, Enum):
127
+ """Type of quiz question"""
128
+ MULTIPLE_CHOICE = "multiple_choice"
129
+ TRUE_FALSE = "true_false"
130
+ SHORT_ANSWER = "short_answer"
131
+
132
+
133
+ class QuizQuestion(BaseModel):
134
+ """A single quiz question"""
135
+ id: str = Field(..., description="Unique identifier")
136
+ question: str = Field(..., description="Question text")
137
+ type: QuestionType = Field(..., description="Question type")
138
+ options: Optional[List[str]] = Field(None, description="Options for multiple choice")
139
+ correct_answer: str = Field(..., description="Correct answer")
140
+ explanation: Optional[str] = Field(None, description="Explanation of the answer")
141
+ points: int = Field(default=1, description="Points for this question")
142
+ difficulty: DifficultyLevel = Field(default=DifficultyLevel.MEDIUM)
143
+
144
+
145
+ class Quiz(BaseModel):
146
+ """A quiz session"""
147
+ id: str = Field(..., description="Unique identifier")
148
+ space_id: str = Field(..., description="Space this quiz belongs to")
149
+ title: str = Field(..., description="Quiz title")
150
+ description: Optional[str] = None
151
+ questions: List[QuizQuestion] = Field(..., description="List of questions")
152
+ source_type: str = Field(default="manual", description="Source: manual, generated, notebook, file")
153
+ source_ids: Optional[List[str]] = None
154
+ created_at: datetime = Field(default_factory=datetime.now)
155
+ metadata: Dict[str, Any] = Field(default_factory=dict)
156
+
157
+
158
+ class QuizCreate(BaseModel):
159
+ """Request model for creating a quiz"""
160
+ space_id: str
161
+ title: str
162
+ description: Optional[str] = None
163
+ questions: List[QuizQuestion] = []
164
+ source_type: str = "manual"
165
+ source_ids: Optional[List[str]] = None
166
+
167
+
168
+ class QuizGenerateRequest(BaseModel):
169
+ """Request to generate a quiz from content"""
170
+ space_id: str
171
+ title: str
172
+ source_type: str = Field(..., description="Source type: notebook, file, text")
173
+ source_ids: Optional[List[str]] = Field(None, description="IDs of notebook entries or files")
174
+ text_content: Optional[str] = Field(None, description="Direct text content")
175
+ num_questions: int = Field(default=5, description="Number of questions")
176
+ question_types: List[QuestionType] = Field(
177
+ default=[QuestionType.MULTIPLE_CHOICE],
178
+ description="Types of questions to include"
179
+ )
180
+ difficulty: DifficultyLevel = DifficultyLevel.MEDIUM
181
+
182
+
183
+ class QuizAnswer(BaseModel):
184
+ """User's answer to a quiz question"""
185
+ question_id: str
186
+ answer: str
187
+ time_spent: Optional[int] = Field(None, description="Time spent in seconds")
188
+
189
+
190
+ class QuizSubmission(BaseModel):
191
+ """User's quiz submission"""
192
+ quiz_id: str
193
+ answers: List[QuizAnswer]
194
+
195
+
196
+ class QuizResult(BaseModel):
197
+ """Result of a quiz submission"""
198
+ quiz_id: str
199
+ submission_id: str = Field(..., description="Unique submission ID")
200
+ total_questions: int
201
+ correct_answers: int
202
+ incorrect_answers: int
203
+ score_percentage: float
204
+ total_points: int
205
+ earned_points: int
206
+ answers: List[Dict[str, Any]] = Field(..., description="Detailed answer results")
207
+ completed_at: datetime = Field(default_factory=datetime.now)
208
+ time_taken: Optional[int] = Field(None, description="Total time in seconds")
209
+
210
+
211
+ class QuizHistory(BaseModel):
212
+ """Quiz attempt history"""
213
+ quiz_id: str
214
+ space_id: str
215
+ quiz_title: str
216
+ results: List[QuizResult] = Field(default_factory=list)
217
+ best_score: float = Field(default=0.0)
218
+ average_score: float = Field(default=0.0)
219
+ attempts_count: int = Field(default=0)
requirements.txt ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # FastAPI Framework - version compatible with Pydantic v1
2
+ fastapi==0.104.1
3
+ uvicorn[standard]==0.24.0
4
+ python-multipart==0.0.6
5
+ pydantic==1.10.13
6
+
7
+ # HTTP client - version compatible with groq
8
+ httpx==0.25.2
9
+
10
+ # Ngrok tunnel for public access
11
+ pyngrok==7.0.0
12
+
13
+ # Vector database - Pydantic v1 compatible version
14
+
15
+ # ML/AI dependencies - compatible versions
16
+ sentence-transformers==2.7.0
17
+ huggingface-hub==0.23.0
18
+ rank-bm25==0.2.2
19
+
20
+ # LLM providers
21
+ groq==1.1.1
22
+ google-generativeai==0.3.2
23
+
24
+ # Streamlit for UI components
25
+ streamlit==1.31.0
26
+
27
+ # File processing
28
+ PyPDF2==3.0.1
29
+ pdfplumber==0.10.3
30
+ python-docx==1.1.0
31
+
32
+ # Environment variables
33
+ python-dotenv==1.0.0
34
+
35
+ # REMOVE this:
36
+ # chromadb==0.3.21
37
+
38
+ # ADD these:
39
+ qdrant-client==1.8.0
runtime.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ 3.11.0
start_ngrok_tunnel.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Alternative tunnel using ngrok (requires one-time free signup).
3
+
4
+ Setup (one time only):
5
+ 1. Sign up at https://ngrok.com (free)
6
+ 2. Get your authtoken from dashboard
7
+ 3. Run this once: ngrok config add-authtoken YOUR_TOKEN_HERE
8
+
9
+ Then run this script to start the tunnel.
10
+ """
11
+
12
+ from pyngrok import ngrok, conf
13
+ import time
14
+ import json
15
+ from pathlib import Path
16
+
17
+ def start_ngrok_tunnel(port=8000):
18
+ """Start ngrok tunnel for the backend."""
19
+
20
+ try:
21
+ # Close any existing tunnels
22
+ ngrok.kill()
23
+
24
+ # Start tunnel (use default free tier settings - no subdomain)
25
+ print(f"Starting ngrok tunnel for port {port}...")
26
+ # Just use basic http connection without any domain options
27
+ tunnel = ngrok.connect(port, bind_tls=True)
28
+ public_url = tunnel.public_url
29
+
30
+ print("\n" + "="*60)
31
+ print("🚀 Backend Tunnel Started!")
32
+ print("="*60)
33
+ print(f"Public URL: {public_url}")
34
+ print(f"Local URL: http://localhost:{port}")
35
+ print("="*60)
36
+ print("\n✅ Update your Flutter app:")
37
+ print(f' static const String baseUrl = "{public_url}";')
38
+ print("=" * 60)
39
+ print("\nPress Ctrl+C to stop the tunnel")
40
+ print("="*60 + "\n")
41
+
42
+ # Save config
43
+ config_file = Path(__file__).parent.parent / "tunnel_config.json"
44
+ config = {
45
+ "backend_url": public_url,
46
+ "created_at": time.strftime("%Y-%m-%d %H:%M:%S")
47
+ }
48
+
49
+ with open(config_file, 'w') as f:
50
+ json.dump(config, f, indent=2)
51
+
52
+ try:
53
+ while True:
54
+ time.sleep(1)
55
+ except KeyboardInterrupt:
56
+ print("\n\n✨ Shutting down tunnel...")
57
+ ngrok.kill()
58
+ print("✅ Tunnel closed\n")
59
+
60
+ except Exception as e:
61
+ print(f"\n❌ Error: {e}\n")
62
+ print("Setup ngrok authentication:")
63
+ print("1. Sign up at https://ngrok.com (free)")
64
+ print("2. Get your authtoken from the dashboard")
65
+ print("3. Run: ngrok config add-authtoken YOUR_TOKEN")
66
+ print("4. Run this script again\n")
67
+
68
+ if __name__ == "__main__":
69
+ start_ngrok_tunnel()
utils/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Utility modules for document processing and vector database operations."""
utils/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (231 Bytes). View file
 
utils/__pycache__/__init__.cpython-314.pyc ADDED
Binary file (226 Bytes). View file
 
utils/__pycache__/config_manager.cpython-311.pyc ADDED
Binary file (5.2 kB). View file
 
utils/__pycache__/config_manager.cpython-314.pyc ADDED
Binary file (6.09 kB). View file
 
utils/__pycache__/document_processor.cpython-311.pyc ADDED
Binary file (10.9 kB). View file
 
utils/__pycache__/document_processor.cpython-314.pyc ADDED
Binary file (11.1 kB). View file
 
utils/__pycache__/hybrid_retriever.cpython-311.pyc ADDED
Binary file (7.29 kB). View file
 
utils/__pycache__/hybrid_retriever.cpython-314.pyc ADDED
Binary file (6.87 kB). View file
 
utils/__pycache__/llm_generator.cpython-311.pyc ADDED
Binary file (13.5 kB). View file
 
utils/__pycache__/llm_generator.cpython-314.pyc ADDED
Binary file (13.8 kB). View file
 
utils/__pycache__/model_inference.cpython-311.pyc ADDED
Binary file (6.86 kB). View file
 
utils/__pycache__/simple_generator.cpython-311.pyc ADDED
Binary file (23.9 kB). View file
 
utils/__pycache__/spaces_manager.cpython-311.pyc ADDED
Binary file (7.54 kB). View file
 
utils/__pycache__/spaces_manager.cpython-314.pyc ADDED
Binary file (8.57 kB). View file
 
utils/__pycache__/studio_generator.cpython-311.pyc ADDED
Binary file (12 kB). View file
 
utils/__pycache__/studio_manager.cpython-311.pyc ADDED
Binary file (26.7 kB). View file
 
utils/__pycache__/vector_db.cpython-311.pyc ADDED
Binary file (8.14 kB). View file
 
utils/__pycache__/vector_db.cpython-314.pyc ADDED
Binary file (6.6 kB). View file
 
utils/chat_manager.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Chat management utilities for NotebookPRO.
3
+ """
4
+
5
+ import json
6
+ from datetime import datetime
7
+ from pathlib import Path
8
+ from typing import List, Dict, Optional
9
+ import config
10
+
11
+
12
+ class ChatManager:
13
+ """Manage chat sessions and history."""
14
+
15
+ def __init__(self):
16
+ self.chats_dir = config.CHATS_DIR
17
+ self.chats_dir.mkdir(parents=True, exist_ok=True)
18
+
19
+ def save_chat(self, chat_id: str, messages: List[Dict], space: Optional[str] = None) -> None:
20
+ """
21
+ Save a chat session.
22
+
23
+ Args:
24
+ chat_id: Unique chat identifier
25
+ messages: List of message dictionaries
26
+ space: Optional space/subject name
27
+ """
28
+ chat_data = {
29
+ 'id': chat_id,
30
+ 'messages': messages,
31
+ 'space': space,
32
+ 'created_at': datetime.now().isoformat(),
33
+ 'updated_at': datetime.now().isoformat()
34
+ }
35
+
36
+ chat_file = self.chats_dir / f"{chat_id}.json"
37
+ with open(chat_file, 'w', encoding='utf-8') as f:
38
+ json.dump(chat_data, f, indent=2, ensure_ascii=False)
39
+
40
+ def load_chat(self, chat_id: str) -> Optional[Dict]:
41
+ """
42
+ Load a chat session.
43
+
44
+ Args:
45
+ chat_id: Unique chat identifier
46
+
47
+ Returns:
48
+ Chat data dictionary or None if not found
49
+ """
50
+ chat_file = self.chats_dir / f"{chat_id}.json"
51
+
52
+ if not chat_file.exists():
53
+ return None
54
+
55
+ with open(chat_file, 'r', encoding='utf-8') as f:
56
+ return json.load(f)
57
+
58
+ def list_chats(self, space: Optional[str] = None) -> List[Dict]:
59
+ """
60
+ List all chats, optionally filtered by space.
61
+
62
+ Args:
63
+ space: Optional space filter
64
+
65
+ Returns:
66
+ List of chat metadata dictionaries
67
+ """
68
+ chats = []
69
+
70
+ for chat_file in self.chats_dir.glob("*.json"):
71
+ with open(chat_file, 'r', encoding='utf-8') as f:
72
+ chat_data = json.load(f)
73
+
74
+ if space is None or chat_data.get('space') == space:
75
+ chats.append({
76
+ 'id': chat_data['id'],
77
+ 'space': chat_data.get('space'),
78
+ 'message_count': len(chat_data['messages']),
79
+ 'created_at': chat_data.get('created_at'),
80
+ 'updated_at': chat_data.get('updated_at')
81
+ })
82
+
83
+ # Sort by updated time (most recent first)
84
+ chats.sort(key=lambda x: x.get('updated_at', ''), reverse=True)
85
+
86
+ return chats
87
+
88
+ def delete_chat(self, chat_id: str) -> bool:
89
+ """
90
+ Delete a chat session.
91
+
92
+ Args:
93
+ chat_id: Unique chat identifier
94
+
95
+ Returns:
96
+ True if deleted, False if not found
97
+ """
98
+ chat_file = self.chats_dir / f"{chat_id}.json"
99
+
100
+ if chat_file.exists():
101
+ chat_file.unlink()
102
+ return True
103
+
104
+ return False
105
+
106
+ def get_chat_preview(self, chat_id: str, max_messages: int = 5) -> Optional[List[Dict]]:
107
+ """
108
+ Get a preview of recent messages from a chat.
109
+
110
+ Args:
111
+ chat_id: Unique chat identifier
112
+ max_messages: Maximum number of messages to return
113
+
114
+ Returns:
115
+ List of recent messages or None if chat not found
116
+ """
117
+ chat_data = self.load_chat(chat_id)
118
+
119
+ if chat_data is None:
120
+ return None
121
+
122
+ messages = chat_data['messages']
123
+ return messages[-max_messages:] if len(messages) > max_messages else messages
utils/config_manager.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Configuration manager for persistent settings (API keys, preferences).
3
+ """
4
+
5
+ import json
6
+ from pathlib import Path
7
+ from typing import Dict, Optional
8
+ import config
9
+
10
+
11
+ class ConfigManager:
12
+ """Manages persistent user configuration."""
13
+
14
+ def __init__(self):
15
+ self.config_file = config.DATA_DIR / "user_config.json"
16
+ self.config_data = self._load_config()
17
+
18
+ def _load_config(self) -> Dict:
19
+ """Load configuration from file."""
20
+ if self.config_file.exists():
21
+ try:
22
+ with open(self.config_file, 'r', encoding='utf-8') as f:
23
+ return json.load(f)
24
+ except Exception:
25
+ return self._default_config()
26
+ return self._default_config()
27
+
28
+ def _default_config(self) -> Dict:
29
+ """Default configuration."""
30
+ return {
31
+ "api_keys": {
32
+ "groq": "",
33
+ "gemini": ""
34
+ },
35
+ "preferences": {
36
+ "llm_provider": "groq",
37
+ "temperature": 0.7,
38
+ "workflow": "Auto-Detect"
39
+ },
40
+ "current_space": "General"
41
+ }
42
+
43
+ def save_config(self):
44
+ """Save configuration to file."""
45
+ try:
46
+ with open(self.config_file, 'w', encoding='utf-8') as f:
47
+ json.dump(self.config_data, f, indent=2)
48
+ except Exception as e:
49
+ print(f"Error saving config: {e}")
50
+
51
+ def get_api_key(self, provider: str) -> str:
52
+ """Get API key for provider."""
53
+ return self.config_data.get("api_keys", {}).get(provider, "")
54
+
55
+ def set_api_key(self, provider: str, api_key: str):
56
+ """Save API key for provider."""
57
+ if "api_keys" not in self.config_data:
58
+ self.config_data["api_keys"] = {}
59
+ self.config_data["api_keys"][provider] = api_key
60
+ self.save_config()
61
+
62
+ def get_preference(self, key: str, default=None):
63
+ """Get user preference."""
64
+ return self.config_data.get("preferences", {}).get(key, default)
65
+
66
+ def set_preference(self, key: str, value):
67
+ """Save user preference."""
68
+ if "preferences" not in self.config_data:
69
+ self.config_data["preferences"] = {}
70
+ self.config_data["preferences"][key] = value
71
+ self.save_config()
72
+
73
+ def get_current_space(self) -> str:
74
+ """Get current workspace."""
75
+ return self.config_data.get("current_space", "General")
76
+
77
+ def set_current_space(self, space_name: str):
78
+ """Set current workspace."""
79
+ self.config_data["current_space"] = space_name
80
+ self.save_config()
utils/document_processor.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import PyPDF2
2
+ import pdfplumber
3
+ from docx import Document
4
+ from pathlib import Path
5
+ from typing import List, Dict
6
+ import re
7
+ import warnings
8
+ import logging
9
+
10
+ # Suppress PyPDF2 warnings about font descriptors
11
+ warnings.filterwarnings('ignore', category=UserWarning, module='PyPDF2')
12
+ logging.getLogger('PyPDF2').setLevel(logging.ERROR)
13
+
14
+
15
+ class DocumentProcessor:
16
+ """Process various document types and extract text content."""
17
+
18
+ def __init__(self):
19
+ self.supported_formats = ['.pdf', '.txt', '.docx']
20
+
21
+ def process_file(self, file_path: Path) -> Dict[str, any]:
22
+ """
23
+ Process a single file and extract its content.
24
+
25
+ Args:
26
+ file_path: Path to the file
27
+
28
+ Returns:
29
+ Dictionary containing file metadata and content
30
+ """
31
+ suffix = file_path.suffix.lower()
32
+
33
+ if suffix == '.pdf':
34
+ content = self._extract_pdf(file_path)
35
+ elif suffix == '.txt':
36
+ content = self._extract_txt(file_path)
37
+ elif suffix == '.docx':
38
+ content = self._extract_docx(file_path)
39
+ else:
40
+ raise ValueError(f"Unsupported file format: {suffix}")
41
+
42
+ return {
43
+ 'filename': file_path.name,
44
+ 'path': str(file_path),
45
+ 'content': content,
46
+ 'format': suffix
47
+ }
48
+
49
+ def _extract_pdf(self, file_path: Path) -> str:
50
+ """Extract text from PDF using pdfplumber with PyPDF2 fallback."""
51
+ text = ""
52
+ try:
53
+ # Primary: Use pdfplumber (better for complex PDFs)
54
+ with pdfplumber.open(file_path) as pdf:
55
+ for page in pdf.pages:
56
+ page_text = page.extract_text()
57
+ if page_text:
58
+ text += page_text + "\n"
59
+ except Exception as e:
60
+ # Fallback: Use PyPDF2 with warnings suppressed
61
+ try:
62
+ with warnings.catch_warnings():
63
+ warnings.simplefilter("ignore")
64
+ with open(file_path, 'rb') as file:
65
+ pdf_reader = PyPDF2.PdfReader(file)
66
+ for page in pdf_reader.pages:
67
+ try:
68
+ page_text = page.extract_text()
69
+ if page_text:
70
+ text += page_text + "\n"
71
+ except Exception:
72
+ continue # Skip problematic pages
73
+ except Exception as e2:
74
+ raise ValueError(f"Could not extract text from PDF: {file_path.name}")
75
+
76
+ return self._clean_text(text)
77
+
78
+ def _extract_txt(self, file_path: Path) -> str:
79
+ """Extract text from TXT file."""
80
+ try:
81
+ with open(file_path, 'r', encoding='utf-8') as file:
82
+ text = file.read()
83
+ except UnicodeDecodeError:
84
+ with open(file_path, 'r', encoding='latin-1') as file:
85
+ text = file.read()
86
+
87
+ return self._clean_text(text)
88
+
89
+ def _extract_docx(self, file_path: Path) -> str:
90
+ """Extract text from DOCX file."""
91
+ doc = Document(file_path)
92
+ text = "\n".join([paragraph.text for paragraph in doc.paragraphs])
93
+ return self._clean_text(text)
94
+
95
+ def _clean_text(self, text: str) -> str:
96
+ """Clean and normalize text."""
97
+ # Remove excessive whitespace
98
+ text = re.sub(r'\s+', ' ', text)
99
+ # Remove special characters but keep punctuation
100
+ text = re.sub(r'[^\w\s.,!?;:()\-\'\"]+', '', text)
101
+ return text.strip()
102
+
103
+ def chunk_text(self, text: str, chunk_size: int = 512, overlap: int = 50, semantic: bool = True) -> List[str]:
104
+ """
105
+ Split text into chunks using semantic or simple chunking.
106
+
107
+ Args:
108
+ text: The text to chunk
109
+ chunk_size: Target size of each chunk in characters
110
+ overlap: Number of overlapping characters between chunks
111
+ semantic: Use semantic chunking (by headers/concepts) if True
112
+
113
+ Returns:
114
+ List of text chunks
115
+ """
116
+ if semantic:
117
+ return self._semantic_chunk(text, chunk_size, overlap)
118
+ else:
119
+ return self._simple_chunk(text, chunk_size, overlap)
120
+
121
+ def _semantic_chunk(self, text: str, target_size: int = 512, overlap: int = 50) -> List[str]:
122
+ """
123
+ Chunk text by detecting headers and logical sections.
124
+ Perfect for lecture slides and structured documents.
125
+ """
126
+ chunks = []
127
+
128
+ # Split by common header patterns
129
+ # Pattern 1: Lines that are ALL CAPS or Title Case followed by newline
130
+ # Pattern 2: Lines starting with numbers like "1.", "1.1", etc.
131
+ # Pattern 3: Lines with clear visual separators
132
+
133
+ # First, split by double newlines (paragraphs)
134
+ sections = text.split('\n\n')
135
+
136
+ current_chunk = ""
137
+ current_header = ""
138
+
139
+ for section in sections:
140
+ section = section.strip()
141
+ if not section:
142
+ continue
143
+
144
+ # Check if this looks like a header
145
+ is_header = self._is_likely_header(section)
146
+
147
+ if is_header and len(current_chunk) > 100:
148
+ # Save previous chunk and start new one with this header
149
+ if current_chunk:
150
+ chunks.append(current_chunk.strip())
151
+ current_chunk = section + "\n\n"
152
+ current_header = section
153
+ else:
154
+ # Add to current chunk
155
+ potential_chunk = current_chunk + section + "\n\n"
156
+
157
+ # If chunk is getting too large, split it
158
+ if len(potential_chunk) > target_size * 1.5:
159
+ if current_chunk:
160
+ chunks.append(current_chunk.strip())
161
+ current_chunk = section + "\n\n"
162
+ else:
163
+ current_chunk = potential_chunk
164
+
165
+ # Add final chunk
166
+ if current_chunk:
167
+ chunks.append(current_chunk.strip())
168
+
169
+ # If semantic chunking produced too few chunks, fall back to simple chunking
170
+ if len(chunks) < len(text) / (target_size * 2):
171
+ return self._simple_chunk(text, target_size, overlap)
172
+
173
+ return chunks
174
+
175
+ def _is_likely_header(self, text: str) -> bool:
176
+ """Detect if text is likely a header/title."""
177
+ # Too long to be a header
178
+ if len(text) > 200:
179
+ return False
180
+
181
+ # Single line headers
182
+ if '\n' not in text:
183
+ # ALL CAPS
184
+ if text.isupper() and len(text.split()) <= 10:
185
+ return True
186
+
187
+ # Title Case
188
+ if text.istitle() and len(text.split()) <= 10:
189
+ return True
190
+
191
+ # Numbered sections like "1.", "1.1", "Chapter 1"
192
+ if re.match(r'^(\d+\.)+\s+', text) or re.match(r'^(Chapter|Section|Part)\s+\d+', text, re.IGNORECASE):
193
+ return True
194
+
195
+ return False
196
+
197
+ def _simple_chunk(self, text: str, chunk_size: int = 512, overlap: int = 50) -> List[str]:
198
+ """
199
+ Split text into overlapping chunks (original method).
200
+ """
201
+ chunks = []
202
+ start = 0
203
+ text_length = len(text)
204
+
205
+ while start < text_length:
206
+ end = start + chunk_size
207
+ chunk = text[start:end]
208
+
209
+ # Try to break at sentence boundary
210
+ if end < text_length:
211
+ last_period = chunk.rfind('.')
212
+ last_newline = chunk.rfind('\n')
213
+ break_point = max(last_period, last_newline)
214
+
215
+ if break_point > chunk_size * 0.5: # At least 50% through the chunk
216
+ chunk = chunk[:break_point + 1]
217
+ end = start + break_point + 1
218
+
219
+ chunks.append(chunk.strip())
220
+ start = end - overlap
221
+
222
+ return chunks
utils/hybrid_retriever.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Hybrid Retriever: Combines Vector (ChromaDB) + Keyword (BM25) search.
3
+ This is the "secret sauce" that makes NotebookLM so accurate.
4
+ """
5
+
6
+ from typing import List, Dict, Tuple
7
+ from rank_bm25 import BM25Okapi
8
+ import numpy as np
9
+
10
+
11
+ class HybridRetriever:
12
+ """
13
+ Combines Dense Retrieval (Embeddings) with Sparse Retrieval (BM25).
14
+
15
+ This is crucial for accuracy:
16
+ - Vector search finds conceptually similar content
17
+ - BM25 finds exact keyword matches (formulas, terms, names)
18
+ """
19
+
20
+ def __init__(self, vector_db, alpha: float = 0.5):
21
+ """
22
+ Initialize hybrid retriever.
23
+
24
+ Args:
25
+ vector_db: VectorDatabase instance
26
+ alpha: Weight balance (0=only BM25, 1=only vector, 0.5=balanced)
27
+ """
28
+ self.vector_db = vector_db
29
+ self.alpha = alpha # Weight for vector search
30
+ self.bm25 = None
31
+ self.bm25_corpus = []
32
+ self.bm25_metadata = []
33
+
34
+ def index_documents(self, documents: List[str], metadatas: List[Dict]):
35
+ """
36
+ Index documents for BM25 keyword search.
37
+
38
+ Args:
39
+ documents: List of document chunks
40
+ metadatas: List of metadata dicts for each chunk
41
+ """
42
+ # Tokenize documents for BM25
43
+ tokenized_corpus = [doc.lower().split() for doc in documents]
44
+
45
+ # Create BM25 index
46
+ self.bm25 = BM25Okapi(tokenized_corpus)
47
+ self.bm25_corpus = documents
48
+ self.bm25_metadata = metadatas
49
+
50
+ def retrieve(
51
+ self,
52
+ query: str,
53
+ n_results: int = 5,
54
+ score_threshold: float = 0.0
55
+ ) -> Tuple[List[str], List[Dict], List[float]]:
56
+ """
57
+ Hybrid retrieval: combines vector + keyword search.
58
+
59
+ Args:
60
+ query: User's question
61
+ n_results: Number of chunks to retrieve
62
+ score_threshold: Minimum score threshold
63
+
64
+ Returns:
65
+ Tuple of (documents, metadatas, scores)
66
+ """
67
+ if not self.bm25:
68
+ # Fallback to pure vector search if BM25 not initialized
69
+ results = self.vector_db.query(query, n_results=n_results * 2)
70
+ return (
71
+ results['documents'][0] if results['documents'] else [],
72
+ results['metadatas'][0] if results['metadatas'] else [],
73
+ results.get('distances', [[]])[0]
74
+ )
75
+
76
+ # Get more results than needed for reranking
77
+ fetch_size = n_results * 3
78
+
79
+ # 1. Vector search (semantic similarity)
80
+ vector_results = self.vector_db.query(query, n_results=fetch_size)
81
+ vector_docs = vector_results['documents'][0] if vector_results['documents'] else []
82
+ vector_meta = vector_results['metadatas'][0] if vector_results['metadatas'] else []
83
+ vector_distances = vector_results.get('distances', [[]])[0]
84
+
85
+ # Convert distances to similarity scores (chromadb uses cosine distance)
86
+ vector_scores = [1 / (1 + d) for d in vector_distances]
87
+
88
+ # 2. BM25 search (keyword matching)
89
+ tokenized_query = query.lower().split()
90
+ bm25_scores = self.bm25.get_scores(tokenized_query)
91
+
92
+ # Get top BM25 results
93
+ top_bm25_indices = np.argsort(bm25_scores)[::-1][:fetch_size]
94
+
95
+ # 3. Combine results with weighted scoring
96
+ combined_docs = {} # Use dict to deduplicate by content
97
+
98
+ # Add vector results
99
+ for doc, meta, score in zip(vector_docs, vector_meta, vector_scores):
100
+ combined_docs[doc] = {
101
+ 'doc': doc,
102
+ 'meta': meta,
103
+ 'score': self.alpha * score
104
+ }
105
+
106
+ # Add BM25 results (normalize scores to 0-1 range)
107
+ max_bm25_score = max(bm25_scores) if max(bm25_scores) > 0 else 1
108
+ for idx in top_bm25_indices:
109
+ doc = self.bm25_corpus[idx]
110
+ meta = self.bm25_metadata[idx]
111
+ bm25_score = bm25_scores[idx] / max_bm25_score
112
+
113
+ if doc in combined_docs:
114
+ # Average if document found by both methods
115
+ combined_docs[doc]['score'] += (1 - self.alpha) * bm25_score
116
+ else:
117
+ combined_docs[doc] = {
118
+ 'doc': doc,
119
+ 'meta': meta,
120
+ 'score': (1 - self.alpha) * bm25_score
121
+ }
122
+
123
+ # 4. Rank by combined score
124
+ ranked_results = sorted(
125
+ combined_docs.values(),
126
+ key=lambda x: x['score'],
127
+ reverse=True
128
+ )
129
+
130
+ # 5. Filter by threshold and limit results
131
+ filtered_results = [
132
+ r for r in ranked_results
133
+ if r['score'] >= score_threshold
134
+ ][:n_results]
135
+
136
+ # 6. Return in expected format
137
+ documents = [r['doc'] for r in filtered_results]
138
+ metadatas = [r['meta'] for r in filtered_results]
139
+ scores = [r['score'] for r in filtered_results]
140
+
141
+ return documents, metadatas, scores
142
+
143
+ def get_stats(self) -> Dict:
144
+ """Get retriever statistics."""
145
+ return {
146
+ 'bm25_indexed': len(self.bm25_corpus) if self.bm25 else 0,
147
+ 'vector_count': self.vector_db.get_collection_count(),
148
+ 'alpha': self.alpha
149
+ }
utils/llm_generator.py ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Real LLM-based generator using Groq or Google Gemini API.
3
+ This ACTUALLY generates responses (unlike SimpleGenerator which just extracts text).
4
+ """
5
+
6
+ import os
7
+ from typing import List, Dict, Optional
8
+ import streamlit as st
9
+
10
+ try:
11
+ from groq import Groq
12
+ GROQ_AVAILABLE = True
13
+ except ImportError:
14
+ GROQ_AVAILABLE = False
15
+
16
+ try:
17
+ import google.generativeai as genai
18
+ GEMINI_AVAILABLE = True
19
+ except ImportError:
20
+ GEMINI_AVAILABLE = False
21
+
22
+
23
+ class LLMGenerator:
24
+ """
25
+ Actual LLM-based response generation using Groq (Llama-3-70B) or Gemini.
26
+ This is what NotebookLM uses - real AI generation, not text extraction.
27
+ """
28
+
29
+ def __init__(self, provider: str = "groq", api_key: Optional[str] = None):
30
+ """
31
+ Initialize LLM generator.
32
+
33
+ Args:
34
+ provider: "groq" or "gemini"
35
+ api_key: API key (if None, reads from environment or asks user)
36
+ """
37
+ self.provider = provider
38
+ self.client = None
39
+ self.ready = False
40
+
41
+ # Get API key
42
+ if api_key:
43
+ self.api_key = api_key
44
+ elif provider == "groq":
45
+ self.api_key = os.getenv("GROQ_API_KEY", "")
46
+ elif provider == "gemini":
47
+ self.api_key = os.getenv("GEMINI_API_KEY", "")
48
+ else:
49
+ self.api_key = ""
50
+
51
+ # Initialize client
52
+ self._initialize_client()
53
+
54
+ def _initialize_client(self):
55
+ """Initialize the LLM client."""
56
+ if not self.api_key:
57
+ return
58
+
59
+ try:
60
+ if self.provider == "groq" and GROQ_AVAILABLE:
61
+ # Initialize Groq client with explicit parameters
62
+ # Avoid potential proxies kwarg issue by not passing extra config
63
+ import os
64
+ os.environ["GROQ_API_KEY"] = self.api_key
65
+ self.client = Groq() # Will read from environment
66
+ self.ready = True
67
+ elif self.provider == "gemini" and GEMINI_AVAILABLE:
68
+ genai.configure(api_key=self.api_key)
69
+ self.client = genai.GenerativeModel('gemini-1.5-flash')
70
+ self.ready = True
71
+ except Exception as e:
72
+ print(f"Failed to initialize {self.provider}: {e}")
73
+ self.ready = False
74
+
75
+ def set_api_key(self, api_key: str):
76
+ """Update API key and reinitialize."""
77
+ self.api_key = api_key
78
+ self._initialize_client()
79
+
80
+ def generate_response(
81
+ self,
82
+ prompt: str,
83
+ context: str = "",
84
+ use_case: str = "explanation",
85
+ metadatas: List[Dict] = None,
86
+ temperature: float = 0.7,
87
+ max_tokens: int = 1500,
88
+ **kwargs
89
+ ) -> str:
90
+ """
91
+ Generate response using actual LLM (NotebookLM-style).
92
+
93
+ Args:
94
+ prompt: User's question
95
+ context: Retrieved context from documents
96
+ use_case: Response type (explanation, summary, qa, notes)
97
+ metadatas: Metadata for citations
98
+ temperature: LLM temperature (0.0-1.0)
99
+ max_tokens: Maximum response length
100
+
101
+ Returns:
102
+ Generated response with inline citations
103
+ """
104
+ if not self.ready:
105
+ return (
106
+ "⚠️ **LLM not configured.** Please add your API key in the sidebar.\n\n"
107
+ "Get a free key:\n"
108
+ "- **Groq** (recommended, very fast): https://console.groq.com/keys\n"
109
+ "- **Gemini** (Google): https://makersuite.google.com/app/apikey"
110
+ )
111
+
112
+ if not context:
113
+ return (
114
+ "I don't have enough information from your uploaded documents to answer this question. "
115
+ "Please upload relevant study materials first."
116
+ )
117
+
118
+ # Build NotebookLM-style system prompt with strict source grounding
119
+ system_prompt = self._build_system_prompt(use_case)
120
+
121
+ # Build user message with context
122
+ user_message = self._build_user_message(prompt, context, metadatas)
123
+
124
+ try:
125
+ # Generate with LLM
126
+ if self.provider == "groq":
127
+ response = self._generate_groq(system_prompt, user_message, temperature, max_tokens)
128
+ elif self.provider == "gemini":
129
+ response = self._generate_gemini(system_prompt, user_message, temperature, max_tokens)
130
+ else:
131
+ return "Error: Unknown provider"
132
+
133
+ return response
134
+
135
+ except Exception as e:
136
+ return f"Error generating response: {str(e)}\n\nPlease check your API key and try again."
137
+
138
+ def _build_system_prompt(self, use_case: str) -> str:
139
+ """Build specialized system prompt based on use case."""
140
+ base_prompt = (
141
+ "You are an expert academic assistant for students, acting like a highly intelligent study buddy. "
142
+ "⚠️ CRITICAL RULE: You MUST ONLY use information from the provided context below. "
143
+ "DO NOT use your training knowledge. DO NOT infer beyond what's explicitly stated. "
144
+ "If the context doesn't contain adequate information to answer the question, you MUST respond: "
145
+ "'I cannot find sufficient information about this in the uploaded documents. Please upload materials covering this topic or rephrase your question.'\n\n"
146
+ "⚠️ GROUNDING REQUIREMENT: Every statement must be traceable to the provided context. "
147
+ "If you cannot find it in the context below, DO NOT answer from general knowledge.\n\n"
148
+ "✨ FORMATTING RULES (NotebookLM Style):\n"
149
+ "- Use clean, hierarchical Markdown (### Headers, **Bold** terms).\n"
150
+ "- Break down long paragraphs into easily readable bullet points.\n"
151
+ "- Be direct and concise. Avoid conversational fluff like 'Certainly!' or 'Here is the answer'.\n"
152
+ "- If applicable to the prompt, always try to extract a **Real-World Example** from the text to aid understanding.\n\n"
153
+ )
154
+
155
+ if use_case == "explanation":
156
+ base_prompt += (
157
+ "**Your task:** Explain the concept in a clear, step-by-step manner suitable for students.\n"
158
+ "1. Start with a concise, one-sentence definition.\n"
159
+ "2. Break down the core mechanics or components using bullet points.\n"
160
+ "3. Provide an example (only if found in the text).\n"
161
+ "4. Add a 'Key Takeaway' at the end.\n"
162
+ )
163
+ elif use_case == "summary":
164
+ base_prompt += (
165
+ "**Your task:** Create a highly structured summary.\n"
166
+ "- Start with a brief high-level overview (2 sentences max).\n"
167
+ "- Use '### Key Themes' and list the main points as bulleted items.\n"
168
+ "- Keep each point concise but factually dense.\n"
169
+ )
170
+ elif use_case == "qa":
171
+ base_prompt += (
172
+ "**Your task:** Answer the question directly and comprehensively.\n"
173
+ "- Provide the direct answer immediately in the first sentence.\n"
174
+ "- Use numbered lists or bullet points to provide supporting details from the context.\n"
175
+ "- Use **bold** for key facts, numbers, and formulas.\n"
176
+ )
177
+ elif use_case == "notes":
178
+ base_prompt += (
179
+ "**Your task:** Create comprehensive, structured study notes.\n"
180
+ "- Use clear section headers (###).\n"
181
+ "- Organize information hierarchically (using nested bullet points).\n"
182
+ "- Explicitly highlight **Definitions**, **Formulas**, and **Important Dates/Names**.\n"
183
+ )
184
+
185
+ base_prompt += (
186
+ "\n**Citation Rules:**\n"
187
+ "- You MUST cite your source at the end of every major claim or paragraph using numbered brackets like **[1]**, **[2]** based on the Source number provided in the context.\n"
188
+ "- If a claim comes from multiple sources, use **[1, 2]**.\n"
189
+ "- Do NOT use the document filename in the citation, ONLY the number.\n"
190
+ "- Do NOT make up information - stick strictly to the provided context.\n"
191
+ )
192
+
193
+ return base_prompt
194
+
195
+ def _build_user_message(self, prompt: str, context: str, metadatas: List[Dict] = None) -> str:
196
+ """Build user message with context and question."""
197
+ # Extract source names from metadata
198
+ sources = []
199
+ if metadatas:
200
+ for meta in metadatas:
201
+ filename = meta.get('filename', 'Unknown')
202
+ clean_name = filename.replace('.pdf', '').replace('.docx', '').replace('.txt', '')
203
+ if clean_name not in sources:
204
+ sources.append(clean_name)
205
+
206
+ message = "**Available Sources (USE ONLY THESE):**\n"
207
+ for source in sources[:5]: # Show up to 5 sources
208
+ message += f"- {source}\n"
209
+
210
+ message += f"\n**===== START OF CONTEXT (ANSWER ONLY FROM THIS) =====**\n\n{context}\n\n"
211
+ message += f"**===== END OF CONTEXT =====**\n\n"
212
+ message += f"**Student's Question:** {prompt}\n\n"
213
+ message += "**Instructions:** Answer ONLY using the context between the markers above. If the context doesn't contain the answer, say you don't have that information. Cite sources in brackets."
214
+
215
+ return message
216
+
217
+ def _generate_groq(self, system_prompt: str, user_message: str, temperature: float, max_tokens: int) -> str:
218
+ """Generate using Groq API (Llama-3.3-70B)."""
219
+ completion = self.client.chat.completions.create(
220
+ model="llama-3.3-70b-versatile", # Latest 70B model (Dec 2024)
221
+ messages=[
222
+ {"role": "system", "content": system_prompt},
223
+ {"role": "user", "content": user_message}
224
+ ],
225
+ temperature=temperature,
226
+ max_tokens=max_tokens,
227
+ top_p=0.95,
228
+ stream=False
229
+ )
230
+
231
+ return completion.choices[0].message.content
232
+
233
+ def _generate_gemini(self, system_prompt: str, user_message: str, temperature: float, max_tokens: int) -> str:
234
+ """Generate using Google Gemini API."""
235
+ full_prompt = f"{system_prompt}\n\n{user_message}"
236
+
237
+ response = self.client.generate_content(
238
+ full_prompt,
239
+ generation_config=genai.GenerationConfig(
240
+ temperature=temperature,
241
+ max_output_tokens=max_tokens,
242
+ top_p=0.95
243
+ )
244
+ )
245
+
246
+ return response.text
247
+
248
+ def is_ready(self) -> bool:
249
+ """Check if LLM is ready to generate."""
250
+ return self.ready
251
+
252
+ def get_provider(self) -> str:
253
+ """Get current provider name."""
254
+ if self.provider == "groq":
255
+ return "Groq (Llama-3.3-70B)"
256
+ elif self.provider == "gemini":
257
+ return "Google Gemini 1.5 Flash"
258
+ return "Unknown"
259
+
260
+ def generate(self, prompt: str, temperature: float = 0.3, max_tokens: int = 1500) -> str:
261
+ """
262
+ Simple wrapper for backend compatibility.
263
+ Generates response from a complete prompt that already includes context.
264
+
265
+ Args:
266
+ prompt: Complete prompt with context already embedded
267
+ temperature: LLM temperature (0.0-1.0)
268
+ max_tokens: Maximum response length
269
+
270
+ Returns:
271
+ Generated response
272
+ """
273
+ if not self.ready:
274
+ return (
275
+ "⚠️ **LLM not configured.** Please add your API key.\n\n"
276
+ "Get a free key:\n"
277
+ "- **Groq** (recommended, very fast): https://console.groq.com/keys\n"
278
+ "- **Gemini** (Google): https://makersuite.google.com/app/apikey"
279
+ )
280
+
281
+ try:
282
+ if self.provider == "groq":
283
+ return self._generate_groq(
284
+ system_prompt="You are a helpful AI assistant.",
285
+ user_message=prompt,
286
+ temperature=temperature,
287
+ max_tokens=max_tokens
288
+ )
289
+ elif self.provider == "gemini":
290
+ return self._generate_gemini(
291
+ system_prompt="You are a helpful AI assistant.",
292
+ user_message=prompt,
293
+ temperature=temperature,
294
+ max_tokens=max_tokens
295
+ )
296
+ except Exception as e:
297
+ return f"Error generating response: {str(e)}"
utils/model_inference.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
3
+ from typing import List, Dict, Optional
4
+ import config
5
+
6
+
7
+ class ModelInference:
8
+ """Handle model loading and inference for text generation."""
9
+
10
+ def __init__(self, model_name: str = None, use_4bit: bool = True):
11
+ """
12
+ Initialize the model for inference.
13
+ RAG Mode: Uses pre-trained model directly (no training needed!).
14
+
15
+ Args:
16
+ model_name: Name or path of the model (uses pre-trained by default)
17
+ use_4bit: Whether to use 4-bit quantization for efficiency
18
+ """
19
+ # Use pre-trained model if specified, otherwise check for fine-tuned model
20
+ if config.USE_PRETRAINED or not Path(config.MODEL_PATH).exists():
21
+ self.model_name = model_name or config.MODEL_NAME
22
+ else:
23
+ self.model_name = model_name or config.MODEL_PATH
24
+
25
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
26
+
27
+ print(f"Loading model: {self.model_name}")
28
+ print(f"Device: {self.device}")
29
+
30
+ # Configure quantization for efficiency
31
+ if use_4bit and self.device == "cuda":
32
+ bnb_config = BitsAndBytesConfig(
33
+ load_in_4bit=True,
34
+ bnb_4bit_quant_type="nf4",
35
+ bnb_4bit_compute_dtype=torch.float16,
36
+ bnb_4bit_use_double_quant=True,
37
+ )
38
+ self.model = AutoModelForCausalLM.from_pretrained(
39
+ self.model_name,
40
+ quantization_config=bnb_config,
41
+ device_map="auto",
42
+ trust_remote_code=True
43
+ )
44
+ else:
45
+ self.model = AutoModelForCausalLM.from_pretrained(
46
+ self.model_name,
47
+ torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
48
+ device_map="auto" if self.device == "cuda" else None,
49
+ trust_remote_code=True
50
+ )
51
+
52
+ self.tokenizer = AutoTokenizer.from_pretrained(
53
+ self.model_name,
54
+ trust_remote_code=True
55
+ )
56
+
57
+ if self.tokenizer.pad_token is None:
58
+ self.tokenizer.pad_token = self.tokenizer.eos_token
59
+
60
+ self.model.eval()
61
+
62
+ def generate_response(
63
+ self,
64
+ prompt: str,
65
+ context: str = "",
66
+ use_case: str = "explanation",
67
+ temperature: float = None,
68
+ max_tokens: int = None
69
+ ) -> str:
70
+ """
71
+ Generate a response based on the prompt and context.
72
+
73
+ Args:
74
+ prompt: User query
75
+ context: Retrieved context from documents
76
+ use_case: Type of response (explanation, summary, qa, notes)
77
+ temperature: Sampling temperature
78
+ max_tokens: Maximum number of tokens to generate
79
+
80
+ Returns:
81
+ Generated text response
82
+ """
83
+ temperature = temperature or config.TEMPERATURE
84
+ max_tokens = max_tokens or config.MAX_TOKENS
85
+
86
+ # Create system prompt based on use case
87
+ system_prompts = {
88
+ "explanation": "You are an expert tutor. Provide detailed, clear explanations of concepts based on the given context.",
89
+ "summary": "You are a summarization expert. Create concise, well-structured summaries of the provided content.",
90
+ "qa": "You are a knowledgeable assistant. Answer questions accurately based on the given context.",
91
+ "notes": "You are a study notes specialist. Create well-organized, structured study notes from the content."
92
+ }
93
+
94
+ system_prompt = system_prompts.get(use_case, system_prompts["explanation"])
95
+
96
+ # Format the full prompt
97
+ full_prompt = self._format_prompt(system_prompt, context, prompt)
98
+
99
+ # Tokenize
100
+ inputs = self.tokenizer(
101
+ full_prompt,
102
+ return_tensors="pt",
103
+ truncation=True,
104
+ max_length=2048
105
+ ).to(self.device)
106
+
107
+ # Generate
108
+ with torch.no_grad():
109
+ outputs = self.model.generate(
110
+ **inputs,
111
+ max_new_tokens=max_tokens,
112
+ temperature=temperature,
113
+ do_sample=True,
114
+ top_p=0.95,
115
+ top_k=50,
116
+ repetition_penalty=1.1,
117
+ pad_token_id=self.tokenizer.pad_token_id,
118
+ eos_token_id=self.tokenizer.eos_token_id
119
+ )
120
+
121
+ # Decode
122
+ response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
123
+
124
+ # Extract only the new generated text
125
+ response = response[len(full_prompt):].strip()
126
+
127
+ return response
128
+
129
+ def _format_prompt(self, system_prompt: str, context: str, query: str) -> str:
130
+ """Format the prompt with system instructions, context, and query."""
131
+ prompt = f"{system_prompt}\n\n"
132
+
133
+ if context:
134
+ prompt += f"Context from your study materials:\n{context}\n\n"
135
+
136
+ prompt += f"Query: {query}\n\nResponse:"
137
+
138
+ return prompt
139
+
140
+ def batch_generate(self, prompts: List[str], **kwargs) -> List[str]:
141
+ """
142
+ Generate responses for multiple prompts.
143
+
144
+ Args:
145
+ prompts: List of prompts
146
+ **kwargs: Additional arguments for generate_response
147
+
148
+ Returns:
149
+ List of generated responses
150
+ """
151
+ responses = []
152
+ for prompt in prompts:
153
+ response = self.generate_response(prompt, **kwargs)
154
+ responses.append(response)
155
+
156
+ return responses
utils/simple_generator.py ADDED
@@ -0,0 +1,444 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ NotebookLM-style response generator with professional formatting.
3
+ """
4
+
5
+ from typing import List, Dict
6
+ import config
7
+ import re
8
+
9
+
10
+ class SimpleGenerator:
11
+ """Lightweight generator with NotebookLM-quality formatting."""
12
+
13
+ def __init__(self):
14
+ self.ready = True
15
+
16
+ def _clean_and_format_text(self, text: str) -> str:
17
+ """Clean and format text with proper spacing like NotebookLM."""
18
+ # Fix spacing after punctuation
19
+ text = re.sub(r'([.!?])([A-Z])', r'\1 \2', text)
20
+ # Remove multiple spaces
21
+ text = re.sub(r'\s+', ' ', text)
22
+ # Add proper line breaks after sentences
23
+ text = re.sub(r'([.!?])\s+', r'\1\n\n', text)
24
+ return text.strip()
25
+
26
+ def _extract_key_terms(self, text: str) -> List[str]:
27
+ """Extract key terms that should be bolded."""
28
+ # Look for capitalized terms, technical terms
29
+ terms = []
30
+
31
+ # Find terms in quotes
32
+ quoted = re.findall(r'"([^"]+)"', text)
33
+ terms.extend(quoted)
34
+
35
+ # Find repeated important words (appear 2+ times)
36
+ words = re.findall(r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b', text)
37
+ word_count = {}
38
+ for word in words:
39
+ word_count[word] = word_count.get(word, 0) + 1
40
+
41
+ # Add words that appear multiple times
42
+ terms.extend([w for w, count in word_count.items() if count >= 2])
43
+
44
+ return list(set(terms))
45
+
46
+ def _apply_bold_formatting(self, text: str) -> str:
47
+ """Apply bold formatting to key terms like NotebookLM."""
48
+ key_terms = self._extract_key_terms(text)
49
+
50
+ # Bold key terms
51
+ for term in key_terms:
52
+ if len(term) > 3: # Skip very short terms
53
+ text = re.sub(rf'\b({re.escape(term)})\b', r'**\1**', text, count=1)
54
+
55
+ # Bold specific patterns
56
+ # Numbers with context
57
+ text = re.sub(r'\b(\d+)\s+(observations?|years?|months?|quarters?)', r'**\1 \2**', text)
58
+
59
+ return text
60
+
61
+ def _create_structured_response(self, context: str, query: str) -> str:
62
+ """Create a NotebookLM-style structured response."""
63
+ # Split into paragraphs
64
+ paragraphs = [p.strip() for p in context.split('\n\n') if len(p.strip()) > 50]
65
+
66
+ # Remove duplicates
67
+ unique_paras = []
68
+ seen = set()
69
+ for para in paragraphs:
70
+ para_key = para.lower()[:150]
71
+ if para_key not in seen:
72
+ unique_paras.append(para)
73
+ seen.add(para_key)
74
+ if len(unique_paras) >= 5:
75
+ break
76
+
77
+ if not unique_paras:
78
+ return context[:1000]
79
+
80
+ # Build NotebookLM-style response
81
+ response = ""
82
+
83
+ # Main explanation (first paragraph - cleaned and formatted)
84
+ main_para = self._clean_and_format_text(unique_paras[0])
85
+ main_para = self._apply_bold_formatting(main_para)
86
+ response += main_para + "\n\n"
87
+
88
+ # Add structured details if more content available
89
+ if len(unique_paras) > 1:
90
+ response += "### Key Points:\n\n"
91
+
92
+ for i, para in enumerate(unique_paras[1:4], 1):
93
+ # Extract first 2-3 sentences
94
+ sentences = [s.strip() for s in para.split('.') if len(s.strip()) > 20]
95
+ if sentences:
96
+ detail = self._clean_and_format_text('. '.join(sentences[:2]) + '.')
97
+ detail = self._apply_bold_formatting(detail)
98
+ response += f"{i}. {detail}\n\n"
99
+
100
+ return response.strip()
101
+
102
+ def generate_response(
103
+ self,
104
+ prompt: str,
105
+ context: str = "",
106
+ use_case: str = "explanation",
107
+ metadatas: List[Dict] = None,
108
+ **kwargs
109
+ ) -> str:
110
+ """
111
+ Generate a NotebookLM-quality response with strict citations.
112
+
113
+ Args:
114
+ prompt: User query
115
+ context: Retrieved context from documents
116
+ use_case: Type of response (explanation, summary, qa,notes)
117
+ metadatas: Metadata for each context chunk (for citations)
118
+
119
+ Returns:
120
+ Professional formatted response with inline citations
121
+ """
122
+ if not context:
123
+ return (
124
+ "I don't have enough information from your uploaded documents to answer this question. "
125
+ "Please upload relevant study materials first, or try rephrasing your question."
126
+ )
127
+
128
+ # Use specialized prompts based on use case
129
+ if use_case == "summary":
130
+ response = self._create_summary_with_citations(context, prompt, metadatas)
131
+ elif use_case == "notes":
132
+ response = self._create_notes_with_citations(context, prompt, metadatas)
133
+ elif use_case == "qa":
134
+ response = self._create_qa_with_citations(context, prompt, metadatas)
135
+ else: # Default to explanation
136
+ response = self._create_structured_response_with_citations(context, prompt, metadatas)
137
+
138
+ return response
139
+
140
+ def _create_structured_response_with_citations(
141
+ self,
142
+ context: str,
143
+ query: str,
144
+ metadatas: List[Dict] = None
145
+ ) -> str:
146
+ """Create NotebookLM-style response with inline citations."""
147
+ # Split into paragraphs
148
+ paragraphs = [p.strip() for p in context.split('\n\n') if len(p.strip()) > 50]
149
+
150
+ # Remove duplicates
151
+ unique_paras = []
152
+ seen = set()
153
+ for para in paragraphs:
154
+ para_key = para.lower()[:150]
155
+ if para_key not in seen:
156
+ unique_paras.append(para)
157
+ seen.add(para_key)
158
+ if len(unique_paras) >= 5:
159
+ break
160
+
161
+ if not unique_paras:
162
+ return context[:1000]
163
+
164
+ # Build response with citations
165
+ response = ""
166
+
167
+ # Main explanation (first paragraph - cleaned and formatted)
168
+ main_para = self._clean_and_format_text(unique_paras[0])
169
+ main_para = self._apply_bold_formatting(main_para)
170
+
171
+ # Add citation to end of main paragraph
172
+ cite_text = self._get_citation(0, metadatas) if metadatas else ""
173
+ response += main_para + cite_text + "\n\n"
174
+
175
+ # Add structured details if more content available
176
+ if len(unique_paras) > 1:
177
+ response += "### Key Points:\n\n"
178
+
179
+ for i, para in enumerate(unique_paras[1:4], 1):
180
+ # Extract first 2-3 sentences
181
+ sentences = [s.strip() for s in para.split('.') if len(s.strip()) > 20]
182
+ if sentences:
183
+ detail = self._clean_and_format_text('. '.join(sentences[:2]) + '.')
184
+ detail = self._apply_bold_formatting(detail)
185
+
186
+ # Add citation
187
+ cite_text = self._get_citation(i, metadatas) if metadatas and i < len(metadatas) else ""
188
+ response += f"{i}. {detail}{cite_text}\n\n"
189
+
190
+ return response.strip()
191
+
192
+ def _get_citation(self, index: int, metadatas: List[Dict] = None) -> str:
193
+ """Generate inline citation from metadata."""
194
+ if not metadatas or index >= len(metadatas):
195
+ return ""
196
+
197
+ meta = metadatas[index]
198
+ filename = meta.get('filename', 'Unknown')
199
+
200
+ # Remove file extension for cleaner citation
201
+ clean_name = filename.replace('.pdf', '').replace('.docx', '').replace('.txt', '')
202
+
203
+ return f" **[{clean_name}]**"
204
+
205
+ def _create_summary_with_citations(
206
+ self,
207
+ context: str,
208
+ query: str,
209
+ metadatas: List[Dict] = None
210
+ ) -> str:
211
+ """Create a summary with citations."""
212
+ sentences = []
213
+ seen = set()
214
+ for s in context.split('.'):
215
+ s_clean = s.strip()
216
+ if len(s_clean) > 40 and s_clean.lower() not in seen:
217
+ sentences.append(s_clean)
218
+ seen.add(s_clean.lower())
219
+ if len(sentences) >= 6:
220
+ break
221
+
222
+ if not sentences:
223
+ return context[:800]
224
+
225
+ response = "## Summary\n\n"
226
+ for i, point in enumerate(sentences, 1):
227
+ cite = self._get_citation(i-1, metadatas) if metadatas else ""
228
+ response += f"{i}. {point}.{cite}\n\n"
229
+
230
+ return response.strip()
231
+
232
+ def _create_qa_with_citations(
233
+ self,
234
+ context: str,
235
+ query: str,
236
+ metadatas: List[Dict] = None
237
+ ) -> str:
238
+ """Answer with strict source grounding."""
239
+ paragraphs = [p.strip() for p in context.split('\n\n') if len(p.strip()) > 50]
240
+
241
+ if not paragraphs:
242
+ sentences = [s.strip() + '.' for s in context.split('.') if len(s.strip()) > 30]
243
+ response = ' '.join(sentences[:6])
244
+ cite = self._get_citation(0, metadatas) if metadatas else ""
245
+ return response + cite
246
+
247
+ # Remove duplicates
248
+ unique_paras = []
249
+ seen = set()
250
+ for para in paragraphs:
251
+ para_key = para.lower()[:150]
252
+ if para_key not in seen:
253
+ unique_paras.append(para)
254
+ seen.add(para_key)
255
+ if len(unique_paras) >= 3:
256
+ break
257
+
258
+ # Fix spacing and add citations
259
+ response = unique_paras[0] if unique_paras else context[:800]
260
+ response = re.sub(r'([.!?])([A-Z])', r'\1 \2', response)
261
+ cite = self._get_citation(0, metadatas) if metadatas else ""
262
+ response += cite
263
+
264
+ # Add supporting details if available
265
+ if len(unique_paras) > 1:
266
+ second_para = re.sub(r'([.!?])([A-Z])', r'\1 \2', unique_paras[1])
267
+ cite2 = self._get_citation(1, metadatas) if metadatas and len(metadatas) > 1 else ""
268
+ response += "\n\n" + second_para + cite2
269
+
270
+ return response.strip()
271
+
272
+ def _create_notes_with_citations(
273
+ self,
274
+ context: str,
275
+ query: str,
276
+ metadatas: List[Dict] = None
277
+ ) -> str:
278
+ """Create study notes with source attribution."""
279
+ sections = [s.strip() for s in context.split('\n\n') if len(s.strip()) > 40]
280
+
281
+ # Remove duplicates
282
+ unique_sections = []
283
+ seen = set()
284
+ for section in sections:
285
+ section_key = section.lower()[:100]
286
+ if section_key not in seen:
287
+ unique_sections.append(section)
288
+ seen.add(section_key)
289
+ if len(unique_sections) >= 6:
290
+ break
291
+
292
+ if not unique_sections:
293
+ return context[:1000]
294
+
295
+ response = "## Study Notes\n\n"
296
+
297
+ for i, section in enumerate(unique_sections, 1):
298
+ sentences = [s.strip() for s in section.split('.') if len(s.strip()) > 20]
299
+
300
+ if sentences:
301
+ heading = sentences[0]
302
+ cite = self._get_citation(i-1, metadatas) if metadatas else ""
303
+ response += f"### {i}. {heading}{cite}\n\n"
304
+
305
+ for sent in sentences[1:3]:
306
+ response += f"- {sent}\n"
307
+ response += "\n"
308
+
309
+ return response.strip()
310
+
311
+ def _create_summary(self, context: str, query: str) -> str:
312
+ """Create a clean summary from retrieved context."""
313
+ # Extract key sentences - remove duplicates
314
+ sentences = []
315
+ seen = set()
316
+ for s in context.split('.'):
317
+ s_clean = s.strip()
318
+ # Remove duplicates and filter short/low-quality sentences
319
+ if len(s_clean) > 40 and s_clean.lower() not in seen:
320
+ sentences.append(s_clean)
321
+ seen.add(s_clean.lower())
322
+ if len(sentences) >= 6:
323
+ break
324
+
325
+ if not sentences:
326
+ return context[:800]
327
+
328
+ response = "## Summary\n\n"
329
+ for i, point in enumerate(sentences, 1):
330
+ response += f"{i}. {point}.\n\n"
331
+
332
+ return response.strip()
333
+
334
+ def _create_explanation(self, context: str, query: str) -> str:
335
+ """Create a well-formatted explanation from retrieved context."""
336
+ # Remove duplicate paragraphs
337
+ paragraphs = []
338
+ seen = set()
339
+ for para in context.split('\n\n'):
340
+ para_clean = para.strip()
341
+ # Keep unique, substantial paragraphs
342
+ if len(para_clean) > 50:
343
+ para_lower = para_clean.lower()[:200] # Check first 200 chars for duplicates
344
+ if para_lower not in seen:
345
+ paragraphs.append(para_clean)
346
+ seen.add(para_lower)
347
+
348
+ if not paragraphs:
349
+ # Fallback: split by sentence
350
+ sentences = [s.strip() + '.' for s in context.split('.') if len(s.strip()) > 30]
351
+ return ' '.join(sentences[:8])
352
+
353
+ # Build clean, formatted response with proper spacing
354
+ response = ""
355
+
356
+ # Add first paragraph as main explanation (ensure spacing between sentences)
357
+ first_para = paragraphs[0]
358
+ # Add space after punctuation if missing
359
+ import re
360
+ first_para = re.sub(r'([.!?])([A-Z])', r'\1 \2', first_para)
361
+ response += first_para
362
+
363
+ # Add additional details if available
364
+ if len(paragraphs) > 1:
365
+ response += "\n\n### Key Points:\n\n"
366
+ for i, para in enumerate(paragraphs[1:4], 1): # Max 3 additional points
367
+ # Extract first sentence as bullet
368
+ sentences = [s.strip() for s in para.split('.') if len(s.strip()) > 20]
369
+ if sentences:
370
+ response += f"• {sentences[0]}.\n"
371
+ if len(sentences) > 1 and len(sentences[1]) > 20:
372
+ response += f" {sentences[1]}.\n"
373
+ response += "\n"
374
+
375
+ return response.strip()
376
+
377
+ def _create_qa(self, context: str, query: str) -> str:
378
+ """Answer a question with clean formatting."""
379
+ # Find most relevant paragraphs
380
+ paragraphs = [p.strip() for p in context.split('\n\n') if len(p.strip()) > 50]
381
+
382
+ if not paragraphs:
383
+ sentences = [s.strip() + '.' for s in context.split('.') if len(s.strip()) > 30]
384
+ return ' '.join(sentences[:6])
385
+
386
+ # Remove duplicates
387
+ unique_paras = []
388
+ seen = set()
389
+ for para in paragraphs:
390
+ para_key = para.lower()[:150]
391
+ if para_key not in seen:
392
+ unique_paras.append(para)
393
+ seen.add(para_key)
394
+ if len(unique_paras) >= 3:
395
+ break
396
+
397
+ # Fix spacing in response
398
+ import re
399
+ response = unique_paras[0] if unique_paras else context[:800]
400
+ response = re.sub(r'([.!?])([A-Z])', r'\1 \2', response)
401
+
402
+ # Add supporting details if available
403
+ if len(unique_paras) > 1:
404
+ second_para = re.sub(r'([.!?])([A-Z])', r'\1 \2', unique_paras[1])
405
+ response += "\n\n" + second_para
406
+
407
+ return response.strip()
408
+
409
+ def _create_notes(self, context: str, query: str) -> str:
410
+ """Create well-structured study notes."""
411
+ # Split and clean sections
412
+ sections = [s.strip() for s in context.split('\n\n') if len(s.strip()) > 40]
413
+
414
+ # Remove duplicates
415
+ unique_sections = []
416
+ seen = set()
417
+ for section in sections:
418
+ section_key = section.lower()[:100]
419
+ if section_key not in seen:
420
+ unique_sections.append(section)
421
+ seen.add(section_key)
422
+ if len(unique_sections) >= 6:
423
+ break
424
+
425
+ if not unique_sections:
426
+ return context[:1000]
427
+
428
+ response = "## Study Notes\n\n"
429
+
430
+ for i, section in enumerate(unique_sections, 1):
431
+ # Extract key information
432
+ sentences = [s.strip() for s in section.split('.') if len(s.strip()) > 20]
433
+
434
+ if sentences:
435
+ # Use first sentence as heading
436
+ heading = sentences[0]
437
+ response += f"### {i}. {heading}\n\n"
438
+
439
+ # Add bullet points for remaining content
440
+ for sent in sentences[1:3]: # Max 2 additional sentences
441
+ response += f"- {sent}\n"
442
+ response += "\n"
443
+
444
+ return response.strip()
utils/spaces_manager.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Spaces (Workspaces) manager for organizing chats and files by subject.
3
+ Each space has its own vector DB and chat history.
4
+ """
5
+
6
+ import json
7
+ from pathlib import Path
8
+ from typing import List, Dict, Optional
9
+ from datetime import datetime
10
+ import config
11
+
12
+
13
+ class SpacesManager:
14
+ """Manages workspaces (Spaces) for organizing study materials by subject."""
15
+
16
+ def __init__(self):
17
+ self.spaces_file = config.DATA_DIR / "spaces.json"
18
+ self.spaces_data = self._load_spaces()
19
+
20
+ def _load_spaces(self) -> Dict:
21
+ """Load spaces from file."""
22
+ if self.spaces_file.exists():
23
+ try:
24
+ with open(self.spaces_file, 'r', encoding='utf-8') as f:
25
+ return json.load(f)
26
+ except Exception:
27
+ return self._create_default_spaces()
28
+ return self._create_default_spaces()
29
+
30
+ def _create_default_spaces(self) -> Dict:
31
+ """Create default spaces structure."""
32
+ return {
33
+ "spaces": [
34
+ {
35
+ "id": "general",
36
+ "name": "General",
37
+ "description": "General study materials",
38
+ "created_at": datetime.now().isoformat(),
39
+ "file_count": 0,
40
+ "chat_count": 0
41
+ }
42
+ ]
43
+ }
44
+
45
+ def save_spaces(self):
46
+ """Save spaces to file."""
47
+ try:
48
+ with open(self.spaces_file, 'w', encoding='utf-8') as f:
49
+ json.dump(self.spaces_data, f, indent=2)
50
+ except Exception as e:
51
+ print(f"Error saving spaces: {e}")
52
+
53
+ def get_all_spaces(self) -> List[Dict]:
54
+ """Get all spaces."""
55
+ return self.spaces_data.get("spaces", [])
56
+
57
+ def get_space(self, space_id: str) -> Optional[Dict]:
58
+ """Get specific space by ID."""
59
+ for space in self.spaces_data.get("spaces", []):
60
+ if space["id"] == space_id:
61
+ return space
62
+ return None
63
+
64
+ def create_space(self, name: str, description: str = "") -> Dict:
65
+ """Create a new space."""
66
+ space_id = name.lower().replace(" ", "_")
67
+
68
+ # Check if space already exists
69
+ if self.get_space(space_id):
70
+ raise ValueError(f"Space '{name}' already exists")
71
+
72
+ new_space = {
73
+ "id": space_id,
74
+ "name": name,
75
+ "description": description,
76
+ "created_at": datetime.now().isoformat(),
77
+ "file_count": 0,
78
+ "chat_count": 0
79
+ }
80
+
81
+ self.spaces_data["spaces"].append(new_space)
82
+ self.save_spaces()
83
+
84
+ # Create dedicated directories for this space
85
+ space_dir = config.DATA_DIR / "spaces" / space_id
86
+ space_dir.mkdir(parents=True, exist_ok=True)
87
+ (space_dir / "chats").mkdir(exist_ok=True)
88
+ (space_dir / "vector_db").mkdir(exist_ok=True)
89
+ (space_dir / "uploads").mkdir(exist_ok=True)
90
+
91
+ return new_space
92
+
93
+ def delete_space(self, space_id: str):
94
+ """Delete a space (except General)."""
95
+ if space_id == "general":
96
+ raise ValueError("Cannot delete General space")
97
+
98
+ self.spaces_data["spaces"] = [
99
+ s for s in self.spaces_data["spaces"]
100
+ if s["id"] != space_id
101
+ ]
102
+ self.save_spaces()
103
+
104
+ def update_space_counts(self, space_id: str, file_count: int = None, chat_count: int = None):
105
+ """Update file/chat counts for a space."""
106
+ space = self.get_space(space_id)
107
+ if space:
108
+ if file_count is not None:
109
+ space["file_count"] = file_count
110
+ if chat_count is not None:
111
+ space["chat_count"] = chat_count
112
+ self.save_spaces()
113
+
114
+ def get_space_chats_dir(self, space_id: str) -> Path:
115
+ """Get chats directory for a space."""
116
+ return config.DATA_DIR / "spaces" / space_id / "chats"
117
+
118
+ def get_space_vector_db_dir(self, space_id: str) -> Path:
119
+ """Get vector DB directory for a space."""
120
+ return config.DATA_DIR / "spaces" / space_id / "vector_db"
121
+
122
+ def get_space_uploads_dir(self, space_id: str) -> Path:
123
+ """Get uploads directory for a space."""
124
+ return config.DATA_DIR / "spaces" / space_id / "uploads"
utils/studio_generator.py ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Studio Generator - Uses LLM to generate flashcards and quiz questions
3
+ """
4
+ from typing import List, Optional
5
+ import json
6
+ import re
7
+
8
+ from models.studio_models import (
9
+ Flashcard, FlashcardCreate, FlashcardGenerateRequest,
10
+ Quiz, QuizQuestion, QuizGenerateRequest,
11
+ DifficultyLevel, QuestionType
12
+ )
13
+ from utils.llm_generator import LLMGenerator
14
+ from utils.studio_manager import StudioManager
15
+
16
+
17
+ class StudioGenerator:
18
+ """Generate flashcards and quizzes using LLM"""
19
+
20
+ def __init__(self, llm_generator: LLMGenerator, studio_manager: StudioManager):
21
+ self.llm = llm_generator
22
+ self.studio = studio_manager
23
+
24
+ async def generate_flashcards(self, request: FlashcardGenerateRequest) -> List[Flashcard]:
25
+ """Generate flashcards from content using LLM"""
26
+
27
+ # Gather source content
28
+ content = await self._gather_content(
29
+ request.space_id,
30
+ request.source_type,
31
+ request.source_ids,
32
+ request.text_content
33
+ )
34
+
35
+ if not content:
36
+ return []
37
+
38
+ # Create prompt for LLM
39
+ prompt = self._create_flashcard_prompt(content, request.num_cards, request.difficulty)
40
+
41
+ # Generate flashcards using LLM
42
+ response = await self.llm.generate(prompt, max_tokens=2000)
43
+
44
+ if not response:
45
+ return []
46
+
47
+ # Parse flashcards from response
48
+ flashcards = self._parse_flashcards(
49
+ response,
50
+ request.space_id,
51
+ request.source_type,
52
+ request.source_ids,
53
+ request.difficulty
54
+ )
55
+
56
+ # Save flashcards to storage
57
+ saved_cards = []
58
+ for card_data in flashcards:
59
+ card = self.studio.create_flashcard(card_data)
60
+ saved_cards.append(card)
61
+
62
+ return saved_cards
63
+
64
+ async def generate_quiz(self, request: QuizGenerateRequest) -> Optional[Quiz]:
65
+ """Generate a quiz from content using LLM"""
66
+
67
+ # Gather source content
68
+ content = await self._gather_content(
69
+ request.space_id,
70
+ request.source_type,
71
+ request.source_ids,
72
+ request.text_content
73
+ )
74
+
75
+ if not content:
76
+ return None
77
+
78
+ # Create prompt for LLM
79
+ prompt = self._create_quiz_prompt(
80
+ content,
81
+ request.num_questions,
82
+ request.question_types,
83
+ request.difficulty
84
+ )
85
+
86
+ # Generate quiz using LLM
87
+ response = await self.llm.generate(prompt, max_tokens=3000)
88
+
89
+ if not response:
90
+ return None
91
+
92
+ # Parse quiz questions from response
93
+ questions = self._parse_quiz_questions(response, request.question_types, request.difficulty)
94
+
95
+ if not questions:
96
+ return None
97
+
98
+ # Create quiz
99
+ from models.studio_models import QuizCreate
100
+ quiz_data = QuizCreate(
101
+ space_id=request.space_id,
102
+ title=request.title,
103
+ description=f"Generated quiz with {len(questions)} questions",
104
+ questions=questions,
105
+ source_type=request.source_type,
106
+ source_ids=request.source_ids
107
+ )
108
+
109
+ quiz = self.studio.create_quiz(quiz_data)
110
+ return quiz
111
+
112
+ async def _gather_content(
113
+ self,
114
+ space_id: str,
115
+ source_type: str,
116
+ source_ids: Optional[List[str]],
117
+ text_content: Optional[str]
118
+ ) -> str:
119
+ """Gather content from various sources"""
120
+
121
+ if text_content:
122
+ return text_content
123
+
124
+ content_parts = []
125
+
126
+ if source_type == "notebook" and source_ids:
127
+ # Get notebook entries
128
+ for entry_id in source_ids:
129
+ entry = self.studio.get_notebook_entry(entry_id)
130
+ if entry:
131
+ content_parts.append(f"# {entry.title}\n\n{entry.content}")
132
+
133
+ elif source_type == "file" and source_ids:
134
+ # TODO: Integrate with file retriever to get file content
135
+ # For now, just return a placeholder
136
+ content_parts.append("File content retrieval not yet implemented")
137
+
138
+ return "\n\n---\n\n".join(content_parts)
139
+
140
+ def _create_flashcard_prompt(self, content: str, num_cards: int, difficulty: DifficultyLevel) -> str:
141
+ """Create prompt for flashcard generation"""
142
+
143
+ difficulty_desc = {
144
+ DifficultyLevel.EASY: "basic concepts and definitions",
145
+ DifficultyLevel.MEDIUM: "key concepts and applications",
146
+ DifficultyLevel.HARD: "advanced concepts and critical thinking"
147
+ }
148
+
149
+ prompt = f"""Based on the following content, create {num_cards} flashcards focusing on {difficulty_desc[difficulty]}.
150
+
151
+ Content:
152
+ {content[:3000]} # Limit content length
153
+
154
+ Format your response as a JSON array of flashcards, where each flashcard has:
155
+ - "question": The question or prompt (front of card)
156
+ - "answer": The answer or explanation (back of card)
157
+
158
+ Example format:
159
+ [
160
+ {{"question": "What is...", "answer": "It is..."}},
161
+ {{"question": "How does...", "answer": "It works by..."}}
162
+ ]
163
+
164
+ Generate exactly {num_cards} flashcards:"""
165
+
166
+ return prompt
167
+
168
+ def _create_quiz_prompt(
169
+ self,
170
+ content: str,
171
+ num_questions: int,
172
+ question_types: List[QuestionType],
173
+ difficulty: DifficultyLevel
174
+ ) -> str:
175
+ """Create prompt for quiz generation"""
176
+
177
+ types_str = ", ".join(qt.value for qt in question_types)
178
+
179
+ prompt = f"""Based on the following content, create a quiz with {num_questions} questions.
180
+
181
+ Content:
182
+ {content[:3000]} # Limit content length
183
+
184
+ Question types to include: {types_str}
185
+ Difficulty level: {difficulty.value}
186
+
187
+ Format your response as a JSON array of questions, where each question has:
188
+ - "question": The question text
189
+ - "type": One of: {types_str}
190
+ - "options": Array of 4 options (for multiple_choice only)
191
+ - "correct_answer": The correct answer
192
+ - "explanation": Brief explanation of why this is correct
193
+
194
+ Example format:
195
+ [
196
+ {{
197
+ "question": "What is...",
198
+ "type": "multiple_choice",
199
+ "options": ["Option A", "Option B", "Option C", "Option D"],
200
+ "correct_answer": "Option A",
201
+ "explanation": "This is correct because..."
202
+ }},
203
+ {{
204
+ "question": "True or False: ...",
205
+ "type": "true_false",
206
+ "options": ["True", "False"],
207
+ "correct_answer": "True",
208
+ "explanation": "This is true because..."
209
+ }}
210
+ ]
211
+
212
+ Generate exactly {num_questions} questions:"""
213
+
214
+ return prompt
215
+
216
+ def _parse_flashcards(
217
+ self,
218
+ response: str,
219
+ space_id: str,
220
+ source_type: str,
221
+ source_ids: Optional[List[str]],
222
+ difficulty: DifficultyLevel
223
+ ) -> List[FlashcardCreate]:
224
+ """Parse flashcards from LLM response"""
225
+
226
+ flashcards = []
227
+
228
+ try:
229
+ # Try to extract JSON from response
230
+ json_match = re.search(r'\[[\s\S]*\]', response)
231
+ if json_match:
232
+ cards_data = json.loads(json_match.group(0))
233
+
234
+ for card_data in cards_data:
235
+ if 'question' in card_data and 'answer' in card_data:
236
+ flashcards.append(FlashcardCreate(
237
+ space_id=space_id,
238
+ question=card_data['question'],
239
+ answer=card_data['answer'],
240
+ difficulty=difficulty,
241
+ source_type=source_type,
242
+ source_id=source_ids[0] if source_ids else None
243
+ ))
244
+ except Exception as e:
245
+ print(f"Error parsing flashcards: {e}")
246
+ # Fallback: Try to parse as simple Q&A pairs
247
+ lines = response.split('\n')
248
+ current_question = None
249
+
250
+ for line in lines:
251
+ line = line.strip()
252
+ if line.startswith('Q:') or line.startswith('Question:'):
253
+ current_question = line.split(':', 1)[1].strip()
254
+ elif line.startswith('A:') or line.startswith('Answer:'):
255
+ if current_question:
256
+ answer = line.split(':', 1)[1].strip()
257
+ flashcards.append(FlashcardCreate(
258
+ space_id=space_id,
259
+ question=current_question,
260
+ answer=answer,
261
+ difficulty=difficulty,
262
+ source_type=source_type,
263
+ source_id=source_ids[0] if source_ids else None
264
+ ))
265
+ current_question = None
266
+
267
+ return flashcards
268
+
269
+ def _parse_quiz_questions(
270
+ self,
271
+ response: str,
272
+ question_types: List[QuestionType],
273
+ difficulty: DifficultyLevel
274
+ ) -> List[QuizQuestion]:
275
+ """Parse quiz questions from LLM response"""
276
+
277
+ questions = []
278
+
279
+ try:
280
+ # Try to extract JSON from response
281
+ json_match = re.search(r'\[[\s\S]*\]', response)
282
+ if json_match:
283
+ questions_data = json.loads(json_match.group(0))
284
+
285
+ for idx, q_data in enumerate(questions_data):
286
+ import uuid
287
+
288
+ # Parse question type
289
+ q_type = QuestionType.MULTIPLE_CHOICE
290
+ if 'type' in q_data:
291
+ try:
292
+ q_type = QuestionType(q_data['type'])
293
+ except ValueError:
294
+ q_type = QuestionType.MULTIPLE_CHOICE
295
+
296
+ questions.append(QuizQuestion(
297
+ id=str(uuid.uuid4()),
298
+ question=q_data.get('question', ''),
299
+ type=q_type,
300
+ options=q_data.get('options'),
301
+ correct_answer=q_data.get('correct_answer', ''),
302
+ explanation=q_data.get('explanation'),
303
+ points=1,
304
+ difficulty=difficulty
305
+ ))
306
+ except Exception as e:
307
+ print(f"Error parsing quiz questions: {e}")
308
+
309
+ return questions
utils/studio_manager.py ADDED
@@ -0,0 +1,473 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Studio Manager - Handles Notebook, Flashcards, and Quiz storage and operations
3
+ """
4
+ import json
5
+ from pathlib import Path
6
+ from typing import List, Optional, Dict, Any
7
+ from datetime import datetime, timedelta
8
+ import uuid
9
+
10
+ from models.studio_models import (
11
+ NotebookEntry, NotebookEntryCreate, NotebookEntryUpdate,
12
+ Flashcard, FlashcardCreate, FlashcardUpdate, FlashcardReview,
13
+ Quiz, QuizCreate, QuizResult, QuizHistory, QuizAnswer,
14
+ MasteryLevel, DifficultyLevel
15
+ )
16
+ import config
17
+
18
+
19
+ class StudioManager:
20
+ """Manages all Studio features: Notebook, Flashcards, Quiz"""
21
+
22
+ def __init__(self):
23
+ """Initialize studio manager with data directories"""
24
+ self.studio_dir = config.DATA_DIR / "studio"
25
+ self.notebooks_dir = self.studio_dir / "notebooks"
26
+ self.notebook_dir = self.studio_dir / "notebook"
27
+ self.flashcards_dir = self.studio_dir / "flashcards"
28
+ self.quizzes_dir = self.studio_dir / "quizzes"
29
+ self.quiz_results_dir = self.studio_dir / "quiz_results"
30
+
31
+ # Create directories
32
+ for directory in [self.notebooks_dir, self.notebook_dir, self.flashcards_dir,
33
+ self.quizzes_dir, self.quiz_results_dir]:
34
+ directory.mkdir(parents=True, exist_ok=True)
35
+
36
+ def _get_notebook_file_path(self, space_id: str) -> Path:
37
+ """Get the metadata file path for a space notebook."""
38
+ return self.notebooks_dir / f"{space_id}.json"
39
+
40
+ def ensure_space_notebook(self, space_id: str, space_name: str = "") -> Dict[str, Any]:
41
+ """Create notebook metadata for a space if it does not exist."""
42
+ file_path = self._get_notebook_file_path(space_id)
43
+
44
+ if file_path.exists():
45
+ with open(file_path, 'r', encoding='utf-8') as f:
46
+ return json.load(f)
47
+
48
+ now = datetime.now().isoformat()
49
+ notebook_name = space_name.strip() if space_name and space_name.strip() else space_id
50
+ notebook_data = {
51
+ "id": space_id,
52
+ "space_id": space_id,
53
+ "name": notebook_name,
54
+ "created_at": now,
55
+ "updated_at": now
56
+ }
57
+
58
+ with open(file_path, 'w', encoding='utf-8') as f:
59
+ json.dump(notebook_data, f, indent=2)
60
+
61
+ return notebook_data
62
+
63
+ def get_space_notebook(self, space_id: str) -> Optional[Dict[str, Any]]:
64
+ """Get notebook metadata for a specific space."""
65
+ file_path = self._get_notebook_file_path(space_id)
66
+ if not file_path.exists():
67
+ return None
68
+
69
+ with open(file_path, 'r', encoding='utf-8') as f:
70
+ return json.load(f)
71
+
72
+ def _derive_title_from_question(self, question: str) -> str:
73
+ """Generate a readable title from a chat question."""
74
+ question = (question or "").strip()
75
+ if not question:
76
+ return "Chat Note"
77
+
78
+ title = question.replace('\n', ' ')
79
+ return title[:80] + "..." if len(title) > 80 else title
80
+
81
+ # ========================================================================
82
+ # NOTEBOOK OPERATIONS
83
+ # ========================================================================
84
+
85
+ def create_notebook_entry(self, entry_data: NotebookEntryCreate) -> NotebookEntry:
86
+ """Create a new notebook entry"""
87
+ # Ensure a notebook record exists for this space.
88
+ self.ensure_space_notebook(entry_data.space_id)
89
+
90
+ entry = NotebookEntry(
91
+ id=str(uuid.uuid4()),
92
+ **entry_data.dict()
93
+ )
94
+
95
+ # Save to file
96
+ file_path = self.notebook_dir / f"{entry.id}.json"
97
+ with open(file_path, 'w', encoding='utf-8') as f:
98
+ json.dump(entry.dict(), f, indent=2, default=str)
99
+
100
+ return entry
101
+
102
+ def create_notebook_entry_from_chat(
103
+ self,
104
+ space_id: str,
105
+ question: str,
106
+ answer: str,
107
+ chat_id: Optional[str] = None,
108
+ assistant_timestamp: Optional[str] = None,
109
+ tags: Optional[List[str]] = None,
110
+ space_name: str = ""
111
+ ) -> NotebookEntry:
112
+ """Create a notebook entry from a chat Q/A pair."""
113
+ self.ensure_space_notebook(space_id, space_name=space_name)
114
+
115
+ metadata: Dict[str, Any] = {
116
+ "question": question,
117
+ "assistant_timestamp": assistant_timestamp,
118
+ }
119
+ if chat_id:
120
+ metadata["chat_id"] = chat_id
121
+
122
+ entry_data = NotebookEntryCreate(
123
+ space_id=space_id,
124
+ title=self._derive_title_from_question(question),
125
+ content=f"Q: {question.strip()}\n\nA: {answer.strip()}",
126
+ source_type="chat",
127
+ source_id=chat_id,
128
+ tags=tags or ["chat"],
129
+ metadata=metadata
130
+ )
131
+
132
+ entry = self.create_notebook_entry(entry_data)
133
+
134
+ # Update notebook metadata timestamp.
135
+ notebook_data = self.ensure_space_notebook(space_id, space_name=space_name)
136
+ notebook_data["updated_at"] = datetime.now().isoformat()
137
+ with open(self._get_notebook_file_path(space_id), 'w', encoding='utf-8') as f:
138
+ json.dump(notebook_data, f, indent=2)
139
+
140
+ return entry
141
+
142
+ def get_notebook_entry(self, entry_id: str) -> Optional[NotebookEntry]:
143
+ """Get a single notebook entry by ID"""
144
+ file_path = self.notebook_dir / f"{entry_id}.json"
145
+ if not file_path.exists():
146
+ return None
147
+
148
+ with open(file_path, 'r', encoding='utf-8') as f:
149
+ data = json.load(f)
150
+
151
+ return NotebookEntry(**data)
152
+
153
+ def list_notebook_entries(self, space_id: Optional[str] = None) -> List[NotebookEntry]:
154
+ """List all notebook entries, optionally filtered by space"""
155
+ entries = []
156
+
157
+ for file_path in self.notebook_dir.glob("*.json"):
158
+ try:
159
+ with open(file_path, 'r', encoding='utf-8') as f:
160
+ data = json.load(f)
161
+
162
+ entry = NotebookEntry(**data)
163
+
164
+ # Filter by space if specified
165
+ if space_id is None or entry.space_id == space_id:
166
+ entries.append(entry)
167
+ except Exception as e:
168
+ print(f"Error loading notebook entry {file_path}: {e}")
169
+
170
+ # Sort by updated_at descending
171
+ entries.sort(key=lambda x: x.updated_at, reverse=True)
172
+ return entries
173
+
174
+ def update_notebook_entry(self, entry_id: str, update_data: NotebookEntryUpdate) -> Optional[NotebookEntry]:
175
+ """Update an existing notebook entry"""
176
+ entry = self.get_notebook_entry(entry_id)
177
+ if not entry:
178
+ return None
179
+
180
+ # Update fields
181
+ update_dict = update_data.dict(exclude_unset=True)
182
+ for key, value in update_dict.items():
183
+ setattr(entry, key, value)
184
+
185
+ entry.updated_at = datetime.now()
186
+
187
+ # Save
188
+ file_path = self.notebook_dir / f"{entry_id}.json"
189
+ with open(file_path, 'w', encoding='utf-8') as f:
190
+ json.dump(entry.dict(), f, indent=2, default=str)
191
+
192
+ return entry
193
+
194
+ def delete_notebook_entry(self, entry_id: str) -> bool:
195
+ """Delete a notebook entry"""
196
+ file_path = self.notebook_dir / f"{entry_id}.json"
197
+ if file_path.exists():
198
+ file_path.unlink()
199
+ return True
200
+ return False
201
+
202
+ # ========================================================================
203
+ # FLASHCARD OPERATIONS
204
+ # ========================================================================
205
+
206
+ def create_flashcard(self, card_data: FlashcardCreate) -> Flashcard:
207
+ """Create a new flashcard"""
208
+ card = Flashcard(
209
+ id=str(uuid.uuid4()),
210
+ **card_data.dict()
211
+ )
212
+
213
+ # Save to file
214
+ file_path = self.flashcards_dir / f"{card.id}.json"
215
+ with open(file_path, 'w', encoding='utf-8') as f:
216
+ json.dump(card.dict(), f, indent=2, default=str)
217
+
218
+ return card
219
+
220
+ def get_flashcard(self, card_id: str) -> Optional[Flashcard]:
221
+ """Get a single flashcard by ID"""
222
+ file_path = self.flashcards_dir / f"{card_id}.json"
223
+ if not file_path.exists():
224
+ return None
225
+
226
+ with open(file_path, 'r', encoding='utf-8') as f:
227
+ data = json.load(f)
228
+
229
+ return Flashcard(**data)
230
+
231
+ def list_flashcards(self, space_id: Optional[str] = None,
232
+ mastery: Optional[MasteryLevel] = None) -> List[Flashcard]:
233
+ """List all flashcards, optionally filtered"""
234
+ cards = []
235
+
236
+ for file_path in self.flashcards_dir.glob("*.json"):
237
+ try:
238
+ with open(file_path, 'r', encoding='utf-8') as f:
239
+ data = json.load(f)
240
+
241
+ card = Flashcard(**data)
242
+
243
+ # Apply filters
244
+ if space_id and card.space_id != space_id:
245
+ continue
246
+ if mastery and card.mastery != mastery:
247
+ continue
248
+
249
+ cards.append(card)
250
+ except Exception as e:
251
+ print(f"Error loading flashcard {file_path}: {e}")
252
+
253
+ # Sort by next_review date (cards due for review first)
254
+ cards.sort(key=lambda x: x.next_review or datetime.now())
255
+ return cards
256
+
257
+ def update_flashcard(self, card_id: str, update_data: FlashcardUpdate) -> Optional[Flashcard]:
258
+ """Update a flashcard"""
259
+ card = self.get_flashcard(card_id)
260
+ if not card:
261
+ return None
262
+
263
+ # Update fields
264
+ update_dict = update_data.dict(exclude_unset=True)
265
+ for key, value in update_dict.items():
266
+ setattr(card, key, value)
267
+
268
+ # Save
269
+ file_path = self.flashcards_dir / f"{card_id}.json"
270
+ with open(file_path, 'w', encoding='utf-8') as f:
271
+ json.dump(card.dict(), f, indent=2, default=str)
272
+
273
+ return card
274
+
275
+ def review_flashcard(self, card_id: str, review: FlashcardReview) -> Optional[Flashcard]:
276
+ """Record a flashcard review and update mastery level"""
277
+ card = self.get_flashcard(card_id)
278
+ if not card:
279
+ return None
280
+
281
+ # Update review stats
282
+ card.review_count += 1
283
+ if review.correct:
284
+ card.correct_count += 1
285
+
286
+ card.last_reviewed = datetime.now()
287
+
288
+ # Update mastery level based on performance
289
+ accuracy = card.correct_count / card.review_count if card.review_count > 0 else 0
290
+
291
+ if accuracy >= 0.9 and card.review_count >= 5:
292
+ card.mastery = MasteryLevel.MASTERED
293
+ card.next_review = datetime.now() + timedelta(days=30)
294
+ elif accuracy >= 0.7 and card.review_count >= 3:
295
+ card.mastery = MasteryLevel.REVIEWING
296
+ card.next_review = datetime.now() + timedelta(days=7)
297
+ elif card.review_count >= 1:
298
+ card.mastery = MasteryLevel.LEARNING
299
+ card.next_review = datetime.now() + timedelta(days=1)
300
+ else:
301
+ card.mastery = MasteryLevel.NEW
302
+ card.next_review = datetime.now()
303
+
304
+ # Save
305
+ file_path = self.flashcards_dir / f"{card_id}.json"
306
+ with open(file_path, 'w', encoding='utf-8') as f:
307
+ json.dump(card.dict(), f, indent=2, default=str)
308
+
309
+ return card
310
+
311
+ def delete_flashcard(self, card_id: str) -> bool:
312
+ """Delete a flashcard"""
313
+ file_path = self.flashcards_dir / f"{card_id}.json"
314
+ if file_path.exists():
315
+ file_path.unlink()
316
+ return True
317
+ return False
318
+
319
+ # ========================================================================
320
+ # QUIZ OPERATIONS
321
+ # ========================================================================
322
+
323
+ def create_quiz(self, quiz_data: QuizCreate) -> Quiz:
324
+ """Create a new quiz"""
325
+ quiz = Quiz(
326
+ id=str(uuid.uuid4()),
327
+ **quiz_data.dict()
328
+ )
329
+
330
+ # Save to file
331
+ file_path = self.quizzes_dir / f"{quiz.id}.json"
332
+ with open(file_path, 'w', encoding='utf-8') as f:
333
+ json.dump(quiz.dict(), f, indent=2, default=str)
334
+
335
+ return quiz
336
+
337
+ def get_quiz(self, quiz_id: str) -> Optional[Quiz]:
338
+ """Get a quiz by ID"""
339
+ file_path = self.quizzes_dir / f"{quiz_id}.json"
340
+ if not file_path.exists():
341
+ return None
342
+
343
+ with open(file_path, 'r', encoding='utf-8') as f:
344
+ data = json.load(f)
345
+
346
+ return Quiz(**data)
347
+
348
+ def list_quizzes(self, space_id: Optional[str] = None) -> List[Quiz]:
349
+ """List all quizzes, optionally filtered by space"""
350
+ quizzes = []
351
+
352
+ for file_path in self.quizzes_dir.glob("*.json"):
353
+ try:
354
+ with open(file_path, 'r', encoding='utf-8') as f:
355
+ data = json.load(f)
356
+
357
+ quiz = Quiz(**data)
358
+
359
+ if space_id is None or quiz.space_id == space_id:
360
+ quizzes.append(quiz)
361
+ except Exception as e:
362
+ print(f"Error loading quiz {file_path}: {e}")
363
+
364
+ # Sort by created_at descending
365
+ quizzes.sort(key=lambda x: x.created_at, reverse=True)
366
+ return quizzes
367
+
368
+ def delete_quiz(self, quiz_id: str) -> bool:
369
+ """Delete a quiz"""
370
+ file_path = self.quizzes_dir / f"{quiz_id}.json"
371
+ if file_path.exists():
372
+ file_path.unlink()
373
+ return True
374
+ return False
375
+
376
+ def submit_quiz(self, quiz_id: str, answers: List[QuizAnswer]) -> Optional[QuizResult]:
377
+ """Submit quiz answers and calculate results"""
378
+ quiz = self.get_quiz(quiz_id)
379
+ if not quiz:
380
+ return None
381
+
382
+ # Create answer lookup
383
+ answer_dict = {ans.question_id: ans for ans in answers}
384
+
385
+ # Calculate results
386
+ total_points = sum(q.points for q in quiz.questions)
387
+ correct_count = 0
388
+ incorrect_count = 0
389
+ earned_points = 0
390
+ detailed_answers = []
391
+
392
+ for question in quiz.questions:
393
+ user_answer = answer_dict.get(question.id)
394
+ is_correct = False
395
+
396
+ if user_answer:
397
+ # Normalize answers for comparison
398
+ correct_ans = question.correct_answer.strip().lower()
399
+ user_ans = user_answer.answer.strip().lower()
400
+
401
+ is_correct = correct_ans == user_ans
402
+
403
+ if is_correct:
404
+ correct_count += 1
405
+ earned_points += question.points
406
+ else:
407
+ incorrect_count += 1
408
+ else:
409
+ incorrect_count += 1
410
+
411
+ detailed_answers.append({
412
+ "question_id": question.id,
413
+ "question": question.question,
414
+ "user_answer": user_answer.answer if user_answer else None,
415
+ "correct_answer": question.correct_answer,
416
+ "is_correct": is_correct,
417
+ "explanation": question.explanation,
418
+ "points": question.points if is_correct else 0
419
+ })
420
+
421
+ # Create result
422
+ result = QuizResult(
423
+ quiz_id=quiz_id,
424
+ submission_id=str(uuid.uuid4()),
425
+ total_questions=len(quiz.questions),
426
+ correct_answers=correct_count,
427
+ incorrect_answers=incorrect_count,
428
+ score_percentage=(correct_count / len(quiz.questions) * 100) if quiz.questions else 0,
429
+ total_points=total_points,
430
+ earned_points=earned_points,
431
+ answers=detailed_answers
432
+ )
433
+
434
+ # Save result
435
+ result_file = self.quiz_results_dir / f"{result.submission_id}.json"
436
+ with open(result_file, 'w', encoding='utf-8') as f:
437
+ json.dump(result.dict(), f, indent=2, default=str)
438
+
439
+ return result
440
+
441
+ def get_quiz_history(self, quiz_id: str) -> QuizHistory:
442
+ """Get quiz attempt history"""
443
+ quiz = self.get_quiz(quiz_id)
444
+ if not quiz:
445
+ return None
446
+
447
+ # Load all results for this quiz
448
+ results = []
449
+ for file_path in self.quiz_results_dir.glob("*.json"):
450
+ try:
451
+ with open(file_path, 'r', encoding='utf-8') as f:
452
+ data = json.load(f)
453
+
454
+ result = QuizResult(**data)
455
+ if result.quiz_id == quiz_id:
456
+ results.append(result)
457
+ except Exception as e:
458
+ print(f"Error loading quiz result {file_path}: {e}")
459
+
460
+ # Calculate statistics
461
+ scores = [r.score_percentage for r in results] if results else [0]
462
+
463
+ history = QuizHistory(
464
+ quiz_id=quiz_id,
465
+ space_id=quiz.space_id,
466
+ quiz_title=quiz.title,
467
+ results=results,
468
+ best_score=max(scores),
469
+ average_score=sum(scores) / len(scores) if scores else 0,
470
+ attempts_count=len(results)
471
+ )
472
+
473
+ return history
utils/vector_db.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import uuid
3
+ import hashlib
4
+ from qdrant_client import QdrantClient
5
+ from qdrant_client.http import models
6
+ from sentence_transformers import SentenceTransformer
7
+ from typing import List, Dict, Optional
8
+ import threading
9
+ import logging
10
+ import warnings
11
+
12
+ warnings.filterwarnings('ignore', category=FutureWarning)
13
+ logging.getLogger('sentence_transformers').setLevel(logging.WARNING)
14
+
15
+ class VectorDatabase:
16
+ """Manage vector database for document embeddings using Qdrant Cloud."""
17
+
18
+ _embedding_model = None
19
+ _embedding_model_name = None
20
+ _embedding_model_lock = threading.Lock()
21
+
22
+ def __init__(self, collection_name: str = "documents", persist_directory: str = None):
23
+ """Initialize Qdrant Client (persist_directory is ignored for Cloud)"""
24
+
25
+ qdrant_url = os.getenv("QDRANT_URL")
26
+ qdrant_api_key = os.getenv("QDRANT_API_KEY")
27
+
28
+ if not qdrant_url or not qdrant_api_key:
29
+ raise ValueError("QDRANT_URL and QDRANT_API_KEY must be set in environment variables.")
30
+
31
+ self.client = QdrantClient(
32
+ url=qdrant_url,
33
+ api_key=qdrant_api_key,
34
+ timeout=60.0
35
+ )
36
+ self.collection_name = collection_name
37
+ self.vector_size = 384 # Size for standard sentence-transformers (e.g. all-MiniLM-L6-v2)
38
+
39
+ # Ensure collection exists
40
+ self._ensure_collection()
41
+
42
+ # Load embedding model
43
+ self.embedding_model = self._get_or_create_embedding_model()
44
+
45
+ def _ensure_collection(self):
46
+ """Creates the collection in Qdrant if it doesn't exist."""
47
+ try:
48
+ collections = self.client.get_collections().collections
49
+ exists = any(c.name == self.collection_name for c in collections)
50
+
51
+ if not exists:
52
+ self.client.create_collection(
53
+ collection_name=self.collection_name,
54
+ vectors_config=models.VectorParams(
55
+ size=self.vector_size,
56
+ distance=models.Distance.COSINE
57
+ )
58
+ )
59
+ except Exception as e:
60
+ print(f"Error checking/creating collection: {e}")
61
+
62
+ @classmethod
63
+ def _get_or_create_embedding_model(cls):
64
+ with cls._embedding_model_lock:
65
+ # Assuming you set EMBEDDING_MODEL in your config, defaulting to MiniLM
66
+ model_name = os.getenv("EMBEDDING_MODEL", "all-MiniLM-L6-v2")
67
+ if cls._embedding_model is None or cls._embedding_model_name != model_name:
68
+ import torch
69
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
70
+ print(f"Loading embedding model on {device}...")
71
+ cls._embedding_model = SentenceTransformer(model_name, device=device)
72
+ cls._embedding_model_name = model_name
73
+ return cls._embedding_model
74
+
75
+ def _string_to_uuid(self, string_id: str) -> str:
76
+ """Qdrant requires proper UUIDs. This hashes your custom string IDs into UUIDs."""
77
+ return str(uuid.UUID(hashlib.md5(string_id.encode()).hexdigest()))
78
+
79
+ def add_documents(self, texts: List[str], metadatas: List[Dict], ids: List[str]):
80
+ if not texts:
81
+ return
82
+
83
+ embeddings = self.embedding_model.encode(texts, show_progress_bar=False, batch_size=64).tolist()
84
+
85
+ points = []
86
+ for i in range(len(texts)):
87
+ payload = metadatas[i] if metadatas[i] else {}
88
+ payload['text'] = texts[i] # Store actual text in payload for retrieval
89
+
90
+ points.append(models.PointStruct(
91
+ id=self._string_to_uuid(ids[i]),
92
+ vector=embeddings[i],
93
+ payload=payload
94
+ ))
95
+
96
+ self.client.upsert(
97
+ collection_name=self.collection_name,
98
+ points=points
99
+ )
100
+
101
+ def query(self, query_text: str, n_results: int = 5, filter_dict: Optional[Dict] = None) -> Dict:
102
+ # Check if collection is empty
103
+ count = self.get_collection_count()
104
+ if count == 0:
105
+ return {"documents": [[]], "metadatas": [[]], "distances": [[]], "ids": [[]]}
106
+
107
+ query_embedding = self.embedding_model.encode([query_text])[0].tolist()
108
+
109
+ # Build Qdrant filter if provided
110
+ qdrant_filter = None
111
+ if filter_dict:
112
+ conditions = [
113
+ models.FieldCondition(key=k, match=models.MatchValue(value=v))
114
+ for k, v in filter_dict.items()
115
+ ]
116
+ qdrant_filter = models.Filter(must=conditions)
117
+
118
+ search_result = self.client.search(
119
+ collection_name=self.collection_name,
120
+ query_vector=query_embedding,
121
+ query_filter=qdrant_filter,
122
+ limit=n_results
123
+ )
124
+
125
+ # Format output to match exactly what your HybridRetriever expects (ChromaDB style)
126
+ docs, metas, scores, ids = [], [], [], []
127
+ for hit in search_result:
128
+ docs.append(hit.payload.get('text', ''))
129
+
130
+ # Remove text from metadata so it mimics Chroma
131
+ meta = {k: v for k, v in hit.payload.items() if k != 'text'}
132
+ metas.append(meta)
133
+
134
+ scores.append(hit.score)
135
+ ids.append(str(hit.id))
136
+
137
+ return {
138
+ "documents": [docs],
139
+ "metadatas": [metas],
140
+ "distances": [scores], # Note: Qdrant uses cosine similarity (higher is better), Chroma uses distance.
141
+ "ids": [ids]
142
+ }
143
+
144
+ def get_collection_count(self) -> int:
145
+ try:
146
+ return self.client.count(collection_name=self.collection_name).count
147
+ except Exception:
148
+ return 0