Spaces:
Running
Running
Initial deployment
Browse files- .dockerignore +17 -0
- Dockerfile +28 -0
- README.md +48 -5
- requirements.txt +49 -0
- scripts/download_models.py +75 -0
- scripts/setup_environment.bat +102 -0
- scripts/setup_environment.sh +135 -0
- scripts/setup_supabase.sql +95 -0
- src/__init__.py +0 -0
- src/__pycache__/__init__.cpython-313.pyc +0 -0
- src/api/__init__.py +0 -0
- src/api/__pycache__/__init__.cpython-313.pyc +0 -0
- src/api/__pycache__/main.cpython-313.pyc +0 -0
- src/api/main.py +355 -0
- src/data/__init__.py +0 -0
- src/data/__pycache__/__init__.cpython-313.pyc +0 -0
- src/data/__pycache__/dataset.cpython-313.pyc +0 -0
- src/data/__pycache__/preprocessing.cpython-313.pyc +0 -0
- src/data/dataset.py +91 -0
- src/data/gnews_collector.py +189 -0
- src/data/preprocessing.py +18 -0
- src/models/__init__.py +12 -0
- src/models/__pycache__/__init__.cpython-313.pyc +0 -0
- src/models/__pycache__/evaluate.cpython-313.pyc +0 -0
- src/models/__pycache__/inference.cpython-313.pyc +0 -0
- src/models/__pycache__/train.cpython-313.pyc +0 -0
- src/models/evaluate.py +46 -0
- src/models/inference.py +345 -0
- src/models/train.py +172 -0
- src/utils/__init__.py +0 -0
- src/utils/__pycache__/__init__.cpython-313.pyc +0 -0
- src/utils/__pycache__/gnews_client.cpython-313.pyc +0 -0
- src/utils/__pycache__/supabase_client.cpython-313.pyc +0 -0
- src/utils/gnews_client.py +121 -0
- src/utils/supabase_client.py +91 -0
.dockerignore
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
models/
|
| 2 |
+
data/
|
| 3 |
+
notebooks/
|
| 4 |
+
frontend/
|
| 5 |
+
venv/
|
| 6 |
+
.git/
|
| 7 |
+
.vscode/
|
| 8 |
+
__pycache__/
|
| 9 |
+
*.pyc
|
| 10 |
+
*.pyo
|
| 11 |
+
*.pyd
|
| 12 |
+
.env
|
| 13 |
+
.env.example
|
| 14 |
+
*.log
|
| 15 |
+
.DS_Store
|
| 16 |
+
README.md
|
| 17 |
+
docker-compose.yml
|
Dockerfile
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.9-slim
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
RUN apt-get update && apt-get install -y \
|
| 6 |
+
build-essential \
|
| 7 |
+
curl \
|
| 8 |
+
git \
|
| 9 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 10 |
+
|
| 11 |
+
COPY requirements.txt .
|
| 12 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 13 |
+
|
| 14 |
+
COPY src/ ./src/
|
| 15 |
+
COPY scripts/ ./scripts/
|
| 16 |
+
COPY .env.example .env.example
|
| 17 |
+
|
| 18 |
+
# Download models from HuggingFace Hub at build time
|
| 19 |
+
RUN mkdir -p models && \
|
| 20 |
+
huggingface-cli download aviseth/distilbert-fakenews --local-dir models/distilbert --exclude "checkpoints/*" && \
|
| 21 |
+
huggingface-cli download aviseth/roberta-fakenews --local-dir models/roberta --exclude "checkpoints/*" && \
|
| 22 |
+
huggingface-cli download aviseth/xlnet-fakenews --local-dir models/xlnet --exclude "checkpoints/*"
|
| 23 |
+
|
| 24 |
+
# HuggingFace Spaces uses port 7860
|
| 25 |
+
ENV PORT=7860
|
| 26 |
+
EXPOSE 7860
|
| 27 |
+
|
| 28 |
+
CMD uvicorn src.api.main:app --host 0.0.0.0 --port ${PORT}
|
README.md
CHANGED
|
@@ -1,10 +1,53 @@
|
|
| 1 |
---
|
| 2 |
-
title: Fake News
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
| 8 |
---
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Fake News Detection API
|
| 3 |
+
emoji: 🔍
|
| 4 |
+
colorFrom: orange
|
| 5 |
+
colorTo: red
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
| 8 |
---
|
| 9 |
|
| 10 |
+
# Fake News Detection API
|
| 11 |
+
|
| 12 |
+
Multi-class fake news detection using fine-tuned DistilBERT, RoBERTa, and XLNet models.
|
| 13 |
+
|
| 14 |
+
Classifies news articles into: **True** · **Fake** · **Satire** · **Bias**
|
| 15 |
+
|
| 16 |
+
## Features
|
| 17 |
+
|
| 18 |
+
- 3 transformer models (DistilBERT, RoBERTa, XLNet) trained on 80k+ articles
|
| 19 |
+
- Real-time explainability via gradient saliency + SHAP
|
| 20 |
+
- Live news integration via GNews API
|
| 21 |
+
- Prediction statistics and user feedback collection
|
| 22 |
+
- FastAPI backend with Swagger docs at `/docs`
|
| 23 |
+
|
| 24 |
+
## Endpoints
|
| 25 |
+
|
| 26 |
+
- `POST /predict` — classify text as True / Fake / Satire / Bias
|
| 27 |
+
- `POST /explain` — gradient saliency + SHAP explainability
|
| 28 |
+
- `GET /news` — live news via GNews
|
| 29 |
+
- `GET /news/newspaper` — news grouped by predicted label
|
| 30 |
+
- `POST /feedback` — submit label corrections
|
| 31 |
+
- `GET /stats` — prediction statistics
|
| 32 |
+
- `GET /health` — health check
|
| 33 |
+
- `GET /docs` — Swagger UI
|
| 34 |
+
|
| 35 |
+
## Environment Variables
|
| 36 |
+
|
| 37 |
+
Set these in your Space settings:
|
| 38 |
+
|
| 39 |
+
```
|
| 40 |
+
SUPABASE_URL=your_supabase_url
|
| 41 |
+
SUPABASE_KEY=your_supabase_anon_key
|
| 42 |
+
SUPABASE_SERVICE_KEY=your_supabase_service_key
|
| 43 |
+
GNEWS_API_KEY=your_gnews_api_key
|
| 44 |
+
ALLOWED_ORIGINS=https://your-frontend.vercel.app
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
## Models
|
| 48 |
+
|
| 49 |
+
Models are automatically downloaded from:
|
| 50 |
+
|
| 51 |
+
- [aviseth/distilbert-fakenews](https://huggingface.co/aviseth/distilbert-fakenews)
|
| 52 |
+
- [aviseth/roberta-fakenews](https://huggingface.co/aviseth/roberta-fakenews)
|
| 53 |
+
- [aviseth/xlnet-fakenews](https://huggingface.co/aviseth/xlnet-fakenews)
|
requirements.txt
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Core ML
|
| 2 |
+
torch>=2.0.0
|
| 3 |
+
transformers>=4.30.0
|
| 4 |
+
datasets>=2.12.0
|
| 5 |
+
scikit-learn>=1.3.0
|
| 6 |
+
accelerate>=0.26.0
|
| 7 |
+
|
| 8 |
+
# Backend API
|
| 9 |
+
fastapi>=0.100.0
|
| 10 |
+
uvicorn[standard]>=0.23.0
|
| 11 |
+
pydantic>=2.0.0
|
| 12 |
+
python-multipart>=0.0.6
|
| 13 |
+
|
| 14 |
+
# Database / Supabase
|
| 15 |
+
supabase>=2.0.0
|
| 16 |
+
postgrest>=0.10.0
|
| 17 |
+
sqlalchemy>=2.0.0
|
| 18 |
+
psycopg2-binary>=2.9.0
|
| 19 |
+
|
| 20 |
+
# Data Processing
|
| 21 |
+
pandas>=2.0.0
|
| 22 |
+
numpy>=1.24.0
|
| 23 |
+
nltk>=3.8.0
|
| 24 |
+
spacy>=3.6.0
|
| 25 |
+
|
| 26 |
+
# Explainability
|
| 27 |
+
shap>=0.42.0
|
| 28 |
+
lime>=0.2.0
|
| 29 |
+
|
| 30 |
+
# News / Web
|
| 31 |
+
requests>=2.31.0
|
| 32 |
+
beautifulsoup4>=4.12.0
|
| 33 |
+
newspaper3k>=0.2.8
|
| 34 |
+
|
| 35 |
+
# MLOps
|
| 36 |
+
wandb>=0.15.0
|
| 37 |
+
|
| 38 |
+
# Utilities
|
| 39 |
+
python-dotenv>=1.0.0
|
| 40 |
+
pyyaml>=6.0
|
| 41 |
+
tqdm>=4.65.0
|
| 42 |
+
|
| 43 |
+
# Testing
|
| 44 |
+
pytest>=7.4.0
|
| 45 |
+
pytest-asyncio>=0.21.0
|
| 46 |
+
|
| 47 |
+
# Visualization
|
| 48 |
+
matplotlib>=3.7.0
|
| 49 |
+
seaborn>=0.12.0
|
scripts/download_models.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Downloads DistilBERT, RoBERTa, and XLNet base models from Hugging Face
|
| 3 |
+
and saves them to the models/ directory with the correct label configuration.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig
|
| 8 |
+
|
| 9 |
+
MODELS = {
|
| 10 |
+
"distilbert": {"name": "distilbert-base-uncased", "description": "66M parameters"},
|
| 11 |
+
"roberta": {"name": "roberta-base", "description": "125M parameters"},
|
| 12 |
+
"xlnet": {"name": "xlnet-base-cased", "description": "110M parameters"},
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
LABEL_MAP = {0: "True", 1: "Fake", 2: "Satire", 3: "Bias"}
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def download_model(model_key: str, model_info: dict, base_dir: Path) -> bool:
    """Download one base model + tokenizer and save it under ``base_dir/model_key``.

    The config is rewritten with the 4-class fake-news label mapping so the
    saved checkpoint is ready for fine-tuning / inference with these labels.

    Args:
        model_key: Short name used as the target directory (e.g. "distilbert").
        model_info: Dict with "name" (HF model id) and "description".
        base_dir: Root ``models/`` directory.

    Returns:
        True on success, False if any download/save step failed.
    """
    model_name = model_info["name"]
    save_path = base_dir / model_key

    print(f"\n{'='*60}")
    print(
        f"Downloading: {model_key} — {model_name} ({model_info['description']})")
    print(f"{'='*60}\n")

    try:
        save_path.mkdir(parents=True, exist_ok=True)

        print("[1/3] Tokenizer…")
        AutoTokenizer.from_pretrained(model_name).save_pretrained(save_path)

        print("[2/3] Config…")
        config = AutoConfig.from_pretrained(
            model_name,
            # Derived from LABEL_MAP (was a hard-coded 4) so the label map and
            # the classification-head size cannot drift apart.
            num_labels=len(LABEL_MAP),
            id2label=LABEL_MAP,
            label2id={v: k for k, v in LABEL_MAP.items()},
        )
        config.save_pretrained(save_path)

        print("[3/3] Model weights…")
        AutoModelForSequenceClassification.from_pretrained(
            model_name, config=config).save_pretrained(save_path)

        # Small human-readable provenance note alongside the checkpoint.
        with open(save_path / "model_info.txt", "w") as f:
            f.write(
                f"Model: {model_name}\nParameters: {model_info['description']}\nLabels: {LABEL_MAP}\nStatus: pre-trained\n")

        print(f"✅ {model_key} saved to {save_path}\n")
        return True

    except Exception as e:
        # Broad catch is deliberate: one failed download must not abort the
        # other models; main() reports per-model success/failure instead.
        print(f"❌ {model_key} failed: {e}\n")
        return False
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def main():
    """Download every configured base model into the repo-level models/ dir."""
    target_root = Path(__file__).parent.parent / "models"
    target_root.mkdir(parents=True, exist_ok=True)

    # Attempt each model independently; record per-model success.
    outcomes = {}
    for model_key, model_info in MODELS.items():
        outcomes[model_key] = download_model(model_key, model_info, target_root)

    separator = "=" * 60
    print("\n" + separator)
    print("SUMMARY")
    print(separator)
    for model_key, succeeded in outcomes.items():
        status_icon = "✅" if succeeded else "❌"
        print(f"  {model_key:15} {status_icon}")
    ok_count = sum(outcomes.values())
    print(f"\n{ok_count}/{len(outcomes)} models downloaded")
    print(separator + "\n")


if __name__ == "__main__":
    main()
|
scripts/setup_environment.bat
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
@echo off
REM ============================================================
REM Fake News Detection - Environment Setup
REM Run from project root: scripts\setup_environment.bat
REM ============================================================

REM Move to project root (one level up from scripts/)
cd /d "%~dp0.."

echo.
echo ============================================================
echo FAKE NEWS DETECTION - ENVIRONMENT SETUP
echo ============================================================
echo.

REM Check Python
echo [1/5] Checking Python...
python --version >nul 2>&1
if errorlevel 1 (
    echo [ERROR] Python not found. Install from https://www.python.org/
    pause & exit /b 1
)
python --version
echo.

REM Handle existing venv.
REM FIX: the prompt is handled OUTSIDE a parenthesized block. In the original,
REM "set /p recreate=" and the "%recreate%" test lived inside the same ( ... )
REM block, so %recreate% was expanded at parse time — before the user typed
REM anything — and the answer was silently ignored (the venv was never
REM recreated). goto-based flow avoids needing delayed expansion.
if not exist venv goto :create_venv
echo [INFO] Virtual environment already exists.
set /p recreate="Recreate it? (y/n): "
if /i "%recreate%"=="y" (
    echo Removing old venv...
    rmdir /s /q venv
    goto :create_venv
)
goto :activate_venv

:create_venv
REM Create venv
echo [2/5] Creating virtual environment...
python -m venv venv
if errorlevel 1 ( echo [ERROR] Failed to create venv & pause & exit /b 1 )
echo [OK] venv created at %CD%\venv
echo.

:activate_venv
echo [3/5] Activating virtual environment...
call venv\Scripts\activate.bat
if errorlevel 1 ( echo [ERROR] Failed to activate venv & pause & exit /b 1 )
echo [OK] Activated
echo.

REM Upgrade pip
echo [4/5] Upgrading pip...
python -m pip install --upgrade pip --quiet
echo [OK] pip upgraded
echo.

REM Install requirements
echo [5/5] Installing requirements.txt...
echo     (This takes a few minutes on first run)
echo.
pip install -r requirements.txt
if errorlevel 1 (
    echo.
    echo [ERROR] Some packages failed. Common fixes:
    echo   - Run as Administrator
    echo   - Install Visual C++ Build Tools: https://visualstudio.microsoft.com/visual-cpp-build-tools/
    echo   - Check internet connection
    pause & exit /b 1
)

echo.
echo ============================================================
echo DONE - Virtual environment ready
echo ============================================================
echo.
echo Location : %CD%\venv
python --version
echo.
echo Key packages installed:
pip list --format=columns | findstr /C:"torch" /C:"transformers" /C:"fastapi" /C:"supabase" /C:"wandb"
echo.
echo ============================================================
echo NEXT STEPS
echo ============================================================
echo.
echo 1. Download base models (run once):
echo      python scripts\download_models.py
echo.
echo 2. Run Supabase SQL schema:
echo      Open Supabase dashboard ^> SQL Editor ^> paste scripts\setup_supabase.sql
echo.
echo 3. Test connections:
echo      python scripts\test_connections.py
echo.
echo 4. Start API:
echo      uvicorn src.api.main:app --reload
echo.
echo To activate venv in future sessions:
echo      venv\Scripts\activate
echo.
pause
|
scripts/setup_environment.sh
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# ============================================================
# Fake News Detection - Environment Setup Script
# This script creates virtual environment and installs all dependencies
# ============================================================
# Creates ./venv in the current working directory — run from the project root.

set -e  # Exit on error
# NOTE(review): with `set -e`, the interactive `read -p` below aborts the
# script when run non-interactively (EOF -> nonzero status) — confirm this is
# acceptable for CI use, or run with stdin attached.

echo ""
echo "============================================================"
echo "FAKE NEWS DETECTION - ENVIRONMENT SETUP"
echo "============================================================"
echo ""

# ANSI color codes for status messages (used with `echo -e`).
# Colors
GREEN='\033[0;32m'
RED='\033[0;31m'
YELLOW='\033[1;33m'
NC='\033[0m' # No Color

# Check if Python is installed: prefer python3, fall back to python.
echo "[1/5] Checking Python installation..."
if ! command -v python3 &> /dev/null; then
    if ! command -v python &> /dev/null; then
        echo -e "${RED}[ERROR] Python is not installed${NC}"
        echo "Please install Python 3.9 or higher"
        exit 1
    else
        PYTHON_CMD=python
    fi
else
    PYTHON_CMD=python3
fi

$PYTHON_CMD --version
echo -e "${GREEN}[SUCCESS] Python is installed${NC}"
echo ""

# Check if virtual environment already exists; offer to wipe and recreate it.
if [ -d "venv" ]; then
    echo -e "${YELLOW}[WARNING] Virtual environment already exists${NC}"
    read -p "Do you want to recreate it? (y/n): " recreate
    if [[ $recreate =~ ^[Yy]$ ]]; then
        echo "[2/5] Removing existing virtual environment..."
        rm -rf venv
        echo -e "${GREEN}[SUCCESS] Removed existing virtual environment${NC}"
    else
        echo "[INFO] Using existing virtual environment"
    fi
fi

# Create virtual environment if it doesn't exist (covers both the fresh case
# and the "user chose to recreate" case above).
if [ ! -d "venv" ]; then
    echo "[2/5] Creating virtual environment..."
    $PYTHON_CMD -m venv venv
    echo -e "${GREEN}[SUCCESS] Virtual environment created${NC}"
    echo ""
else
    echo "[2/5] Virtual environment already exists"
    echo ""
fi

# Activate virtual environment (affects this shell; pip below targets venv).
echo "[3/5] Activating virtual environment..."
source venv/bin/activate
echo -e "${GREEN}[SUCCESS] Virtual environment activated${NC}"
echo ""

# Upgrade pip
echo "[4/5] Upgrading pip..."
pip install --upgrade pip --quiet
echo -e "${GREEN}[SUCCESS] Pip upgraded${NC}"
echo ""

# Install requirements
echo "[5/5] Installing dependencies from requirements.txt..."
echo "This may take a few minutes..."
echo ""

if pip install -r requirements.txt; then
    echo ""
    echo "============================================================"
    echo "INSTALLATION COMPLETE"
    echo "============================================================"
    echo ""
    echo -e "${GREEN}[SUCCESS] All dependencies installed successfully!${NC}"
    echo ""
    echo "Virtual environment location: $(pwd)/venv"
    echo "Python version: $($PYTHON_CMD --version)"
    echo ""
    echo "Installed packages:"
    # NOTE(review): under `set -e`, if this grep matches nothing it exits 1
    # and aborts the script before the NEXT STEPS banner — verify intended.
    pip list | grep -E "torch|transformers|fastapi|supabase"
    echo ""
else
    echo ""
    echo -e "${RED}[ERROR] Failed to install some dependencies${NC}"
    echo "Please check the error messages above"
    echo ""
    echo "Common solutions:"
    echo "1. Make sure you have internet connection"
    echo "2. Check if requirements.txt exists"
    echo "3. Install build tools if needed"
    echo ""
    exit 1
fi

echo "============================================================"
echo "NEXT STEPS"
echo "============================================================"
echo ""
echo "1. Virtual environment is already activated"
echo ""
echo "2. Download models from Hugging Face:"
echo "   python scripts/download_models.py"
echo ""
echo "3. Setup Supabase database:"
echo "   - Open Supabase dashboard"
echo "   - Run scripts/setup_supabase.sql in SQL Editor"
echo ""
echo "4. Test connections:"
echo "   python scripts/test_connections.py"
echo ""
echo "5. Start the API server:"
echo "   uvicorn src.api.main:app --reload"
echo ""
echo "============================================================"
echo ""
echo "To activate virtual environment in future sessions:"
echo "   source venv/bin/activate"
echo ""
echo "To deactivate:"
echo "   deactivate"
echo ""
echo "============================================================"
echo ""
|
scripts/setup_supabase.sql
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
-- ============================================================
-- Fake News Detection - Supabase (PostgreSQL) schema.
-- DESTRUCTIVE: drops and recreates every table/view. Run in SQL Editor.
-- ============================================================

DROP TABLE IF EXISTS feedback CASCADE;
DROP TABLE IF EXISTS predictions CASCADE;
DROP TABLE IF EXISTS news_articles CASCADE;
DROP TABLE IF EXISTS model_performance CASCADE;
DROP TABLE IF EXISTS user_sessions CASCADE;

DROP VIEW IF EXISTS prediction_stats CASCADE;
DROP VIEW IF EXISTS daily_predictions CASCADE;
DROP VIEW IF EXISTS feedback_accuracy CASCADE;
DROP VIEW IF EXISTS model_comparison CASCADE;

CREATE EXTENSION IF NOT EXISTS "uuid-ossp";

-- One row per model prediction served by the API (/predict).
CREATE TABLE predictions (
    id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
    article_id VARCHAR NOT NULL UNIQUE,
    text TEXT,
    predicted_label VARCHAR(50) NOT NULL,
    confidence FLOAT NOT NULL,
    model_name VARCHAR(100) NOT NULL,
    explanation JSONB,
    created_at TIMESTAMPTZ DEFAULT NOW()
);

CREATE INDEX idx_pred_created ON predictions(created_at DESC);
CREATE INDEX idx_pred_label ON predictions(predicted_label);
CREATE INDEX idx_pred_model ON predictions(model_name);

-- User label corrections submitted via /feedback.
CREATE TABLE feedback (
    id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
    article_id VARCHAR NOT NULL,
    predicted_label VARCHAR(50) NOT NULL,
    actual_label VARCHAR(50) NOT NULL,
    user_comment TEXT,
    created_at TIMESTAMPTZ DEFAULT NOW()
);

CREATE INDEX idx_fb_created ON feedback(created_at DESC);
CREATE INDEX idx_fb_article ON feedback(article_id);

-- Articles fetched from GNews; optionally linked to a prediction once analyzed.
CREATE TABLE news_articles (
    id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
    title TEXT NOT NULL,
    description TEXT,
    content TEXT,
    url TEXT NOT NULL UNIQUE,
    image_url TEXT,
    published_at TIMESTAMPTZ,
    source_name VARCHAR(255),
    source_url TEXT,
    fetched_at TIMESTAMPTZ DEFAULT NOW(),
    analyzed BOOLEAN DEFAULT FALSE,
    -- FIX: was a bare UUID with no constraint; enforce referential integrity
    -- and detach (rather than delete) articles when a prediction is removed.
    prediction_id UUID REFERENCES predictions(id) ON DELETE SET NULL
);

CREATE INDEX idx_news_published ON news_articles(published_at DESC);
CREATE INDEX idx_news_analyzed ON news_articles(analyzed);

-- Periodic offline evaluation snapshots per model.
CREATE TABLE model_performance (
    id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
    model_name VARCHAR(100) NOT NULL,
    accuracy FLOAT,
    precision FLOAT,
    recall FLOAT,
    f1_score FLOAT,
    total_predictions INTEGER DEFAULT 0,
    correct_predictions INTEGER DEFAULT 0,
    evaluated_at TIMESTAMPTZ DEFAULT NOW()
);

CREATE TABLE user_sessions (
    id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
    session_id VARCHAR NOT NULL UNIQUE,
    user_agent TEXT,
    ip_address INET,
    created_at TIMESTAMPTZ DEFAULT NOW(),
    last_activity TIMESTAMPTZ DEFAULT NOW()
);

-- SECURITY NOTE: RLS is disabled because the API accesses these tables with
-- the service key only. Re-enable RLS with explicit policies before granting
-- any anon/authenticated client direct table access.
ALTER TABLE predictions DISABLE ROW LEVEL SECURITY;
ALTER TABLE feedback DISABLE ROW LEVEL SECURITY;
ALTER TABLE news_articles DISABLE ROW LEVEL SECURITY;
ALTER TABLE model_performance DISABLE ROW LEVEL SECURITY;
ALTER TABLE user_sessions DISABLE ROW LEVEL SECURITY;

-- Aggregate label distribution backing GET /stats.
CREATE VIEW prediction_stats AS
SELECT predicted_label, COUNT(*) AS total_count, AVG(confidence) AS avg_confidence
FROM predictions
GROUP BY predicted_label;

-- Confusion-matrix-style rollup of user feedback.
CREATE VIEW feedback_accuracy AS
SELECT predicted_label, actual_label, COUNT(*) AS count
FROM feedback
GROUP BY predicted_label, actual_label
ORDER BY count DESC;
|
src/__init__.py
ADDED
|
File without changes
|
src/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (192 Bytes). View file
|
|
|
src/api/__init__.py
ADDED
|
File without changes
|
src/api/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (196 Bytes). View file
|
|
|
src/api/__pycache__/main.cpython-313.pyc
ADDED
|
Binary file (17 kB). View file
|
|
|
src/api/main.py
ADDED
|
@@ -0,0 +1,355 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import FastAPI, HTTPException, BackgroundTasks
|
| 2 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 3 |
+
from pydantic import BaseModel
|
| 4 |
+
from typing import Optional, List, Dict
|
| 5 |
+
import os
|
| 6 |
+
import uuid
|
| 7 |
+
from dotenv import load_dotenv
|
| 8 |
+
|
| 9 |
+
from src.utils.supabase_client import get_supabase_client
|
| 10 |
+
from src.utils.gnews_client import get_gnews_client
|
| 11 |
+
|
| 12 |
+
load_dotenv()
|
| 13 |
+
|
| 14 |
+
# FastAPI application instance; interactive docs served at /docs and /redoc.
app = FastAPI(
    title="Fake News Detection API",
    description="Multi-class fake news detection using DistilBERT, RoBERTa, and XLNet",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc",
)

# CORS: comma-separated origin whitelist from the environment, defaulting to
# the usual local dev servers (React on 3000, Vite on 5173). Blank entries
# from stray commas are dropped.
_raw_origins = os.getenv(
    "ALLOWED_ORIGINS", "http://localhost:3000,http://localhost:5173"
)
allowed_origins = [
    origin.strip() for origin in _raw_origins.split(",") if origin.strip()
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=allowed_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Model keys the API accepts; anything else falls back to the default model.
VALID_MODELS = {"distilbert", "roberta", "xlnet"}
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class PredictionRequest(BaseModel):
    """Input for /predict: raw text or a URL to scrape, plus a model choice."""

    text: Optional[str] = None
    url: Optional[str] = None
    model: Optional[str] = "distilbert"
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class ExplanationData(BaseModel):
    """A single input token paired with its saliency/importance score."""

    token: str
    score: float
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class PredictionResponse(BaseModel):
    """Response body of /predict: label, per-class scores, and explanation."""

    # FIX: Pydantic v2 reserves the "model_" prefix (protected namespace);
    # without this override the `model_used` field emits a runtime warning.
    model_config = {"protected_namespaces": ()}

    article_id: str
    label: str
    confidence: float
    scores: dict
    model_used: str
    explanation: List[ExplanationData]
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class FeedbackRequest(BaseModel):
    """User correction for a previous prediction, submitted via /feedback."""

    article_id: str
    predicted_label: str
    actual_label: str
    user_comment: Optional[str] = None
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class ExplainRequest(BaseModel):
    """Input for /explain; `deep` opts into the slower SHAP path."""

    text: str
    model: Optional[str] = "distilbert"
    deep: Optional[bool] = False
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
@app.on_event("startup")
async def startup_event():
    """Probe external dependencies once at boot; failures are logged, not fatal."""
    probes = [
        (get_supabase_client, "✅ Supabase connected", "⚠️ Supabase: {}"),
        (get_gnews_client, "✅ GNews API connected", "⚠️ GNews: {}"),
    ]
    for factory, ok_msg, fail_fmt in probes:
        try:
            factory()
            print(ok_msg)
        except Exception as exc:
            print(fail_fmt.format(exc))
    print("🚀 API server started")
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
@app.get("/")
async def root():
    """Service banner: name, status, version, and available model keys."""
    banner = {
        "message": "Fake News Detection API",
        "status": "running",
        "version": "1.0.0",
        "models": list(VALID_MODELS),
    }
    return banner
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
@app.get("/health")
async def health_check():
    """Per-dependency health report; never raises, always returns 200."""
    status = {"api": "healthy", "supabase": "unknown", "gnews": "unknown"}
    for name, factory in (("supabase", get_supabase_client),
                          ("gnews", get_gnews_client)):
        try:
            factory()
            status[name] = "healthy"
        except Exception as exc:
            status[name] = f"unhealthy: {exc}"
    return status
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
@app.post("/predict", response_model=PredictionResponse)
|
| 115 |
+
async def predict(request: PredictionRequest, background_tasks: BackgroundTasks):
|
| 116 |
+
"""Classify news as True / Fake / Satire / Bias."""
|
| 117 |
+
if not request.text and not request.url:
|
| 118 |
+
raise HTTPException(status_code=400, detail="Provide text or url")
|
| 119 |
+
|
| 120 |
+
model_key = request.model if request.model in VALID_MODELS else "distilbert"
|
| 121 |
+
article_id = str(uuid.uuid4())
|
| 122 |
+
|
| 123 |
+
text = request.text or ""
|
| 124 |
+
if not text and request.url:
|
| 125 |
+
try:
|
| 126 |
+
import requests as req
|
| 127 |
+
from bs4 import BeautifulSoup
|
| 128 |
+
r = req.get(request.url, timeout=10)
|
| 129 |
+
soup = BeautifulSoup(r.text, "html.parser")
|
| 130 |
+
text = " ".join(p.get_text() for p in soup.find_all("p"))[:4000]
|
| 131 |
+
except Exception as e:
|
| 132 |
+
raise HTTPException(
|
| 133 |
+
status_code=422, detail=f"Could not fetch URL: {e}")
|
| 134 |
+
|
| 135 |
+
if len(text.strip()) < 10:
|
| 136 |
+
raise HTTPException(
|
| 137 |
+
status_code=422, detail="Text too short to classify")
|
| 138 |
+
|
| 139 |
+
try:
|
| 140 |
+
from src.models.inference import predict as run_inference
|
| 141 |
+
result = run_inference(text, model_key)
|
| 142 |
+
except Exception as e:
|
| 143 |
+
raise HTTPException(status_code=500, detail=f"Inference error: {e}")
|
| 144 |
+
|
| 145 |
+
response = PredictionResponse(
|
| 146 |
+
article_id=article_id,
|
| 147 |
+
label=result["label"],
|
| 148 |
+
confidence=result["confidence"],
|
| 149 |
+
scores=result["scores"],
|
| 150 |
+
model_used=model_key,
|
| 151 |
+
explanation=[ExplanationData(**t) for t in result.get("tokens", [])],
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
def _store():
|
| 155 |
+
try:
|
| 156 |
+
supabase = get_supabase_client()
|
| 157 |
+
supabase.store_prediction(
|
| 158 |
+
article_id=article_id,
|
| 159 |
+
text=text,
|
| 160 |
+
predicted_label=result["label"],
|
| 161 |
+
confidence=result["confidence"],
|
| 162 |
+
model_name=model_key,
|
| 163 |
+
explanation=result.get("tokens", []),
|
| 164 |
+
)
|
| 165 |
+
except Exception as e:
|
| 166 |
+
print(f"[bg] store_prediction failed: {e}")
|
| 167 |
+
|
| 168 |
+
background_tasks.add_task(_store)
|
| 169 |
+
return response
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
@app.post("/feedback")
|
| 173 |
+
async def submit_feedback(feedback: FeedbackRequest):
|
| 174 |
+
"""Submit user correction for active learning."""
|
| 175 |
+
try:
|
| 176 |
+
supabase = get_supabase_client()
|
| 177 |
+
result = supabase.store_feedback(
|
| 178 |
+
article_id=feedback.article_id,
|
| 179 |
+
predicted_label=feedback.predicted_label,
|
| 180 |
+
actual_label=feedback.actual_label,
|
| 181 |
+
user_comment=feedback.user_comment,
|
| 182 |
+
)
|
| 183 |
+
return {"status": "success", "message": "Feedback recorded", "data": result}
|
| 184 |
+
except Exception as e:
|
| 185 |
+
import traceback
|
| 186 |
+
print(f"[feedback] ERROR: {e}\n{traceback.format_exc()}")
|
| 187 |
+
raise HTTPException(
|
| 188 |
+
status_code=500, detail=f"Error storing feedback: {str(e)}")
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
@app.get("/news")
|
| 192 |
+
async def get_recent_news(
|
| 193 |
+
query: str = "breaking news",
|
| 194 |
+
max_results: int = 10,
|
| 195 |
+
category: Optional[str] = None,
|
| 196 |
+
):
|
| 197 |
+
"""Fetch recent articles from GNews."""
|
| 198 |
+
try:
|
| 199 |
+
gnews = get_gnews_client()
|
| 200 |
+
if category:
|
| 201 |
+
articles = gnews.get_top_headlines(
|
| 202 |
+
category=category, max_results=max_results)
|
| 203 |
+
else:
|
| 204 |
+
articles = gnews.search_news(query=query, max_results=max_results)
|
| 205 |
+
return {"status": "success", "count": len(articles), "articles": articles}
|
| 206 |
+
except Exception as e:
|
| 207 |
+
raise HTTPException(
|
| 208 |
+
status_code=500, detail=f"Error fetching news: {e}")
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
@app.get("/news/analyze")
|
| 212 |
+
async def analyze_recent_news(topic: str = "politics", max_articles: int = 5):
|
| 213 |
+
"""Fetch and classify recent news articles."""
|
| 214 |
+
try:
|
| 215 |
+
gnews = get_gnews_client()
|
| 216 |
+
articles = gnews.search_news(query=topic, max_results=max_articles)
|
| 217 |
+
|
| 218 |
+
from src.models.inference import predict as run_inference
|
| 219 |
+
results = []
|
| 220 |
+
for article in articles:
|
| 221 |
+
text = article.get("content") or article.get(
|
| 222 |
+
"description") or article.get("title", "")
|
| 223 |
+
if len(text.strip()) < 10:
|
| 224 |
+
continue
|
| 225 |
+
try:
|
| 226 |
+
pred = run_inference(text, "distilbert")
|
| 227 |
+
results.append({"article": article, "prediction": pred})
|
| 228 |
+
except Exception:
|
| 229 |
+
results.append({"article": article, "prediction": None})
|
| 230 |
+
|
| 231 |
+
return {"status": "success", "topic": topic, "analyzed_count": len(results), "results": results}
|
| 232 |
+
except Exception as e:
|
| 233 |
+
raise HTTPException(
|
| 234 |
+
status_code=500, detail=f"Error analyzing news: {e}")
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
@app.get("/news/newspaper")
|
| 238 |
+
async def get_newspaper(max_per_topic: int = 6):
|
| 239 |
+
"""Fetch and classify news across multiple topics, grouped by predicted label."""
|
| 240 |
+
topics = ["world news", "politics", "technology",
|
| 241 |
+
"science", "health", "business"]
|
| 242 |
+
try:
|
| 243 |
+
gnews = get_gnews_client()
|
| 244 |
+
from src.models.inference import predict as run_inference
|
| 245 |
+
|
| 246 |
+
all_results = []
|
| 247 |
+
seen_urls: set = set()
|
| 248 |
+
|
| 249 |
+
for topic in topics:
|
| 250 |
+
articles = gnews.search_news(
|
| 251 |
+
query=topic, max_results=max_per_topic)
|
| 252 |
+
for article in articles:
|
| 253 |
+
url = article.get("url", "")
|
| 254 |
+
if url in seen_urls:
|
| 255 |
+
continue
|
| 256 |
+
seen_urls.add(url)
|
| 257 |
+
text = article.get("content") or article.get(
|
| 258 |
+
"description") or article.get("title", "")
|
| 259 |
+
if len(text.strip()) < 10:
|
| 260 |
+
continue
|
| 261 |
+
try:
|
| 262 |
+
pred = run_inference(text, "distilbert")
|
| 263 |
+
all_results.append(
|
| 264 |
+
{"article": article, "prediction": pred})
|
| 265 |
+
except Exception:
|
| 266 |
+
all_results.append({"article": article, "prediction": {
|
| 267 |
+
"label": "True", "confidence": 0.5, "scores": {}, "tokens": []
|
| 268 |
+
}})
|
| 269 |
+
|
| 270 |
+
grouped: Dict[str, list] = {"True": [],
|
| 271 |
+
"Fake": [], "Satire": [], "Bias": []}
|
| 272 |
+
for item in all_results:
|
| 273 |
+
lbl = item["prediction"].get(
|
| 274 |
+
"label", "True") if item["prediction"] else "True"
|
| 275 |
+
if lbl in grouped:
|
| 276 |
+
grouped[lbl].append(item)
|
| 277 |
+
|
| 278 |
+
return {"status": "success", "total": len(all_results), "grouped": grouped}
|
| 279 |
+
except Exception as e:
|
| 280 |
+
raise HTTPException(
|
| 281 |
+
status_code=500, detail=f"Error building newspaper: {e}")
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
@app.post("/explain")
|
| 285 |
+
async def explain_prediction(request: ExplainRequest):
|
| 286 |
+
"""
|
| 287 |
+
Return explainability data for a piece of text.
|
| 288 |
+
Always returns gradient saliency highlights. If deep=True, also runs SHAP via RoBERTa.
|
| 289 |
+
"""
|
| 290 |
+
if len(request.text.strip()) < 10:
|
| 291 |
+
raise HTTPException(status_code=422, detail="Text too short")
|
| 292 |
+
|
| 293 |
+
model_key = request.model if request.model in VALID_MODELS else "distilbert"
|
| 294 |
+
|
| 295 |
+
try:
|
| 296 |
+
from src.models.inference import get_classifier
|
| 297 |
+
import asyncio
|
| 298 |
+
|
| 299 |
+
clf = get_classifier(model_key)
|
| 300 |
+
loop = asyncio.get_event_loop()
|
| 301 |
+
|
| 302 |
+
attention = await loop.run_in_executor(None, clf.attention_weights, request.text)
|
| 303 |
+
|
| 304 |
+
shap_tokens = []
|
| 305 |
+
explanation_text = ""
|
| 306 |
+
if request.deep:
|
| 307 |
+
shap_tokens = await loop.run_in_executor(None, clf.shap_explain, request.text)
|
| 308 |
+
if shap_tokens:
|
| 309 |
+
from src.models.inference import generate_explanation_text, predict as run_predict
|
| 310 |
+
pred = run_predict(request.text, model_key)
|
| 311 |
+
explanation_text = generate_explanation_text(
|
| 312 |
+
shap_tokens, pred["label"], pred["confidence"], model_key
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
return {"attention": attention, "shap": shap_tokens, "explanation_text": explanation_text, "model_used": model_key}
|
| 316 |
+
except Exception as e:
|
| 317 |
+
import traceback
|
| 318 |
+
print(f"[explain] ERROR: {e}\n{traceback.format_exc()}")
|
| 319 |
+
raise HTTPException(status_code=500, detail=f"Explain error: {e}")
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
@app.get("/stats")
|
| 323 |
+
async def get_statistics():
|
| 324 |
+
"""Prediction statistics from Supabase."""
|
| 325 |
+
try:
|
| 326 |
+
supabase = get_supabase_client()
|
| 327 |
+
stats = supabase.get_prediction_stats()
|
| 328 |
+
return {"status": "success", "statistics": stats}
|
| 329 |
+
except Exception as e:
|
| 330 |
+
raise HTTPException(
|
| 331 |
+
status_code=500, detail=f"Error fetching stats: {e}")
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
@app.get("/models")
|
| 335 |
+
async def list_models():
|
| 336 |
+
"""List available models and their training status."""
|
| 337 |
+
from pathlib import Path
|
| 338 |
+
models_dir = Path(__file__).parents[2] / "models"
|
| 339 |
+
available = []
|
| 340 |
+
for name in ["distilbert", "roberta", "xlnet"]:
|
| 341 |
+
path = models_dir / name
|
| 342 |
+
trained = (path / "config.json").exists()
|
| 343 |
+
available.append({"name": name, "trained": trained,
|
| 344 |
+
"path": str(path) if trained else None})
|
| 345 |
+
return {"models": available, "default": "distilbert"}
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
if __name__ == "__main__":
    import uvicorn

    # Host/port/reload are environment-configurable for deployment; reload
    # defaults to on (development mode) unless API_RELOAD is set to "false".
    uvicorn.run(
        "src.api.main:app",
        host=os.getenv("API_HOST", "0.0.0.0"),
        port=int(os.getenv("API_PORT", 8000)),
        reload=os.getenv("API_RELOAD", "true").lower() == "true",
    )
|
src/data/__init__.py
ADDED
|
File without changes
|
src/data/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (197 Bytes). View file
|
|
|
src/data/__pycache__/dataset.cpython-313.pyc
ADDED
|
Binary file (5.85 kB). View file
|
|
|
src/data/__pycache__/preprocessing.cpython-313.pyc
ADDED
|
Binary file (1.28 kB). View file
|
|
|
src/data/dataset.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Dataset loader — reads Dataset_Clean.csv and returns tokenized HuggingFace DatasetDict splits.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import pandas as pd
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from datasets import Dataset, DatasetDict
|
| 8 |
+
from transformers import AutoTokenizer
|
| 9 |
+
from sklearn.model_selection import train_test_split
|
| 10 |
+
from src.data.preprocessing import clean_text
|
| 11 |
+
|
| 12 |
+
LABEL2ID = {"True": 0, "Fake": 1, "Satire": 2, "Bias": 3}
|
| 13 |
+
ID2LABEL = {v: k for k, v in LABEL2ID.items()}
|
| 14 |
+
|
| 15 |
+
DEFAULT_CSV = Path(__file__).parents[2] / \
|
| 16 |
+
"data" / "processed" / "Dataset_Clean.csv"
|
| 17 |
+
MAX_LENGTH = 256
|
| 18 |
+
VAL_SPLIT = 0.10
|
| 19 |
+
TEST_SPLIT = 0.10
|
| 20 |
+
RANDOM_SEED = 42
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def load_dataframe(csv_path: str | Path = DEFAULT_CSV) -> pd.DataFrame:
    """Load and clean Dataset_Clean.csv. Returns a DataFrame with columns: text, label (int)."""
    frame = pd.read_csv(csv_path, low_memory=False)

    # Normalise label spelling ("FAKE" -> "Fake") and drop unknown labels.
    frame["label_text"] = (
        frame["label_text"].astype(str).str.strip().str.capitalize()
    )
    frame = frame[frame["label_text"].isin(LABEL2ID)].copy()

    frame["content"] = frame["content"].fillna("").astype(str)
    frame["title"] = frame["title"].fillna("").astype(str)
    # Prefer the article body; fall back to the title for very short bodies.
    frame["text"] = frame.apply(
        lambda row: row["content"] if len(row["content"]) > 30 else row["title"],
        axis=1,
    )
    frame["text"] = frame["text"].apply(clean_text)
    frame = frame[frame["text"].str.len() > 10].copy()
    frame["label"] = frame["label_text"].map(LABEL2ID).astype(int)

    print(f"[dataset] Loaded {len(frame):,} rows")
    print(
        f"[dataset] Label distribution:\n{frame['label_text'].value_counts().to_string()}\n")
    return frame[["text", "label"]].reset_index(drop=True)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def make_splits(df: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Stratified train / val / test split."""
    holdout_frac = VAL_SPLIT + TEST_SPLIT
    train_df, holdout_df = train_test_split(
        df, test_size=holdout_frac, stratify=df["label"], random_state=RANDOM_SEED
    )
    # Split the holdout so val/test keep their configured proportions.
    val_df, test_df = train_test_split(
        holdout_df,
        test_size=TEST_SPLIT / holdout_frac,
        stratify=holdout_df["label"],
        random_state=RANDOM_SEED,
    )
    print(
        f"[dataset] Train: {len(train_df):,} Val: {len(val_df):,} Test: {len(test_df):,}")
    return train_df, val_df, test_df
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def tokenize_dataset(dataset_dict: DatasetDict, tokenizer_name: str, max_length: int = MAX_LENGTH) -> DatasetDict:
    """Tokenize all splits."""
    tok = AutoTokenizer.from_pretrained(tokenizer_name)

    def _encode(examples):
        # Fixed-length padding so every example has an identical tensor shape.
        return tok(examples["text"], padding="max_length", truncation=True, max_length=max_length)

    encoded = dataset_dict.map(
        _encode,
        batched=True,
        batch_size=512,
        remove_columns=["text"],
        desc="Tokenizing",
    )
    encoded.set_format("torch")
    return encoded
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def build_dataset(
    csv_path: str | Path = DEFAULT_CSV,
    tokenizer_name: str = "distilbert-base-uncased",
    max_length: int = MAX_LENGTH,
) -> DatasetDict:
    """Full pipeline: CSV → cleaned DataFrame → HuggingFace DatasetDict → tokenized splits."""
    train_df, val_df, test_df = make_splits(load_dataframe(csv_path))

    split_frames = {"train": train_df, "validation": val_df, "test": test_df}
    raw = DatasetDict({
        split: Dataset.from_pandas(frame, preserve_index=False)
        for split, frame in split_frames.items()
    })
    return tokenize_dataset(raw, tokenizer_name, max_length)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
if __name__ == "__main__":
    # Smoke test: build the splits with defaults and print one tokenized sample.
    ds = build_dataset()
    print(ds)
    print("Sample:", ds["train"][0])
|
src/data/gnews_collector.py
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Fetches live GNews articles and appends them to the training dataset.
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
python -m src.data.gnews_collector # fetch and save
|
| 6 |
+
python -m src.data.gnews_collector --preview # print without saving
|
| 7 |
+
python -m src.data.gnews_collector --label --model-path models/distilbert --merge
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from src.data.preprocessing import clean_text
|
| 11 |
+
from src.utils.gnews_client import GNewsClient
|
| 12 |
+
import os
|
| 13 |
+
import sys
|
| 14 |
+
import uuid
|
| 15 |
+
import argparse
|
| 16 |
+
import pandas as pd
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
from datetime import datetime
|
| 19 |
+
from dotenv import load_dotenv
|
| 20 |
+
|
| 21 |
+
sys.path.insert(0, str(Path(__file__).parents[2]))
|
| 22 |
+
load_dotenv()
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
PROJECT_ROOT = Path(__file__).parents[2]
|
| 26 |
+
AUGMENTED_DIR = PROJECT_ROOT / "data" / "augmented"
|
| 27 |
+
CLEAN_CSV = PROJECT_ROOT / "data" / "processed" / "Dataset_Clean.csv"
|
| 28 |
+
|
| 29 |
+
FETCH_TOPICS = [
|
| 30 |
+
"scientific research breakthrough",
|
| 31 |
+
"official government announcement",
|
| 32 |
+
"verified breaking news",
|
| 33 |
+
"conspiracy theory debunked",
|
| 34 |
+
"fact check false claim",
|
| 35 |
+
"misinformation viral",
|
| 36 |
+
"satire news comedy",
|
| 37 |
+
"parody news article",
|
| 38 |
+
"political opinion editorial",
|
| 39 |
+
"partisan news analysis",
|
| 40 |
+
]
|
| 41 |
+
|
| 42 |
+
MAX_PER_TOPIC = 5
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def fetch_articles(max_per_topic: int = MAX_PER_TOPIC) -> list[dict]:
    """Query GNews once per topic, de-duplicating articles by URL."""
    client = GNewsClient()
    unique_articles: list[dict] = []
    seen_urls: set[str] = set()

    for topic in FETCH_TOPICS:
        try:
            batch = client.search_news(query=topic, max_results=max_per_topic)
            for article in batch:
                url = article.get("url", "")
                if url and url not in seen_urls:
                    seen_urls.add(url)
                    unique_articles.append(article)
            print(f" ✓ '{topic}' → {len(batch)} articles")
        except Exception as e:
            print(f" ✗ '{topic}' → error: {e}")

    print(f"\n[collector] Fetched {len(unique_articles)} unique articles\n")
    return unique_articles
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def articles_to_dataframe(articles: list[dict]) -> pd.DataFrame:
    """Convert raw GNews articles to Dataset_Clean.csv schema. Labels are set to -1 (unlabelled)."""
    records = []
    for article in articles:
        headline = clean_text(article.get("title", ""))
        body = clean_text(article.get("content", "") or article.get("description", ""))
        # Prefer the body; fall back to the headline for very short bodies.
        text = body if len(body) > 30 else headline
        if len(text) < 10:
            continue
        records.append({
            "id": f"GNEWS_{uuid.uuid4().hex[:8].upper()}",
            "title": headline,
            "content": body,
            "label": -1,
            "label_text": "UNLABELLED",
            "label_original": "gnews_live",
            "source_dataset": "GNews_Live",
            "topic": "",
            "url": article.get("url", ""),
            "speaker": article.get("source", ""),
            "fetched_at": datetime.utcnow().isoformat(),
        })
    return pd.DataFrame(records)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def pseudo_label(df: pd.DataFrame, model_path: str) -> pd.DataFrame:
    """Assign pseudo-labels to unlabelled articles using a trained model.

    Returns a copy of *df* with `label`, `label_text` and `confidence` filled in.
    """
    import torch
    from transformers import AutoTokenizer, AutoModelForSequenceClassification

    # Must stay in sync with the label map in src/models/inference.py.
    ID2LABEL = {0: "True", 1: "Fake", 2: "Satire", 3: "Bias"}
    print(f"[pseudo_label] Loading model from {model_path}…")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    model.eval()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    # Prefer the article body; fall back to the title when content is missing.
    texts = df["content"].fillna(df["title"]).tolist()
    labels = []
    confidences = []

    # Batched inference (size 16) to bound memory use.
    for i in range(0, len(texts), 16):
        batch = texts[i: i + 16]
        enc = tokenizer(batch, padding=True, truncation=True,
                        max_length=256, return_tensors="pt").to(device)
        with torch.no_grad():
            probs = torch.softmax(model(**enc).logits, dim=-1)
            labels.extend(probs.argmax(dim=-1).cpu().tolist())
            confidences.extend(probs.max(dim=-1).values.cpu().tolist())

    df = df.copy()
    df["label"] = labels
    df["label_text"] = [ID2LABEL[l] for l in labels]
    # Winning-class probability; used later by merge_into_training as a filter.
    df["confidence"] = [round(c, 4) for c in confidences]
    print(
        f"[pseudo_label] Label distribution:\n{df['label_text'].value_counts().to_string()}")
    return df
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def save_augmented(df: pd.DataFrame, tag: str = "") -> Path:
    """Write a timestamped CSV under data/augmented/ and return its path."""
    AUGMENTED_DIR.mkdir(parents=True, exist_ok=True)
    stamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
    suffix = f"_{tag}" if tag else ""
    out_path = AUGMENTED_DIR / f"gnews_{stamp}{suffix}.csv"
    df.to_csv(out_path, index=False, encoding="utf-8")
    print(f"[collector] Saved {len(df)} rows → {out_path}")
    return out_path
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def merge_into_training(augmented_path: Path, min_confidence: float = 0.80) -> int:
    """Merge pseudo-labelled articles into Dataset_Clean.csv, filtered by confidence threshold.

    Returns the number of rows appended.
    """
    aug_df = pd.read_csv(augmented_path)
    # Keep only confident, actually-labelled rows (-1 means unlabelled).
    if "confidence" in aug_df.columns:
        aug_df = aug_df[aug_df["confidence"] >= min_confidence]
    aug_df = aug_df[aug_df["label"] != -1]

    if len(aug_df) == 0:
        print("[merge] No rows met the confidence threshold.")
        return 0

    # NOTE(review): rows are appended with header=False — this assumes keep_cols
    # matches the column order of the existing Dataset_Clean.csv; verify if the
    # CSV schema ever changes.
    keep_cols = ["id", "title", "content", "label", "label_text",
                 "label_original", "source_dataset", "topic", "url", "speaker"]
    aug_df = aug_df[[c for c in keep_cols if c in aug_df.columns]]
    aug_df.to_csv(CLEAN_CSV, mode="a", header=False,
                  index=False, encoding="utf-8")
    print(f"[merge] Added {len(aug_df)} rows to {CLEAN_CSV}")
    return len(aug_df)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def main():
    """CLI entry point: fetch → (optionally) pseudo-label → (optionally) merge."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--preview", action="store_true")
    parser.add_argument("--label", action="store_true")
    parser.add_argument("--model-path", type=str,
                        default="models/distilbert")
    parser.add_argument("--merge", action="store_true")
    parser.add_argument("--min-conf", type=float, default=0.80)
    parser.add_argument("--max-per-topic", type=int, default=MAX_PER_TOPIC)
    opts = parser.parse_args()

    frame = articles_to_dataframe(
        fetch_articles(max_per_topic=opts.max_per_topic))

    if opts.preview:
        # Preview mode: show what was fetched, write nothing.
        print(frame[["title", "source_dataset", "url"]].to_string())
        return

    save_augmented(frame, tag="raw")

    if not opts.label:
        return
    model_path = str(PROJECT_ROOT / opts.model_path)
    if not Path(model_path).exists():
        print(f"[error] Model not found at {model_path}")
        return
    frame = pseudo_label(frame, model_path)
    labelled_path = save_augmented(frame, tag="labelled")
    if opts.merge:
        merge_into_training(labelled_path, min_confidence=opts.min_conf)


if __name__ == "__main__":
    main()
|
src/data/preprocessing.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import html
|
| 3 |
+
from typing import List
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def clean_text(text: str) -> str:
    """Clean and normalize raw text — decodes HTML, strips URLs, normalizes whitespace.

    Returns "" for None/empty input instead of raising.
    """
    if not text:
        # Guard: html.unescape raises TypeError on None; empty in → empty out.
        return ""
    text = html.unescape(text)
    # Drop bare URLs entirely; they carry no signal for classification.
    text = re.sub(r'http\S+', '', text)
    # Normalize typographic double quotes and en-dash to plain ASCII.
    text = text.replace('\u201c', '"').replace(
        '\u201d', '"').replace('\u2013', '-')
    # Collapse all whitespace runs (including newlines/tabs) to single spaces.
    text = re.sub(r'\s+', ' ', text).strip()
    return text
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def preprocess_batch(texts: List[str]) -> List[str]:
    """Apply clean_text to a list of strings."""
    return list(map(clean_text, texts))
|
src/models/__init__.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from src.models.inference import predict, get_classifier, FakeNewsClassifier
|
| 2 |
+
from src.models.train import train_model
|
| 3 |
+
from src.models.evaluate import compute_metrics, full_report
|
| 4 |
+
|
| 5 |
+
__all__ = [
|
| 6 |
+
"predict",
|
| 7 |
+
"get_classifier",
|
| 8 |
+
"FakeNewsClassifier",
|
| 9 |
+
"train_model",
|
| 10 |
+
"compute_metrics",
|
| 11 |
+
"full_report",
|
| 12 |
+
]
|
src/models/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (502 Bytes). View file
|
|
|
src/models/__pycache__/evaluate.cpython-313.pyc
ADDED
|
Binary file (2.53 kB). View file
|
|
|
src/models/__pycache__/inference.cpython-313.pyc
ADDED
|
Binary file (18.9 kB). View file
|
|
|
src/models/__pycache__/train.cpython-313.pyc
ADDED
|
Binary file (8.63 kB). View file
|
|
|
src/models/evaluate.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Evaluation utilities — metrics computed during and after training.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
from sklearn.metrics import classification_report, confusion_matrix, f1_score, accuracy_score
|
| 7 |
+
from transformers import EvalPrediction
|
| 8 |
+
|
| 9 |
+
LABEL_NAMES = ["True", "Fake", "Satire", "Bias"]
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def compute_metrics(eval_pred: EvalPrediction) -> dict:
    """Called by HuggingFace Trainer after every eval step. Returns accuracy and macro/weighted F1."""
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    raw = {
        "accuracy": accuracy_score(labels, predictions),
        "f1_macro": f1_score(labels, predictions, average="macro", zero_division=0),
        "f1_weighted": f1_score(labels, predictions, average="weighted", zero_division=0),
    }
    return {name: round(value, 4) for name, value in raw.items()}
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def full_report(model, tokenized_test, label_names=LABEL_NAMES) -> dict:
    """Run full evaluation on the test split. Returns per-class metrics and confusion matrix."""
    from transformers import Trainer

    # Prediction-only Trainer: no TrainingArguments are passed here.
    trainer = Trainer(model=model, compute_metrics=compute_metrics)
    preds_out = trainer.predict(tokenized_test)

    preds = np.argmax(preds_out.predictions, axis=-1)
    labels = preds_out.label_ids

    # Machine-readable report for the caller...
    report = classification_report(
        labels, preds, target_names=label_names, output_dict=True, zero_division=0)
    cm = confusion_matrix(labels, preds)

    # ...plus a human-readable copy on stdout.
    print("\n" + "=" * 60)
    print("CLASSIFICATION REPORT")
    print("=" * 60)
    print(classification_report(labels, preds,
                                target_names=label_names, zero_division=0))
    print("Confusion Matrix:")
    print(cm)
    print("=" * 60 + "\n")

    return {"report": report, "confusion_matrix": cm.tolist()}
|
src/models/inference.py
ADDED
|
@@ -0,0 +1,345 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Model inference — lazy-loads fine-tuned models and runs predictions with explainability.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import torch
|
| 7 |
+
import numpy as np
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
| 10 |
+
from dotenv import load_dotenv
|
| 11 |
+
|
| 12 |
+
load_dotenv()
|
| 13 |
+
|
| 14 |
+
ID2LABEL = {0: "True", 1: "Fake", 2: "Satire", 3: "Bias"}
|
| 15 |
+
LABEL2ID = {v: k for k, v in ID2LABEL.items()}
|
| 16 |
+
|
| 17 |
+
PROJECT_ROOT = Path(__file__).parents[2]
|
| 18 |
+
MODELS_DIR = PROJECT_ROOT / "models"
|
| 19 |
+
|
| 20 |
+
# Override with HF Hub repo IDs via env vars, e.g. HF_REPO_DISTILBERT=your-username/distilbert-fakenews
|
| 21 |
+
MODEL_NAMES = {
|
| 22 |
+
"distilbert": os.getenv("HF_REPO_DISTILBERT", "distilbert-base-uncased"),
|
| 23 |
+
"roberta": os.getenv("HF_REPO_ROBERTA", "roberta-base"),
|
| 24 |
+
"xlnet": os.getenv("HF_REPO_XLNET", "xlnet-base-cased"),
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class FakeNewsClassifier:
    """Wraps a fine-tuned HuggingFace model. Lazy-loads on first call and caches in memory.

    Nothing touches the network or disk until ``model``/``tokenizer`` is first
    accessed; after that the pair stays resident for the life of the process.
    """

    def __init__(self, model_key: str = "distilbert", max_length: int = 256):
        # Key into MODEL_NAMES / MODELS_DIR ("distilbert", "roberta" or "xlnet").
        self.model_key = model_key
        # Truncation length applied to every tokenizer call.
        self.max_length = max_length
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Populated lazily by _load() via the properties below.
        self._model = None
        self._tokenizer = None

    def _load(self):
        """Load tokenizer + model, preferring a local fine-tuned checkpoint.

        Falls back to the HF Hub id from MODEL_NAMES when no
        models/<model_key>/config.json exists on disk.
        """
        local_path = MODELS_DIR / self.model_key
        source = str(local_path) if (
            local_path / "config.json").exists() else MODEL_NAMES[self.model_key]
        print(f"[inference] Loading {self.model_key} from: {source}")
        self._tokenizer = AutoTokenizer.from_pretrained(source)
        self._model = AutoModelForSequenceClassification.from_pretrained(
            source,
            num_labels=4,
            id2label=ID2LABEL,
            label2id=LABEL2ID,
            # Base checkpoints ship a head of a different size; allow the
            # 4-class classification head to be (re-)initialised.
            ignore_mismatched_sizes=True,
        )
        self._model.to(self.device)
        self._model.eval()
        print(f"[inference] Model ready on {self.device}")

    @property
    def model(self):
        # Lazy accessor: triggers _load() exactly once.
        if self._model is None:
            self._load()
        return self._model

    @property
    def tokenizer(self):
        # Lazy accessor: triggers _load() exactly once.
        if self._tokenizer is None:
            self._load()
        return self._tokenizer

    def predict(self, text: str) -> dict:
        """
        Run inference on a single text.
        Returns label, confidence (0-1), per-class scores, and top token importance scores.
        """
        enc = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=self.max_length,
            padding=True,
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**enc)
            # Single input -> take batch element 0; probs is a 4-vector.
            probs = torch.softmax(outputs.logits, dim=-1)[0].cpu().numpy()

        pred_id = int(np.argmax(probs))
        label = ID2LABEL[pred_id]
        confidence = float(probs[pred_id])
        scores = {ID2LABEL[i]: round(float(p), 4) for i, p in enumerate(probs)}
        # Gradient-based saliency for the predicted class (best effort; may be []).
        tokens = self._token_importance(enc, pred_id)

        return {
            "label": label,
            "confidence": round(confidence, 4),
            "scores": scores,
            "tokens": tokens,
        }

    def _token_importance(self, enc, pred_id: int, top_k: int = 8) -> list[dict]:
        """Gradient saliency — returns top-k tokens sorted by importance.

        Backprops the predicted-class logit into the input embeddings and uses
        the per-token gradient L2 norm as an importance score, normalised to
        the max. Returns [] on any failure (explanation is best-effort).
        """
        try:
            self.model.zero_grad()
            input_ids = enc["input_ids"]
            # Detach + requires_grad so gradients accumulate on the embeddings,
            # not on the (frozen, eval-mode) model weights.
            embeds = self.model.get_input_embeddings()(
                input_ids).detach().requires_grad_(True)
            outputs = self.model(inputs_embeds=embeds,
                                 attention_mask=enc.get("attention_mask"))
            outputs.logits[0, pred_id].backward()
            importance = embeds.grad[0].norm(dim=-1).cpu().numpy()
            tokens = self.tokenizer.convert_ids_to_tokens(
                input_ids[0].cpu().tolist())
            # Special/markup tokens across BERT-, RoBERTa- and XLNet-style vocabs.
            special = {"[CLS]", "[SEP]", "[PAD]", "<s>",
                       "</s>", "<pad>", "<cls>", "<sep>", "▁", "Ġ"}
            # Strip subword markers; drop specials and single-character pieces.
            pairs = [
                (t.replace("##", "").replace("▁", "").replace("Ġ", ""), float(s))
                for t, s in zip(tokens, importance)
                if t not in special and len(t.strip()) > 1
            ]
            if pairs:
                # Normalise to [0, 1]; `or 1.0` guards an all-zero gradient.
                max_s = max(s for _, s in pairs) or 1.0
                pairs = [(t, round(s / max_s, 4)) for t, s in pairs]
            pairs.sort(key=lambda x: x[1], reverse=True)
            return [{"token": t, "score": s} for t, s in pairs[:top_k]]
        except Exception:
            # Explanations are optional; never let them break prediction.
            return []

    def attention_weights(self, text: str) -> list[dict]:
        """
        Gradient saliency mapped to original words in reading order.
        Merges subword tokens (BERT ## and RoBERTa Ġ) back into full words.

        Despite the name, scores come from gradient saliency, not attention
        matrices; the "attention" key is kept for the API consumer.
        Returns [] on failure.
        """
        try:
            enc = self.tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                max_length=self.max_length,
                padding=False,
            ).to(self.device)

            input_ids = enc["input_ids"]
            self.model.zero_grad()
            embeds = self.model.get_input_embeddings()(
                input_ids).detach().requires_grad_(True)
            outputs = self.model(inputs_embeds=embeds,
                                 attention_mask=enc.get("attention_mask"))
            # Saliency is taken w.r.t. this forward pass's own argmax class.
            pred_id = int(torch.argmax(outputs.logits, dim=-1)[0])
            outputs.logits[0, pred_id].backward()
            importance = embeds.grad[0].norm(dim=-1).cpu().numpy()

            tokens = self.tokenizer.convert_ids_to_tokens(
                input_ids[0].cpu().tolist())
            SPECIAL = {"[CLS]", "[SEP]", "[PAD]", "<s>",
                       "</s>", "<pad>", "<cls>", "<sep>", "<unk>"}

            # Re-assemble subword pieces into whole words, keeping the max
            # piece-score as the word score.
            words = []
            current_word = ""
            current_score = 0.0
            for tok, score in zip(tokens, importance):
                if tok in SPECIAL:
                    # Specials terminate the current word but carry no score.
                    if current_word:
                        words.append((current_word, current_score))
                        current_word = ""
                        current_score = 0.0
                    continue

                is_continuation = tok.startswith("##")          # BERT subword
                is_new_word = tok.startswith("Ġ") or tok.startswith("▁")  # RoBERTa / SentencePiece
                clean = tok.replace("##", "").replace("Ġ", "").replace("▁", "")

                if is_continuation:
                    current_word += clean
                    current_score = max(current_score, float(score))
                elif is_new_word:
                    if current_word:
                        words.append((current_word, current_score))
                    current_word = clean
                    current_score = float(score)
                else:
                    # No marker: treat as a fresh word (BERT first-piece case).
                    if current_word:
                        words.append((current_word, current_score))
                    current_word = clean
                    current_score = float(score)

            if current_word:
                words.append((current_word, current_score))

            if not words:
                return []

            # Normalise to [0, 1]; `or 1.0` guards an all-zero saliency map.
            max_s = max(s for _, s in words) or 1.0
            return [{"word": w, "attention": round(s / max_s, 4)} for w, s in words if w.strip()]

        except Exception as e:
            print(f"[attention_weights] failed: {e}")
            import traceback
            traceback.print_exc()
            return []

    def shap_explain(self, text: str) -> list[dict]:
        """
        Word-level SHAP explanation using RoBERTa for better context.
        Returns words sorted by absolute SHAP value, most influential first.

        NOTE: always explains with the shared RoBERTa classifier, regardless of
        this instance's model_key. Values are normalised by the max |SHAP| and
        returned in original sentence order. Returns [] on failure (e.g. the
        optional `shap` package is not installed).
        """
        try:
            import shap

            clf = get_classifier("roberta")

            def predict_proba(texts):
                # SHAP calls this with batches of masked text variants.
                results = []
                for t in texts:
                    enc = clf.tokenizer(
                        t, return_tensors="pt", truncation=True,
                        max_length=clf.max_length, padding=True,
                    ).to(clf.device)
                    with torch.no_grad():
                        logits = clf.model(**enc).logits
                    probs = torch.softmax(logits, dim=-1)[0].cpu().numpy()
                    results.append(probs)
                return np.array(results)

            # Split on non-word runs so SHAP masks whole words.
            masker = shap.maskers.Text(r"\W+")
            explainer = shap.Explainer(
                predict_proba, masker, output_names=list(ID2LABEL.values()))
            # max_evals bounds the number of model calls; keeps latency sane.
            shap_values = explainer([text], max_evals=200, batch_size=8)

            # Separate forward pass to find which class column to read.
            enc = clf.tokenizer(text, return_tensors="pt", truncation=True,
                                max_length=clf.max_length).to(clf.device)
            with torch.no_grad():
                pred_id = int(torch.argmax(clf.model(**enc).logits, dim=-1)[0])

            words = shap_values.data[0]
            values = shap_values.values[0, :, pred_id]

            # Normalise by max |value|; guard empty and all-zero cases.
            max_abs = float(np.max(np.abs(values))) if len(values) else 1.0
            if max_abs == 0:
                max_abs = 1.0

            result = []
            for word, val in zip(words, values):
                w = word.strip()
                if not w:
                    continue
                result.append(
                    {"word": w, "shap_value": round(float(val) / max_abs, 4)})

            # Keep original sentence order so inline text rendering makes sense
            return result

        except Exception as e:
            print(f"[shap_explain] failed: {e}")
            import traceback
            traceback.print_exc()
            return []
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
# Process-wide cache: one FakeNewsClassifier per model key.
_classifiers: dict[str, FakeNewsClassifier] = {}


def get_classifier(model_key: str = "distilbert") -> FakeNewsClassifier:
    """Return the cached classifier for *model_key*, creating it on first request."""
    try:
        return _classifiers[model_key]
    except KeyError:
        classifier = FakeNewsClassifier(model_key)
        _classifiers[model_key] = classifier
        return classifier
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def predict(text: str, model_key: str = "distilbert") -> dict:
    """Convenience wrapper: classify *text* with the cached *model_key* classifier."""
    classifier = get_classifier(model_key)
    return classifier.predict(text)
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def generate_explanation_text(
    shap_tokens: list[dict],
    label: str,
    confidence: float,
    model_key: str,
) -> str:
    """
    Build a natural-language paragraph explaining the prediction from SHAP data.
    No LLM required — derived entirely from token scores and prediction metadata.

    Args:
        shap_tokens: [{"word": str, "shap_value": float}, ...] with values
            normalised to [-1, 1]; positive pushes toward the predicted label.
        label: predicted class name ("True" / "Fake" / "Satire" / "Bias").
        confidence: prediction probability in [0, 1].
        model_key: internal model id ("distilbert" / "roberta" / "xlnet").

    Returns:
        A single-paragraph English explanation.
    """
    if not shap_tokens:
        return (
            f"The {model_key} model classified this article as {label} "
            f"with {round(confidence * 100)}% confidence, but no word-level "
            f"explanation data was available for this prediction."
        )

    # Strongest words for / against the prediction; |shap| <= 0.05 is noise.
    positive = sorted(
        [t for t in shap_tokens if t["shap_value"] > 0.05],
        key=lambda x: x["shap_value"], reverse=True
    )[:5]
    negative = sorted(
        [t for t in shap_tokens if t["shap_value"] < -0.05],
        key=lambda x: x["shap_value"]
    )[:3]

    conf_pct = round(confidence * 100)
    model_display = {"distilbert": "DistilBERT", "roberta": "RoBERTa",
                     "xlnet": "XLNet"}.get(model_key, model_key)

    conf_phrase = (
        "with very high confidence" if conf_pct >= 90 else
        "with high confidence" if conf_pct >= 75 else
        "with moderate confidence" if conf_pct >= 55 else
        "with low confidence"
    )

    label_descriptions = {
        "True": "factual and credible reporting",
        "Fake": "fabricated or misleading content",
        "Satire": "satirical or parody content",
        "Bias": "politically or ideologically biased reporting",
    }
    label_desc = label_descriptions.get(label, label)

    parts = [
        f"{model_display} classified this article as {label} ({label_desc}) "
        f"{conf_phrase} ({conf_pct}%)."
    ]

    if positive:
        word_list = ", ".join(f'"{t["word"]}"' for t in positive)
        parts.append(
            f"The words most strongly associated with this classification were {word_list}, "
            f"which the model weighted heavily toward a {label} prediction."
        )

    if negative:
        word_list = ", ".join(f'"{t["word"]}"' for t in negative)
        parts.append(
            f"On the other hand, terms like {word_list} pulled against this classification, "
            f"suggesting some linguistic signals that are inconsistent with {label} content."
        )
    else:
        # FIX: was a redundant `elif not negative:` on the same chain (always
        # equivalent to `else`); also dropped a needless f-prefix on a literal
        # containing no placeholders.
        parts.append(
            "The model found little linguistic evidence contradicting this classification."
        )

    if conf_pct < 65:
        parts.append(
            "The relatively lower confidence suggests the article contains mixed signals "
            "and the prediction should be interpreted with caution."
        )

    return " ".join(parts)
|
src/models/train.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Training script for fake news detection.
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
python -m src.models.train --model distilbert
|
| 6 |
+
python -m src.models.train --model roberta --epochs 5
|
| 7 |
+
python -m src.models.train --all
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from src.data.dataset import build_dataset, LABEL2ID, ID2LABEL
|
| 11 |
+
from src.models.evaluate import compute_metrics, full_report
|
| 12 |
+
import os
|
| 13 |
+
import sys
|
| 14 |
+
import json
|
| 15 |
+
import argparse
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
from datetime import datetime
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
from transformers import (
|
| 21 |
+
AutoTokenizer,
|
| 22 |
+
AutoModelForSequenceClassification,
|
| 23 |
+
TrainingArguments,
|
| 24 |
+
Trainer,
|
| 25 |
+
EarlyStoppingCallback,
|
| 26 |
+
)
|
| 27 |
+
from dotenv import load_dotenv
|
| 28 |
+
|
| 29 |
+
sys.path.insert(0, str(Path(__file__).parents[2]))
|
| 30 |
+
load_dotenv()
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
MODELS = {
|
| 34 |
+
"distilbert": "distilbert-base-uncased",
|
| 35 |
+
"roberta": "roberta-base",
|
| 36 |
+
"xlnet": "xlnet-base-cased",
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
PROJECT_ROOT = Path(__file__).parents[2]
|
| 40 |
+
MODELS_DIR = PROJECT_ROOT / "models"
|
| 41 |
+
DATA_CSV = PROJECT_ROOT / "data" / "processed" / "Dataset_Clean.csv"
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def get_training_args(model_key, output_dir, epochs, batch_size, learning_rate, use_wandb) -> TrainingArguments:
    """Assemble the TrainingArguments shared by every fine-tuning run.

    Checkpoints and logs go under *output_dir*; the best checkpoint (by macro
    F1 on the validation split) is reloaded at the end of training.
    """
    run_name = f"{model_key}-{datetime.now().strftime('%Y%m%d-%H%M')}"
    settings = {
        "output_dir": str(output_dir / "checkpoints"),
        "num_train_epochs": epochs,
        "per_device_train_batch_size": batch_size,
        # Evaluation has no gradients, so a larger batch fits in memory.
        "per_device_eval_batch_size": batch_size * 2,
        "learning_rate": learning_rate,
        "weight_decay": 0.01,
        "warmup_ratio": 0.06,
        "eval_strategy": "epoch",
        "save_strategy": "epoch",
        "load_best_model_at_end": True,
        "metric_for_best_model": "f1_macro",
        "greater_is_better": True,
        "save_total_limit": 2,
        "logging_dir": str(output_dir / "logs"),
        "logging_steps": 50,
        "report_to": "wandb" if use_wandb else "none",
        "run_name": run_name,
        # Mixed precision only when a GPU is present.
        "fp16": torch.cuda.is_available(),
        "dataloader_num_workers": 0,
        "push_to_hub": False,
    }
    return TrainingArguments(**settings)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def train_model(model_key, epochs=3, batch_size=16, learning_rate=2e-5, max_length=256, use_wandb=False) -> dict:
    """Full training run for one model. Returns evaluation metrics.

    Pipeline: build tokenized dataset -> load base checkpoint -> fine-tune with
    early stopping -> save model+tokenizer to models/<model_key>/ -> evaluate
    on the test split and dump metrics.json.

    Args:
        model_key: key into MODELS ("distilbert" / "roberta" / "xlnet").
        epochs: max training epochs (early stopping may end sooner).
        batch_size: per-device train batch size (eval uses 2x).
        learning_rate: optimizer learning rate.
        max_length: tokenizer truncation length.
        use_wandb: when True, logs the run and final metrics to Weights & Biases.

    Returns:
        The dict produced by full_report() on the test split (contains "report").
    """
    model_name = MODELS[model_key]
    output_dir = MODELS_DIR / model_key

    print("\n" + "=" * 60)
    print(f"TRAINING: {model_key} ({model_name})")
    print(f"Epochs: {epochs} | Batch: {batch_size} | LR: {learning_rate}")
    print(
        f"Device: {'GPU (' + torch.cuda.get_device_name(0) + ')' if torch.cuda.is_available() else 'CPU'}")
    print("=" * 60 + "\n")

    print("[1/4] Building dataset…")
    tokenized = build_dataset(
        csv_path=DATA_CSV, tokenizer_name=model_name, max_length=max_length)

    print("[2/4] Loading model…")
    # ignore_mismatched_sizes: the base checkpoint's head differs from our
    # 4-class head, so the classifier layer gets (re-)initialised.
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name, num_labels=4, id2label=ID2LABEL, label2id=LABEL2ID, ignore_mismatched_sizes=True,
    )

    print("[3/4] Setting up trainer…")
    output_dir.mkdir(parents=True, exist_ok=True)

    if use_wandb:
        # Imported lazily so wandb is only required when actually used.
        import wandb
        wandb.init(
            project=os.getenv("WANDB_PROJECT", "fake-news-detection"),
            name=f"{model_key}-{datetime.now().strftime('%Y%m%d-%H%M')}",
            config={"model": model_name, "epochs": epochs, "batch_size": batch_size,
                    "learning_rate": learning_rate, "max_length": max_length},
        )

    trainer = Trainer(
        model=model,
        args=get_training_args(model_key, output_dir,
                               epochs, batch_size, learning_rate, use_wandb),
        train_dataset=tokenized["train"],
        eval_dataset=tokenized["validation"],
        compute_metrics=compute_metrics,
        # Stop after 2 epochs without f1_macro improvement (see get_training_args).
        callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
    )

    print("[4/4] Training…\n")
    trainer.train()

    print(f"\n[✓] Saving model to {output_dir}")
    trainer.save_model(str(output_dir))
    # Save the tokenizer next to the weights so inference loads from one dir.
    AutoTokenizer.from_pretrained(model_name).save_pretrained(str(output_dir))

    print("[✓] Evaluating on test set…")
    metrics = full_report(model, tokenized["test"])

    metrics_path = output_dir / "metrics.json"
    with open(metrics_path, "w") as f:
        json.dump(metrics["report"], f, indent=2)
    print(f"[✓] Metrics saved to {metrics_path}")

    if use_wandb:
        import wandb
        wandb.log(metrics["report"])
        wandb.finish()

    return metrics
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def main():
|
| 136 |
+
parser = argparse.ArgumentParser(
|
| 137 |
+
description="Train fake news detection models")
|
| 138 |
+
parser.add_argument(
|
| 139 |
+
"--model", choices=list(MODELS.keys()), default="distilbert")
|
| 140 |
+
parser.add_argument("--all", action="store_true",
|
| 141 |
+
help="Train all three models sequentially")
|
| 142 |
+
parser.add_argument("--epochs", type=int, default=3)
|
| 143 |
+
parser.add_argument("--batch-size", type=int, default=16)
|
| 144 |
+
parser.add_argument("--lr", type=float, default=2e-5)
|
| 145 |
+
parser.add_argument("--max-length", type=int, default=256)
|
| 146 |
+
parser.add_argument("--wandb", action="store_true")
|
| 147 |
+
args = parser.parse_args()
|
| 148 |
+
|
| 149 |
+
targets = list(MODELS.keys()) if args.all else [args.model]
|
| 150 |
+
all_metrics = {}
|
| 151 |
+
for model_key in targets:
|
| 152 |
+
all_metrics[model_key] = train_model(
|
| 153 |
+
model_key=model_key, epochs=args.epochs, batch_size=args.batch_size,
|
| 154 |
+
learning_rate=args.lr, max_length=args.max_length, use_wandb=args.wandb,
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
print("\n" + "=" * 60)
|
| 158 |
+
print("TRAINING SUMMARY")
|
| 159 |
+
print("=" * 60)
|
| 160 |
+
for key, m in all_metrics.items():
|
| 161 |
+
r = m["report"]
|
| 162 |
+
print(f"\n{key.upper()}")
|
| 163 |
+
print(f" Accuracy: {r.get('accuracy', 'N/A'):.4f}")
|
| 164 |
+
print(
|
| 165 |
+
f" Macro F1: {r.get('macro avg', {}).get('f1-score', 'N/A'):.4f}")
|
| 166 |
+
print(
|
| 167 |
+
f" Weighted F1: {r.get('weighted avg', {}).get('f1-score', 'N/A'):.4f}")
|
| 168 |
+
print("\n" + "=" * 60 + "\n")
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
if __name__ == "__main__":
|
| 172 |
+
main()
|
src/utils/__init__.py
ADDED
|
File without changes
|
src/utils/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (198 Bytes). View file
|
|
|
src/utils/__pycache__/gnews_client.cpython-313.pyc
ADDED
|
Binary file (6.04 kB). View file
|
|
|
src/utils/__pycache__/supabase_client.cpython-313.pyc
ADDED
|
Binary file (4.75 kB). View file
|
|
|
src/utils/gnews_client.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import requests
|
| 3 |
+
from typing import List, Dict, Any, Optional
|
| 4 |
+
from datetime import datetime, timedelta
|
| 5 |
+
from dotenv import load_dotenv
|
| 6 |
+
|
| 7 |
+
load_dotenv()
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class GNewsClient:
    """Thin wrapper around the GNews REST API (search + top headlines).

    Reads GNEWS_API_KEY (required) and GNEWS_API_URL (optional) from the
    environment. All fetch methods swallow network errors and return [].
    """

    def __init__(self):
        self.api_key = os.getenv("GNEWS_API_KEY")
        self.base_url = os.getenv("GNEWS_API_URL", "https://gnews.io/api/v4")
        if not self.api_key:
            raise ValueError(
                "GNEWS_API_KEY must be set in environment variables")

    def search_news(
        self,
        query: str = "politics",
        lang: str = "en",
        country: Optional[str] = None,
        max_results: int = 10,
        from_date: Optional[datetime] = None,
        to_date: Optional[datetime] = None,
    ) -> List[Dict[str, Any]]:
        """Search for news articles by query.

        Args:
            query: free-text search string.
            lang: two-letter language code.
            country: optional two-letter country filter.
            max_results: capped at 100 (API limit).
            from_date / to_date: optional UTC bounds on publication time.

        Returns:
            Normalised article dicts (see _format_articles); [] on error.
        """
        params = {
            "q": query,
            "lang": lang,
            "max": min(max_results, 100),
            "apikey": self.api_key,
        }
        if country:
            params["country"] = country
        if from_date:
            params["from"] = from_date.strftime("%Y-%m-%dT%H:%M:%SZ")
        if to_date:
            params["to"] = to_date.strftime("%Y-%m-%dT%H:%M:%SZ")

        try:
            response = requests.get(
                f"{self.base_url}/search", params=params, timeout=10)
            response.raise_for_status()
            return self._format_articles(response.json().get("articles", []))
        except requests.exceptions.RequestException as e:
            # Best-effort: callers treat an empty list as "nothing fetched".
            print(f"Error fetching news: {e}")
            return []

    def get_top_headlines(
        self,
        category: Optional[str] = None,
        lang: str = "en",
        country: Optional[str] = None,
        max_results: int = 10,
    ) -> List[Dict[str, Any]]:
        """Get top headlines, optionally filtered by category.

        Returns normalised article dicts; [] on error.
        """
        params = {
            "lang": lang,
            "max": min(max_results, 100),
            "apikey": self.api_key,
        }
        if category:
            params["category"] = category
        if country:
            params["country"] = country

        try:
            response = requests.get(
                f"{self.base_url}/top-headlines", params=params, timeout=10)
            response.raise_for_status()
            return self._format_articles(response.json().get("articles", []))
        except requests.exceptions.RequestException as e:
            print(f"Error fetching headlines: {e}")
            return []

    def _format_articles(self, articles: List[Dict]) -> List[Dict[str, Any]]:
        """Map raw GNews article payloads to this project's flat schema."""
        return [
            {
                "title": article.get("title", ""),
                "description": article.get("description", ""),
                "content": article.get("content", ""),
                "url": article.get("url", ""),
                "image": article.get("image", ""),
                "published_at": article.get("publishedAt", ""),
                "source": article.get("source", {}).get("name", ""),
                "source_url": article.get("source", {}).get("url", ""),
            }
            for article in articles
        ]

    def get_recent_news_for_analysis(
        self,
        topics: Optional[List[str]] = None,
        max_per_topic: int = 5,
    ) -> List[Dict[str, Any]]:
        """Fetch and deduplicate articles across multiple topics.

        Only articles published within the last 24 hours are requested;
        duplicates (same URL) across topics are dropped.

        BUG FIX: *topics* previously used a mutable list literal as its
        default argument; it now defaults to None with the same effective
        value ["politics", "breaking news", "world news"].
        """
        if topics is None:
            topics = ["politics", "breaking news", "world news"]
        all_articles = []
        seen_urls: set = set()
        for topic in topics:
            articles = self.search_news(
                query=topic,
                max_results=max_per_topic,
                from_date=datetime.now() - timedelta(days=1),
            )
            for article in articles:
                url = article.get("url", "")
                if url and url not in seen_urls:
                    seen_urls.add(url)
                    all_articles.append(article)
        return all_articles
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
# Lazily-constructed module-level singleton.
_gnews_client: Optional[GNewsClient] = None


def get_gnews_client() -> GNewsClient:
    """Return the shared GNewsClient, constructing it on first use."""
    global _gnews_client
    if _gnews_client is not None:
        return _gnews_client
    _gnews_client = GNewsClient()
    return _gnews_client
|
src/utils/supabase_client.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import Optional, Dict, Any, List
|
| 3 |
+
from datetime import datetime
|
| 4 |
+
from supabase import create_client, Client
|
| 5 |
+
from dotenv import load_dotenv
|
| 6 |
+
|
| 7 |
+
load_dotenv()
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class SupabaseClient:
    """Wrapper around the Supabase client for the predictions/feedback tables."""

    def __init__(self):
        self.url = os.getenv("SUPABASE_URL")
        # Prefer the service key; fall back to the public/anon key.
        self.key = os.getenv(
            "SUPABASE_SERVICE_KEY") or os.getenv("SUPABASE_KEY")
        if not self.url or not self.key:
            # Message now mentions the accepted SUPABASE_KEY fallback.
            raise ValueError(
                "SUPABASE_URL and SUPABASE_SERVICE_KEY (or SUPABASE_KEY) must be set")
        self.client: Client = create_client(self.url, self.key)

    @staticmethod
    def _utc_now_iso() -> str:
        """Current UTC time as a timezone-aware ISO-8601 string.

        FIX: datetime.utcnow() is deprecated (Python 3.12+) and returns a
        naive datetime; this helper produces an aware timestamp instead.
        """
        from datetime import timezone  # local: module only imports datetime
        return datetime.now(timezone.utc).isoformat()

    def store_prediction(
        self,
        article_id: str,
        text: str,
        predicted_label: str,
        confidence: float,
        model_name: str,
        explanation=None,
    ) -> Dict[str, Any]:
        """Insert one prediction row (text truncated to 1000 chars); returns inserted rows."""
        data = {
            "article_id": article_id,
            "text": text[:1000],
            "predicted_label": predicted_label,
            "confidence": confidence,
            "model_name": model_name,
            "explanation": explanation,
            "created_at": self._utc_now_iso(),
        }
        response = self.client.table("predictions").insert(data).execute()
        return response.data

    def store_feedback(
        self,
        article_id: str,
        predicted_label: str,
        actual_label: str,
        user_comment: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Insert one user-feedback row; returns inserted rows."""
        data = {
            "article_id": article_id,
            "predicted_label": predicted_label,
            "actual_label": actual_label,
            "user_comment": user_comment,
            "created_at": self._utc_now_iso(),
        }
        response = self.client.table("feedback").insert(data).execute()
        return response.data

    def get_prediction_stats(self) -> Dict[str, Any]:
        """Return total prediction count and a per-label breakdown.

        NOTE(review): fetches every predicted_label row to count locally —
        fine for small tables, consider a server-side aggregate at scale.
        """
        total = self.client.table("predictions").select(
            "*", count="exact").execute()
        by_label_rows = self.client.table(
            "predictions").select("predicted_label").execute()
        label_counts: Dict[str, int] = {}
        for row in by_label_rows.data:
            lbl = row["predicted_label"]
            label_counts[lbl] = label_counts.get(lbl, 0) + 1
        return {
            "total_predictions": total.count,
            "by_label": label_counts,
        }

    def get_feedback_for_training(self, limit: int = 1000) -> List[Dict[str, Any]]:
        """Fetch up to *limit* feedback rows for retraining."""
        response = self.client.table("feedback").select(
            "*").limit(limit).execute()
        return response.data
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
# Lazily-constructed module-level singleton.
_supabase_client: Optional[SupabaseClient] = None


def get_supabase_client() -> SupabaseClient:
    """Return the shared SupabaseClient, building it on first use."""
    global _supabase_client
    if _supabase_client is not None:
        return _supabase_client
    _supabase_client = SupabaseClient()
    return _supabase_client


def reset_client():
    """Force re-initialisation."""
    global _supabase_client
    _supabase_client = None
|