Merge pull request #25 from cyberalertnepal/Testing

Implement rate limiting, Nepali support, and RAG pipeline fixes
- .env-example +32 -0
- .gitignore +3 -0
- README.md +13 -0
- READMEs.md +152 -0
- docs/detector/ai_human_image_checker.md +132 -0
- features/ai_human_image_classifier/controller.py +35 -0
- features/ai_human_image_classifier/inferencer.py +48 -0
- features/ai_human_image_classifier/main.py +27 -0
- features/ai_human_image_classifier/model_loader.py +80 -0
- features/ai_human_image_classifier/preprocessor.py +34 -0
- features/ai_human_image_classifier/routes.py +44 -0
- features/nepali_text_classifier/preprocess.py +5 -6
- features/rag_chatbot/__init__.py +0 -0
- features/rag_chatbot/controller.py +182 -0
- features/rag_chatbot/document_handler.py +37 -0
- features/rag_chatbot/rag_pipeline.py +327 -0
- features/rag_chatbot/routes.py +111 -0
- features/real_forged_classifier/controller.py +36 -0
- features/real_forged_classifier/inferencer.py +52 -0
- features/real_forged_classifier/main.py +26 -0
- features/real_forged_classifier/model.py +34 -0
- features/real_forged_classifier/model_loader.py +60 -0
- features/real_forged_classifier/preprocessor.py +67 -0
- features/real_forged_classifier/routes.py +37 -0
- features/text_classifier/controller.py +7 -4
- features/text_classifier/preprocess.py +5 -7
- requirements.txt +1 -1
- test.md +31 -0
.env-example
CHANGED

```diff
@@ -1,2 +1,34 @@
 MY_SECRET_TOKEN="SECRET_CODE_TOKEN"
 
+# CHROMA_HOST = "localhost" (set this to the address where Chroma is hosted)
+
+
+# EXAMPLE CONFIGURATIONS FOR DIFFERENT PROVIDERS (use only one at a time)
+# ===========================================
+
+# FOR OPENAI (PAID):
+# LLM_PROVIDER=openai
+# LLM_API_KEY=sk-your-openai-api-key
+# LLM_MODEL=gpt-3.5-turbo
+# # Other options: gpt-4, gpt-4-turbo-preview, etc.
+
+# FOR GROQ (FREE: excellent -> prefer this):
+# LLM_PROVIDER=groq
+# LLM_API_KEY=gsk_your-groq-api-key
+# LLM_MODEL=llama-3.3-70b-versatile
+# # Other options: llama-3.1-70b-versatile, mixtral-8x7b-32768, etc.
+
+# FOR OPENROUTER (FREE: applies heavy rate limits eventually):
+# LLM_PROVIDER=openrouter
+# LLM_API_KEY=sk-or-your-openrouter-api-key
+# LLM_MODEL=meta-llama/llama-3.1-8b-instruct:free
+# # Other options: anthropic/claude-3-haiku, google/gemma-7b-it, etc.
+
+# ===========================================
+# ADVANCED CONFIGURATION
+# ===========================================
+# Temperature (0.0 to 1.0) - controls randomness
+# LLM_TEMPERATURE=0.1
+
+# Maximum tokens for response
+# LLM_MAX_TOKENS=4096
```
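The commented-out provider settings above are consumed at runtime via environment variables. A minimal sketch of how an application might read them; the function name and fallback defaults here are illustrative, not part of this PR:

```python
import os

def load_llm_config() -> dict:
    """Collect the LLM_* settings documented in .env-example into one dict."""
    return {
        "provider": os.getenv("LLM_PROVIDER", "groq"),
        "api_key": os.getenv("LLM_API_KEY", ""),
        "model": os.getenv("LLM_MODEL", "llama-3.3-70b-versatile"),
        # Advanced knobs, defaulting to the values suggested in the comments.
        "temperature": float(os.getenv("LLM_TEMPERATURE", "0.1")),
        "max_tokens": int(os.getenv("LLM_MAX_TOKENS", "4096")),
    }
```

Keeping a single loader like this also makes "use only one provider at once" enforceable in one place.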
.gitignore
CHANGED

```diff
@@ -66,3 +66,6 @@ notebooks
 np_text_model/classifier/sentencepiece.bpe.model
 np_text_model/classifier/tokenizer.json
 
+# vector database
+chroma_data
+chroma_database
```
README.md
CHANGED

```diff
@@ -1,3 +1,15 @@
+---
+title: Testing AI Contain
+emoji: 🤖
+colorFrom: blue
+colorTo: green
+sdk: docker
+sdk_version: "latest"
+app_file: app.py
+pinned: false
+---
+
+# AI-Contain-Checker
 # AI-Content-Checker
 
 A modular AI content detection system with support for **image classification**, **image edit detection**, **Nepali text classification**, and **general text classification**. Built for performance and extensibility, it is ideal for detecting AI-generated content in both visual and textual forms.
@@ -150,3 +162,4 @@ AI-Checker/
 ## 📄 License
 
 See full license terms here: [`LICENSE.md`](license.md)
+
```
READMEs.md
ADDED

@@ -0,0 +1,152 @@

# AI-Contain-Checker

A modular AI content detection system with support for **image classification**, **image edit detection**, **Nepali text classification**, and **general text classification**. Built for performance and extensibility, it is ideal for detecting AI-generated content in both visual and textual forms.


## 🌟 Features

### 🖼️ Image Classifier

* **Purpose**: Classifies whether an image is AI-generated or a real-life photo.
* **Model**: Fine-tuned **InceptionV3** CNN.
* **Dataset**: Custom curated dataset with **\~79,950 images** for binary classification.
* **Location**: [`features/image_classifier`](features/image_classifier)
* **Docs**: [`docs/features/image_classifier.md`](docs/features/image_classifier.md)

### 🖌️ Image Edit Detector

* **Purpose**: Detects image tampering or post-processing.
* **Techniques Used**:

  * **Error Level Analysis (ELA)**: Visualizes compression artifacts.
  * **Fast Fourier Transform (FFT)**: Detects unnatural frequency patterns.
* **Location**: [`features/image_edit_detector`](features/image_edit_detector)
* **Docs**:

  * [ELA](docs/detector/ELA.md)
  * [FFT](docs/detector/fft.md)
  * [Metadata Analysis](docs/detector/meta.md)
  * [Backend Notes](docs/detector/note-for-backend.md)

### 📝 Nepali Text Classifier

* **Purpose**: Determines if Nepali text content is AI-generated or written by a human.
* **Model**: Based on `XLMRClassifier` fine-tuned on Nepali language data.
* **Dataset**: Scraped dataset of **\~18,000** Nepali texts.
* **Location**: [`features/nepali_text_classifier`](features/nepali_text_classifier)
* **Docs**: [`docs/features/nepali_text_classifier.md`](docs/features/nepali_text_classifier.md)

### 🌐 English Text Classifier

* **Purpose**: Detects if English text is AI-generated or human-written.
* **Pipeline**:

  * Uses **GPT2 tokenizer** for input preprocessing.
  * Custom binary classifier to differentiate between AI and human-written content.
* **Location**: [`features/text_classifier`](features/text_classifier)
* **Docs**: [`docs/features/text_classifier.md`](docs/features/text_classifier.md)

---

## 🗂️ Project Structure

```bash
AI-Checker/
│
├── app.py               # Main FastAPI entry point
├── config.py            # Configuration settings
├── Dockerfile           # Docker build script
├── Procfile             # Deployment file for Heroku or similar
├── requirements.txt     # Python dependencies
├── README.md            # You are here 📘
│
├── features/            # Core detection modules
│   ├── image_classifier/
│   ├── image_edit_detector/
│   ├── nepali_text_classifier/
│   └── text_classifier/
│
├── docs/                # Internal and API documentation
│   ├── api_endpoints.md
│   ├── deployment.md
│   ├── detector/
│   │   ├── ELA.md
│   │   ├── fft.md
│   │   ├── meta.md
│   │   └── note-for-backend.md
│   ├── functions.md
│   ├── nestjs_integration.md
│   ├── security.md
│   ├── setup.md
│   └── structure.md
│
├── IMG_Models/          # Saved image classifier model(s)
│   └── latest-my_cnn_model.h5
│
├── notebooks/           # Experimental and debug notebooks
├── static/              # Static assets if needed
└── test.md              # Test notes
```

---

## 📚 Documentation Links

* [API Endpoints](docs/api_endpoints.md)
* [Deployment Guide](docs/deployment.md)
* [Detector Documentation](docs/detector/)

  * [Error Level Analysis (ELA)](docs/detector/ELA.md)
  * [Fast Fourier Transform (FFT)](docs/detector/fft.md)
  * [Metadata Analysis](docs/detector/meta.md)
  * [Backend Notes](docs/detector/note-for-backend.md)
* [Functions Overview](docs/functions.md)
* [NestJS Integration Guide](docs/nestjs_integration.md)
* [Security Details](docs/security.md)
* [Setup Instructions](docs/setup.md)
* [Project Structure](docs/structure.md)

---

## 🚀 Usage

1. **Install dependencies**

   ```bash
   pip install -r requirements.txt
   ```

2. **Run the API**

   ```bash
   uvicorn app:app --reload
   ```

3. **Build Docker (optional)**

   ```bash
   docker build -t ai-contain-checker .
   docker run -p 8000:8000 ai-contain-checker
   ```

---

## 🔐 Security & Integration

* **Token Authentication** and **IP Whitelisting** supported.
* NestJS integration guide: [`docs/nestjs_integration.md`](docs/nestjs_integration.md)
* Rate limiting handled using `slowapi`.

---

## 🛡️ Future Plans

* Add **video classifier** module.
* Expand dataset for **multilingual** AI content detection.
* Add **fine-tuning UI** for models.

---

## 📄 License

See full license terms here: [`LICENSE.md`](license.md)
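The Security & Integration section above mentions token authentication; the usual shape is comparing a presented bearer token against `MY_SECRET_TOKEN` from the environment with a constant-time compare. A sketch of that check; only the env-var name comes from `.env-example`, the helper name is illustrative:

```python
import hmac
import os

def token_is_valid(presented: str) -> bool:
    """Compare a presented bearer token against MY_SECRET_TOKEN from the environment."""
    expected = os.getenv("MY_SECRET_TOKEN", "")
    # compare_digest runs in constant time, so the check leaks no timing signal.
    return bool(expected) and hmac.compare_digest(presented, expected)
```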
docs/detector/ai_human_image_checker.md
ADDED

@@ -0,0 +1,132 @@

Real vs. Fake Image Classification for Production Pipeline
==========================================================

1\. Business Problem
--------------------

This project addresses the critical business need to automatically identify and flag manipulated or synthetically generated images. By accurately classifying images as **"real"** or **"fake,"** we can enhance the integrity of our platform, prevent the spread of misinformation, and protect our users from fraudulent content. This solution is designed for integration into our production pipeline to process images in real time.

2\. Solution Overview
---------------------

This solution leverages OpenAI's CLIP (Contrastive Language-Image Pre-Training) model to differentiate between real and fake images. The system operates as follows:

1. **Feature Extraction:** A pre-trained CLIP model ('ViT-L/14') converts input images into 768-dimensional feature vectors.

2. **Classification:** A Support Vector Machine (SVM) model, trained on our internal dataset of real and fake images, classifies the feature vectors.

3. **Deployment:** The trained model is deployed as a service that can be integrated into our production image processing pipeline.

The model has achieved an accuracy of **98.29%** on our internal test set, demonstrating its effectiveness in distinguishing between real and fake images.

3\. Getting Started
-------------------

### 3.1. Dependencies

To ensure a reproducible environment, all dependencies are listed in the requirements.txt file. Install them using pip:

```bash
pip install -r requirements.txt
```

**requirements.txt**:

- numpy
- Pillow
- torch
- clip-by-openai
- scikit-learn
- tqdm
- seaborn
- matplotlib

### 3.2. Data Preparation

The model was trained on a dataset of real and fake images obtained from Kaggle: https://www.kaggle.com/datasets/tristanzhang32/ai-generated-images-vs-real-images/data

### 3.3. Usage

#### 3.3.1. Feature Extraction

To extract features from a new dataset, run the following command:

```bash
python extract_features.py --data_dir /path/to/your/data --output_file features.npz
```

#### 3.3.2. Model Training

To retrain the SVM model on a new set of extracted features, run:

```bash
python train_model.py --features_file features.npz --model_output_path model.joblib
```

#### 3.3.3. Inference

To classify a single image using the trained model, use the provided inference script:

```bash
python classify.py --image_path /path/to/your/image.jpg --model_path model.joblib
```

4\. Production Deployment
-------------------------

The image classification model is deployed as a microservice. The service exposes an API endpoint that accepts an image and returns a classification result ("real" or "fake").

### 4.1. API Specification

* **Endpoint:** /classify

* **Method:** POST

* **Request Body:** multipart/form-data with a single field `image`.

* **Response:**

  * Success (JSON): `{ "classification": "real", "confidence": 0.95 }`
  * Error (JSON): `{ "error": "Error message" }`

### 4.2. Scalability and Monitoring

The service is deployed in a containerized environment (e.g., Docker) and managed by an orchestrator (e.g., Kubernetes) to ensure scalability and high availability. Monitoring and logging are in place to track model performance, API latency, and error rates.

5\. Model Versioning
--------------------

We use a combination of Git for code versioning and a model registry for tracking trained model artifacts. Each model is versioned and associated with the commit hash of the code that produced it. The current production model is **v1.2.0**.

6\. Testing
-----------

The project includes a suite of tests to ensure correctness and reliability:

* **Unit tests:** To verify individual functions and components.

* **Integration tests:** To test the interaction between different parts of the system.

* **Model evaluation tests:** To continuously monitor model performance on a golden dataset.

To run the tests, execute:

```bash
pytest
```

7\. Future Work
---------------

* **Explore more advanced classifiers:** Investigate the use of neural network-based classifiers on top of CLIP features.

* **Fine-tune the CLIP model:** For even better performance, we can fine-tune the CLIP model on our specific domain of images.

* **Expand the training dataset:** Continuously augment the training data with new examples of real and fake images to improve the model's robustness.

8\. Contact/Support
-------------------

For any questions or issues regarding this project, please contact the Machine Learning team at [your-team-email@yourcompany.com](mailto:your-team-email@yourcompany.com).
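A caller of the `/classify` endpoint described in §4.1 has to branch on the two response shapes (success vs. `error`). A client-side sketch; the helper name and confidence threshold are hypothetical, only the JSON shapes come from the spec above:

```python
import json

def handle_response(body: str, threshold: float = 0.8) -> str:
    """Interpret a /classify JSON body (success or error shape) as an action string."""
    data = json.loads(body)
    if "error" in data:
        return f"failed: {data['error']}"
    verdict, confidence = data["classification"], data["confidence"]
    if confidence < threshold:
        # Low-confidence results may deserve human review rather than auto-flagging.
        return f"uncertain ({verdict}, {confidence:.2f})"
    return verdict
```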
features/ai_human_image_classifier/controller.py
ADDED

@@ -0,0 +1,35 @@

```python
from typing import IO
from preprocessor import preprocessor
from inferencer import inferencer

class ClassificationController:
    """
    Controller to handle the image classification logic.
    """
    def classify_image(self, image_file: IO) -> dict:
        """
        Orchestrates the classification of a single image file.

        Args:
            image_file (IO): The image file to classify.

        Returns:
            dict: The classification result.
        """
        try:
            # Step 1: Preprocess the image
            image_tensor = preprocessor.process(image_file)

            # Step 2: Perform inference
            result = inferencer.predict(image_tensor)

            return result
        except ValueError as e:
            # Handle specific errors like invalid images
            return {"error": str(e)}
        except Exception as e:
            # Handle unexpected errors
            print(f"An unexpected error occurred: {e}")
            return {"error": "An internal error occurred during classification."}

controller = ClassificationController()
```
features/ai_human_image_classifier/inferencer.py
ADDED

@@ -0,0 +1,48 @@

```python
import torch
import numpy as np
from model_loader import models

class Inferencer:

    def __init__(self):
        self.clip_model = models.clip_model
        self.svm_model = models.svm_model

    @torch.no_grad()
    def predict(self, image_tensor: torch.Tensor) -> dict:
        """
        Takes a preprocessed image tensor and returns the classification result.

        Args:
            image_tensor (torch.Tensor): The preprocessed image tensor.

        Returns:
            dict: A dictionary containing the classification label and confidence score.
        """
        image_features = self.clip_model.encode_image(image_tensor)
        image_features_np = image_features.cpu().numpy()

        prediction = self.svm_model.predict(image_features_np)[0]

        if hasattr(self.svm_model, "predict_proba"):
            # If yes, use predict_proba for a true confidence score
            confidence_scores = self.svm_model.predict_proba(image_features_np)[0]
            confidence = float(np.max(confidence_scores))
        else:
            # If no, use decision_function as a fallback confidence measure.
            # The absolute value of the decision function score indicates confidence.
            # We apply a sigmoid function to scale it to a [0, 1] range for consistency.
            decision_score = self.svm_model.decision_function(image_features_np)[0]
            confidence = float(1 / (1 + np.exp(-np.abs(decision_score))))

        label_map = {0: 'real', 1: 'fake'}
        classification_label = label_map.get(prediction, "unknown")

        return {
            "classification": classification_label,
            "confidence": confidence
        }

inferencer = Inferencer()
```
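The `decision_function` fallback in `inferencer.py` maps an unbounded SVM margin into a probability-like score: the further a sample sits from the separating hyperplane, the closer the sigmoid of the absolute margin gets to 1. The mapping in isolation (a standalone helper for illustration, not part of the module):

```python
import math

def margin_to_confidence(decision_score: float) -> float:
    # abs(margin) grows with distance from the SVM hyperplane; the sigmoid
    # squashes it into [0.5, 1.0) so it can be reported like a probability.
    return 1 / (1 + math.exp(-abs(decision_score)))
```

Note the floor of 0.5: a sample sitting exactly on the hyperplane reports the minimum possible "confidence", which is consistent with the classifier being maximally unsure there.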
features/ai_human_image_classifier/main.py
ADDED

@@ -0,0 +1,27 @@

```python
from fastapi import FastAPI
from routes import router as api_router

# Initialize the FastAPI app
app = FastAPI(
    title="Real vs. Fake Image Classification API",
    description="An API to classify images as real or fake using OpenAI's CLIP and an SVM model.",
    version="1.0.0"
)

# Include the API router
# All routes defined in routes.py will be available under the /api prefix
app.include_router(api_router, prefix="/api", tags=["Classification"])

@app.get("/", tags=["Root"])
async def read_root():
    """
    A simple root endpoint to confirm the API is running.
    """
    return {"message": "Welcome to the Image Classification API. Go to /docs for the API documentation."}


# To run this application:
# 1. Make sure you have all dependencies from requirements.txt installed.
# 2. Make sure the 'svm_model.joblib' file is in the same directory.
# 3. Run the following command in your terminal:
#    uvicorn main:app --reload
```
features/ai_human_image_classifier/model_loader.py
ADDED

@@ -0,0 +1,80 @@

```python
import clip
import torch
import joblib
from pathlib import Path
from huggingface_hub import hf_hub_download

class ModelLoader:
    """
    A class to load and hold the machine learning models.
    This ensures that models are loaded only once.
    """
    def __init__(self, clip_model_name: str, svm_repo_id: str, svm_filename: str):
        """
        Initializes the ModelLoader and loads the models.

        Args:
            clip_model_name (str): The name of the CLIP model to load (e.g., 'ViT-L/14').
            svm_repo_id (str): The repository ID on Hugging Face (e.g., 'rhnsa/ai_human_image_detector').
            svm_filename (str): The name of the model file in the repository (e.g., 'model.joblib').
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        self.clip_model, self.clip_preprocess = self._load_clip_model(clip_model_name)
        self.svm_model = self._load_svm_model(repo_id=svm_repo_id, filename=svm_filename)
        print("Models loaded successfully.")

    def _load_clip_model(self, model_name: str):
        """
        Loads the specified CLIP model and its preprocessor.

        Args:
            model_name (str): The name of the CLIP model.

        Returns:
            A tuple containing the loaded CLIP model and its preprocess function.
        """
        try:
            model, preprocess = clip.load(model_name, device=self.device)
            return model, preprocess
        except Exception as e:
            print(f"Error loading CLIP model: {e}")
            raise

    def _load_svm_model(self, repo_id: str, filename: str):
        """
        Downloads and loads the SVM model from a Hugging Face Hub repository.

        Args:
            repo_id (str): The repository ID on Hugging Face.
            filename (str): The name of the model file in the repository.

        Returns:
            The loaded SVM model object.
        """
        print(f"Downloading SVM model from Hugging Face repo: {repo_id}")
        try:
            # Download the model file from the Hub. It returns the cached path.
            model_path = hf_hub_download(repo_id=repo_id, filename=filename)
            print(f"SVM model downloaded to: {model_path}")

            # Load the model from the downloaded path
            svm_model = joblib.load(model_path)
            return svm_model
        except Exception as e:
            print(f"Error downloading or loading SVM model from Hugging Face: {e}")
            raise

# --- Global Model Instance ---
# This creates a single instance of the models that can be imported by other modules.
CLIP_MODEL_NAME = 'ViT-L/14'
SVM_REPO_ID = 'rhnsa/ai_human_image_detector'
SVM_FILENAME = 'svm_model_real.joblib'  # The name of your model file in the Hugging Face repo

# This instance will be created when the application starts.
models = ModelLoader(
    clip_model_name=CLIP_MODEL_NAME,
    svm_repo_id=SVM_REPO_ID,
    svm_filename=SVM_FILENAME
)
```
features/ai_human_image_classifier/preprocessor.py
ADDED

@@ -0,0 +1,34 @@

```python
from PIL import Image
import torch
from typing import IO
from model_loader import models

class ImagePreprocessor:

    def __init__(self):
        self.preprocess = models.clip_preprocess
        self.device = models.device

    def process(self, image_file: IO) -> torch.Tensor:
        """
        Opens an image file, preprocesses it, and returns it as a tensor.

        Args:
            image_file (IO): The image file object (e.g., from a file upload).

        Returns:
            torch.Tensor: The preprocessed image as a tensor, ready for the model.
        """
        try:
            # Open the image from the file-like object
            image = Image.open(image_file).convert("RGB")
        except Exception as e:
            print(f"Error opening image: {e}")
            # You might want to raise a custom exception here
            raise ValueError("Invalid or corrupted image file.")

        # Apply the CLIP preprocessing transformations and move to the correct device
        image_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
        return image_tensor

preprocessor = ImagePreprocessor()
```
features/ai_human_image_classifier/routes.py
ADDED

@@ -0,0 +1,44 @@

```python
from fastapi import APIRouter, File, UploadFile, HTTPException, status
from fastapi.responses import JSONResponse
from controller import controller

from fastapi import Request, Depends
from fastapi.security import HTTPBearer
from slowapi import Limiter
from slowapi.util import get_remote_address


# Create an API router
router = APIRouter()
limiter = Limiter(key_func=get_remote_address)
security = HTTPBearer()

@router.post("/classify", summary="Classify an image as Real or Fake")
async def classify_image_endpoint(image: UploadFile = File(...)):
    """
    Accepts an image file and classifies it as 'real' or 'fake'.

    - **image**: The image file to be classified (e.g., JPEG, PNG).

    Returns a JSON object with the classification and a confidence score.
    """
    # Check for a valid image content type
    if not image.content_type.startswith("image/"):
        raise HTTPException(
            status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
            detail="Unsupported file type. Please upload an image (e.g., JPEG, PNG)."
        )

    # The controller expects a file-like object, which `image.file` provides
    result = controller.classify_image(image.file)

    if "error" in result:
        # If the controller returned an error, forward it as an HTTP exception
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=result["error"]
        )

    return JSONResponse(content=result, status_code=status.HTTP_200_OK)
```
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
-
import fitz # PyMuPDF
|
| 2 |
import docx
|
| 3 |
from io import BytesIO
|
| 4 |
import logging
|
| 5 |
from fastapi import HTTPException
|
| 6 |
-
|
| 7 |
|
| 8 |
def parse_docx(file: BytesIO):
|
| 9 |
doc = docx.Document(file)
|
|
@@ -15,11 +15,10 @@ def parse_docx(file: BytesIO):
|
|
| 15 |
|
| 16 |
def parse_pdf(file: BytesIO):
|
| 17 |
try:
|
| 18 |
-
doc =
|
| 19 |
text = ""
|
| 20 |
-
for
|
| 21 |
-
|
| 22 |
-
text += page.get_text()
|
| 23 |
return text
|
| 24 |
except Exception as e:
|
| 25 |
logging.error(f"Error while processing PDF: {str(e)}")
|
|
|
|
| 1 |
+
# import fitz # PyMuPDF
|
| 2 |
import docx
|
| 3 |
from io import BytesIO
|
| 4 |
import logging
|
| 5 |
from fastapi import HTTPException
|
| 6 |
+
from pypdf import PdfReader
|
| 7 |
|
| 8 |
def parse_docx(file: BytesIO):
|
| 9 |
doc = docx.Document(file)
|
|
|
|
| 15 |
|
| 16 |
def parse_pdf(file: BytesIO):
|
| 17 |
try:
|
| 18 |
+
doc = PdfReader(file)
|
| 19 |
text = ""
|
| 20 |
+
for page in doc.pages:
|
| 21 |
+
text += page.extract_text()
|
|
|
|
| 22 |
return text
|
| 23 |
except Exception as e:
|
| 24 |
logging.error(f"Error while processing PDF: {str(e)}")
|
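One caveat with the pypdf migration: `page.extract_text()` can return `None` for pages with no extractable text, so the concatenation loop is often written defensively (the project's own `document_handler.py` uses `or ""` for exactly this reason). A minimal pure-Python sketch of that pattern, where `page_texts` stands in for the per-page `extract_text()` results:

```python
def join_page_texts(page_texts) -> str:
    # Concatenate per-page text, treating None (no extractable text) as empty.
    text = ""
    for page_text in page_texts:
        text += page_text or ""
    return text
```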
features/rag_chatbot/__init__.py
ADDED

File without changes
features/rag_chatbot/controller.py
ADDED

@@ -0,0 +1,182 @@
import os
import asyncio
import logging
from io import BytesIO
from typing import Dict, Any

from fastapi import HTTPException, UploadFile, status, Depends
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials

from .rag_pipeline import route_and_process_query, add_document_to_rag, check_system_health
from .document_handler import extract_text_from_file

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

security = HTTPBearer()

# Supported file types
SUPPORTED_CONTENT_TYPES = {
    "application/pdf",
    "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
    "text/plain"
}

MAX_FILE_SIZE = 100 * 1024 * 1024  # 100MB

async def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Verify Bearer token from Authorization header."""
    token = credentials.credentials
    expected_token = os.getenv("MY_SECRET_TOKEN")

    if not expected_token:
        logger.error("MY_SECRET_TOKEN not configured")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Server configuration error"
        )

    if token != expected_token:
        logger.warning(f"Invalid token attempt: {token[:10]}...")
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid or expired token"
        )
    return token

async def handle_rag_query(query: str) -> Dict[str, Any]:
    """Handle an incoming query by routing it and getting the appropriate answer."""

    # Input validation
    if not query or not query.strip():
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Query cannot be empty"
        )

    if len(query) > 1000:  # Reasonable limit
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Query too long. Please limit to 1000 characters."
        )

    try:
        logger.info(f"Processing query: {query[:50]}...")

        # Process query in thread pool
        response = await asyncio.to_thread(route_and_process_query, query)

        logger.info(f"Query processed successfully. Route: {response.get('route', 'Unknown')}")
        return response

    except Exception as e:
        logger.error(f"Error processing query: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Error processing your query. Please try again."
        )

async def handle_document_upload(file: UploadFile) -> Dict[str, str]:
    """Handle uploading a document to the RAG's vector store."""

    # File validation
    if not file.filename:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="No file provided"
        )

    if file.content_type not in SUPPORTED_CONTENT_TYPES:
        raise HTTPException(
            status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
            detail=f"Unsupported file type: {file.content_type}. "
                   f"Supported types: {', '.join(SUPPORTED_CONTENT_TYPES)}"
        )

    # Check file size
    contents = await file.read()
    if len(contents) > MAX_FILE_SIZE:
        raise HTTPException(
            status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
            detail=f"File too large. Maximum size: {MAX_FILE_SIZE / (1024*1024):.1f}MB"
        )

    # Reset file pointer
    await file.seek(0)

    try:
        logger.info(f"Processing file upload: {file.filename}")

        # Extract text from file
        text = await extract_text_from_file(file)

        if not text or not text.strip():
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="The file appears to be empty or could not be read."
            )

        if len(text) < 50:  # Too short to be meaningful
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="The extracted text is too short to be meaningful."
            )

        # Add to RAG system
        success = await asyncio.to_thread(
            add_document_to_rag,
            text,
            {
                "source": file.filename,
                "content_type": file.content_type,
                "size": len(contents)
            }
        )

        if not success:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Failed to add document to the knowledge base"
            )

        logger.info(f"Successfully processed file: {file.filename}")

        return {
            "message": f"Successfully uploaded and processed '{file.filename}'. "
                       f"It is now available for querying.",
            "filename": file.filename,
            "text_length": len(text),
            "content_type": file.content_type
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error processing file {file.filename}: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Error processing the file. Please try again."
        )

async def handle_health_check() -> Dict[str, Any]:
    """Handle health check requests."""
    try:
        health_status = await asyncio.to_thread(check_system_health)

        if health_status["status"] == "unhealthy":
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="Service is currently unhealthy"
            )

        return health_status

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Health check failed: {e}")
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Health check failed"
        )
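A note on `verify_token`: the `token != expected_token` comparison leaks timing information about how many leading characters match. The usual hardening for secret comparison is the standard library's `hmac.compare_digest`. A sketch of that alternative (an assumption about how one might harden it, not what the PR ships):

```python
import hmac


def tokens_match(provided: str, expected: str) -> bool:
    # compare_digest runs in time independent of where the strings differ,
    # which prevents timing side channels on the bearer token check.
    return hmac.compare_digest(provided.encode("utf-8"), expected.encode("utf-8"))
```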
features/rag_chatbot/document_handler.py
ADDED

@@ -0,0 +1,37 @@
from io import BytesIO
from fastapi import UploadFile, HTTPException
import PyPDF2
import docx

async def extract_text_from_file(file: UploadFile) -> str:
    """Extracts text from various file types."""
    content = await file.read()
    file_stream = BytesIO(content)

    if file.content_type == "application/pdf":
        return extract_text_from_pdf(file_stream)
    elif file.content_type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
        return extract_text_from_docx(file_stream)
    elif file.content_type == "text/plain":
        return file_stream.read().decode("utf-8")
    else:
        raise HTTPException(
            status_code=415,
            detail="Unsupported file type. Please upload a .pdf, .docx, or .txt file."
        )

def extract_text_from_pdf(file_stream: BytesIO) -> str:
    """Extracts text from a PDF file."""
    reader = PyPDF2.PdfReader(file_stream)
    text = ""
    for page in reader.pages:
        text += page.extract_text() or ""
    return text

def extract_text_from_docx(file_stream: BytesIO) -> str:
    """Extracts text from a DOCX file."""
    doc = docx.Document(file_stream)
    text = ""
    for para in doc.paragraphs:
        text += para.text + "\n"
    return text
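The if/elif dispatch on `file.content_type` can equivalently be table-driven, which keeps the supported-types set and the extractors in one place. A pure-Python sketch with the extractors stubbed out for illustration (the stub names and return values are hypothetical):

```python
def extract_pdf(stream):   # stand-in for extract_text_from_pdf
    return "pdf text"

def extract_docx(stream):  # stand-in for extract_text_from_docx
    return "docx text"

# Map each supported MIME type to its extractor; membership in this dict
# doubles as the "is this type supported?" check.
EXTRACTORS = {
    "application/pdf": extract_pdf,
    "application/vnd.openxmlformats-officedocument.wordprocessingml.document": extract_docx,
    "text/plain": lambda stream: stream.read().decode("utf-8"),
}

def extract(content_type, stream):
    extractor = EXTRACTORS.get(content_type)
    if extractor is None:
        raise ValueError(f"Unsupported file type: {content_type}")
    return extractor(stream)
```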
features/rag_chatbot/rag_pipeline.py
ADDED

@@ -0,0 +1,327 @@
import os
import chromadb
from dotenv import load_dotenv
from langchain_core.documents import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import OpenAI
from langchain.chains.question_answering import load_qa_chain
from langchain_community.vectorstores import Chroma
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.chat_models import ChatOpenAI


load_dotenv()

# ChromaDB configuration
CHROMA_HOST = os.getenv("CHROMA_HOST", "localhost")  # change in env in production when hosted
COLLECTION_NAME = "company_docs_collection"

# LLM Provider Configuration
LLM_PROVIDER = os.getenv("LLM_PROVIDER", "openai").lower()
LLM_API_KEY = os.getenv("LLM_API_KEY")
LLM_MODEL = os.getenv("LLM_MODEL", "gpt-3.5-turbo")
LLM_TEMPERATURE = float(os.getenv("LLM_TEMPERATURE", "0"))
LLM_MAX_TOKENS = int(os.getenv("LLM_MAX_TOKENS", "2048"))

# Provider-specific configurations
PROVIDER_CONFIGS = {
    "openai": {
        "api_base": "https://api.openai.com/v1",
        "default_model": "gpt-3.5-turbo"
    },
    "groq": {
        "api_base": "https://api.groq.com/openai/v1",
        "default_model": "llama-3.3-70b-versatile"
    },
    "openrouter": {
        "api_base": "https://openrouter.ai/api/v1",
        "default_model": "mistralai/mistral-small-3.2-24b-instruct:free"
    }
}

vector_store = None
company_qa_chain = None
query_router_chain = None
cybersecurity_chain = None
llm = None

def get_llm_config():
    """Get the appropriate LLM configuration based on the provider."""
    if LLM_PROVIDER not in PROVIDER_CONFIGS:
        raise ValueError(f"Unsupported LLM provider: {LLM_PROVIDER}. Supported: {list(PROVIDER_CONFIGS.keys())}")

    config = PROVIDER_CONFIGS[LLM_PROVIDER].copy()

    # Use provided model or fall back to default
    model = LLM_MODEL if LLM_MODEL != "gpt-3.5-turbo" else config["default_model"]

    return {
        "model": model,
        "openai_api_key": LLM_API_KEY,
        "openai_api_base": config["api_base"],
        "temperature": LLM_TEMPERATURE,
        "max_tokens": LLM_MAX_TOKENS,
    }

def initialize_llm():
    """Initialize the LLM based on the configured provider."""
    if not LLM_API_KEY:
        raise ValueError(f"LLM_API_KEY environment variable is required for {LLM_PROVIDER}")

    config = get_llm_config()

    print(f"Initializing {LLM_PROVIDER.upper()} with model: {config['model']}")

    return ChatOpenAI(**config)

def initialize_pipelines():
    """Initializes all required models, chains, and the vector store."""
    global vector_store, company_qa_chain, query_router_chain, cybersecurity_chain, llm

    try:
        # Initialize LLM
        llm = initialize_llm()

        # Initialize embeddings
        embeddings = HuggingFaceEmbeddings(
            model_name="all-MiniLM-L6-v2",
            model_kwargs={'device': 'cpu'},
            encode_kwargs={'normalize_embeddings': True}
        )

        # Initialize ChromaDB client
        try:
            chroma_client = chromadb.HttpClient(host=CHROMA_HOST, port=8000)
            chroma_client.heartbeat()
        except Exception as e:
            raise ConnectionError("Failed to connect to ChromaDB.") from e

        # Initialize vector store
        vector_store = Chroma(
            client=chroma_client,
            collection_name=COLLECTION_NAME,
            embedding_function=embeddings,
        )

        # Query Router Chain
        router_template = """You are a query classifier. Classify the following query into one of these categories:
- COMPANY: Questions about our company, its products, services, or general information
- CYBERSECURITY: Questions about cybersecurity, security threats, best practices, or vulnerabilities
- OFF_TOPIC: Questions that don't fit the above categories

Query: {query}

Respond with only the category name (COMPANY, CYBERSECURITY, or OFF_TOPIC):"""

        router_prompt = PromptTemplate(
            input_variables=["query"],
            template=router_template
        )

        query_router_chain = LLMChain(
            llm=llm,
            prompt=router_prompt
        )

        # Custom Company QA Chain
        company_qa_template = """You are a helpful assistant for CyberAlertNepal. Answer the following question about our company using the information provided and links if only available. Give a natural, direct and polite response.

Question: {question}

Information:
{context}

Answer:"""

        company_qa_prompt = PromptTemplate(
            input_variables=["question", "context"],
            template=company_qa_template
        )

        company_qa_chain = LLMChain(
            llm=llm,
            prompt=company_qa_prompt
        )

        # Cybersecurity Chain
        cybersecurity_template = """You are a cybersecurity professional. Answer the following question truthfully and concisely.
If you are not 100% sure about the answer, simply respond with: "I am not sure about the answer."
Do not add extra explanations or assumptions. Do not provide false or speculative information.

Question: {question}

Provide a comprehensive and accurate answer about cybersecurity:"""

        cybersecurity_prompt = PromptTemplate(
            input_variables=["question"],
            template=cybersecurity_template
        )

        cybersecurity_chain = LLMChain(
            llm=llm,
            prompt=cybersecurity_prompt
        )

        print(f"Successfully initialized pipelines with {LLM_PROVIDER.upper()}")

    except Exception as e:
        print(f"Error initializing pipelines: {e}")
        raise

def add_document_to_rag(text: str, metadata: dict):
    """Splits a document and adds it to the ChromaDB index."""
    global vector_store

    if not vector_store:
        initialize_pipelines()

    try:
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200
        )
        docs = text_splitter.create_documents([text], metadatas=[metadata])

        if not docs:
            print("Document was empty after splitting, not adding to ChromaDB.")
            return False

        vector_store.add_documents(docs)
        print("Successfully added documents.")
        return True

    except Exception as e:
        print(f"Error adding document to RAG: {e}")
        return False

def route_and_process_query(query: str):
    """Routes the query and processes it using the appropriate pipeline."""
    global query_router_chain, vector_store, company_qa_chain, cybersecurity_chain

    if not all([query_router_chain, vector_store, company_qa_chain, cybersecurity_chain]):
        initialize_pipelines()

    try:
        # 1. Classify the query
        route_result = query_router_chain.run(query)
        route = route_result.strip().upper()

        # 2. Route to appropriate logic
        if "CYBERSECURITY" in route:
            answer = cybersecurity_chain.run(question=query)
            return {
                "answer": answer,
                "source": "Cybersecurity Knowledge Base",
                "route": "CYBERSECURITY",
                "provider": LLM_PROVIDER.upper(),
                "model": get_llm_config()["model"]
            }

        elif "COMPANY" in route:
            # Perform similarity search on ChromaDB
            docs = vector_store.similarity_search(query, k=3)

            if not docs:
                return {
                    "answer": "I could not find any relevant information to answer your question.",
                    "source": "Company Documents",
                    "route": "COMPANY",
                    "provider": LLM_PROVIDER.upper(),
                    "model": get_llm_config()["model"]
                }

            # Combine document content for context
            context = "\n\n".join([doc.page_content for doc in docs])

            # Run the custom QA chain
            answer = company_qa_chain.run(question=query, context=context)
            sources = list(set([doc.metadata.get("source", "Unknown") for doc in docs]))

            return {
                "answer": answer,
                "source": "Company Documents",
                "documents": sources,
                "route": "COMPANY",
                "provider": LLM_PROVIDER.upper(),
                "model": get_llm_config()["model"]
            }

        else:  # OFF_TOPIC
            return {
                "answer": "I am a specialized assistant of CyberAlertNepal. I cannot answer questions outside of cybersecurity topics.",
                "source": "N/A",
                "route": "OFF_TOPIC",
                "provider": LLM_PROVIDER.upper(),
                "model": get_llm_config()["model"]
            }

    except Exception as e:
        print(f"Error processing query: {e}")
        return {
            "answer": "I encountered an error while processing your query. Please try again.",
            "source": "Error",
            "route": None,
            "documents": None,
            "provider": LLM_PROVIDER.upper(),
            "error": str(e)
        }

def check_system_health():
    """Check if all components are properly initialized."""
    try:
        # Test ChromaDB connection
        if vector_store:
            vector_store._client.heartbeat()

        # Test if all chains are initialized
        components = {
            "vector_store": vector_store is not None,
            "company_qa_chain": company_qa_chain is not None,
            "query_router_chain": query_router_chain is not None,
            "cybersecurity_chain": cybersecurity_chain is not None,
            "llm": llm is not None
        }

        return {
            "status": "healthy" if all(components.values()) else "unhealthy",
            "components": components,
            "provider": LLM_PROVIDER.upper(),
            "model": get_llm_config()["model"] if llm else "Not initialized"
        }

    except Exception as e:
        return {
            "status": "unhealthy",
            "error": str(e),
            "provider": LLM_PROVIDER.upper()
        }

def test_llm_connection():
    """Test the LLM API connection."""
    try:
        if not llm:
            initialize_pipelines()

        # Simple test query
        test_response = llm("Say 'Hello, LLM is working!'")
        return {
            "success": True,
            "provider": LLM_PROVIDER.upper(),
            "model": get_llm_config()["model"],
            "response": str(test_response)
        }
    except Exception as e:
        return {
            "success": False,
            "provider": LLM_PROVIDER.upper(),
            "error": str(e)
        }

# Initialize pipelines on module import
try:
    initialize_pipelines()
except Exception as e:
    print(f"Failed to initialize pipelines on startup: {e}")
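The pipeline splits documents with `RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)`. The essential sliding-window behaviour behind those two parameters can be sketched in plain Python (a simplified character-window version for illustration, not LangChain's actual separator-aware algorithm):

```python
def chunk_text(text: str, chunk_size: int = 1000, overlap: int = 200):
    # Split text into windows of `chunk_size` characters, where each window
    # repeats the last `overlap` characters of the previous one so that
    # sentences cut at a boundary still appear whole in some chunk.
    if chunk_size <= overlap:
        raise ValueError("chunk_size must be larger than overlap")
    chunks = []
    start = 0
    while start < len(text):
        chunks.append(text[start:start + chunk_size])
        start += chunk_size - overlap
    return chunks
```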
features/rag_chatbot/routes.py
ADDED

@@ -0,0 +1,111 @@
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Request
from fastapi.security import HTTPBearer
from pydantic import BaseModel, Field
from slowapi.util import get_remote_address
from slowapi import Limiter
from typing import Optional
from config import ACCESS_RATE
from .controller import (
    handle_rag_query,
    handle_document_upload,
    handle_health_check,
    verify_token,
)

limiter = Limiter(key_func=get_remote_address)
router = APIRouter(prefix="/rag", tags=["RAG Chatbot"])
security = HTTPBearer()

class QueryInput(BaseModel):
    query: str = Field(..., min_length=1, max_length=1000, description="The question to ask")

class QueryResponse(BaseModel):
    answer: str
    source: str
    route: Optional[str] = None
    documents: Optional[list] = None
    error: Optional[str] = None

class UploadResponse(BaseModel):
    message: str
    filename: str
    text_length: int
    content_type: str

class HealthResponse(BaseModel):
    status: str
    components: Optional[dict] = None
    error: Optional[str] = None

@router.post("/question", response_model=QueryResponse)
@limiter.limit(ACCESS_RATE)
async def ask_question(
    request: Request,
    data: QueryInput,
    token: str = Depends(verify_token)
) -> QueryResponse:
    """
    Ask a question to the RAG chatbot.

    The chatbot can answer:
    - Company-related questions (based on uploaded documents)
    - Cybersecurity questions (from knowledge base)
    """
    response = await handle_rag_query(data.query)
    return QueryResponse(**response)

@router.post("/upload", response_model=UploadResponse)
@limiter.limit(ACCESS_RATE)
async def upload_document(
    request: Request,
    file: UploadFile = File(..., description="Document file (PDF, DOCX, or TXT)"),
    token: str = Depends(verify_token)
) -> UploadResponse:
    """
    Upload a document to the company knowledge base.

    Supported formats:
    - PDF (.pdf)
    - Word documents (.docx)
    - Plain text (.txt)

    Maximum file size: 10MB
    """
    response = await handle_document_upload(file)
    return UploadResponse(**response)

@router.get("/health", response_model=HealthResponse)
@limiter.limit(ACCESS_RATE)
async def health_check(request: Request) -> HealthResponse:
    """
    Check the health status of the RAG system.

    Returns the status of all components:
    - ChromaDB connection
    - Vector store
    - AI chains
    """
    response = await handle_health_check()
    return HealthResponse(**response)

@router.get("/info")
@limiter.limit(ACCESS_RATE)
async def get_system_info(request: Request):
    """Get information about the RAG system capabilities."""
    return {
        "name": "RAG Chatbot",
        "version": "1.0.0",
        "description": "A specialized chatbot for cybersecurity and company-related questions",
        "capabilities": [
            "Company document Q&A (based on uploaded documents)",
            "Cybersecurity knowledge and best practices",
            "Document upload and processing (PDF, DOCX, TXT)"
        ],
        "supported_file_types": [
            "application/pdf",
            "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
            "text/plain"
        ],
        "max_file_size_mb": 10,
        "max_query_length": 1000
    }
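A client of the `/rag/question` endpoint needs a Bearer token and a JSON body matching `QueryInput`. A minimal sketch of assembling that request payload, mirroring the model's 1-1000 character constraint (the helper name is hypothetical; the token value comes from the caller's environment):

```python
import json


def build_question_request(query: str, token: str) -> dict:
    # Mirrors QueryInput's Field(min_length=1, max_length=1000) constraint
    # client-side, so obviously invalid queries fail before the HTTP call.
    if not (1 <= len(query) <= 1000):
        raise ValueError("query must be 1-1000 characters")
    return {
        "headers": {
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        },
        "body": json.dumps({"query": query}),
    }
```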
features/real_forged_classifier/controller.py
ADDED

@@ -0,0 +1,36 @@
from typing import IO
from preprocessor import preprocessor
from inferencer import interferencer

class ClassificationController:
    """
    Controller to handle the image classification logic.
    """
    def classify_image(self, image_file: IO) -> dict:
        """
        Orchestrates the classification of a single image file.

        Args:
            image_file (IO): The image file to classify.

        Returns:
            dict: The classification result.
        """
        try:
            # Step 1: Preprocess the image
            image_tensor = preprocessor.process(image_file)

            # Step 2: Perform inference
            result = interferencer.predict(image_tensor)

            return result
        except ValueError as e:
            # Handle specific errors like invalid images
            return {"error": str(e)}
        except Exception as e:
            # Handle unexpected errors
            print(f"An unexpected error occurred: {e}")
            return {"error": "An internal error occurred during classification."}

# Create a single instance of the controller
controller = ClassificationController()
features/real_forged_classifier/inferencer.py
ADDED
@@ -0,0 +1,52 @@
```python
import torch
import torch.nn.functional as F
import numpy as np

# Import the globally loaded models instance
from model_loader import models

class Interferencer:
    """
    Performs inference using the FFT CNN model.
    """
    def __init__(self):
        """
        Initializes the interferencer with the loaded model.
        """
        self.fft_model = models.fft_model

    @torch.no_grad()
    def predict(self, image_tensor: torch.Tensor) -> dict:
        """
        Takes a preprocessed image tensor and returns the classification result.

        Args:
            image_tensor (torch.Tensor): The preprocessed image tensor.

        Returns:
            dict: A dictionary containing the classification label and confidence score.
        """
        # 1. Get model outputs (logits)
        outputs = self.fft_model(image_tensor)

        # 2. Apply softmax to get probabilities
        probabilities = F.softmax(outputs, dim=1)

        # 3. Get the confidence and the predicted class index
        confidence, predicted_idx = torch.max(probabilities, 1)

        prediction = predicted_idx.item()

        # 4. Map the prediction to a human-readable label
        # Ensure this mapping matches the labels used during training
        # Typically: 0 -> fake, 1 -> real
        label_map = {0: 'fake', 1: 'real'}
        classification_label = label_map.get(prediction, "unknown")

        return {
            "classification": classification_label,
            "confidence": confidence.item()
        }

# Create a single instance of the interferencer
interferencer = Interferencer()
```
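The softmax-then-argmax step in `predict` can be illustrated without PyTorch. A pure-Python sketch with hypothetical logits for the `[fake, real]` classes (the logit values are made up for illustration):

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical logits from the CNN's final linear layer: [fake, real]
logits = [0.3, 2.1]
probs = softmax(logits)
confidence = max(probs)
prediction = probs.index(confidence)
label_map = {0: "fake", 1: "real"}
print(label_map[prediction], round(confidence, 3))
```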
features/real_forged_classifier/main.py
ADDED
@@ -0,0 +1,26 @@
```python
from fastapi import FastAPI
from routes import router as api_router

# Initialize the FastAPI app
app = FastAPI(
    title="Real vs. Fake Image Classification API",
    description="An API to classify images as real or forged using FFT and a CNN.",
    version="1.0.0"
)

# Include the API router
# All routes defined in routes.py will be available under the /api prefix
app.include_router(api_router, prefix="/api", tags=["Classification"])

@app.get("/", tags=["Root"])
async def read_root():
    """
    A simple root endpoint to confirm the API is running.
    """
    return {"message": "Welcome to the Image Classification API. Go to /docs for the API documentation."}

# To run this application:
# 1. Make sure you have all dependencies from requirements.txt installed.
# 2. The FFT CNN weights ('fft_cnn_model_78.pth') are downloaded automatically
#    from Hugging Face on startup (see model_loader.py).
# 3. Run the following command in your terminal:
#    uvicorn main:app --reload
```
features/real_forged_classifier/model.py
ADDED
@@ -0,0 +1,34 @@
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FFTCNN(nn.Module):
    """
    Defines the Convolutional Neural Network architecture.
    This structure must match the model that was trained and saved.
    """
    def __init__(self):
        super(FFTCNN, self).__init__()
        # Ensure 'self.' is used here to define the layers as instance attributes
        self.conv_layers = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        # Ensure 'self.' is used here as well
        self.fc_layers = nn.Sequential(
            nn.Linear(32 * 56 * 56, 128),  # This size depends on the 224x224 input
            nn.ReLU(),
            nn.Linear(128, 2)  # 2 output classes
        )

    def forward(self, x):
        # 'self.conv_layers' can be found because it was defined correctly above
        x = self.conv_layers(x)
        x = x.view(x.size(0), -1)  # Flatten the feature maps
        x = self.fc_layers(x)
        return x
```
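The `32 * 56 * 56` input size of the first linear layer follows from the 224x224 input: each 3x3 convolution with `padding=1` preserves height and width, and each of the two `MaxPool2d(kernel_size=2, stride=2)` layers halves them. A quick sketch of that arithmetic (the helper name is illustrative):

```python
def pooled_side(side: int, n_pools: int, pool: int = 2) -> int:
    """Spatial size after n_pools MaxPool2d(kernel=2, stride=2) layers.
    The 3x3 convolutions with padding=1 leave height/width unchanged."""
    for _ in range(n_pools):
        side //= pool
    return side

side = pooled_side(224, n_pools=2)  # 224 -> 112 -> 56
flattened = 32 * side * side        # 32 channels after the second conv block
print(side, flattened)              # 56 100352
```

This matches `nn.Linear(32 * 56 * 56, 128)` in the architecture above.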
features/real_forged_classifier/model_loader.py
ADDED
@@ -0,0 +1,60 @@
```python
import torch
from pathlib import Path
from huggingface_hub import hf_hub_download
from model import FFTCNN  # Import the model architecture

class ModelLoader:
    """
    A class to load and hold the PyTorch CNN model.
    """
    def __init__(self, model_repo_id: str, model_filename: str):
        """
        Initializes the ModelLoader and loads the model.

        Args:
            model_repo_id (str): The repository ID on Hugging Face.
            model_filename (str): The name of the model file (.pth) in the repository.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        self.fft_model = self._load_fft_model(repo_id=model_repo_id, filename=model_filename)
        print("FFT CNN model loaded successfully.")

    def _load_fft_model(self, repo_id: str, filename: str):
        """
        Downloads and loads the FFT CNN model from a Hugging Face Hub repository.

        Args:
            repo_id (str): The repository ID on Hugging Face.
            filename (str): The name of the model file (.pth) in the repository.

        Returns:
            The loaded PyTorch model object.
        """
        print(f"Downloading FFT CNN model from Hugging Face repo: {repo_id}")
        try:
            # Download the model file from the Hub. It returns the cached path.
            model_path = hf_hub_download(repo_id=repo_id, filename=filename)
            print(f"Model downloaded to: {model_path}")

            # Initialize the model architecture
            model = FFTCNN()

            # Load the saved weights (state_dict) into the model
            model.load_state_dict(torch.load(model_path, map_location=torch.device(self.device)))

            # Set the model to evaluation mode
            model.to(self.device)
            model.eval()

            return model
        except Exception as e:
            print(f"Error downloading or loading model from Hugging Face: {e}")
            raise

# --- Global Model Instance ---
MODEL_REPO_ID = 'rhnsa/real_forged_classifier'
MODEL_FILENAME = 'fft_cnn_model_78.pth'
models = ModelLoader(model_repo_id=MODEL_REPO_ID, model_filename=MODEL_FILENAME)
```
features/real_forged_classifier/preprocessor.py
ADDED
@@ -0,0 +1,67 @@
```python
from PIL import Image
import torch
import numpy as np
from typing import IO
import cv2
from torchvision import transforms

# Import the globally loaded models instance
from model_loader import models

class ImagePreprocessor:
    """
    Handles preprocessing of images for the FFT CNN model.
    """
    def __init__(self):
        """
        Initializes the preprocessor.
        """
        self.device = models.device
        # Define the image transformations, matching the training process
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])

    def process(self, image_file: IO) -> torch.Tensor:
        """
        Opens an image file, applies FFT, preprocesses it, and returns a tensor.

        Args:
            image_file (IO): The image file object (e.g., from a file upload).

        Returns:
            torch.Tensor: The preprocessed image as a tensor, ready for the model.
        """
        try:
            # Read the image file into a numpy array
            image_np = np.frombuffer(image_file.read(), np.uint8)
            # Decode the image as grayscale
            img = cv2.imdecode(image_np, cv2.IMREAD_GRAYSCALE)
        except Exception as e:
            print(f"Error reading or decoding image: {e}")
            raise ValueError("Invalid or corrupted image file.")

        if img is None:
            raise ValueError("Could not decode image. File may be empty or corrupted.")

        # 1. Apply Fast Fourier Transform (FFT)
        f = np.fft.fft2(img)
        fshift = np.fft.fftshift(f)
        magnitude_spectrum = 20 * np.log(np.abs(fshift) + 1)  # Add 1 to avoid log(0)

        # Normalize the magnitude spectrum to be in the range [0, 255]
        magnitude_spectrum = cv2.normalize(magnitude_spectrum, None, 0, 255, cv2.NORM_MINMAX)
        magnitude_spectrum = np.uint8(magnitude_spectrum)

        # 2. Apply torchvision transforms
        image_tensor = self.transform(magnitude_spectrum)

        # Add a batch dimension and move to the correct device
        image_tensor = image_tensor.unsqueeze(0).to(self.device)

        return image_tensor

# Create a single instance of the preprocessor
preprocessor = ImagePreprocessor()
```
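The FFT step in `process` can be tried on a tiny array. A minimal NumPy sketch; the toy 4x4 checkerboard "image" is illustrative, and note that `np.log` is the natural logarithm, matching the code above:

```python
import numpy as np

# Toy 4x4 grayscale "image": a checkerboard of 0s and 255s
img = np.array([[0, 255, 0, 255],
                [255, 0, 255, 0],
                [0, 255, 0, 255],
                [255, 0, 255, 0]], dtype=np.float64)

f = np.fft.fft2(img)                         # 2-D FFT
fshift = np.fft.fftshift(f)                  # move the DC component to the center
magnitude = 20 * np.log(np.abs(fshift) + 1)  # log-scaled magnitude spectrum

# After fftshift, the center entry is the DC term: |F[0, 0]| equals the pixel sum
print(magnitude.shape, float(magnitude[2, 2]))
```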
features/real_forged_classifier/routes.py
ADDED
@@ -0,0 +1,37 @@
```python
from fastapi import APIRouter, File, UploadFile, HTTPException, status
from fastapi.responses import JSONResponse

# Import the controller instance
from controller import controller

# Create an API router
router = APIRouter()

@router.post("/classify_forgery", summary="Classify an image as Real or Fake")
async def classify_image_endpoint(image: UploadFile = File(...)):
    """
    Accepts an image file and classifies it as 'real' or 'fake'.

    - **image**: The image file to be classified (e.g., JPEG, PNG).

    Returns a JSON object with the classification and a confidence score.
    """
    # Check for a valid image content type
    if not image.content_type.startswith("image/"):
        raise HTTPException(
            status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
            detail="Unsupported file type. Please upload an image (e.g., JPEG, PNG)."
        )

    # The controller expects a file-like object, which `image.file` provides
    result = controller.classify_image(image.file)

    if "error" in result:
        # If the controller returned an error, forward it as an HTTP exception
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=result["error"]
        )

    return JSONResponse(content=result, status_code=status.HTTP_200_OK)
```
features/text_classifier/controller.py
CHANGED
```diff
@@ -60,12 +60,12 @@ async def handle_file_upload(file: UploadFile):
     try:
         file_contents = await extract_file_contents(file)
         if len(file_contents) > 10000:
-            raise HTTPException(status_code=413, detail="Text must be less than 10,000 characters")
+            return {"status_code": 413, "detail": "Text must be less than 10,000 characters"}

         cleaned_text = file_contents.replace("\n", " ").replace("\t", " ").strip()
         if not cleaned_text:
             raise HTTPException(status_code=404, detail="The file is empty or only contains whitespace.")
-
+        # print(f"Cleaned text: '{cleaned_text}'")  # Debugging statement
         label, perplexity, ai_likelihood = await asyncio.to_thread(classify_text, cleaned_text)
         return {
             "content": file_contents,
@@ -102,12 +102,15 @@ async def handle_sentence_level_analysis(text: str):
             "ai_likelihood": ai_likelihood
         })

     return {"analysis": results}
+
+# Analyze each sentence from uploaded file
 async def handle_file_sentence(file: UploadFile):
     try:
         file_contents = await extract_file_contents(file)
         if len(file_contents) > 10000:
-            raise HTTPException(status_code=413, detail="Text must be less than 10,000 characters")
+            # raise HTTPException(status_code=413, detail="Text must be less than 10,000 characters")
+            return {"status_code": 413, "detail": "Text must be less than 10,000 characters"}

         cleaned_text = file_contents.replace("\n", " ").replace("\t", " ").strip()
         if not cleaned_text:
```
features/text_classifier/preprocess.py
CHANGED
```diff
@@ -1,4 +1,4 @@
-import
+from pypdf import PdfReader
 import docx
 from io import BytesIO
 import logging
@@ -15,18 +15,16 @@ def parse_docx(file: BytesIO):

 def parse_pdf(file: BytesIO):
     try:
-        doc =
+        doc = PdfReader(file)
         text = ""
-        for
-
-
-        return text
+        for page in doc.pages:
+            text += page.extract_text()
+        return text
     except Exception as e:
         logging.error(f"Error while processing PDF: {str(e)}")
         raise HTTPException(
             status_code=500, detail="Error processing PDF file")

-
 def parse_txt(file: BytesIO):
     return file.read().decode("utf-8")
```
requirements.txt
CHANGED
```diff
@@ -15,6 +15,6 @@ tensorflow
 opencv-python
 pillow
 scipy
-
+pypdf
 frontend
 tools
```
test.md
ADDED
@@ -0,0 +1,31 @@

**Update: Edited & AI-Generated Content Detection – Project Plan**

### 🔍 Phase 1: Rule-Based Image Detection (In Progress)

We're implementing three core techniques to individually flag edited or AI-generated images:

* **ELA (Error Level Analysis):** Highlights inconsistencies via JPEG recompression.
* **FFT (Frequency Analysis):** Uses a 2D Fourier Transform to detect unnatural image frequency patterns.
* **Metadata Analysis:** Parses EXIF data to catch clues like editing-software tags.

These give us visual + interpretable results for each image, and currently offer ~60–70% accuracy on typical AI-edited content.
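The ELA technique above can be sketched in a few lines with Pillow; the JPEG quality of 90 and the flat test image are illustrative choices, not the project's settings:

```python
from io import BytesIO
from PIL import Image, ImageChops

def ela_map(image: Image.Image, quality: int = 90) -> Image.Image:
    """Recompress the image as JPEG and return the per-pixel difference.
    Edited regions tend to recompress differently and so stand out."""
    buf = BytesIO()
    image.convert("RGB").save(buf, "JPEG", quality=quality)
    buf.seek(0)
    return ImageChops.difference(image.convert("RGB"), Image.open(buf))

# A flat synthetic image recompresses almost losslessly, so its ELA map stays near zero
flat = Image.new("RGB", (64, 64), (128, 128, 128))
diff = ela_map(flat)
max_error = max(hi for _, hi in diff.getextrema())
print(max_error)
```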
---

### Phase 2: AI vs Human Detection System (Coming Soon)

**Goal:** Build an AI model that classifies whether content is AI- or human-made — initially focusing on **images**, and later expanding to **text**.

**Data Strategy:**

* Scraping large volumes of recent AI-generated images (e.g. SDXL, Ghibli-style, MidJourney).
* Balancing with high-quality human images.

**Model Plan:**

* Use ELA, FFT, and metadata as feature extractors.
* Feed these into a CNN or ensemble model.
* Later, unify into a full web-based platform (upload → get AI/human probability).