aviseth commited on
Commit
06e73d2
·
1 Parent(s): 16da8ce

Initial deployment

Browse files
.dockerignore ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ models/
2
+ data/
3
+ notebooks/
4
+ frontend/
5
+ venv/
6
+ .git/
7
+ .vscode/
8
+ __pycache__/
9
+ *.pyc
10
+ *.pyo
11
+ *.pyd
12
+ .env
13
+ .env.example
14
+ *.log
15
+ .DS_Store
16
+ README.md
17
+ docker-compose.yml
Dockerfile ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.9-slim

WORKDIR /app

# Build tools for packages with native extensions; clear the apt cache to
# keep the layer small.
RUN apt-get update && apt-get install -y \
    build-essential \
    curl \
    git \
    && rm -rf /var/lib/apt/lists/*

# Install dependencies first so this layer is cached across code-only changes.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY src/ ./src/
COPY scripts/ ./scripts/
# BUGFIX: .env.example is listed in .dockerignore, so it is excluded from the
# build context and "COPY .env.example .env.example" fails the build with
# "no source files were specified". The file is not read at runtime (the app
# uses real environment variables / .env via load_dotenv), so the COPY is
# simply removed.

# Download models from HuggingFace Hub at build time
# (huggingface-cli ships with huggingface_hub, pulled in by transformers)
RUN mkdir -p models && \
    huggingface-cli download aviseth/distilbert-fakenews --local-dir models/distilbert --exclude "checkpoints/*" && \
    huggingface-cli download aviseth/roberta-fakenews --local-dir models/roberta --exclude "checkpoints/*" && \
    huggingface-cli download aviseth/xlnet-fakenews --local-dir models/xlnet --exclude "checkpoints/*"

# HuggingFace Spaces uses port 7860
ENV PORT=7860
EXPOSE 7860

# Shell form so ${PORT} is expanded by the shell at container start.
CMD uvicorn src.api.main:app --host 0.0.0.0 --port ${PORT}
README.md CHANGED
@@ -1,10 +1,53 @@
1
  ---
2
- title: Fake News Api
3
- emoji: 🌍
4
- colorFrom: purple
5
- colorTo: pink
6
  sdk: docker
7
  pinned: false
8
  ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: Fake News Detection API
3
+ emoji: 🔍
4
+ colorFrom: orange
5
+ colorTo: red
6
  sdk: docker
7
  pinned: false
8
  ---
9
 
10
+ # Fake News Detection API
11
+
12
+ Multi-class fake news detection using fine-tuned DistilBERT, RoBERTa, and XLNet models.
13
+
14
+ Classifies news articles into: **True** · **Fake** · **Satire** · **Bias**
15
+
16
+ ## Features
17
+
18
+ - 3 transformer models (DistilBERT, RoBERTa, XLNet) trained on 80k+ articles
19
+ - Real-time explainability via gradient saliency + SHAP
20
+ - Live news integration via GNews API
21
+ - Prediction statistics and user feedback collection
22
+ - FastAPI backend with Swagger docs at `/docs`
23
+
24
+ ## Endpoints
25
+
26
+ - `POST /predict` — classify text as True / Fake / Satire / Bias
27
+ - `POST /explain` — gradient saliency + SHAP explainability
28
+ - `GET /news` — live news via GNews
29
+ - `GET /news/newspaper` — news grouped by predicted label
30
+ - `POST /feedback` — submit label corrections
31
+ - `GET /stats` — prediction statistics
32
+ - `GET /health` — health check
33
+ - `GET /docs` — Swagger UI
34
+
35
+ ## Environment Variables
36
+
37
+ Set these in your Space settings:
38
+
39
+ ```
40
+ SUPABASE_URL=your_supabase_url
41
+ SUPABASE_KEY=your_supabase_anon_key
42
+ SUPABASE_SERVICE_KEY=your_supabase_service_key
43
+ GNEWS_API_KEY=your_gnews_api_key
44
+ ALLOWED_ORIGINS=https://your-frontend.vercel.app
45
+ ```
46
+
47
+ ## Models
48
+
49
+ Models are automatically downloaded from:
50
+
51
+ - [aviseth/distilbert-fakenews](https://huggingface.co/aviseth/distilbert-fakenews)
52
+ - [aviseth/roberta-fakenews](https://huggingface.co/aviseth/roberta-fakenews)
53
+ - [aviseth/xlnet-fakenews](https://huggingface.co/aviseth/xlnet-fakenews)
requirements.txt ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Core ML
2
+ torch>=2.0.0
3
+ transformers>=4.30.0
4
+ datasets>=2.12.0
5
+ scikit-learn>=1.3.0
6
+ accelerate>=0.26.0
7
+
8
+ # Backend API
9
+ fastapi>=0.100.0
10
+ uvicorn[standard]>=0.23.0
11
+ pydantic>=2.0.0
12
+ python-multipart>=0.0.6
13
+
14
+ # Database / Supabase
15
+ supabase>=2.0.0
16
+ postgrest>=0.10.0
17
+ sqlalchemy>=2.0.0
18
+ psycopg2-binary>=2.9.0
19
+
20
+ # Data Processing
21
+ pandas>=2.0.0
22
+ numpy>=1.24.0
23
+ nltk>=3.8.0
24
+ spacy>=3.6.0
25
+
26
+ # Explainability
27
+ shap>=0.42.0
28
+ lime>=0.2.0
29
+
30
+ # News / Web
31
+ requests>=2.31.0
32
+ beautifulsoup4>=4.12.0
33
+ newspaper3k>=0.2.8
34
+
35
+ # MLOps
36
+ wandb>=0.15.0
37
+
38
+ # Utilities
39
+ python-dotenv>=1.0.0
40
+ pyyaml>=6.0
41
+ tqdm>=4.65.0
42
+
43
+ # Testing
44
+ pytest>=7.4.0
45
+ pytest-asyncio>=0.21.0
46
+
47
+ # Visualization
48
+ matplotlib>=3.7.0
49
+ seaborn>=0.12.0
scripts/download_models.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Downloads DistilBERT, RoBERTa, and XLNet base models from Hugging Face
3
+ and saves them to the models/ directory with the correct label configuration.
4
+ """
5
+
6
+ from pathlib import Path
7
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig
8
+
9
# Base checkpoints to fetch from the Hugging Face Hub, keyed by the short
# name used for the local models/<key> directory.
MODELS = {
    "distilbert": {"name": "distilbert-base-uncased", "description": "66M parameters"},
    "roberta": {"name": "roberta-base", "description": "125M parameters"},
    "xlnet": {"name": "xlnet-base-cased", "description": "110M parameters"},
}

# id -> label mapping baked into each saved config (4-class head).
LABEL_MAP = {0: "True", 1: "Fake", 2: "Satire", 3: "Bias"}
16
+
17
+
18
def download_model(model_key: str, model_info: dict, base_dir: Path) -> bool:
    """Download one base model (tokenizer, config, weights) into base_dir/model_key.

    Args:
        model_key: short key used as the target folder name (e.g. "distilbert").
        model_info: dict with "name" (HF hub id) and "description" (free text).
        base_dir: parent directory that will hold one subfolder per model.

    Returns:
        True on success, False if any download/save step raised.
    """
    model_name = model_info["name"]
    save_path = base_dir / model_key

    print(f"\n{'='*60}")
    print(
        f"Downloading: {model_key} — {model_name} ({model_info['description']})")
    print(f"{'='*60}\n")

    try:
        save_path.mkdir(parents=True, exist_ok=True)

        print("[1/3] Tokenizer…")
        AutoTokenizer.from_pretrained(model_name).save_pretrained(save_path)

        print("[2/3] Config…")
        # The 4-way label space is written into the config up front so the
        # fine-tuning code can load the checkpoint without re-declaring it.
        config = AutoConfig.from_pretrained(
            model_name,
            num_labels=4,
            id2label=LABEL_MAP,
            label2id={v: k for k, v in LABEL_MAP.items()},
        )
        config.save_pretrained(save_path)

        print("[3/3] Model weights…")
        # Loading a base checkpoint with a new num_labels gives a randomly
        # initialized classification head — hence "Status: pre-trained" below.
        AutoModelForSequenceClassification.from_pretrained(
            model_name, config=config).save_pretrained(save_path)

        with open(save_path / "model_info.txt", "w") as f:
            f.write(
                f"Model: {model_name}\nParameters: {model_info['description']}\nLabels: {LABEL_MAP}\nStatus: pre-trained\n")

        print(f"✅ {model_key} saved to {save_path}\n")
        return True

    except Exception as e:
        # Broad catch is deliberate: one failed model must not stop the others.
        print(f"❌ {model_key} failed: {e}\n")
        return False
56
+
57
+
58
def main():
    """Download every model listed in MODELS and print a success summary."""
    target_dir = Path(__file__).parent.parent / "models"
    target_dir.mkdir(parents=True, exist_ok=True)

    # Download sequentially, remembering the per-model outcome.
    results = {}
    for key, info in MODELS.items():
        results[key] = download_model(key, info, target_dir)

    banner = "=" * 60
    print("\n" + banner)
    print("SUMMARY")
    print(banner)
    for key, ok in results.items():
        print(f" {key:15} {'✅' if ok else '❌'}")
    print(f"\n{sum(results.values())}/{len(results)} models downloaded")
    print(banner + "\n")
72
+
73
+
74
# Script entry point: python scripts/download_models.py
if __name__ == "__main__":
    main()
scripts/setup_environment.bat ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
@echo off
REM ============================================================
REM Fake News Detection - Environment Setup
REM Run from project root: scripts\setup_environment.bat
REM ============================================================

REM Move to project root (one level up from scripts/)
cd /d "%~dp0.."

echo.
echo ============================================================
echo FAKE NEWS DETECTION - ENVIRONMENT SETUP
echo ============================================================
echo.

REM Check Python
echo [1/5] Checking Python...
python --version >nul 2>&1
if errorlevel 1 (
    echo [ERROR] Python not found. Install from https://www.python.org/
    pause & exit /b 1
)
python --version
echo.

REM Handle existing venv.
REM BUGFIX: the original read %recreate% inside the same parenthesized
REM "if exist venv ( ... )" block that ran "set /p". cmd expands %vars%
REM when the whole block is parsed, so the comparison always saw the value
REM from BEFORE the prompt and the user's answer was ignored. Restructured
REM with labels so the variable is expanded after it is set (no delayed
REM expansion required).
if not exist venv goto :create_venv
echo [INFO] Virtual environment already exists.
set /p recreate="Recreate it? (y/n): "
if /i not "%recreate%"=="y" goto :activate_venv
echo Removing old venv...
rmdir /s /q venv

:create_venv
REM Create venv
echo [2/5] Creating virtual environment...
python -m venv venv
if errorlevel 1 ( echo [ERROR] Failed to create venv & pause & exit /b 1 )
echo [OK] venv created at %CD%\venv
echo.

:activate_venv
echo [3/5] Activating virtual environment...
call venv\Scripts\activate.bat
if errorlevel 1 ( echo [ERROR] Failed to activate venv & pause & exit /b 1 )
echo [OK] Activated
echo.

REM Upgrade pip
echo [4/5] Upgrading pip...
python -m pip install --upgrade pip --quiet
echo [OK] pip upgraded
echo.

REM Install requirements
echo [5/5] Installing requirements.txt...
echo (This takes a few minutes on first run)
echo.
pip install -r requirements.txt
if errorlevel 1 (
    echo.
    echo [ERROR] Some packages failed. Common fixes:
    echo   - Run as Administrator
    echo   - Install Visual C++ Build Tools: https://visualstudio.microsoft.com/visual-cpp-build-tools/
    echo   - Check internet connection
    pause & exit /b 1
)

echo.
echo ============================================================
echo DONE - Virtual environment ready
echo ============================================================
echo.
echo Location : %CD%\venv
python --version
echo.
echo Key packages installed:
pip list --format=columns | findstr /C:"torch" /C:"transformers" /C:"fastapi" /C:"supabase" /C:"wandb"
echo.
echo ============================================================
echo NEXT STEPS
echo ============================================================
echo.
echo 1. Download base models (run once):
echo      python scripts\download_models.py
echo.
echo 2. Run Supabase SQL schema:
echo      Open Supabase dashboard ^> SQL Editor ^> paste scripts\setup_supabase.sql
echo.
echo 3. Test connections:
echo      python scripts\test_connections.py
echo.
echo 4. Start API:
echo      uvicorn src.api.main:app --reload
echo.
echo To activate venv in future sessions:
echo      venv\Scripts\activate
echo.
pause
scripts/setup_environment.sh ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# ============================================================
# Fake News Detection - Environment Setup Script
# This script creates virtual environment and installs all dependencies
# ============================================================
# NOTE(review): unlike the .bat sibling, this script does NOT cd to the
# project root — it assumes it is run FROM the project root. Confirm, or
# add `cd "$(dirname "$0")/.."` for parity.

set -e  # Exit on error

echo ""
echo "============================================================"
echo "FAKE NEWS DETECTION - ENVIRONMENT SETUP"
echo "============================================================"
echo ""

# Colors
GREEN='\033[0;32m'
RED='\033[0;31m'
YELLOW='\033[1;33m'
NC='\033[0m' # No Color

# Check if Python is installed — prefer python3, fall back to python.
echo "[1/5] Checking Python installation..."
if ! command -v python3 &> /dev/null; then
    if ! command -v python &> /dev/null; then
        echo -e "${RED}[ERROR] Python is not installed${NC}"
        echo "Please install Python 3.9 or higher"
        exit 1
    else
        PYTHON_CMD=python
    fi
else
    PYTHON_CMD=python3
fi

$PYTHON_CMD --version
echo -e "${GREEN}[SUCCESS] Python is installed${NC}"
echo ""

# Check if virtual environment already exists; offer to recreate it.
if [ -d "venv" ]; then
    echo -e "${YELLOW}[WARNING] Virtual environment already exists${NC}"
    read -p "Do you want to recreate it? (y/n): " recreate
    if [[ $recreate =~ ^[Yy]$ ]]; then
        echo "[2/5] Removing existing virtual environment..."
        rm -rf venv
        echo -e "${GREEN}[SUCCESS] Removed existing virtual environment${NC}"
    else
        echo "[INFO] Using existing virtual environment"
    fi
fi

# Create virtual environment if it doesn't exist
# (also reached after the user opted to recreate above).
if [ ! -d "venv" ]; then
    echo "[2/5] Creating virtual environment..."
    $PYTHON_CMD -m venv venv
    echo -e "${GREEN}[SUCCESS] Virtual environment created${NC}"
    echo ""
else
    echo "[2/5] Virtual environment already exists"
    echo ""
fi

# Activate virtual environment (affects this shell only; see footer for reuse).
echo "[3/5] Activating virtual environment..."
source venv/bin/activate
echo -e "${GREEN}[SUCCESS] Virtual environment activated${NC}"
echo ""

# Upgrade pip
echo "[4/5] Upgrading pip..."
pip install --upgrade pip --quiet
echo -e "${GREEN}[SUCCESS] Pip upgraded${NC}"
echo ""

# Install requirements
echo "[5/5] Installing dependencies from requirements.txt..."
echo "This may take a few minutes..."
echo ""

if pip install -r requirements.txt; then
    echo ""
    echo "============================================================"
    echo "INSTALLATION COMPLETE"
    echo "============================================================"
    echo ""
    echo -e "${GREEN}[SUCCESS] All dependencies installed successfully!${NC}"
    echo ""
    echo "Virtual environment location: $(pwd)/venv"
    echo "Python version: $($PYTHON_CMD --version)"
    echo ""
    echo "Installed packages:"
    pip list | grep -E "torch|transformers|fastapi|supabase"
    echo ""
else
    echo ""
    echo -e "${RED}[ERROR] Failed to install some dependencies${NC}"
    echo "Please check the error messages above"
    echo ""
    echo "Common solutions:"
    echo "1. Make sure you have internet connection"
    echo "2. Check if requirements.txt exists"
    echo "3. Install build tools if needed"
    echo ""
    exit 1
fi

echo "============================================================"
echo "NEXT STEPS"
echo "============================================================"
echo ""
echo "1. Virtual environment is already activated"
echo ""
echo "2. Download models from Hugging Face:"
echo "   python scripts/download_models.py"
echo ""
echo "3. Setup Supabase database:"
echo "   - Open Supabase dashboard"
echo "   - Run scripts/setup_supabase.sql in SQL Editor"
echo ""
echo "4. Test connections:"
echo "   python scripts/test_connections.py"
echo ""
echo "5. Start the API server:"
echo "   uvicorn src.api.main:app --reload"
echo ""
echo "============================================================"
echo ""
echo "To activate virtual environment in future sessions:"
echo "   source venv/bin/activate"
echo ""
echo "To deactivate:"
echo "   deactivate"
echo ""
echo "============================================================"
echo ""
scripts/setup_supabase.sql ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
-- Supabase schema for the fake-news detection service.
-- Rerunnable: everything is dropped first, then recreated.
DROP TABLE IF EXISTS feedback CASCADE;
DROP TABLE IF EXISTS predictions CASCADE;
DROP TABLE IF EXISTS news_articles CASCADE;
DROP TABLE IF EXISTS model_performance CASCADE;
DROP TABLE IF EXISTS user_sessions CASCADE;

-- NOTE(review): daily_predictions and model_comparison are dropped here but
-- never recreated in this script — confirm whether they are defined elsewhere.
DROP VIEW IF EXISTS prediction_stats CASCADE;
DROP VIEW IF EXISTS daily_predictions CASCADE;
DROP VIEW IF EXISTS feedback_accuracy CASCADE;
DROP VIEW IF EXISTS model_comparison CASCADE;

-- uuid_generate_v4() used as default for every primary key below.
CREATE EXTENSION IF NOT EXISTS "uuid-ossp";

-- One row per classification result.
CREATE TABLE predictions (
    id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
    article_id VARCHAR NOT NULL UNIQUE,
    text TEXT,
    predicted_label VARCHAR(50) NOT NULL,
    confidence FLOAT NOT NULL,
    model_name VARCHAR(100) NOT NULL,
    explanation JSONB,
    created_at TIMESTAMPTZ DEFAULT NOW()
);

CREATE INDEX idx_pred_created ON predictions(created_at DESC);
CREATE INDEX idx_pred_label ON predictions(predicted_label);
CREATE INDEX idx_pred_model ON predictions(model_name);

-- User label corrections.
-- NOTE(review): article_id carries no FK to predictions — presumably so
-- feedback can reference predictions that were never persisted; confirm.
CREATE TABLE feedback (
    id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
    article_id VARCHAR NOT NULL,
    predicted_label VARCHAR(50) NOT NULL,
    actual_label VARCHAR(50) NOT NULL,
    user_comment TEXT,
    created_at TIMESTAMPTZ DEFAULT NOW()
);

CREATE INDEX idx_fb_created ON feedback(created_at DESC);
CREATE INDEX idx_fb_article ON feedback(article_id);

-- Cache of articles fetched from the news provider; url is the dedup key.
CREATE TABLE news_articles (
    id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
    title TEXT NOT NULL,
    description TEXT,
    content TEXT,
    url TEXT NOT NULL UNIQUE,
    image_url TEXT,
    published_at TIMESTAMPTZ,
    source_name VARCHAR(255),
    source_url TEXT,
    fetched_at TIMESTAMPTZ DEFAULT NOW(),
    analyzed BOOLEAN DEFAULT FALSE,
    prediction_id UUID  -- soft link to predictions.id (no FK constraint)
);

CREATE INDEX idx_news_published ON news_articles(published_at DESC);
CREATE INDEX idx_news_analyzed ON news_articles(analyzed);

-- Evaluation snapshots per model.
CREATE TABLE model_performance (
    id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
    model_name VARCHAR(100) NOT NULL,
    accuracy FLOAT,
    precision FLOAT,
    recall FLOAT,
    f1_score FLOAT,
    total_predictions INTEGER DEFAULT 0,
    correct_predictions INTEGER DEFAULT 0,
    evaluated_at TIMESTAMPTZ DEFAULT NOW()
);

-- Lightweight session tracking.
CREATE TABLE user_sessions (
    id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
    session_id VARCHAR NOT NULL UNIQUE,
    user_agent TEXT,
    ip_address INET,
    created_at TIMESTAMPTZ DEFAULT NOW(),
    last_activity TIMESTAMPTZ DEFAULT NOW()
);

-- NOTE(review): RLS is disabled on every table — presumably the API uses the
-- service key exclusively. Re-enable with policies before granting anon access.
ALTER TABLE predictions DISABLE ROW LEVEL SECURITY;
ALTER TABLE feedback DISABLE ROW LEVEL SECURITY;
ALTER TABLE news_articles DISABLE ROW LEVEL SECURITY;
ALTER TABLE model_performance DISABLE ROW LEVEL SECURITY;
ALTER TABLE user_sessions DISABLE ROW LEVEL SECURITY;

-- Per-label counts and mean confidence.
CREATE VIEW prediction_stats AS
SELECT predicted_label, COUNT(*) AS total_count, AVG(confidence) AS avg_confidence
FROM predictions
GROUP BY predicted_label;

-- Confusion-style counts derived from user feedback.
CREATE VIEW feedback_accuracy AS
SELECT predicted_label, actual_label, COUNT(*) AS count
FROM feedback
GROUP BY predicted_label, actual_label
ORDER BY count DESC;
src/__init__.py ADDED
File without changes
src/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (192 Bytes). View file
 
src/api/__init__.py ADDED
File without changes
src/api/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (196 Bytes). View file
 
src/api/__pycache__/main.cpython-313.pyc ADDED
Binary file (17 kB). View file
 
src/api/main.py ADDED
@@ -0,0 +1,355 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException, BackgroundTasks
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from pydantic import BaseModel
4
+ from typing import Optional, List, Dict
5
+ import os
6
+ import uuid
7
+ from dotenv import load_dotenv
8
+
9
+ from src.utils.supabase_client import get_supabase_client
10
+ from src.utils.gnews_client import get_gnews_client
11
+
12
+ load_dotenv()
13
+
14
# FastAPI application object; Swagger UI at /docs, ReDoc at /redoc.
app = FastAPI(
    title="Fake News Detection API",
    description="Multi-class fake news detection using DistilBERT, RoBERTa, and XLNet",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc",
)

# CORS origins come from the ALLOWED_ORIGINS env var (comma-separated);
# the default covers local React (3000) and Vite (5173) dev servers.
allowed_origins = [
    o.strip()
    for o in os.getenv(
        "ALLOWED_ORIGINS", "http://localhost:3000,http://localhost:5173"
    ).split(",")
    if o.strip()
]
app.add_middleware(
    CORSMiddleware,
    allow_origins=allowed_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Model keys accepted by /predict and /explain; any other value silently
# falls back to "distilbert" in the handlers.
VALID_MODELS = {"distilbert", "roberta", "xlnet"}
38
+
39
+
40
class PredictionRequest(BaseModel):
    """Body for POST /predict — supply raw text OR a URL to scrape."""
    text: Optional[str] = None  # raw article text; takes precedence over url
    url: Optional[str] = None   # fetched and reduced to <p> text when text is absent
    model: Optional[str] = "distilbert"  # one of VALID_MODELS; unknown keys fall back
44
+
45
+
46
class ExplanationData(BaseModel):
    """One explanation token: the token text and its importance score."""
    token: str
    score: float
49
+
50
+
51
class PredictionResponse(BaseModel):
    """Response for POST /predict."""
    article_id: str  # server-generated UUID identifying this prediction
    label: str       # predicted class name
    confidence: float
    scores: dict     # per-label scores from the inference result
    model_used: str  # model key actually used (after fallback)
    explanation: List[ExplanationData]  # token-level highlights
58
+
59
+
60
class FeedbackRequest(BaseModel):
    """Body for POST /feedback — a user's label correction."""
    article_id: str
    predicted_label: str
    actual_label: str
    user_comment: Optional[str] = None
65
+
66
+
67
class ExplainRequest(BaseModel):
    """Body for POST /explain."""
    text: str
    model: Optional[str] = "distilbert"  # one of VALID_MODELS; unknown keys fall back
    deep: Optional[bool] = False         # when True also run SHAP (much slower)
71
+
72
+
73
@app.on_event("startup")
async def startup_event():
    """Probe external services at startup; failures are logged, never fatal."""
    # NOTE(review): @app.on_event is deprecated in newer FastAPI in favor of
    # lifespan handlers — worth migrating when dependencies are bumped.
    try:
        get_supabase_client()
        print("✅ Supabase connected")
    except Exception as e:
        print(f"⚠️ Supabase: {e}")
    try:
        get_gnews_client()
        print("✅ GNews API connected")
    except Exception as e:
        print(f"⚠️ GNews: {e}")
    print("🚀 API server started")
86
+
87
+
88
+ @app.get("/")
89
+ async def root():
90
+ return {
91
+ "message": "Fake News Detection API",
92
+ "status": "running",
93
+ "version": "1.0.0",
94
+ "models": list(VALID_MODELS),
95
+ }
96
+
97
+
98
+ @app.get("/health")
99
+ async def health_check():
100
+ status = {"api": "healthy", "supabase": "unknown", "gnews": "unknown"}
101
+ try:
102
+ get_supabase_client()
103
+ status["supabase"] = "healthy"
104
+ except Exception as e:
105
+ status["supabase"] = f"unhealthy: {e}"
106
+ try:
107
+ get_gnews_client()
108
+ status["gnews"] = "healthy"
109
+ except Exception as e:
110
+ status["gnews"] = f"unhealthy: {e}"
111
+ return status
112
+
113
+
114
+ @app.post("/predict", response_model=PredictionResponse)
115
+ async def predict(request: PredictionRequest, background_tasks: BackgroundTasks):
116
+ """Classify news as True / Fake / Satire / Bias."""
117
+ if not request.text and not request.url:
118
+ raise HTTPException(status_code=400, detail="Provide text or url")
119
+
120
+ model_key = request.model if request.model in VALID_MODELS else "distilbert"
121
+ article_id = str(uuid.uuid4())
122
+
123
+ text = request.text or ""
124
+ if not text and request.url:
125
+ try:
126
+ import requests as req
127
+ from bs4 import BeautifulSoup
128
+ r = req.get(request.url, timeout=10)
129
+ soup = BeautifulSoup(r.text, "html.parser")
130
+ text = " ".join(p.get_text() for p in soup.find_all("p"))[:4000]
131
+ except Exception as e:
132
+ raise HTTPException(
133
+ status_code=422, detail=f"Could not fetch URL: {e}")
134
+
135
+ if len(text.strip()) < 10:
136
+ raise HTTPException(
137
+ status_code=422, detail="Text too short to classify")
138
+
139
+ try:
140
+ from src.models.inference import predict as run_inference
141
+ result = run_inference(text, model_key)
142
+ except Exception as e:
143
+ raise HTTPException(status_code=500, detail=f"Inference error: {e}")
144
+
145
+ response = PredictionResponse(
146
+ article_id=article_id,
147
+ label=result["label"],
148
+ confidence=result["confidence"],
149
+ scores=result["scores"],
150
+ model_used=model_key,
151
+ explanation=[ExplanationData(**t) for t in result.get("tokens", [])],
152
+ )
153
+
154
+ def _store():
155
+ try:
156
+ supabase = get_supabase_client()
157
+ supabase.store_prediction(
158
+ article_id=article_id,
159
+ text=text,
160
+ predicted_label=result["label"],
161
+ confidence=result["confidence"],
162
+ model_name=model_key,
163
+ explanation=result.get("tokens", []),
164
+ )
165
+ except Exception as e:
166
+ print(f"[bg] store_prediction failed: {e}")
167
+
168
+ background_tasks.add_task(_store)
169
+ return response
170
+
171
+
172
+ @app.post("/feedback")
173
+ async def submit_feedback(feedback: FeedbackRequest):
174
+ """Submit user correction for active learning."""
175
+ try:
176
+ supabase = get_supabase_client()
177
+ result = supabase.store_feedback(
178
+ article_id=feedback.article_id,
179
+ predicted_label=feedback.predicted_label,
180
+ actual_label=feedback.actual_label,
181
+ user_comment=feedback.user_comment,
182
+ )
183
+ return {"status": "success", "message": "Feedback recorded", "data": result}
184
+ except Exception as e:
185
+ import traceback
186
+ print(f"[feedback] ERROR: {e}\n{traceback.format_exc()}")
187
+ raise HTTPException(
188
+ status_code=500, detail=f"Error storing feedback: {str(e)}")
189
+
190
+
191
+ @app.get("/news")
192
+ async def get_recent_news(
193
+ query: str = "breaking news",
194
+ max_results: int = 10,
195
+ category: Optional[str] = None,
196
+ ):
197
+ """Fetch recent articles from GNews."""
198
+ try:
199
+ gnews = get_gnews_client()
200
+ if category:
201
+ articles = gnews.get_top_headlines(
202
+ category=category, max_results=max_results)
203
+ else:
204
+ articles = gnews.search_news(query=query, max_results=max_results)
205
+ return {"status": "success", "count": len(articles), "articles": articles}
206
+ except Exception as e:
207
+ raise HTTPException(
208
+ status_code=500, detail=f"Error fetching news: {e}")
209
+
210
+
211
+ @app.get("/news/analyze")
212
+ async def analyze_recent_news(topic: str = "politics", max_articles: int = 5):
213
+ """Fetch and classify recent news articles."""
214
+ try:
215
+ gnews = get_gnews_client()
216
+ articles = gnews.search_news(query=topic, max_results=max_articles)
217
+
218
+ from src.models.inference import predict as run_inference
219
+ results = []
220
+ for article in articles:
221
+ text = article.get("content") or article.get(
222
+ "description") or article.get("title", "")
223
+ if len(text.strip()) < 10:
224
+ continue
225
+ try:
226
+ pred = run_inference(text, "distilbert")
227
+ results.append({"article": article, "prediction": pred})
228
+ except Exception:
229
+ results.append({"article": article, "prediction": None})
230
+
231
+ return {"status": "success", "topic": topic, "analyzed_count": len(results), "results": results}
232
+ except Exception as e:
233
+ raise HTTPException(
234
+ status_code=500, detail=f"Error analyzing news: {e}")
235
+
236
+
237
+ @app.get("/news/newspaper")
238
+ async def get_newspaper(max_per_topic: int = 6):
239
+ """Fetch and classify news across multiple topics, grouped by predicted label."""
240
+ topics = ["world news", "politics", "technology",
241
+ "science", "health", "business"]
242
+ try:
243
+ gnews = get_gnews_client()
244
+ from src.models.inference import predict as run_inference
245
+
246
+ all_results = []
247
+ seen_urls: set = set()
248
+
249
+ for topic in topics:
250
+ articles = gnews.search_news(
251
+ query=topic, max_results=max_per_topic)
252
+ for article in articles:
253
+ url = article.get("url", "")
254
+ if url in seen_urls:
255
+ continue
256
+ seen_urls.add(url)
257
+ text = article.get("content") or article.get(
258
+ "description") or article.get("title", "")
259
+ if len(text.strip()) < 10:
260
+ continue
261
+ try:
262
+ pred = run_inference(text, "distilbert")
263
+ all_results.append(
264
+ {"article": article, "prediction": pred})
265
+ except Exception:
266
+ all_results.append({"article": article, "prediction": {
267
+ "label": "True", "confidence": 0.5, "scores": {}, "tokens": []
268
+ }})
269
+
270
+ grouped: Dict[str, list] = {"True": [],
271
+ "Fake": [], "Satire": [], "Bias": []}
272
+ for item in all_results:
273
+ lbl = item["prediction"].get(
274
+ "label", "True") if item["prediction"] else "True"
275
+ if lbl in grouped:
276
+ grouped[lbl].append(item)
277
+
278
+ return {"status": "success", "total": len(all_results), "grouped": grouped}
279
+ except Exception as e:
280
+ raise HTTPException(
281
+ status_code=500, detail=f"Error building newspaper: {e}")
282
+
283
+
284
+ @app.post("/explain")
285
+ async def explain_prediction(request: ExplainRequest):
286
+ """
287
+ Return explainability data for a piece of text.
288
+ Always returns gradient saliency highlights. If deep=True, also runs SHAP via RoBERTa.
289
+ """
290
+ if len(request.text.strip()) < 10:
291
+ raise HTTPException(status_code=422, detail="Text too short")
292
+
293
+ model_key = request.model if request.model in VALID_MODELS else "distilbert"
294
+
295
+ try:
296
+ from src.models.inference import get_classifier
297
+ import asyncio
298
+
299
+ clf = get_classifier(model_key)
300
+ loop = asyncio.get_event_loop()
301
+
302
+ attention = await loop.run_in_executor(None, clf.attention_weights, request.text)
303
+
304
+ shap_tokens = []
305
+ explanation_text = ""
306
+ if request.deep:
307
+ shap_tokens = await loop.run_in_executor(None, clf.shap_explain, request.text)
308
+ if shap_tokens:
309
+ from src.models.inference import generate_explanation_text, predict as run_predict
310
+ pred = run_predict(request.text, model_key)
311
+ explanation_text = generate_explanation_text(
312
+ shap_tokens, pred["label"], pred["confidence"], model_key
313
+ )
314
+
315
+ return {"attention": attention, "shap": shap_tokens, "explanation_text": explanation_text, "model_used": model_key}
316
+ except Exception as e:
317
+ import traceback
318
+ print(f"[explain] ERROR: {e}\n{traceback.format_exc()}")
319
+ raise HTTPException(status_code=500, detail=f"Explain error: {e}")
320
+
321
+
322
+ @app.get("/stats")
323
+ async def get_statistics():
324
+ """Prediction statistics from Supabase."""
325
+ try:
326
+ supabase = get_supabase_client()
327
+ stats = supabase.get_prediction_stats()
328
+ return {"status": "success", "statistics": stats}
329
+ except Exception as e:
330
+ raise HTTPException(
331
+ status_code=500, detail=f"Error fetching stats: {e}")
332
+
333
+
334
+ @app.get("/models")
335
async def list_models():
    """List available models and their training status."""
    from pathlib import Path

    # A model counts as trained when <repo-root>/models/<key>/config.json exists.
    root = Path(__file__).parents[2] / "models"

    def entry(key):
        model_dir = root / key
        ready = (model_dir / "config.json").exists()
        return {"name": key, "trained": ready,
                "path": str(model_dir) if ready else None}

    return {"models": [entry(k) for k in ("distilbert", "roberta", "xlnet")],
            "default": "distilbert"}
346
+
347
+
348
# Dev entry point — in deployment the container runs uvicorn directly.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(
        "src.api.main:app",
        host=os.getenv("API_HOST", "0.0.0.0"),
        port=int(os.getenv("API_PORT", 8000)),
        # Auto-reload defaults ON for local development; set API_RELOAD=false to disable.
        reload=os.getenv("API_RELOAD", "true").lower() == "true",
    )
src/data/__init__.py ADDED
File without changes
src/data/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (197 Bytes). View file
 
src/data/__pycache__/dataset.cpython-313.pyc ADDED
Binary file (5.85 kB). View file
 
src/data/__pycache__/preprocessing.cpython-313.pyc ADDED
Binary file (1.28 kB). View file
 
src/data/dataset.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Dataset loader — reads Dataset_Clean.csv and returns tokenized HuggingFace DatasetDict splits.
3
+ """
4
+
5
+ import pandas as pd
6
+ from pathlib import Path
7
+ from datasets import Dataset, DatasetDict
8
+ from transformers import AutoTokenizer
9
+ from sklearn.model_selection import train_test_split
10
+ from src.data.preprocessing import clean_text
11
+
12
+ LABEL2ID = {"True": 0, "Fake": 1, "Satire": 2, "Bias": 3}
13
+ ID2LABEL = {v: k for k, v in LABEL2ID.items()}
14
+
15
+ DEFAULT_CSV = Path(__file__).parents[2] / \
16
+ "data" / "processed" / "Dataset_Clean.csv"
17
+ MAX_LENGTH = 256
18
+ VAL_SPLIT = 0.10
19
+ TEST_SPLIT = 0.10
20
+ RANDOM_SEED = 42
21
+
22
+
23
def load_dataframe(csv_path: str | Path = DEFAULT_CSV) -> pd.DataFrame:
    """Load and clean Dataset_Clean.csv. Returns a DataFrame with columns: text, label (int)."""
    df = pd.read_csv(csv_path, low_memory=False)

    # Normalize label spelling ("fake" -> "Fake") and drop rows with unknown labels.
    df["label_text"] = df["label_text"].astype(str).str.strip().str.capitalize()
    df = df[df["label_text"].isin(LABEL2ID)].copy()

    df["content"] = df["content"].fillna("").astype(str)
    df["title"] = df["title"].fillna("").astype(str)
    # Prefer the article body; fall back to the headline for very short bodies.
    df["text"] = df.apply(
        lambda row: row["content"] if len(row["content"]) > 30 else row["title"],
        axis=1,
    )
    df["text"] = df["text"].apply(clean_text)
    df = df[df["text"].str.len() > 10].copy()
    df["label"] = df["label_text"].map(LABEL2ID).astype(int)

    print(f"[dataset] Loaded {len(df):,} rows")
    print(
        f"[dataset] Label distribution:\n{df['label_text'].value_counts().to_string()}\n")
    return df[["text", "label"]].reset_index(drop=True)
42
+
43
+
44
def make_splits(df: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Stratified train / val / test split."""
    holdout = VAL_SPLIT + TEST_SPLIT
    # First carve off the combined val+test pool, stratified on the label.
    train_df, holdout_df = train_test_split(
        df, test_size=holdout, stratify=df["label"], random_state=RANDOM_SEED
    )
    # Then split that pool so val/test land at their configured fractions.
    val_df, test_df = train_test_split(
        holdout_df,
        test_size=TEST_SPLIT / holdout,
        stratify=holdout_df["label"],
        random_state=RANDOM_SEED,
    )
    print(
        f"[dataset] Train: {len(train_df):,} Val: {len(val_df):,} Test: {len(test_df):,}")
    return train_df, val_df, test_df
56
+
57
+
58
def tokenize_dataset(dataset_dict: DatasetDict, tokenizer_name: str, max_length: int = MAX_LENGTH) -> DatasetDict:
    """Tokenize all splits."""
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    def _encode(batch):
        # Fixed-length padding so all examples share one tensor shape.
        return tokenizer(batch["text"], padding="max_length", truncation=True, max_length=max_length)

    encoded = dataset_dict.map(
        _encode,
        batched=True,
        batch_size=512,
        remove_columns=["text"],
        desc="Tokenizing",
    )
    encoded.set_format("torch")
    return encoded
69
+
70
+
71
def build_dataset(
    csv_path: str | Path = DEFAULT_CSV,
    tokenizer_name: str = "distilbert-base-uncased",
    max_length: int = MAX_LENGTH,
) -> DatasetDict:
    """Full pipeline: CSV → cleaned DataFrame → HuggingFace DatasetDict → tokenized splits."""
    frame = load_dataframe(csv_path)
    split_frames = dict(zip(("train", "validation", "test"), make_splits(frame)))
    raw = DatasetDict({
        name: Dataset.from_pandas(part, preserve_index=False)
        for name, part in split_frames.items()
    })
    return tokenize_dataset(raw, tokenizer_name, max_length)
86
+
87
+
88
+ if __name__ == "__main__":
89
+ ds = build_dataset()
90
+ print(ds)
91
+ print("Sample:", ds["train"][0])
src/data/gnews_collector.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Fetches live GNews articles and appends them to the training dataset.
3
+
4
+ Usage:
5
+ python -m src.data.gnews_collector # fetch and save
6
+ python -m src.data.gnews_collector --preview # print without saving
7
+ python -m src.data.gnews_collector --label --model-path models/distilbert --merge
8
+ """
9
+
10
+ from src.data.preprocessing import clean_text
11
+ from src.utils.gnews_client import GNewsClient
12
+ import os
13
+ import sys
14
+ import uuid
15
+ import argparse
16
+ import pandas as pd
17
+ from pathlib import Path
18
+ from datetime import datetime
19
+ from dotenv import load_dotenv
20
+
21
+ sys.path.insert(0, str(Path(__file__).parents[2]))
22
+ load_dotenv()
23
+
24
+
25
+ PROJECT_ROOT = Path(__file__).parents[2]
26
+ AUGMENTED_DIR = PROJECT_ROOT / "data" / "augmented"
27
+ CLEAN_CSV = PROJECT_ROOT / "data" / "processed" / "Dataset_Clean.csv"
28
+
29
+ FETCH_TOPICS = [
30
+ "scientific research breakthrough",
31
+ "official government announcement",
32
+ "verified breaking news",
33
+ "conspiracy theory debunked",
34
+ "fact check false claim",
35
+ "misinformation viral",
36
+ "satire news comedy",
37
+ "parody news article",
38
+ "political opinion editorial",
39
+ "partisan news analysis",
40
+ ]
41
+
42
+ MAX_PER_TOPIC = 5
43
+
44
+
45
def fetch_articles(max_per_topic: int = MAX_PER_TOPIC) -> list[dict]:
    """Query GNews once per topic and return the URL-deduplicated union of results."""
    client = GNewsClient()
    collected: list[dict] = []
    seen: set[str] = set()

    for topic in FETCH_TOPICS:
        try:
            batch = client.search_news(query=topic, max_results=max_per_topic)
            for article in batch:
                url = article.get("url", "")
                # Skip duplicates and articles with no URL to key on.
                if url and url not in seen:
                    seen.add(url)
                    collected.append(article)
            print(f"  ✓ '{topic}' → {len(batch)} articles")
        except Exception as e:
            # One failing topic should not abort the whole collection run.
            print(f"  ✗ '{topic}' → error: {e}")

    print(f"\n[collector] Fetched {len(collected)} unique articles\n")
    return collected
65
+
66
+
67
def articles_to_dataframe(articles: list[dict]) -> pd.DataFrame:
    """Convert raw GNews articles to Dataset_Clean.csv schema. Labels are set to -1 (unlabelled)."""
    rows = []
    for article in articles:
        headline = clean_text(article.get("title", ""))
        body = clean_text(article.get("content", "") or article.get("description", ""))
        # Classify on the body unless it is too short to be meaningful.
        text = body if len(body) > 30 else headline
        if len(text) < 10:
            # Too short to carry any training signal at all.
            continue
        rows.append({
            "id": f"GNEWS_{uuid.uuid4().hex[:8].upper()}",
            "title": headline,
            "content": body,
            "label": -1,
            "label_text": "UNLABELLED",
            "label_original": "gnews_live",
            "source_dataset": "GNews_Live",
            "topic": "",
            "url": article.get("url", ""),
            "speaker": article.get("source", ""),
            "fetched_at": datetime.utcnow().isoformat(),
        })
    return pd.DataFrame(rows)
90
+
91
+
92
def pseudo_label(df: pd.DataFrame, model_path: str) -> pd.DataFrame:
    """Assign pseudo-labels to unlabelled articles using a trained model.

    Returns a copy of *df* with ``label`` (int id), ``label_text`` and
    ``confidence`` (max softmax probability, rounded to 4 dp) filled in.
    """
    import torch
    from transformers import AutoTokenizer, AutoModelForSequenceClassification

    ID2LABEL = {0: "True", 1: "Fake", 2: "Satire", 3: "Bias"}
    print(f"[pseudo_label] Loading model from {model_path}…")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    model.eval()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    # Classify the article body, falling back to the headline when the body is
    # missing OR too short. The previous `fillna(df["title"])` only replaced
    # NaN, so rows whose content was the empty string (the collector stores ""
    # for missing content) were classified on empty text. Mirror the >30-char
    # rule used by articles_to_dataframe().
    contents = df["content"].fillna("").astype(str)
    titles = df["title"].fillna("").astype(str)
    texts = [c if len(c) > 30 else t for c, t in zip(contents, titles)]

    labels = []
    confidences = []

    # Batched, gradient-free inference to bound memory use.
    for i in range(0, len(texts), 16):
        batch = texts[i: i + 16]
        enc = tokenizer(batch, padding=True, truncation=True,
                        max_length=256, return_tensors="pt").to(device)
        with torch.no_grad():
            probs = torch.softmax(model(**enc).logits, dim=-1)
        labels.extend(probs.argmax(dim=-1).cpu().tolist())
        confidences.extend(probs.max(dim=-1).values.cpu().tolist())

    df = df.copy()
    df["label"] = labels
    df["label_text"] = [ID2LABEL[l] for l in labels]
    df["confidence"] = [round(c, 4) for c in confidences]
    print(
        f"[pseudo_label] Label distribution:\n{df['label_text'].value_counts().to_string()}")
    return df
125
+
126
+
127
def save_augmented(df: pd.DataFrame, tag: str = "") -> Path:
    """Write an augmented batch to data/augmented/ with a UTC-timestamped name."""
    AUGMENTED_DIR.mkdir(parents=True, exist_ok=True)
    stamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
    suffix = f"_{tag}" if tag else ""
    path = AUGMENTED_DIR / f"gnews_{stamp}{suffix}.csv"
    df.to_csv(path, index=False, encoding="utf-8")
    print(f"[collector] Saved {len(df)} rows → {path}")
    return path
135
+
136
+
137
def merge_into_training(augmented_path: Path, min_confidence: float = 0.80) -> int:
    """Merge pseudo-labelled articles into Dataset_Clean.csv, filtered by confidence threshold."""
    batch = pd.read_csv(augmented_path)
    if "confidence" in batch.columns:
        batch = batch[batch["confidence"] >= min_confidence]
    # Drop anything that is still unlabelled.
    batch = batch[batch["label"] != -1]

    if batch.empty:
        print("[merge] No rows met the confidence threshold.")
        return 0

    keep_cols = ["id", "title", "content", "label", "label_text",
                 "label_original", "source_dataset", "topic", "url", "speaker"]
    batch = batch[[c for c in keep_cols if c in batch.columns]]
    # NOTE(review): appended without a header — assumes the column order above
    # matches the existing Dataset_Clean.csv layout; verify if the schema changes.
    batch.to_csv(CLEAN_CSV, mode="a", header=False,
                 index=False, encoding="utf-8")
    print(f"[merge] Added {len(batch)} rows to {CLEAN_CSV}")
    return len(batch)
155
+
156
+
157
def main():
    """CLI entry point: fetch → (optionally) pseudo-label → (optionally) merge."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--preview", action="store_true")
    parser.add_argument("--label", action="store_true")
    parser.add_argument("--model-path", type=str,
                        default="models/distilbert")
    parser.add_argument("--merge", action="store_true")
    parser.add_argument("--min-conf", type=float, default=0.80)
    parser.add_argument("--max-per-topic", type=int, default=MAX_PER_TOPIC)
    args = parser.parse_args()

    df = articles_to_dataframe(fetch_articles(max_per_topic=args.max_per_topic))

    if args.preview:
        # Dry run: show what was fetched, save nothing.
        print(df[["title", "source_dataset", "url"]].to_string())
        return

    save_augmented(df, tag="raw")

    if not args.label:
        return
    model_path = str(PROJECT_ROOT / args.model_path)
    if not Path(model_path).exists():
        print(f"[error] Model not found at {model_path}")
        return
    df = pseudo_label(df, model_path)
    labelled_path = save_augmented(df, tag="labelled")
    if args.merge:
        merge_into_training(labelled_path, min_confidence=args.min_conf)
186
+
187
+
188
+ if __name__ == "__main__":
189
+ main()
src/data/preprocessing.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import html
3
+ from typing import List
4
+
5
+
6
def clean_text(text: str) -> str:
    """Clean and normalize raw text — decodes HTML, strips URLs, normalizes whitespace.

    Defensive against non-string input: NaN/None (common for missing CSV or
    API fields upstream) is treated as empty text instead of raising.
    """
    if not isinstance(text, str):
        # html.unescape / re.sub would raise TypeError on float('nan') or None.
        return ""
    text = html.unescape(text)
    # Bare URLs carry no linguistic signal for classification.
    text = re.sub(r'http\S+', '', text)
    # Normalize curly quotes and en-dash to ASCII equivalents.
    text = text.replace('\u201c', '"').replace(
        '\u201d', '"').replace('\u2013', '-')
    # Collapse all whitespace runs (newlines, tabs, multiple spaces) to one space.
    text = re.sub(r'\s+', ' ', text).strip()
    return text
14
+
15
+
16
def preprocess_batch(texts: List[str]) -> List[str]:
    """Apply clean_text to a list of strings."""
    return list(map(clean_text, texts))
src/models/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from src.models.inference import predict, get_classifier, FakeNewsClassifier
2
+ from src.models.train import train_model
3
+ from src.models.evaluate import compute_metrics, full_report
4
+
5
+ __all__ = [
6
+ "predict",
7
+ "get_classifier",
8
+ "FakeNewsClassifier",
9
+ "train_model",
10
+ "compute_metrics",
11
+ "full_report",
12
+ ]
src/models/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (502 Bytes). View file
 
src/models/__pycache__/evaluate.cpython-313.pyc ADDED
Binary file (2.53 kB). View file
 
src/models/__pycache__/inference.cpython-313.pyc ADDED
Binary file (18.9 kB). View file
 
src/models/__pycache__/train.cpython-313.pyc ADDED
Binary file (8.63 kB). View file
 
src/models/evaluate.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Evaluation utilities — metrics computed during and after training.
3
+ """
4
+
5
+ import numpy as np
6
+ from sklearn.metrics import classification_report, confusion_matrix, f1_score, accuracy_score
7
+ from transformers import EvalPrediction
8
+
9
+ LABEL_NAMES = ["True", "Fake", "Satire", "Bias"]
10
+
11
+
12
def compute_metrics(eval_pred: EvalPrediction) -> dict:
    """Called by HuggingFace Trainer after every eval step. Returns accuracy and macro/weighted F1."""
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    # zero_division=0 keeps F1 defined even if a class is never predicted.
    metrics = {
        "accuracy": accuracy_score(labels, predictions),
        "f1_macro": f1_score(labels, predictions, average="macro", zero_division=0),
        "f1_weighted": f1_score(labels, predictions, average="weighted", zero_division=0),
    }
    return {name: round(value, 4) for name, value in metrics.items()}
21
+
22
+
23
def full_report(model, tokenized_test, label_names=LABEL_NAMES) -> dict:
    """Run full evaluation on the test split. Returns per-class metrics and confusion matrix."""
    from transformers import Trainer

    evaluator = Trainer(model=model, compute_metrics=compute_metrics)
    output = evaluator.predict(tokenized_test)

    y_pred = np.argmax(output.predictions, axis=-1)
    y_true = output.label_ids

    report = classification_report(
        y_true, y_pred, target_names=label_names, output_dict=True, zero_division=0)
    cm = confusion_matrix(y_true, y_pred)

    divider = "=" * 60
    print("\n" + divider)
    print("CLASSIFICATION REPORT")
    print(divider)
    print(classification_report(y_true, y_pred,
                                target_names=label_names, zero_division=0))
    print("Confusion Matrix:")
    print(cm)
    print(divider + "\n")

    return {"report": report, "confusion_matrix": cm.tolist()}
src/models/inference.py ADDED
@@ -0,0 +1,345 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Model inference — lazy-loads fine-tuned models and runs predictions with explainability.
3
+ """
4
+
5
+ import os
6
+ import torch
7
+ import numpy as np
8
+ from pathlib import Path
9
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
10
+ from dotenv import load_dotenv
11
+
12
+ load_dotenv()
13
+
14
+ ID2LABEL = {0: "True", 1: "Fake", 2: "Satire", 3: "Bias"}
15
+ LABEL2ID = {v: k for k, v in ID2LABEL.items()}
16
+
17
+ PROJECT_ROOT = Path(__file__).parents[2]
18
+ MODELS_DIR = PROJECT_ROOT / "models"
19
+
20
+ # Override with HF Hub repo IDs via env vars, e.g. HF_REPO_DISTILBERT=your-username/distilbert-fakenews
21
+ MODEL_NAMES = {
22
+ "distilbert": os.getenv("HF_REPO_DISTILBERT", "distilbert-base-uncased"),
23
+ "roberta": os.getenv("HF_REPO_ROBERTA", "roberta-base"),
24
+ "xlnet": os.getenv("HF_REPO_XLNET", "xlnet-base-cased"),
25
+ }
26
+
27
+
28
class FakeNewsClassifier:
    """Wraps a fine-tuned HuggingFace model. Lazy-loads on first call and caches in memory."""

    def __init__(self, model_key: str = "distilbert", max_length: int = 256):
        # model_key selects a sub-directory of MODELS_DIR / an entry in MODEL_NAMES.
        self.model_key = model_key
        self.max_length = max_length
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Backing fields for the lazy `model` / `tokenizer` properties.
        self._model = None
        self._tokenizer = None

    def _load(self):
        # Prefer a local fine-tuned checkpoint; a saved config.json marks a
        # completed fine-tune. Otherwise fall back to the MODEL_NAMES source
        # (an HF Hub repo id or base model name).
        local_path = MODELS_DIR / self.model_key
        source = str(local_path) if (
            local_path / "config.json").exists() else MODEL_NAMES[self.model_key]
        print(f"[inference] Loading {self.model_key} from: {source}")
        self._tokenizer = AutoTokenizer.from_pretrained(source)
        self._model = AutoModelForSequenceClassification.from_pretrained(
            source,
            num_labels=4,
            id2label=ID2LABEL,
            label2id=LABEL2ID,
            # Allows loading a base checkpoint whose classification head does
            # not match the 4-label setup (head is re-initialized).
            ignore_mismatched_sizes=True,
        )
        self._model.to(self.device)
        self._model.eval()
        print(f"[inference] Model ready on {self.device}")

    @property
    def model(self):
        # Lazy-load on first access.
        if self._model is None:
            self._load()
        return self._model

    @property
    def tokenizer(self):
        # Lazy-load on first access (_load brings in the model as well).
        if self._tokenizer is None:
            self._load()
        return self._tokenizer

    def predict(self, text: str) -> dict:
        """
        Run inference on a single text.
        Returns label, confidence (0-1), per-class scores, and top token importance scores.
        """
        enc = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=self.max_length,
            padding=True,
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**enc)
            probs = torch.softmax(outputs.logits, dim=-1)[0].cpu().numpy()

        pred_id = int(np.argmax(probs))
        label = ID2LABEL[pred_id]
        confidence = float(probs[pred_id])
        scores = {ID2LABEL[i]: round(float(p), 4) for i, p in enumerate(probs)}
        # Best-effort gradient-saliency highlights for the predicted class.
        tokens = self._token_importance(enc, pred_id)

        return {
            "label": label,
            "confidence": round(confidence, 4),
            "scores": scores,
            "tokens": tokens,
        }

    def _token_importance(self, enc, pred_id: int, top_k: int = 8) -> list[dict]:
        """Gradient saliency — returns top-k tokens sorted by importance."""
        try:
            self.model.zero_grad()
            input_ids = enc["input_ids"]
            # Differentiate w.r.t. input embeddings, not the (integer) token ids.
            embeds = self.model.get_input_embeddings()(
                input_ids).detach().requires_grad_(True)
            outputs = self.model(inputs_embeds=embeds,
                                 attention_mask=enc.get("attention_mask"))
            outputs.logits[0, pred_id].backward()
            # Per-token saliency = L2 norm of the embedding gradient.
            importance = embeds.grad[0].norm(dim=-1).cpu().numpy()
            tokens = self.tokenizer.convert_ids_to_tokens(
                input_ids[0].cpu().tolist())
            # Special tokens across the BERT/RoBERTa/XLNet vocabularies.
            special = {"[CLS]", "[SEP]", "[PAD]", "<s>",
                       "</s>", "<pad>", "<cls>", "<sep>", "▁", "Ġ"}
            # Strip subword markers and drop 1-char fragments.
            pairs = [
                (t.replace("##", "").replace("▁", "").replace("Ġ", ""), float(s))
                for t, s in zip(tokens, importance)
                if t not in special and len(t.strip()) > 1
            ]
            if pairs:
                # Normalize scores to [0, 1]; `or 1.0` guards an all-zero gradient.
                max_s = max(s for _, s in pairs) or 1.0
                pairs = [(t, round(s / max_s, 4)) for t, s in pairs]
            pairs.sort(key=lambda x: x[1], reverse=True)
            return [{"token": t, "score": s} for t, s in pairs[:top_k]]
        except Exception:
            # Explainability is best-effort; never fail the prediction itself.
            return []

    def attention_weights(self, text: str) -> list[dict]:
        """
        Gradient saliency mapped to original words in reading order.
        Merges subword tokens (BERT ## and RoBERTa Ġ) back into full words.
        """
        try:
            enc = self.tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                max_length=self.max_length,
                padding=False,
            ).to(self.device)

            input_ids = enc["input_ids"]
            self.model.zero_grad()
            embeds = self.model.get_input_embeddings()(
                input_ids).detach().requires_grad_(True)
            outputs = self.model(inputs_embeds=embeds,
                                 attention_mask=enc.get("attention_mask"))
            # Explain the model's own argmax class for this text.
            pred_id = int(torch.argmax(outputs.logits, dim=-1)[0])
            outputs.logits[0, pred_id].backward()
            importance = embeds.grad[0].norm(dim=-1).cpu().numpy()

            tokens = self.tokenizer.convert_ids_to_tokens(
                input_ids[0].cpu().tolist())
            SPECIAL = {"[CLS]", "[SEP]", "[PAD]", "<s>",
                       "</s>", "<pad>", "<cls>", "<sep>", "<unk>"}

            # Re-assemble subword pieces into words; a word's score is the max
            # saliency over its pieces.
            words = []
            current_word = ""
            current_score = 0.0
            for tok, score in zip(tokens, importance):
                if tok in SPECIAL:
                    # Special token terminates any in-progress word.
                    if current_word:
                        words.append((current_word, current_score))
                    current_word = ""
                    current_score = 0.0
                    continue

                # "##" marks a BERT continuation piece; "Ġ"/"▁" mark a
                # word-initial piece in RoBERTa/XLNet vocabularies.
                is_continuation = tok.startswith("##")
                is_new_word = tok.startswith("Ġ") or tok.startswith("▁")
                clean = tok.replace("##", "").replace("Ġ", "").replace("▁", "")

                if is_continuation:
                    current_word += clean
                    current_score = max(current_score, float(score))
                elif is_new_word:
                    if current_word:
                        words.append((current_word, current_score))
                    current_word = clean
                    current_score = float(score)
                else:
                    # NOTE(review): unmarked tokens are treated as new words;
                    # for sentencepiece vocabularies they can be continuations —
                    # acceptable for display purposes, confirm if precision matters.
                    if current_word:
                        words.append((current_word, current_score))
                    current_word = clean
                    current_score = float(score)

            if current_word:
                words.append((current_word, current_score))

            if not words:
                return []

            # Normalize to [0, 1]; `or 1.0` guards an all-zero gradient.
            max_s = max(s for _, s in words) or 1.0
            return [{"word": w, "attention": round(s / max_s, 4)} for w, s in words if w.strip()]

        except Exception as e:
            print(f"[attention_weights] failed: {e}")
            import traceback
            traceback.print_exc()
            return []

    def shap_explain(self, text: str) -> list[dict]:
        """
        Word-level SHAP explanation (always computed with the RoBERTa classifier).
        Returns words in original sentence order with signed SHAP values
        normalized to [-1, 1] for the predicted class.
        """
        try:
            import shap

            # Always explain with RoBERTa, regardless of which model predicted.
            clf = get_classifier("roberta")

            def predict_proba(texts):
                # SHAP probes the model with many masked variants of the text.
                results = []
                for t in texts:
                    enc = clf.tokenizer(
                        t, return_tensors="pt", truncation=True,
                        max_length=clf.max_length, padding=True,
                    ).to(clf.device)
                    with torch.no_grad():
                        logits = clf.model(**enc).logits
                    probs = torch.softmax(logits, dim=-1)[0].cpu().numpy()
                    results.append(probs)
                return np.array(results)

            # Split on non-word characters so attributions land on whole words.
            masker = shap.maskers.Text(r"\W+")
            explainer = shap.Explainer(
                predict_proba, masker, output_names=list(ID2LABEL.values()))
            # max_evals bounds the number of model calls per explanation.
            shap_values = explainer([text], max_evals=200, batch_size=8)

            # Attribute toward the class the explainer model itself predicts.
            enc = clf.tokenizer(text, return_tensors="pt", truncation=True,
                                max_length=clf.max_length).to(clf.device)
            with torch.no_grad():
                pred_id = int(torch.argmax(clf.model(**enc).logits, dim=-1)[0])

            words = shap_values.data[0]
            values = shap_values.values[0, :, pred_id]

            # Normalize by the largest |value|; guard empty / all-zero cases.
            max_abs = float(np.max(np.abs(values))) if len(values) else 1.0
            if max_abs == 0:
                max_abs = 1.0

            result = []
            for word, val in zip(words, values):
                w = word.strip()
                if not w:
                    continue
                result.append(
                    {"word": w, "shap_value": round(float(val) / max_abs, 4)})

            # Keep original sentence order so inline text rendering makes sense
            return result

        except Exception as e:
            print(f"[shap_explain] failed: {e}")
            import traceback
            traceback.print_exc()
            return []
254
+
255
+
256
+ _classifiers: dict[str, FakeNewsClassifier] = {}
257
+
258
+
259
def get_classifier(model_key: str = "distilbert") -> FakeNewsClassifier:
    """Return the cached classifier for *model_key*, creating it on first use."""
    try:
        return _classifiers[model_key]
    except KeyError:
        # First request for this model: build and memoize the wrapper
        # (the underlying weights still load lazily on first predict).
        return _classifiers.setdefault(model_key, FakeNewsClassifier(model_key))
264
+
265
+
266
def predict(text: str, model_key: str = "distilbert") -> dict:
    """Convenience wrapper for single prediction."""
    clf = get_classifier(model_key)
    return clf.predict(text)
269
+
270
+
271
def generate_explanation_text(
    shap_tokens: list[dict],
    label: str,
    confidence: float,
    model_key: str,
) -> str:
    """
    Build a natural-language paragraph explaining the prediction from SHAP data.
    No LLM required — derived entirely from token scores and prediction metadata.
    """
    pct = round(confidence * 100)

    # Without SHAP data we can only report the bare prediction.
    if not shap_tokens:
        return (
            f"The {model_key} model classified this article as {label} "
            f"with {pct}% confidence, but no word-level "
            f"explanation data was available for this prediction."
        )

    # Top 5 words pushing toward the label, top 3 pushing away (|value| > 0.05).
    supporting = sorted(
        (t for t in shap_tokens if t["shap_value"] > 0.05),
        key=lambda t: t["shap_value"],
        reverse=True,
    )[:5]
    opposing = sorted(
        (t for t in shap_tokens if t["shap_value"] < -0.05),
        key=lambda t: t["shap_value"],
    )[:3]

    model_display = {"distilbert": "DistilBERT", "roberta": "RoBERTa",
                     "xlnet": "XLNet"}.get(model_key, model_key)

    if pct >= 90:
        conf_phrase = "with very high confidence"
    elif pct >= 75:
        conf_phrase = "with high confidence"
    elif pct >= 55:
        conf_phrase = "with moderate confidence"
    else:
        conf_phrase = "with low confidence"

    label_desc = {
        "True": "factual and credible reporting",
        "Fake": "fabricated or misleading content",
        "Satire": "satirical or parody content",
        "Bias": "politically or ideologically biased reporting",
    }.get(label, label)

    sentences = [
        f"{model_display} classified this article as {label} ({label_desc}) "
        f"{conf_phrase} ({pct}%)."
    ]

    if supporting:
        word_list = ", ".join('"' + t["word"] + '"' for t in supporting)
        sentences.append(
            f"The words most strongly associated with this classification were {word_list}, "
            f"which the model weighted heavily toward a {label} prediction."
        )

    if opposing:
        word_list = ", ".join('"' + t["word"] + '"' for t in opposing)
        sentences.append(
            f"On the other hand, terms like {word_list} pulled against this classification, "
            f"suggesting some linguistic signals that are inconsistent with {label} content."
        )
    else:
        sentences.append(
            "The model found little linguistic evidence contradicting this classification."
        )

    if pct < 65:
        sentences.append(
            "The relatively lower confidence suggests the article contains mixed signals "
            "and the prediction should be interpreted with caution."
        )

    return " ".join(sentences)
src/models/train.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Training script for fake news detection.
3
+
4
+ Usage:
5
+ python -m src.models.train --model distilbert
6
+ python -m src.models.train --model roberta --epochs 5
7
+ python -m src.models.train --all
8
+ """
9
+
10
+ from src.data.dataset import build_dataset, LABEL2ID, ID2LABEL
11
+ from src.models.evaluate import compute_metrics, full_report
12
+ import os
13
+ import sys
14
+ import json
15
+ import argparse
16
+ from pathlib import Path
17
+ from datetime import datetime
18
+
19
+ import torch
20
+ from transformers import (
21
+ AutoTokenizer,
22
+ AutoModelForSequenceClassification,
23
+ TrainingArguments,
24
+ Trainer,
25
+ EarlyStoppingCallback,
26
+ )
27
+ from dotenv import load_dotenv
28
+
29
+ sys.path.insert(0, str(Path(__file__).parents[2]))
30
+ load_dotenv()
31
+
32
+
33
+ MODELS = {
34
+ "distilbert": "distilbert-base-uncased",
35
+ "roberta": "roberta-base",
36
+ "xlnet": "xlnet-base-cased",
37
+ }
38
+
39
+ PROJECT_ROOT = Path(__file__).parents[2]
40
+ MODELS_DIR = PROJECT_ROOT / "models"
41
+ DATA_CSV = PROJECT_ROOT / "data" / "processed" / "Dataset_Clean.csv"
42
+
43
+
44
def get_training_args(model_key, output_dir, epochs, batch_size, learning_rate, use_wandb) -> TrainingArguments:
    """Build the HuggingFace TrainingArguments shared by all three models."""
    return TrainingArguments(
        output_dir=str(output_dir / "checkpoints"),
        num_train_epochs=epochs,
        per_device_train_batch_size=batch_size,
        # Eval runs without gradients, so it can afford double the batch size.
        per_device_eval_batch_size=batch_size * 2,
        learning_rate=learning_rate,
        weight_decay=0.01,
        warmup_ratio=0.06,
        # Evaluate and checkpoint once per epoch.
        eval_strategy="epoch",
        save_strategy="epoch",
        # At the end, restore the checkpoint with the best validation macro-F1.
        load_best_model_at_end=True,
        metric_for_best_model="f1_macro",
        greater_is_better=True,
        # Keep disk usage bounded: only the 2 most recent checkpoints.
        save_total_limit=2,
        logging_dir=str(output_dir / "logs"),
        logging_steps=50,
        report_to="wandb" if use_wandb else "none",
        run_name=f"{model_key}-{datetime.now().strftime('%Y%m%d-%H%M')}",
        # Mixed precision only when a GPU is present.
        fp16=torch.cuda.is_available(),
        dataloader_num_workers=0,
        push_to_hub=False,
    )
67
+
68
+
69
def train_model(model_key, epochs=3, batch_size=16, learning_rate=2e-5, max_length=256, use_wandb=False) -> dict:
    """Full training run for one model. Returns evaluation metrics.

    Pipeline: build + tokenize dataset → fine-tune with early stopping →
    save model + tokenizer to models/<model_key> → evaluate on the test split
    → write metrics.json. `model_key` must be a key of MODELS.
    """
    model_name = MODELS[model_key]
    output_dir = MODELS_DIR / model_key

    print("\n" + "=" * 60)
    print(f"TRAINING: {model_key} ({model_name})")
    print(f"Epochs: {epochs} | Batch: {batch_size} | LR: {learning_rate}")
    print(
        f"Device: {'GPU (' + torch.cuda.get_device_name(0) + ')' if torch.cuda.is_available() else 'CPU'}")
    print("=" * 60 + "\n")

    print("[1/4] Building dataset…")
    # Tokenizer must match the model being fine-tuned.
    tokenized = build_dataset(
        csv_path=DATA_CSV, tokenizer_name=model_name, max_length=max_length)

    print("[2/4] Loading model…")
    # ignore_mismatched_sizes lets the base checkpoint's head be replaced
    # by a freshly initialized 4-label classification head.
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name, num_labels=4, id2label=ID2LABEL, label2id=LABEL2ID, ignore_mismatched_sizes=True,
    )

    print("[3/4] Setting up trainer…")
    output_dir.mkdir(parents=True, exist_ok=True)

    if use_wandb:
        import wandb
        wandb.init(
            project=os.getenv("WANDB_PROJECT", "fake-news-detection"),
            name=f"{model_key}-{datetime.now().strftime('%Y%m%d-%H%M')}",
            config={"model": model_name, "epochs": epochs, "batch_size": batch_size,
                    "learning_rate": learning_rate, "max_length": max_length},
        )

    trainer = Trainer(
        model=model,
        args=get_training_args(model_key, output_dir,
                               epochs, batch_size, learning_rate, use_wandb),
        train_dataset=tokenized["train"],
        eval_dataset=tokenized["validation"],
        compute_metrics=compute_metrics,
        # Stop if validation macro-F1 fails to improve for 2 consecutive epochs.
        callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
    )

    print("[4/4] Training…\n")
    trainer.train()

    print(f"\n[✓] Saving model to {output_dir}")
    trainer.save_model(str(output_dir))
    # Save the tokenizer alongside so the directory is self-contained for inference.
    AutoTokenizer.from_pretrained(model_name).save_pretrained(str(output_dir))

    print("[✓] Evaluating on test set…")
    metrics = full_report(model, tokenized["test"])

    metrics_path = output_dir / "metrics.json"
    with open(metrics_path, "w") as f:
        json.dump(metrics["report"], f, indent=2)
    print(f"[✓] Metrics saved to {metrics_path}")

    if use_wandb:
        import wandb
        wandb.log(metrics["report"])
        wandb.finish()

    return metrics
133
+
134
+
135
def main():
    """CLI entry point: train one model (default) or all three with --all."""
    parser = argparse.ArgumentParser(
        description="Train fake news detection models")
    parser.add_argument(
        "--model", choices=list(MODELS.keys()), default="distilbert")
    parser.add_argument("--all", action="store_true",
                        help="Train all three models sequentially")
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--lr", type=float, default=2e-5)
    parser.add_argument("--max-length", type=int, default=256)
    parser.add_argument("--wandb", action="store_true")
    args = parser.parse_args()

    targets = list(MODELS.keys()) if args.all else [args.model]
    all_metrics = {}
    for model_key in targets:
        all_metrics[model_key] = train_model(
            model_key=model_key, epochs=args.epochs, batch_size=args.batch_size,
            learning_rate=args.lr, max_length=args.max_length, use_wandb=args.wandb,
        )

    def _fmt(value):
        # Guard against missing metrics: the previous code formatted the 'N/A'
        # fallback with ':.4f', which raises ValueError on a string.
        return f"{value:.4f}" if isinstance(value, (int, float)) else "N/A"

    print("\n" + "=" * 60)
    print("TRAINING SUMMARY")
    print("=" * 60)
    for key, m in all_metrics.items():
        r = m["report"]
        print(f"\n{key.upper()}")
        print(f"  Accuracy:    {_fmt(r.get('accuracy'))}")
        print(f"  Macro F1:    {_fmt(r.get('macro avg', {}).get('f1-score'))}")
        print(f"  Weighted F1: {_fmt(r.get('weighted avg', {}).get('f1-score'))}")
    print("\n" + "=" * 60 + "\n")
169
+
170
+
171
+ if __name__ == "__main__":
172
+ main()
src/utils/__init__.py ADDED
File without changes
src/utils/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (198 Bytes). View file
 
src/utils/__pycache__/gnews_client.cpython-313.pyc ADDED
Binary file (6.04 kB). View file
 
src/utils/__pycache__/supabase_client.cpython-313.pyc ADDED
Binary file (4.75 kB). View file
 
src/utils/gnews_client.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional

import requests
from dotenv import load_dotenv
6
+
7
+ load_dotenv()
8
+
9
+
10
class GNewsClient:
    """Thin client for the GNews v4 REST API (search and top headlines).

    Reads GNEWS_API_KEY (required) and GNEWS_API_URL (optional override of
    the default endpoint) from the environment; a .env file is honoured via
    load_dotenv() at module import.
    """

    def __init__(self):
        self.api_key = os.getenv("GNEWS_API_KEY")
        self.base_url = os.getenv("GNEWS_API_URL", "https://gnews.io/api/v4")
        if not self.api_key:
            raise ValueError(
                "GNEWS_API_KEY must be set in environment variables")

    def search_news(
        self,
        query: str = "politics",
        lang: str = "en",
        country: Optional[str] = None,
        max_results: int = 10,
        from_date: Optional[datetime] = None,
        to_date: Optional[datetime] = None,
    ) -> List[Dict[str, Any]]:
        """Search for news articles by query.

        Returns a list of normalised article dicts (see _format_articles).
        Network/HTTP errors are logged and yield an empty list rather than
        raising, so callers can treat the result as best-effort.
        """
        params = {
            "q": query,
            "lang": lang,
            "max": min(max_results, 100),  # GNews caps page size at 100
            "apikey": self.api_key,
        }
        if country:
            params["country"] = country
        # The 'Z' suffix declares UTC to the API, so callers should pass
        # timezone-aware UTC datetimes here.
        if from_date:
            params["from"] = from_date.strftime("%Y-%m-%dT%H:%M:%SZ")
        if to_date:
            params["to"] = to_date.strftime("%Y-%m-%dT%H:%M:%SZ")

        try:
            response = requests.get(
                f"{self.base_url}/search", params=params, timeout=10)
            response.raise_for_status()
            return self._format_articles(response.json().get("articles", []))
        except requests.exceptions.RequestException as e:
            print(f"Error fetching news: {e}")
            return []

    def get_top_headlines(
        self,
        category: Optional[str] = None,
        lang: str = "en",
        country: Optional[str] = None,
        max_results: int = 10,
    ) -> List[Dict[str, Any]]:
        """Get top headlines, optionally filtered by category.

        Same best-effort error handling as search_news: failures return [].
        """
        params = {
            "lang": lang,
            "max": min(max_results, 100),  # GNews caps page size at 100
            "apikey": self.api_key,
        }
        if category:
            params["category"] = category
        if country:
            params["country"] = country

        try:
            response = requests.get(
                f"{self.base_url}/top-headlines", params=params, timeout=10)
            response.raise_for_status()
            return self._format_articles(response.json().get("articles", []))
        except requests.exceptions.RequestException as e:
            print(f"Error fetching headlines: {e}")
            return []

    def _format_articles(self, articles: List[Dict]) -> List[Dict[str, Any]]:
        """Normalise raw GNews payload entries into a flat dict schema.

        Missing fields default to "" so downstream code never sees None.
        """
        return [
            {
                "title": article.get("title", ""),
                "description": article.get("description", ""),
                "content": article.get("content", ""),
                "url": article.get("url", ""),
                "image": article.get("image", ""),
                "published_at": article.get("publishedAt", ""),
                "source": article.get("source", {}).get("name", ""),
                "source_url": article.get("source", {}).get("url", ""),
            }
            for article in articles
        ]

    def get_recent_news_for_analysis(
        self,
        topics: Optional[List[str]] = None,
        max_per_topic: int = 5,
    ) -> List[Dict[str, Any]]:
        """Fetch last-24h articles across topics, deduplicated by URL.

        ``topics`` defaults to ["politics", "breaking news", "world news"].
        The default is None (not a list literal) to avoid the shared
        mutable-default-argument pitfall.
        """
        if topics is None:
            topics = ["politics", "breaking news", "world news"]
        all_articles = []
        seen_urls: set = set()
        for topic in topics:
            articles = self.search_news(
                query=topic,
                max_results=max_per_topic,
                # Use aware UTC: the API request is stamped with a 'Z'
                # (Zulu/UTC) suffix, but datetime.now() is local time, so
                # the window sent to the API was previously skewed by the
                # host's UTC offset.
                from_date=datetime.now(timezone.utc) - timedelta(days=1),
            )
            for article in articles:
                url = article.get("url", "")
                if url and url not in seen_urls:
                    seen_urls.add(url)
                    all_articles.append(article)
        return all_articles
112
+
113
+
114
# Process-wide singleton so the client (and its env validation) is built once.
_gnews_client: Optional[GNewsClient] = None


def get_gnews_client() -> GNewsClient:
    """Return the shared GNewsClient, constructing it lazily on first call."""
    global _gnews_client
    if _gnews_client is not None:
        return _gnews_client
    _gnews_client = GNewsClient()
    return _gnews_client
src/utils/supabase_client.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

from dotenv import load_dotenv
from supabase import create_client, Client
6
+
7
+ load_dotenv()
8
+
9
+
10
class SupabaseClient:
    """Persistence layer for model predictions and user feedback in Supabase.

    Credentials come from the environment: SUPABASE_URL plus
    SUPABASE_SERVICE_KEY (preferred) or SUPABASE_KEY as a fallback.
    """

    def __init__(self):
        self.url = os.getenv("SUPABASE_URL")
        # Prefer the privileged service key; fall back to the public key.
        self.key = os.getenv(
            "SUPABASE_SERVICE_KEY") or os.getenv("SUPABASE_KEY")
        if not self.url or not self.key:
            raise ValueError(
                "SUPABASE_URL and SUPABASE_SERVICE_KEY must be set")
        self.client: Client = create_client(self.url, self.key)

    def store_prediction(
        self,
        article_id: str,
        text: str,
        predicted_label: str,
        confidence: float,
        model_name: str,
        explanation=None,
    ) -> Dict[str, Any]:
        """Insert one prediction row; returns the inserted row data."""
        data = {
            "article_id": article_id,
            # Truncate to keep the stored row bounded in size.
            "text": text[:1000],
            "predicted_label": predicted_label,
            "confidence": confidence,
            "model_name": model_name,
            "explanation": explanation,
            # Aware UTC timestamp. datetime.utcnow() is deprecated (3.12+)
            # and returns a naive datetime, which serialises without an
            # offset and is ambiguous to the database.
            "created_at": datetime.now(timezone.utc).isoformat(),
        }
        response = self.client.table("predictions").insert(data).execute()
        return response.data

    def store_feedback(
        self,
        article_id: str,
        predicted_label: str,
        actual_label: str,
        user_comment: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Insert one user-feedback row; returns the inserted row data."""
        data = {
            "article_id": article_id,
            "predicted_label": predicted_label,
            "actual_label": actual_label,
            "user_comment": user_comment,
            # Aware UTC timestamp (see store_prediction).
            "created_at": datetime.now(timezone.utc).isoformat(),
        }
        response = self.client.table("feedback").insert(data).execute()
        return response.data

    def get_prediction_stats(self) -> Dict[str, Any]:
        """Return the total prediction count and a per-label breakdown.

        NOTE(review): label counting is done client-side over all rows;
        fine at small scale, but a SQL GROUP BY would be cheaper for
        large tables.
        """
        total = self.client.table("predictions").select(
            "*", count="exact").execute()
        by_label_rows = self.client.table(
            "predictions").select("predicted_label").execute()
        label_counts: Dict[str, int] = {}
        for row in by_label_rows.data:
            lbl = row["predicted_label"]
            label_counts[lbl] = label_counts.get(lbl, 0) + 1
        return {
            "total_predictions": total.count,
            "by_label": label_counts,
        }

    def get_feedback_for_training(self, limit: int = 1000) -> List[Dict[str, Any]]:
        """Fetch up to ``limit`` feedback rows for retraining pipelines."""
        response = self.client.table("feedback").select(
            "*").limit(limit).execute()
        return response.data
76
+
77
+
78
# Process-wide singleton; reset_client() clears it so the next access rebuilds.
_supabase_client: Optional[SupabaseClient] = None


def get_supabase_client() -> SupabaseClient:
    """Return the shared SupabaseClient, creating it lazily on first access."""
    global _supabase_client
    if _supabase_client is not None:
        return _supabase_client
    _supabase_client = SupabaseClient()
    return _supabase_client


def reset_client():
    """Force re-initialisation."""
    global _supabase_client
    _supabase_client = None