Commit 1a5863d · 1 Parent: eca2087 · committed by aviseth

feat: Phase 1 enhancements - ensemble endpoint, history API, rate limiting, storage monitoring

requirements.txt CHANGED
@@ -44,6 +44,7 @@ tqdm>=4.65.0
 # Testing
 pytest>=7.4.0
 pytest-asyncio>=0.21.0
+hypothesis>=6.0.0
 
 # Visualization
 matplotlib>=3.7.0
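The new `hypothesis` pin enables property-based tests. A minimal sketch of one such test (a hypothetical test file, not part of this commit; it exercises the `soft_voting` helper from `src/models/ensemble.py`, added later in this commit, calling the method unbound so the three model checkpoints never load):

```python
# Hypothetical property-based test: averaged soft-voting scores must stay in [0, 1].
from hypothesis import given, strategies as st

from src.models.ensemble import EnsembleClassifier, ModelPrediction

LABELS = ["True", "Fake", "Satire", "Bias"]

score_dicts = st.lists(
    st.fixed_dictionaries({label: st.floats(0.0, 1.0) for label in LABELS}),
    min_size=1, max_size=3,
)

@given(score_dicts)
def test_soft_voting_stays_in_unit_interval(dicts):
    preds = [
        ModelPrediction(model_name=f"m{i}", label=max(s, key=s.get),
                        confidence=max(s.values()), scores=s, tokens=[])
        for i, s in enumerate(dicts)
    ]
    # soft_voting does not touch self, so pass None to avoid loading models
    averaged = EnsembleClassifier.soft_voting(None, preds)
    assert all(0.0 <= v <= 1.0 for v in averaged.values())
```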
scripts/phase2_migration.sql ADDED
@@ -0,0 +1,90 @@
+-- Phase 2: User Analysis History Database Migration
+-- This script creates the user_analysis_history table and related indexes
+-- Execute this in your Supabase SQL Editor
+
+-- Step 1: Create the user_analysis_history table
+CREATE TABLE IF NOT EXISTS user_analysis_history (
+    id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
+    session_id VARCHAR(36) NOT NULL,
+    article_id VARCHAR(36) NOT NULL UNIQUE,
+    text_preview VARCHAR(200) NOT NULL,
+    predicted_label VARCHAR(50) NOT NULL CHECK (predicted_label IN ('True', 'Fake', 'Satire', 'Bias')),
+    confidence FLOAT NOT NULL CHECK (confidence >= 0.0 AND confidence <= 1.0),
+    model_name VARCHAR(100) NOT NULL,
+    created_at TIMESTAMPTZ DEFAULT NOW() NOT NULL,
+
+    CONSTRAINT fk_article FOREIGN KEY (article_id) REFERENCES predictions(article_id) ON DELETE CASCADE
+);
+
+-- Step 2: Create indexes for efficient queries
+CREATE INDEX IF NOT EXISTS idx_history_session_created ON user_analysis_history(session_id, created_at DESC);
+CREATE INDEX IF NOT EXISTS idx_history_article ON user_analysis_history(article_id);
+
+-- Step 3: Enable row-level security
+ALTER TABLE user_analysis_history ENABLE ROW LEVEL SECURITY;
+
+-- Step 4: Create policy to allow all operations (for development)
+-- Note: In production, you should restrict this based on your security requirements
+DROP POLICY IF EXISTS "allow_all_history" ON user_analysis_history;
+CREATE POLICY "allow_all_history" ON user_analysis_history FOR ALL USING (true) WITH CHECK (true);
+
+-- Step 5: Verify the table was created
+SELECT
+    table_name,
+    column_name,
+    data_type,
+    is_nullable,
+    column_default
+FROM information_schema.columns
+WHERE table_name = 'user_analysis_history'
+ORDER BY ordinal_position;
+
+-- Step 6: Verify indexes were created
+SELECT
+    indexname,
+    indexdef
+FROM pg_indexes
+WHERE tablename = 'user_analysis_history';
+
+-- Step 7: Verify RLS policy was created
+SELECT
+    policyname,
+    permissive,
+    roles,
+    cmd
+FROM pg_policies
+WHERE tablename = 'user_analysis_history';
+
+-- Optional: Insert a test record to verify everything works
+-- Uncomment the following lines to test (replace with actual article_id from predictions table)
+/*
+DO $$
+DECLARE
+    test_article_id VARCHAR(36);
+    test_session_id VARCHAR(36);
+BEGIN
+    -- First, insert a test prediction
+    test_article_id := gen_random_uuid()::text;
+    test_session_id := gen_random_uuid()::text;
+
+    INSERT INTO predictions (article_id, text, predicted_label, confidence, model_name)
+    VALUES (test_article_id, 'Test article for migration verification', 'True', 0.95, 'ensemble');
+
+    -- Then, insert a test history record
+    INSERT INTO user_analysis_history (session_id, article_id, text_preview, predicted_label, confidence, model_name)
+    VALUES (test_session_id, test_article_id, 'Test article for migration verification', 'True', 0.95, 'ensemble');
+
+    -- Verify the record was inserted
+    IF EXISTS (SELECT 1 FROM user_analysis_history WHERE article_id = test_article_id) THEN
+        RAISE NOTICE 'Test record inserted successfully!';
+    ELSE
+        RAISE EXCEPTION 'Test record insertion failed!';
+    END IF;
+
+    -- Clean up test data
+    DELETE FROM user_analysis_history WHERE article_id = test_article_id;
+    DELETE FROM predictions WHERE article_id = test_article_id;
+
+    RAISE NOTICE 'Test data cleaned up. Migration verification complete!';
+END $$;
+*/
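After running the migration, a quick smoke check from Python confirms the new table is reachable through PostgREST. A minimal sketch, assuming the same `SUPABASE_URL` / `SUPABASE_SERVICE_KEY` environment variables the app already uses:

```python
import os

from supabase import create_client

client = create_client(os.environ["SUPABASE_URL"],
                       os.environ["SUPABASE_SERVICE_KEY"])

# If the migration ran, this returns an empty result set instead of raising
res = client.table("user_analysis_history").select(
    "*", count="exact").limit(1).execute()
print(f"user_analysis_history reachable, {res.count} rows")
```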
scripts/setup_supabase.sql CHANGED
@@ -1,3 +1,4 @@
+DROP TABLE IF EXISTS user_analysis_history CASCADE;
 DROP TABLE IF EXISTS feedback CASCADE;
 DROP TABLE IF EXISTS predictions CASCADE;
 DROP TABLE IF EXISTS news_articles CASCADE;
@@ -77,11 +78,30 @@ CREATE TABLE user_sessions (
     last_activity TIMESTAMPTZ DEFAULT NOW()
 );
 
-ALTER TABLE predictions DISABLE ROW LEVEL SECURITY;
-ALTER TABLE feedback DISABLE ROW LEVEL SECURITY;
-ALTER TABLE news_articles DISABLE ROW LEVEL SECURITY;
-ALTER TABLE model_performance DISABLE ROW LEVEL SECURITY;
-ALTER TABLE user_sessions DISABLE ROW LEVEL SECURITY;
+CREATE TABLE user_analysis_history (
+    id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
+    session_id VARCHAR(36) NOT NULL,
+    article_id VARCHAR(36) NOT NULL UNIQUE,
+    text_preview VARCHAR(200) NOT NULL,
+    predicted_label VARCHAR(50) NOT NULL CHECK (predicted_label IN ('True', 'Fake', 'Satire', 'Bias')),
+    confidence FLOAT NOT NULL CHECK (confidence >= 0.0 AND confidence <= 1.0),
+    model_name VARCHAR(100) NOT NULL,
+    created_at TIMESTAMPTZ DEFAULT NOW() NOT NULL,
+
+    CONSTRAINT fk_article FOREIGN KEY (article_id) REFERENCES predictions(article_id) ON DELETE CASCADE
+);
+
+CREATE INDEX idx_history_session_created ON user_analysis_history(session_id, created_at DESC);
+CREATE INDEX idx_history_article ON user_analysis_history(article_id);
+
+ALTER TABLE predictions DISABLE ROW LEVEL SECURITY;
+ALTER TABLE feedback DISABLE ROW LEVEL SECURITY;
+ALTER TABLE news_articles DISABLE ROW LEVEL SECURITY;
+ALTER TABLE model_performance DISABLE ROW LEVEL SECURITY;
+ALTER TABLE user_sessions DISABLE ROW LEVEL SECURITY;
+ALTER TABLE user_analysis_history ENABLE ROW LEVEL SECURITY;
+
+CREATE POLICY "allow_all_history" ON user_analysis_history FOR ALL USING (true) WITH CHECK (true);
 
 CREATE VIEW prediction_stats AS
 SELECT predicted_label, COUNT(*) AS total_count, AVG(confidence) AS avg_confidence
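One caveat worth noting: the API connects with the service key, which bypasses row-level security, so the allow-all policy on `user_analysis_history` matters mainly for anon/browser access. A hedged sanity check of the rebuilt schema via supabase-py (PostgREST exposes the `prediction_stats` view like a table):

```python
import os

from supabase import create_client

client = create_client(os.environ["SUPABASE_URL"],
                       os.environ["SUPABASE_SERVICE_KEY"])

# Read the aggregate view defined at the end of the setup script
for row in client.table("prediction_stats").select("*").execute().data:
    print(row["predicted_label"], row["total_count"], row["avg_confidence"])
```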
src/api/main.py CHANGED
@@ -1,9 +1,14 @@
-from fastapi import FastAPI, HTTPException, BackgroundTasks
+from fastapi import FastAPI, HTTPException, BackgroundTasks, Header, Query, Request
+from fastapi.responses import JSONResponse
 from fastapi.middleware.cors import CORSMiddleware
-from pydantic import BaseModel
+from pydantic import BaseModel, validator
 from typing import Optional, List, Dict
 import os
 import uuid
+import asyncio
+import logging
+import time
+from collections import defaultdict
 from dotenv import load_dotenv
 
 from src.utils.supabase_client import get_supabase_client
@@ -11,6 +16,14 @@ from src.utils.gnews_client import get_gnews_client
 
 load_dotenv()
 
+# Configure logger
+logger = logging.getLogger(__name__)
+
+# Rate limiting: track request timestamps per client IP
+request_tracker = defaultdict(list)
+RATE_LIMIT_REQUESTS = 100  # Max requests per window
+RATE_LIMIT_WINDOW = 60  # Window in seconds
+
 app = FastAPI(
     title="Fake News Detection API",
     description="Multi-class fake news detection using DistilBERT, RoBERTa, and XLNet",
@@ -34,6 +47,38 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
+
+@app.middleware("http")
+async def rate_limit_middleware(request: Request, call_next):
+    """
+    Rate limiting middleware to prevent abuse.
+    Allows RATE_LIMIT_REQUESTS per RATE_LIMIT_WINDOW seconds per IP.
+    """
+    client_ip = request.client.host
+    current_time = time.time()
+
+    # Drop request timestamps that have aged out of the window
+    request_tracker[client_ip] = [
+        req_time for req_time in request_tracker[client_ip]
+        if current_time - req_time < RATE_LIMIT_WINDOW
+    ]
+
+    # Check the limit; return the 429 directly, since exceptions raised in
+    # middleware bypass FastAPI's HTTPException handler and surface as 500s
+    if len(request_tracker[client_ip]) >= RATE_LIMIT_REQUESTS:
+        logger.warning(f"Rate limit exceeded for IP: {client_ip}")
+        return JSONResponse(
+            status_code=429,
+            content={"detail": f"Rate limit exceeded. Maximum {RATE_LIMIT_REQUESTS} requests per {RATE_LIMIT_WINDOW} seconds."},
+        )
+
+    # Track this request
+    request_tracker[client_ip].append(current_time)
+
+    response = await call_next(request)
+    return response
+
+
 VALID_MODELS = {"distilbert", "roberta", "xlnet"}
 
 
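The middleware above keeps a sliding window of timestamps per client IP. A minimal sketch of exercising it with FastAPI's `TestClient` (the `/health` path is an assumption based on the `health_check` handler visible in a later hunk; all `TestClient` requests share one client IP, so the 101st call inside a 60s window should be rejected):

```python
from fastapi.testclient import TestClient

from src.api.main import app

client = TestClient(app)
# 101 rapid requests from the same client IP
statuses = [client.get("/health").status_code for _ in range(101)]
print(statuses[-1])  # expected 429 once RATE_LIMIT_REQUESTS is exceeded
```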
@@ -70,6 +115,48 @@ class ExplainRequest(BaseModel):
     deep: Optional[bool] = False
 
 
+# Ensemble API Models
+class EnsemblePredictionRequest(BaseModel):
+    text: str
+    session_id: Optional[str] = None
+
+    @validator('text')
+    def validate_text(cls, v):
+        if len(v.strip()) < 10:
+            raise ValueError("Text too short to classify")
+        return v
+
+
+class VotingResult(BaseModel):
+    label: str
+    confidence: float
+    scores: Dict[str, float]
+
+
+class VotingStrategies(BaseModel):
+    hard_voting: VotingResult
+    soft_voting: VotingResult
+    weighted_voting: VotingResult
+
+
+class ModelPredictionResponse(BaseModel):
+    model_name: str
+    label: str
+    confidence: float
+    scores: Dict[str, float]
+    tokens: List[ExplanationData]
+
+
+class EnsemblePredictionResponse(BaseModel):
+    article_id: str
+    primary_prediction: VotingResult  # hard voting result
+    voting_strategies: VotingStrategies
+    individual_models: List[ModelPredictionResponse]
+    merged_explanation: List[ExplanationData]
+    execution_time_ms: float
+    warnings: Optional[List[str]] = None
+
+
 @app.on_event("startup")
 async def startup_event():
     try:
@@ -122,8 +209,16 @@ async def health_check():
 
 
 @app.post("/predict", response_model=PredictionResponse)
-async def predict(request: PredictionRequest, background_tasks: BackgroundTasks):
-    """Classify news as True / Fake / Satire / Bias."""
+async def predict(
+    request: PredictionRequest,
+    background_tasks: BackgroundTasks,
+    x_session_id: Optional[str] = Header(None, alias="X-Session-ID")
+):
+    """
+    Classify news as True / Fake / Satire / Bias.
+
+    Requirements: 4.4, 4.6, 2.7
+    """
     if not request.text and not request.url:
         raise HTTPException(status_code=400, detail="Provide text or url")
 
@@ -164,23 +259,274 @@ async def predict(request: PredictionRequest, background_tasks: BackgroundTasks)
     )
 
     def _store():
+        """
+        Store prediction in both predictions and user_analysis_history tables.
+        Requirements: 4.4, 4.6, 2.7
+        """
         try:
             supabase = get_supabase_client()
-            supabase.store_prediction(
-                article_id=article_id,
-                text=text,
-                predicted_label=result["label"],
-                confidence=result["confidence"],
-                model_name=model_key,
-                explanation=result.get("tokens", []),
-            )
+
+            # Store in predictions table (Requirement 2.7)
+            try:
+                supabase.store_prediction(
+                    article_id=article_id,
+                    text=text,
+                    predicted_label=result["label"],
+                    confidence=result["confidence"],
+                    model_name=model_key,
+                    explanation=result.get("tokens", []),
+                )
+                logger.info(
+                    f"Stored prediction {article_id} in predictions table")
+            except Exception as e:
+                logger.error(
+                    f"Failed to store prediction in predictions table: {e}")
+
+            # Store in user_analysis_history if session_id is provided (Requirements 4.4, 4.6)
+            if x_session_id:
+                try:
+                    supabase.store_user_history(
+                        session_id=x_session_id,
+                        article_id=article_id,
+                        text=text,
+                        predicted_label=result["label"],
+                        confidence=result["confidence"],
+                        model_name=model_key
+                    )
+                    logger.info(
+                        f"Stored prediction {article_id} in user_analysis_history for session {x_session_id}")
+                except Exception as e:
+                    # Handle history-storage failures gracefully (Requirement 4.4)
+                    logger.error(
+                        f"Failed to store prediction in user_analysis_history: {e}")
+            else:
+                logger.debug(
+                    f"No session_id provided for prediction {article_id}, skipping history storage")
+
         except Exception as e:
-            print(f"[bg] store_prediction failed: {e}")
+            logger.error(
+                f"Database storage failed for prediction {article_id}: {e}")
 
     background_tasks.add_task(_store)
     return response
 
 
+@app.post("/predict/ensemble", response_model=EnsemblePredictionResponse)
+async def predict_ensemble(
+    request: EnsemblePredictionRequest,
+    background_tasks: BackgroundTasks,
+    x_session_id: Optional[str] = Header(None, alias="X-Session-ID")
+):
+    """
+    Run ensemble prediction using all three models (DistilBERT, RoBERTa, XLNet).
+    Combines predictions using hard voting, soft voting, and weighted voting strategies.
+
+    Requirements: 2.1, 2.2, 2.5, 2.8
+    """
+    article_id = str(uuid.uuid4())
+    session_id = x_session_id or request.session_id
+
+    try:
+        from src.models.ensemble import get_ensemble_classifier
+
+        # Get ensemble classifier instance
+        ensemble = get_ensemble_classifier()
+
+        # Run ensemble prediction with 15s timeout (Requirement 2.8)
+        result = await asyncio.wait_for(
+            ensemble.predict_ensemble(request.text),
+            timeout=15.0
+        )
+
+        # Build response with all voting strategies
+        primary_prediction = VotingResult(
+            label=result.hard_voting_label,
+            confidence=result.hard_voting_confidence,
+            scores={result.hard_voting_label: result.hard_voting_confidence}
+        )
+
+        voting_strategies = VotingStrategies(
+            hard_voting=VotingResult(
+                label=result.hard_voting_label,
+                confidence=result.hard_voting_confidence,
+                scores={result.hard_voting_label: result.hard_voting_confidence}
+            ),
+            soft_voting=VotingResult(
+                label=result.soft_voting_label,
+                confidence=result.soft_voting_confidence,
+                scores=result.soft_voting_scores
+            ),
+            weighted_voting=VotingResult(
+                label=result.weighted_voting_label,
+                confidence=result.weighted_voting_confidence,
+                scores=result.weighted_voting_scores
+            )
+        )
+
+        # Convert individual model predictions
+        individual_models = [
+            ModelPredictionResponse(
+                model_name=pred.model_name,
+                label=pred.label,
+                confidence=pred.confidence,
+                scores=pred.scores,
+                tokens=[ExplanationData(**t) for t in pred.tokens]
+            )
+            for pred in result.individual_predictions
+        ]
+
+        # Convert merged explanation
+        merged_explanation = [
+            ExplanationData(**token) for token in result.merged_explanation
+        ]
+
+        response = EnsemblePredictionResponse(
+            article_id=article_id,
+            primary_prediction=primary_prediction,
+            voting_strategies=voting_strategies,
+            individual_models=individual_models,
+            merged_explanation=merged_explanation,
+            execution_time_ms=result.execution_time_ms,
+            warnings=result.warnings
+        )
+
+        # Background task: store ensemble prediction to database
+        def store_ensemble_prediction():
+            """
+            Store prediction in both predictions and user_analysis_history tables.
+            Handles database failures gracefully - logs errors but doesn't crash.
+            Requirements: 2.3, 2.4, 2.6, 2.7, 14.3
+            """
+            try:
+                supabase = get_supabase_client()
+
+                # Store in predictions table with model_name="ensemble" (Requirement 2.7)
+                try:
+                    supabase.store_prediction(
+                        article_id=article_id,
+                        text=request.text,
+                        predicted_label=result.hard_voting_label,
+                        confidence=result.hard_voting_confidence,
+                        model_name="ensemble",
+                        explanation=result.merged_explanation,
+                    )
+                    logger.info(
+                        f"Stored ensemble prediction {article_id} in predictions table")
+                except Exception as e:
+                    # Log but continue - don't let predictions table failure stop history storage
+                    logger.error(
+                        f"Failed to store prediction in predictions table: {e}")
+
+                # Store in user_analysis_history if session_id is provided (Requirement 2.4)
+                if session_id:
+                    try:
+                        supabase.store_user_history(
+                            session_id=session_id,
+                            article_id=article_id,
+                            text=request.text,
+                            predicted_label=result.hard_voting_label,
+                            confidence=result.hard_voting_confidence,
+                            model_name="ensemble"
+                        )
+                        logger.info(
+                            f"Stored ensemble prediction {article_id} in user_analysis_history for session {session_id}")
+                    except Exception as e:
+                        # Log but don't crash - history storage is non-critical (Requirement 14.3)
+                        logger.error(
+                            f"Failed to store prediction in user_analysis_history: {e}")
+                else:
+                    logger.debug(
+                        f"No session_id provided for prediction {article_id}, skipping history storage")
+
+            except Exception as e:
+                # Catch-all for any database connection failures (Requirement 14.3)
+                logger.error(
+                    f"Database storage failed for prediction {article_id}: {e}")
+
+        background_tasks.add_task(store_ensemble_prediction)
+        return response
+
+    except asyncio.TimeoutError:
+        # Requirement 2.8: Return HTTP 504 after 15s timeout
+        raise HTTPException(
+            status_code=504,
+            detail="Ensemble prediction timed out after 15 seconds"
+        )
+    except ValueError as e:
+        # Handle validation errors (e.g., text too short)
+        raise HTTPException(status_code=422, detail=str(e))
+    except RuntimeError as e:
+        # Handle case where all models fail
+        raise HTTPException(status_code=500, detail=str(e))
+    except Exception as e:
+        import traceback
+        traceback.print_exc()
+        raise HTTPException(
+            status_code=500,
+            detail=f"Ensemble prediction error: {str(e)}"
+        )
+
+
+@app.get("/history/{session_id}")
+async def get_user_history(
+    session_id: str,
+    limit: int = Query(100, ge=1, le=100)
+):
+    """
+    Retrieve user's analysis history by session ID.
+
+    Args:
+        session_id: UUID v4 session identifier
+        limit: Maximum records to return (1-100, default 100)
+
+    Returns:
+        List of prediction records with metadata
+
+    Requirements: 6.1, 6.2, 6.3, 6.4, 6.5, 6.6, 6.7
+    """
+    # Validate UUID format (Requirement 6.6)
+    try:
+        uuid.UUID(session_id)
+    except ValueError:
+        raise HTTPException(
+            status_code=400,
+            detail="Invalid session ID format"
+        )
+
+    try:
+        # Add 2s timeout (Requirement 6.7)
+        supabase = get_supabase_client()
+        history = await asyncio.wait_for(
+            asyncio.get_running_loop().run_in_executor(
+                None,
+                supabase.get_user_history,
+                session_id,
+                limit
+            ),
+            timeout=2.0
+        )
+
+        # Return empty array with HTTP 200 for sessions with no history (Requirement 6.5)
+        return {
+            "status": "success",
+            "session_id": session_id,
+            "count": len(history),
+            "history": history
+        }
+    except asyncio.TimeoutError:
+        # Requirement 6.7: Return HTTP 504 after 2s timeout
+        raise HTTPException(
+            status_code=504,
+            detail="History retrieval timed out after 2 seconds"
+        )
+    except Exception as e:
+        logger.error(f"Failed to fetch history for session {session_id}: {e}")
+        raise HTTPException(
+            status_code=500,
+            detail="Failed to load history"
+        )
+
+
 @app.post("/feedback")
 async def submit_feedback(feedback: FeedbackRequest):
     """Submit user correction for active learning."""
 
@@ -343,6 +689,23 @@ async def get_statistics():
             status_code=500, detail=f"Error fetching stats: {e}")
 
 
+@app.get("/storage")
+async def get_storage_usage():
+    """
+    Get database storage usage metrics and warnings.
+
+    Returns storage usage information and warns when approaching 90% of the 500MB limit.
+    """
+    try:
+        supabase = get_supabase_client()
+        usage = supabase.check_storage_usage()
+        return {"status": "success", "storage": usage}
+    except Exception as e:
+        logger.error(f"Error fetching storage usage: {e}")
+        raise HTTPException(
+            status_code=500, detail=f"Error fetching storage usage: {e}")
+
+
 @app.get("/models")
 async def list_models():
     """List available models and their training status."""
src/models/ensemble.py ADDED
@@ -0,0 +1,198 @@
+"""
+Ensemble classifier combining DistilBERT, RoBERTa, and XLNet
+with parallel execution and multiple voting strategies.
+"""
+
+import asyncio
+import time
+import logging
+from dataclasses import dataclass
+from typing import List, Dict, Optional
+from concurrent.futures import ThreadPoolExecutor
+
+from .inference import get_classifier
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class TokenImportance:
+    token: str
+    score: float
+
+
+@dataclass
+class ModelPrediction:
+    model_name: str
+    label: str
+    confidence: float
+    scores: Dict[str, float]
+    tokens: List[Dict]
+
+
+@dataclass
+class EnsembleResult:
+    hard_voting_label: str
+    hard_voting_confidence: float
+    soft_voting_label: str
+    soft_voting_confidence: float
+    soft_voting_scores: Dict[str, float]
+    weighted_voting_label: str
+    weighted_voting_confidence: float
+    weighted_voting_scores: Dict[str, float]
+    individual_predictions: List[ModelPrediction]
+    merged_explanation: List[Dict]
+    execution_time_ms: float
+    warnings: Optional[List[str]] = None
+
+
+class EnsembleClassifier:
+    """Combines predictions from all three models using voting strategies."""
+
+    def __init__(self):
+        self.model_names = ['distilbert', 'roberta', 'xlnet']
+        self.models = {name: get_classifier(name) for name in self.model_names}
+        # Per-model weights: each model's validation accuracy
+        self.weights = {'distilbert': 0.859, 'roberta': 0.858, 'xlnet': 0.862}
+        self.executor = ThreadPoolExecutor(max_workers=3)
+
+    async def predict_ensemble(self, text: str, model_timeout: float = 10.0,
+                               total_timeout: float = 15.0) -> EnsembleResult:
+        start_time = time.time()
+        warnings = []
+
+        loop = asyncio.get_running_loop()
+        tasks = [
+            loop.run_in_executor(
+                self.executor, self._predict_with_timeout, name, text, model_timeout)
+            for name in self.model_names
+        ]
+
+        try:
+            results = await asyncio.wait_for(
+                asyncio.gather(*tasks, return_exceptions=True),
+                timeout=total_timeout
+            )
+        except asyncio.TimeoutError:
+            warnings.append("Ensemble prediction exceeded total timeout")
+            raise
+
+        valid_predictions = []
+        for name, result in zip(self.model_names, results):
+            if isinstance(result, Exception):
+                warnings.append(f"{name} failed: {str(result)}")
+            elif result is None:
+                warnings.append(f"{name} returned no result")
+            else:
+                valid_predictions.append(ModelPrediction(
+                    model_name=name,
+                    label=result['label'],
+                    confidence=result['confidence'],
+                    scores=result['scores'],
+                    tokens=result['tokens']
+                ))
+
+        if not valid_predictions:
+            raise RuntimeError("All models failed to process the request")
+
+        hard_label, hard_conf = self.hard_voting(valid_predictions)
+        soft_scores = self.soft_voting(valid_predictions)
+        soft_label = max(soft_scores.items(), key=lambda x: x[1])[0]
+        soft_conf = soft_scores[soft_label]
+
+        weighted_scores = self.weighted_voting(valid_predictions)
+        weighted_label = max(weighted_scores.items(), key=lambda x: x[1])[0]
+        weighted_conf = weighted_scores[weighted_label]
+
+        merged_tokens = self._merge_explanations(valid_predictions)
+        execution_time = (time.time() - start_time) * 1000
+
+        logger.info(
+            f"Ensemble completed in {execution_time:.2f}ms with {len(valid_predictions)}/{len(self.model_names)} models")
+        if warnings:
+            logger.warning(f"Ensemble warnings: {warnings}")
+
+        return EnsembleResult(
+            hard_voting_label=hard_label,
+            hard_voting_confidence=hard_conf,
+            soft_voting_label=soft_label,
+            soft_voting_confidence=soft_conf,
+            soft_voting_scores=soft_scores,
+            weighted_voting_label=weighted_label,
+            weighted_voting_confidence=weighted_conf,
+            weighted_voting_scores=weighted_scores,
+            individual_predictions=valid_predictions,
+            merged_explanation=merged_tokens,
+            execution_time_ms=round(execution_time, 2),
+            warnings=warnings if warnings else None
+        )
+
+    def _predict_with_timeout(self, model_name: str, text: str, timeout: float) -> Optional[Dict]:
+        # The per-model timeout is not enforced here; the caller's total_timeout
+        # bounds overall latency for the gathered tasks.
+        try:
+            return self.models[model_name].predict(text)
+        except Exception as e:
+            logger.error(f"[ensemble] {model_name} prediction failed: {e}")
+            return None
+
+    def hard_voting(self, predictions: List[ModelPrediction]) -> tuple[str, float]:
+        # Majority vote; confidence is the mean confidence of the winning label's voters
+        votes = {}
+        for pred in predictions:
+            votes[pred.label] = votes.get(pred.label, 0) + 1
+        winning_label = max(votes.items(), key=lambda x: x[1])[0]
+        confidences = [
+            p.confidence for p in predictions if p.label == winning_label]
+        return winning_label, round(sum(confidences) / len(confidences), 4)
+
+    def soft_voting(self, predictions: List[ModelPrediction]) -> Dict[str, float]:
+        # Unweighted mean of per-label probabilities across available models
+        all_labels = set(
+            label for pred in predictions for label in pred.scores)
+        return {
+            label: round(sum(p.scores.get(label, 0.0)
+                         for p in predictions) / len(predictions), 4)
+            for label in all_labels
+        }
+
+    def weighted_voting(self, predictions: List[ModelPrediction]) -> Dict[str, float]:
+        # Probability mean weighted by each model's accuracy weight
+        all_labels = set(
+            label for pred in predictions for label in pred.scores)
+        total_weight = sum(self.weights[p.model_name] for p in predictions)
+        return {
+            label: round(
+                sum(p.scores.get(label, 0.0) *
+                    self.weights[p.model_name] for p in predictions) / total_weight,
+                4
+            )
+            for label in all_labels
+        }
+
+    def _merge_explanations(self, predictions: List[ModelPrediction]) -> List[Dict]:
+        # Average each token's importance across the models that reported it
+        token_scores: Dict[str, float] = {}
+        token_counts: Dict[str, int] = {}
+        for pred in predictions:
+            for td in pred.tokens:
+                token = td['token']
+                token_scores[token] = token_scores.get(
+                    token, 0.0) + td['score']
+                token_counts[token] = token_counts.get(token, 0) + 1
+        merged = [
+            {'token': t, 'score': round(token_scores[t] / token_counts[t], 4)}
+            for t in token_scores
+        ]
+        return sorted(merged, key=lambda x: x['score'], reverse=True)[:10]
+
+
+_ensemble_classifier: Optional[EnsembleClassifier] = None
+
+
+def get_ensemble_classifier() -> EnsembleClassifier:
+    global _ensemble_classifier
+    if _ensemble_classifier is None:
+        _ensemble_classifier = EnsembleClassifier()
+    return _ensemble_classifier
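A minimal smoke run of the classifier outside the API (this loads all three model checkpoints, so it assumes `get_classifier` can find trained weights):

```python
import asyncio

from src.models.ensemble import get_ensemble_classifier

async def main():
    ensemble = get_ensemble_classifier()
    result = await ensemble.predict_ensemble(
        "Scientists announced a breakthrough in battery technology today.")
    print(result.hard_voting_label, result.hard_voting_confidence)
    print(result.soft_voting_scores)
    print(f"{result.execution_time_ms} ms, warnings: {result.warnings}")

asyncio.run(main())
```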
src/utils/supabase_client.py CHANGED
@@ -1,11 +1,39 @@
 import os
+import uuid
+import time
+import logging
+import functools
 from typing import Optional, Dict, Any, List
-from datetime import datetime
+from datetime import datetime, timezone
 from supabase import create_client, Client
 from dotenv import load_dotenv
 
 load_dotenv()
 
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+def retry_with_exponential_backoff(max_retries=3, base_delay=1.0):
+    """Retry a callable on failure, doubling the delay between attempts."""
+    def decorator(func):
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            for attempt in range(max_retries):
+                try:
+                    return func(*args, **kwargs)
+                except Exception as e:
+                    if attempt == max_retries - 1:
+                        logger.error(
+                            f"{func.__name__} failed after {max_retries} attempts: {e}")
+                        raise
+                    delay = base_delay * (2 ** attempt)
+                    logger.warning(
+                        f"{func.__name__} attempt {attempt + 1} failed: {e}. Retrying in {delay}s...")
+                    time.sleep(delay)
+        return wrapper
+    return decorator
+
 
 class SupabaseClient:
     def __init__(self):
@@ -17,15 +45,9 @@ class SupabaseClient:
                 "SUPABASE_URL and SUPABASE_SERVICE_KEY must be set")
         self.client: Client = create_client(self.url, self.key)
 
-    def store_prediction(
-        self,
-        article_id: str,
-        text: str,
-        predicted_label: str,
-        confidence: float,
-        model_name: str,
-        explanation=None,
-    ) -> Dict[str, Any]:
+    @retry_with_exponential_backoff(max_retries=3)
+    def store_prediction(self, article_id: str, text: str, predicted_label: str,
+                         confidence: float, model_name: str, explanation=None) -> Dict[str, Any]:
         data = {
             "article_id": article_id,
             "text": text[:1000],
@@ -33,24 +55,24 @@ class SupabaseClient:
             "confidence": confidence,
             "model_name": model_name,
             "explanation": explanation,
-            "created_at": datetime.utcnow().isoformat(),
+            "created_at": datetime.now(timezone.utc).isoformat(),
         }
-        response = self.client.table("predictions").insert(data).execute()
-        return response.data
+        try:
+            response = self.client.table("predictions").insert(data).execute()
+            logger.info(f"Stored prediction for article {article_id}")
+            return response.data
+        except Exception as e:
+            logger.error(f"Failed to store prediction: {e}")
+            raise
 
-    def store_feedback(
-        self,
-        article_id: str,
-        predicted_label: str,
-        actual_label: str,
-        user_comment: Optional[str] = None,
-    ) -> Dict[str, Any]:
+    def store_feedback(self, article_id: str, predicted_label: str,
+                       actual_label: str, user_comment: Optional[str] = None) -> Dict[str, Any]:
         data = {
             "article_id": article_id,
             "predicted_label": predicted_label,
             "actual_label": actual_label,
            "user_comment": user_comment,
+            "created_at": datetime.now(timezone.utc).isoformat(),
         }
         response = self.client.table("feedback").insert(data).execute()
         return response.data
@@ -64,16 +86,97 @@ class SupabaseClient:
         for row in by_label_rows.data:
             lbl = row["predicted_label"]
             label_counts[lbl] = label_counts.get(lbl, 0) + 1
-        return {
-            "total_predictions": total.count,
-            "by_label": label_counts,
-        }
+        logger.info(f"Total predictions: {total.count}")
+        return {"total_predictions": total.count, "by_label": label_counts}
+
+    def check_storage_usage(self) -> Dict[str, Any]:
+        """Check database storage usage and warn if approaching the 500MB free-tier limit."""
+        try:
+            predictions_count = self.client.table("predictions").select(
+                "*", count="exact").execute().count
+            history_count = self.client.table("user_analysis_history").select(
+                "*", count="exact").execute().count
+            # Rough estimate: ~1KB per prediction row, ~0.5KB per history row
+            estimated_mb = (predictions_count * 1.0 +
+                            history_count * 0.5) / 1024
+            limit_mb = 500
+            usage_percent = (estimated_mb / limit_mb) * 100
+            result = {
+                "predictions_count": predictions_count,
+                "history_count": history_count,
+                "estimated_storage_mb": round(estimated_mb, 2),
+                "limit_mb": limit_mb,
+                "usage_percent": round(usage_percent, 2),
+                "warning": None
+            }
+            if usage_percent >= 90:
+                warning = f"Storage usage at {usage_percent:.1f}% ({estimated_mb:.1f}MB / {limit_mb}MB). Consider archiving old data."
+                result["warning"] = warning
+                logger.warning(warning)
+            elif usage_percent >= 75:
+                logger.info(
+                    f"Storage usage at {usage_percent:.1f}% ({estimated_mb:.1f}MB / {limit_mb}MB)")
+            return result
+        except Exception as e:
+            logger.error(f"Failed to check storage usage: {e}")
+            return {"error": str(e), "warning": "Unable to check storage usage"}
 
     def get_feedback_for_training(self, limit: int = 1000) -> List[Dict[str, Any]]:
         response = self.client.table("feedback").select(
             "*").limit(limit).execute()
         return response.data
 
+    @retry_with_exponential_backoff(max_retries=3)
+    def store_user_history(self, session_id: str, article_id: str, text: str,
+                           predicted_label: str, confidence: float, model_name: str) -> Dict[str, Any]:
+        try:
+            uuid.UUID(session_id)
+        except (ValueError, AttributeError) as e:
+            logger.error(f"Invalid session_id format: {e}")
+            raise ValueError(f"session_id must be a valid UUID: {e}")
+
+        data = {
+            "session_id": session_id,
+            "article_id": article_id,
+            "text_preview": text[:200],
+            "predicted_label": predicted_label,
+            "confidence": confidence,
+            "model_name": model_name,
+            "created_at": datetime.now(timezone.utc).isoformat()
+        }
+        try:
+            response = self.client.table(
+                "user_analysis_history").insert(data).execute()
+            logger.info(f"Stored user history for session {session_id}")
+            return response.data
+        except Exception as e:
+            logger.error(f"Failed to store user history: {e}")
+            raise
+
+    @retry_with_exponential_backoff(max_retries=3)
+    def get_user_history(self, session_id: str, limit: int = 100) -> List[Dict[str, Any]]:
+        try:
+            uuid.UUID(session_id)
+        except (ValueError, AttributeError) as e:
+            logger.error(f"Invalid session_id format: {e}")
+            raise ValueError(f"session_id must be a valid UUID: {e}")
+
+        try:
+            response = (
+                self.client.table("user_analysis_history")
+                .select("*")
+                .eq("session_id", session_id)
+                .order("created_at", desc=True)
+                .limit(limit)
+                .execute()
+            )
+            logger.info(
+                f"Retrieved {len(response.data)} history records for session {session_id}")
+            return response.data
+        except Exception as e:
+            logger.error(f"Failed to retrieve user history: {e}")
+            raise
+
 
 _supabase_client: Optional[SupabaseClient] = None
 
@@ -86,6 +189,5 @@ def get_supabase_client() -> SupabaseClient:
 
 
 def reset_client():
-    """Force re-initialisation."""
     global _supabase_client
     _supabase_client = None
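The retry decorator is generic, so its behavior is easy to verify in isolation. A small sketch (hypothetical flaky function, with `base_delay` shortened to keep the demo fast):

```python
from src.utils.supabase_client import retry_with_exponential_backoff

attempts = {"count": 0}

@retry_with_exponential_backoff(max_retries=3, base_delay=0.1)
def flaky():
    # Fails twice, then succeeds on the third attempt
    attempts["count"] += 1
    if attempts["count"] < 3:
        raise ConnectionError("transient failure")
    return "ok"

print(flaky(), "after", attempts["count"], "attempts")  # -> ok after 3 attempts
```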