Upload folder using huggingface_hub
This view is limited to 50 files because it contains too many changes.
- .gitattributes +33 -0
- README.md +64 -12
- frontend/app/components/dashboard/EconomicIndicators.tsx +34 -9
- frontend/app/components/dashboard/TrendingTopics.tsx +3 -1
- frontend/app/hooks/use-roger-data.ts +12 -6
- frontend/app/pages/Index.tsx +2 -0
- main.py +213 -54
- models/anomaly-detection/main.py +123 -70
- models/anomaly-detection/src/components/data_ingestion.py +4 -1
- models/anomaly-detection/src/components/data_transformation.py +6 -1
- models/anomaly-detection/src/components/data_validation.py +1 -1
- models/anomaly-detection/src/components/model_trainer.py +78 -27
- models/anomaly-detection/src/constants/__init__.py +1 -0
- models/anomaly-detection/src/constants/training_pipeline/__init__.py +65 -0
- models/anomaly-detection/src/entity/config_entity.py +1 -1
- models/anomaly-detection/src/exception/__init__.py +1 -0
- models/anomaly-detection/src/exception/exception.py +24 -0
- models/anomaly-detection/src/logging/__init__.py +1 -0
- models/anomaly-detection/src/logging/logger.py +32 -0
- models/anomaly-detection/src/pipeline/training_pipeline.py +4 -4
- models/currency-volatility-prediction/main.py +132 -118
- models/currency-volatility-prediction/src/components/model_trainer.py +5 -0
- models/currency-volatility-prediction/src/components/predictor.py +5 -1
- models/currency-volatility-prediction/src/exception/__init__.py +1 -0
- models/currency-volatility-prediction/src/exception/exception.py +17 -15
- models/currency-volatility-prediction/src/logging/__init__.py +1 -0
- models/currency-volatility-prediction/src/logging/logger.py +16 -4
- models/weather-prediction/main.py +196 -121
- models/weather-prediction/src/components/model_trainer.py +5 -0
- models/weather-prediction/src/components/predictor.py +4 -1
- models/weather-prediction/src/exception/__init__.py +1 -0
- models/weather-prediction/src/exception/exception.py +17 -15
- models/weather-prediction/src/logging/__init__.py +1 -0
- models/weather-prediction/src/logging/logger.py +16 -4
- src/graphs/economicalAgentGraph.py +3 -3
- src/graphs/intelligenceAgentGraph.py +3 -3
- src/graphs/meteorologicalAgentGraph.py +3 -3
- src/graphs/politicalAgentGraph.py +3 -3
- src/graphs/socialAgentGraph.py +17 -3
- src/nodes/socialAgentNode.py +136 -1
- src/rag.py +322 -38
- src/storage/storage_manager.py +94 -0
- src/utils/.browser_data/linkedin/BrowserMetrics-spare.pma +3 -0
- src/utils/.browser_data/linkedin/Crashpad/metadata +0 -0
- src/utils/.browser_data/linkedin/Crashpad/reports/1bb2b465-675d-47f0-b953-a844af38ce6b.dmp +3 -0
- src/utils/.browser_data/linkedin/Crashpad/reports/55792d7f-8397-4730-8518-c50a507a611a.dmp +3 -0
- src/utils/.browser_data/linkedin/Crashpad/reports/880fc1e0-3241-4d76-a26b-0f9d6135dcd6.dmp +3 -0
- src/utils/.browser_data/linkedin/Crashpad/settings.dat +0 -0
- src/utils/.browser_data/linkedin/Default/Account Web Data +0 -0
- src/utils/.browser_data/linkedin/Default/Account Web Data-journal +0 -0
.gitattributes
CHANGED
@@ -36,3 +36,36 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 ModelX[[:space:]]Final[[:space:]]Problem.pdf filter=lfs diff=lfs merge=lfs -text
 trending_detection_visualization.png filter=lfs diff=lfs merge=lfs -text
 vectorizer_anomaly_visualization.png filter=lfs diff=lfs merge=lfs -text
+src/utils/.browser_data/linkedin/BrowserMetrics-spare.pma filter=lfs diff=lfs merge=lfs -text
+src/utils/.browser_data/linkedin/Crashpad/reports/1bb2b465-675d-47f0-b953-a844af38ce6b.dmp filter=lfs diff=lfs merge=lfs -text
+src/utils/.browser_data/linkedin/Crashpad/reports/55792d7f-8397-4730-8518-c50a507a611a.dmp filter=lfs diff=lfs merge=lfs -text
+src/utils/.browser_data/linkedin/Crashpad/reports/880fc1e0-3241-4d76-a26b-0f9d6135dcd6.dmp filter=lfs diff=lfs merge=lfs -text
+src/utils/.browser_data/linkedin/Default/Cache/Cache_Data/data_1 filter=lfs diff=lfs merge=lfs -text
+src/utils/.browser_data/linkedin/Default/Cache/Cache_Data/data_2 filter=lfs diff=lfs merge=lfs -text
+src/utils/.browser_data/linkedin/Default/Cache/Cache_Data/data_3 filter=lfs diff=lfs merge=lfs -text
+src/utils/.browser_data/linkedin/Default/Cache/Cache_Data/f_000002 filter=lfs diff=lfs merge=lfs -text
+src/utils/.browser_data/linkedin/Default/Cache/Cache_Data/f_000003 filter=lfs diff=lfs merge=lfs -text
+src/utils/.browser_data/linkedin/Default/Cache/Cache_Data/f_000006 filter=lfs diff=lfs merge=lfs -text
+src/utils/.browser_data/linkedin/Default/Cache/Cache_Data/f_00000b filter=lfs diff=lfs merge=lfs -text
+src/utils/.browser_data/linkedin/Default/Cache/Cache_Data/index filter=lfs diff=lfs merge=lfs -text
+src/utils/.browser_data/linkedin/Default/Code[[:space:]]Cache/js/3d01be7861bd5850_0 filter=lfs diff=lfs merge=lfs -text
+src/utils/.browser_data/linkedin/Default/Code[[:space:]]Cache/js/4f0cb78a57ef4137_0 filter=lfs diff=lfs merge=lfs -text
+src/utils/.browser_data/linkedin/Default/Code[[:space:]]Cache/js/aaeed4cfeb9c324a_0 filter=lfs diff=lfs merge=lfs -text
+src/utils/.browser_data/linkedin/Default/Code[[:space:]]Cache/js/bc082d8e612dbd10_0 filter=lfs diff=lfs merge=lfs -text
+src/utils/.browser_data/linkedin/Default/Code[[:space:]]Cache/js/e3df1293cf5ee96e_0 filter=lfs diff=lfs merge=lfs -text
+src/utils/.browser_data/linkedin/Default/DawnGraphiteCache/data_1 filter=lfs diff=lfs merge=lfs -text
+src/utils/.browser_data/linkedin/Default/DawnGraphiteCache/index filter=lfs diff=lfs merge=lfs -text
+src/utils/.browser_data/linkedin/Default/DawnWebGPUCache/data_1 filter=lfs diff=lfs merge=lfs -text
+src/utils/.browser_data/linkedin/Default/DawnWebGPUCache/index filter=lfs diff=lfs merge=lfs -text
+src/utils/.browser_data/linkedin/Default/GPUCache/data_1 filter=lfs diff=lfs merge=lfs -text
+src/utils/.browser_data/linkedin/Default/GPUCache/data_2 filter=lfs diff=lfs merge=lfs -text
+src/utils/.browser_data/linkedin/Default/GPUCache/index filter=lfs diff=lfs merge=lfs -text
+src/utils/.browser_data/linkedin/Default/History filter=lfs diff=lfs merge=lfs -text
+src/utils/.browser_data/linkedin/Default/Web[[:space:]]Data filter=lfs diff=lfs merge=lfs -text
+src/utils/.browser_data/linkedin/GrShaderCache/data_1 filter=lfs diff=lfs merge=lfs -text
+src/utils/.browser_data/linkedin/GrShaderCache/data_3 filter=lfs diff=lfs merge=lfs -text
+src/utils/.browser_data/linkedin/GrShaderCache/index filter=lfs diff=lfs merge=lfs -text
+src/utils/.browser_data/linkedin/GraphiteDawnCache/data_1 filter=lfs diff=lfs merge=lfs -text
+src/utils/.browser_data/linkedin/GraphiteDawnCache/index filter=lfs diff=lfs merge=lfs -text
+src/utils/.browser_data/linkedin/ShaderCache/data_1 filter=lfs diff=lfs merge=lfs -text
+src/utils/.browser_data/linkedin/ShaderCache/index filter=lfs diff=lfs merge=lfs -text
README.md
CHANGED
@@ -11,7 +11,7 @@ pinned: false
 
 **Real-Time Situational Awareness for Sri Lanka**
 
-A multi-agent AI system that aggregates intelligence from 47+ data sources to pr…
+A multi-agent AI system that aggregates intelligence from **50+ data sources** to provide risk analysis and opportunity detection for businesses operating in Sri Lanka.
 
 ## 🌐 Live Demo
 
@@ -24,14 +24,14 @@ A multi-agent AI system that aggregates intelligence from 47+ data sources to pr
 
 ## 🎯 Key Features
 
-✅ **…
-- Social…
-- Political…
-…
-- Meteorological…
-- Intelligence Agent…
-…
-…
+✅ **5 Domain Agents + 2 Orchestrators** running in parallel:
+- **Social Agent** - Reddit, Twitter, Facebook, Threads, BlueSky monitoring
+- **Political Agent** - Gazette, Parliament, District Social Media
+- **Economical Agent** - CSE Stock Market + Technical Indicators (SMA, EMA, RSI, MACD)
+- **Meteorological Agent** - DMC Weather + RiverNet + **FloodWatch Integration**
+- **Intelligence Agent** - Brand Monitoring + Threat Detection + **User-Configurable Targets**
+- **Combined Agent (Orchestrator)** - Fan-out/Fan-in coordination, LLM filtering, feed ranking
+- **Data Retrieval Agent** - Web scraping orchestration with anti-bot features
 
 ✅ **Situational Awareness Dashboard**:
 - **CEB Power Status** - Load shedding / power outage monitoring
@@ -109,6 +109,13 @@ A multi-agent AI system that aggregates intelligence from 47+ data sources to pr
 - Supports: Western, Southern, Central, Northern, Eastern, Sabaragamuwa, Uva, North Western, North Central provinces
 - Both frontend (MapView, DistrictInfoPanel) and backend are synchronized
 
+✅ **3-Tier Storage Architecture** with Deduplication:
+- **Tier 1: SQLite** - Fast hash-based exact match (microseconds)
+- **Tier 2: ChromaDB** - Semantic similarity search with sentence transformers (milliseconds)
+- **Tier 3: Neo4j Aura** - Knowledge graph for event relationships and entity tracking
+- Unified `StorageManager` orchestrates all backends
+- Deduplication prevents duplicate feeds across all domain agents
+
 ---
 
 ## 🏗️ System Architecture
@@ -185,6 +192,40 @@ graph TD
 - **Non-Blocking Refresh**: 60-second cycle with interruptible sleep
 - `threading.Event.wait()` instead of blocking `time.sleep()`
 
+### Storage Data Flow
+
+```
+┌─────────────────────────────────────────────────────────────────────────────┐
+│                          DOMAIN AGENTS (Parallel)                           │
+│ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────────┐        │
+│ │ Social   │ │Political │ │Economic  │ │ Meteo    │ │ Intelligence │        │
+│ │ Agent    │ │ Agent    │ │ Agent    │ │ Agent    │ │ Agent        │        │
+│ └────┬─────┘ └────┬─────┘ └────┬─────┘ └────┬─────┘ └──────┬───────┘        │
+│      └────────────┴────────────┴────────────┴──────────────┘                │
+│                                │ Fan-In                                     │
+│                   ┌────────────▼─────────────┐                              │
+│                   │    CombinedAgentNode     │                              │
+│                   │   (LLM Filter + Rank)    │                              │
+│                   └────────────┬─────────────┘                              │
+└────────────────────────────────┼─────────────────────────────────────────────┘
+                                 │
+                    ┌────────────▼───────────────┐
+                    │       StorageManager       │
+                    │   (3-Tier Deduplication)   │
+                    └────────────┬───────────────┘
+        ┌───────────────────────┼──────────────────────────┐
+        │                       │                          │
+        ▼                       ▼                          ▼
+┌─────────────────┐  ┌──────────────────┐  ┌─────────────────────────┐
+│     SQLite      │  │     ChromaDB     │  │       Neo4j Aura        │
+│  (Fast Cache)   │  │  (Vector Store)  │  │    (Knowledge Graph)    │
+│  ─────────────  │  │  ──────────────  │  │  ───────────────────    │
+│  Hash-based     │  │  Semantic search │  │  Event relationships    │
+│  Exact match    │  │  Similarity 0.85 │  │  Domain nodes           │
+│  ~microseconds  │  │  ~milliseconds   │  │  Entity tracking        │
+└─────────────────┘  └──────────────────┘  └─────────────────────────┘
+```
+
 ---
 
 ### 2. Political Agent Graph (`politicalAgentGraph.py`)
@@ -870,9 +911,20 @@ Roger-Ultimate/
 # LLM
 GROQ_API_KEY=your_groq_key
 
-#…
-…
-…
+# Neo4j (Knowledge Graph)
+NEO4J_URI=neo4j+s://your-instance.databases.neo4j.io
+NEO4J_USERNAME=neo4j
+NEO4J_PASSWORD=your_password
+NEO4J_ENABLED=true
+NEO4J_DATABASE=neo4j
+
+# ChromaDB (Vector Store)
+CHROMADB_PATH=./data/chromadb
+CHROMADB_COLLECTION=Roger_feeds
+CHROMADB_SIMILARITY_THRESHOLD=0.85
+
+# SQLite (Fast Cache)
+SQLITE_DB_PATH=./data/cache/feeds.db
 
 # MLflow (DagsHub)
 MLFLOW_TRACKING_URI=https://dagshub.com/...
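The 3-tier cascade is only summarized in the README additions above; the `StorageManager` internals are not shown in this commit. As a rough sketch of how such a dedup check could be wired — `sqlite_hashes` and `chroma_search` are hypothetical stand-ins for the SQLite and ChromaDB tiers, and the 0.85 threshold mirrors `CHROMADB_SIMILARITY_THRESHOLD` from the env config — returning the 3-tuple shape that `main.py` unpacks (`is_dup, _, _`):

```python
import hashlib

SIMILARITY_THRESHOLD = 0.85  # mirrors CHROMADB_SIMILARITY_THRESHOLD above

def is_duplicate(summary: str, sqlite_hashes: set, chroma_search) -> tuple:
    """Cascade: exact hash match first (cheap), semantic match second (slower).

    `sqlite_hashes` stands in for the SQLite tier; `chroma_search` is a
    hypothetical callable returning (best_score, best_id) from the vector store.
    """
    digest = hashlib.sha256(summary.encode("utf-8")).hexdigest()
    # Tier 1: SQLite-style exact match (~microseconds)
    if digest in sqlite_hashes:
        return True, "exact", digest
    # Tier 2: ChromaDB-style semantic similarity (~milliseconds)
    score, match_id = chroma_search(summary)
    if score >= SIMILARITY_THRESHOLD:
        return True, "semantic", match_id
    # Not a duplicate: the caller would then persist to all tiers (incl. Neo4j)
    sqlite_hashes.add(digest)
    return False, None, None
```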
frontend/app/components/dashboard/EconomicIndicators.tsx
CHANGED
@@ -2,7 +2,7 @@
 
 import { Card } from "../ui/card";
 import { Badge } from "../ui/badge";
-import { TrendingUp, TrendingDown, Minus, Landmark, DollarSign, Percent, Building2 } from "lucide-react";
+import { TrendingUp, TrendingDown, Minus, Landmark, DollarSign, Percent, Building2, Radio } from "lucide-react";
 
 interface EconomicIndicatorsProps {
   economyData?: Record<string, unknown> | null;
@@ -15,6 +15,7 @@ const EconomicIndicators = ({ economyData }: EconomicIndicatorsProps) => {
   const exchangeRate = indicators?.exchange_rate || {};
   const forexReserves = indicators?.forex_reserves || {};
   const dataAsOf = economyData?.data_as_of as string;
+  const scrapeStatus = economyData?.scrape_status as string;
 
   const getTrendIcon = (trend: string) => {
     if (trend === "improving" || trend === "stable") return <TrendingUp className="w-3 h-3 text-success" />;
@@ -22,6 +23,14 @@ const EconomicIndicators = ({ economyData }: EconomicIndicatorsProps) => {
     return <Minus className="w-3 h-3 text-muted-foreground" />;
   };
 
+  // Get the exchange rate - prefer mid rate, fallback to sell or buy
+  const usdLkr = (exchangeRate.usd_lkr as number) ||
+    (exchangeRate.usd_lkr_sell as number) ||
+    (exchangeRate.usd_lkr_buy as number) || 0;
+
+  // Get policy rate - prefer overnight, fallback to SDFR
+  const policyRate = (policyRates.overnight_rate as number) || (policyRates.sdfr as number) || 0;
+
   return (
     <Card className="p-4 bg-card border-border">
       <div className="flex items-center justify-between mb-3">
@@ -34,9 +43,17 @@ const EconomicIndicators = ({ economyData }: EconomicIndicatorsProps) => {
           <p className="text-xs text-muted-foreground">CBSL Indicators</p>
         </div>
       </div>
-      <…
-      {…
-…
+      <div className="flex items-center gap-1">
+        {scrapeStatus === "live" && (
+          <Badge className="bg-success/20 text-success text-xs flex items-center gap-1">
+            <Radio className="w-2 h-2 animate-pulse" />
+            LIVE
+          </Badge>
+        )}
+        <Badge className="bg-muted text-muted-foreground">
+          {dataAsOf || "Latest"}
+        </Badge>
+      </div>
     </div>
 
     <div className="grid grid-cols-2 gap-2">
@@ -44,7 +61,7 @@ const EconomicIndicators = ({ economyData }: EconomicIndicatorsProps) => {
       <div className="p-2 rounded-lg bg-muted/30 border border-border">
         <div className="flex items-center gap-1 mb-1">
           <Percent className="w-3 h-3 text-muted-foreground" />
-          <span className="text-xs text-muted-foreground">Inflation…
+          <span className="text-xs text-muted-foreground">CCPI Inflation</span>
         </div>
         <div className="flex items-center gap-1">
           <span className="text-lg font-bold">{inflation.ccpi_yoy as number || 0}%</span>
@@ -59,18 +76,25 @@ const EconomicIndicators = ({ economyData }: EconomicIndicatorsProps) => {
           <span className="text-xs text-muted-foreground">USD/LKR</span>
         </div>
         <div className="flex items-center gap-1">
-          <span className="text-lg font-bold">{…
+          <span className="text-lg font-bold">{usdLkr.toFixed(2)}</span>
           {getTrendIcon(exchangeRate.trend as string)}
         </div>
+        {/* Show Buy/Sell if available */}
+        {((exchangeRate.usd_lkr_buy as number | undefined) || (exchangeRate.usd_lkr_sell as number | undefined)) && (
+          <p className="text-xs text-muted-foreground mt-0.5">
+            Buy: {((exchangeRate.usd_lkr_buy as number | undefined)?.toFixed(2)) || "—"} |
+            Sell: {((exchangeRate.usd_lkr_sell as number | undefined)?.toFixed(2)) || "—"}
+          </p>
+        )}
       </div>
 
       {/* Policy Rate */}
       <div className="p-2 rounded-lg bg-muted/30 border border-border">
         <div className="flex items-center gap-1 mb-1">
           <Landmark className="w-3 h-3 text-muted-foreground" />
-          <span className="text-xs text-muted-foreground">…
+          <span className="text-xs text-muted-foreground">Policy Rate</span>
         </div>
-        <span className="text-lg font-bold">{…
+        <span className="text-lg font-bold">{policyRate}%</span>
       </div>
 
       {/* Forex Reserves */}
@@ -80,7 +104,7 @@ const EconomicIndicators = ({ economyData }: EconomicIndicatorsProps) => {
           <span className="text-xs text-muted-foreground">Reserves</span>
         </div>
         <div className="flex items-center gap-1">
-          <span className="text-lg font-bold">${forexReserves.value as number || 0}B</span>
+          <span className="text-lg font-bold">${(forexReserves.value as number) || 0}B</span>
           {getTrendIcon(forexReserves.trend as string)}
         </div>
       </div>
@@ -94,3 +118,4 @@ const EconomicIndicators = ({ economyData }: EconomicIndicatorsProps) => {
 };
 
 export default EconomicIndicators;
+
frontend/app/components/dashboard/TrendingTopics.tsx
CHANGED
@@ -26,10 +26,12 @@ export const TrendingTopics: React.FC = () => {
   const [loading, setLoading] = useState(true);
   const [error, setError] = useState<string | null>(null);
 
+  const API_BASE = process.env.NEXT_PUBLIC_API_URL || 'http://localhost:8000';
+
   useEffect(() => {
     const fetchTrending = async () => {
       try {
-        const response = await fetch(…
+        const response = await fetch(`${API_BASE}/api/trending`);
         const result = await response.json();
         setData(result);
         setError(null);
frontend/app/hooks/use-roger-data.ts
CHANGED
@@ -11,9 +11,10 @@ const API_BASE = process.env.NEXT_PUBLIC_API_URL || 'http://localhost:8000';
 const WS_URL = API_BASE.replace('http', 'ws') + '/ws';
 
 // Timeouts for resilient connection
-const RECONNECT_DELAY = …
+const RECONNECT_DELAY = 1000; // Reduced from 3s to 1s for faster recovery
 const MAX_LOADING_TIME = 120000; // 2 minutes max loading time
-const INITIAL_FETCH_DELAY = …
+const INITIAL_FETCH_DELAY = 1000; // Fetch from REST after 1s if no WS data
+const FALLBACK_POLL_INTERVAL = 2000; // Poll REST every 2s when WS disconnected
 
 export interface RogerEvent {
   event_id: string;
@@ -96,6 +97,7 @@ export function useRogerData() {
   const wsRef = useRef<WebSocket | null>(null);
   const loadingTimeoutRef = useRef<NodeJS.Timeout | null>(null);
   const initialFetchDoneRef = useRef(false);
+  const lastDataTimeRef = useRef<number>(Date.now()); // Track when we last got data
 
   // Fetch rivernet data
   const fetchRiverData = useCallback(async () => {
@@ -213,9 +215,12 @@
     };
 
     websocket.onclose = () => {
-      console.log('[Roger] WebSocket disconnected. Reconnecting in …
+      console.log('[Roger] WebSocket disconnected. Reconnecting in 1s...');
       setIsConnected(false);
 
+      // IMMEDIATELY fetch from REST to prevent blank UI
+      fetchInitialData();
+
       // Reconnect after delay
       reconnectTimeout = setTimeout(() => {
         connect();
@@ -288,12 +293,13 @@
     }
   }, [isConnected]);
 
-  // Fallback polling if WebSocket fails
+  // Fallback polling if WebSocket fails - more aggressive when disconnected
   useEffect(() => {
     if (isConnected) return;
 
-    …
-    fetchData…
+    console.log('[Roger] WebSocket disconnected - starting aggressive REST polling');
+    const interval = setInterval(fetchData, FALLBACK_POLL_INTERVAL);
+    fetchData(); // Initial fetch immediately
 
     return () => clearInterval(interval);
   }, [isConnected, fetchData]);
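The hook above recovers in two ways at once: it reconnects on a short fixed delay, and it immediately falls back to REST so the UI never goes blank. The same pattern, sketched in Python with stub `connect_ws`/`fetch_rest` coroutines (hypothetical names, not part of this repo):

```python
import asyncio
import random

RECONNECT_DELAY = 1.0         # seconds, mirrors the hook's 1s reconnect
FALLBACK_POLL_INTERVAL = 2.0  # seconds, mirrors the hook's REST polling

async def connect_ws() -> None:
    """Stub: pretend the WebSocket drops after a short, random lifetime."""
    await asyncio.sleep(random.uniform(0.5, 2.0))
    raise ConnectionError("socket closed")

async def fetch_rest() -> dict:
    """Stub: one REST snapshot fetch."""
    return {"status": "operational"}

async def resilient_loop() -> None:
    while True:
        try:
            await connect_ws()          # blocks while the socket is healthy
        except ConnectionError:
            state = await fetch_rest()  # immediate REST fetch so UI never blanks
            print("WS down, REST snapshot:", state)
        await asyncio.sleep(RECONNECT_DELAY)  # short fixed delay, then retry

# asyncio.run(resilient_loop())  # runs forever; Ctrl+C to stop
```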
frontend/app/pages/Index.tsx
CHANGED
@@ -10,6 +10,7 @@ import WeatherPredictions from "../components/dashboard/WeatherPredictions";
 import CurrencyPrediction from "../components/dashboard/CurrencyPrediction";
 import NationalThreatCard from "../components/dashboard/NationalThreatCard";
 import HistoricalIntel from "../components/dashboard/HistoricalIntel";
+import TrendingTopics from "../components/dashboard/TrendingTopics";
 import SatelliteView from "../components/map/SatelliteView";
 import LoadingScreen from "../components/LoadingScreen";
 import { Activity, Map, Radio, BarChart3, Zap, Brain, Cloud, DollarSign, Satellite } from "lucide-react";
@@ -119,6 +120,7 @@ const Index = () => {
 
         <TabsContent value="overview" className="space-y-6 animate-fade-in">
           <DashboardOverview />
+          <TrendingTopics />
           <div className="grid grid-cols-1 lg:grid-cols-2 gap-6">
             <StockPredictions />
             <CurrencyPrediction />
main.py
CHANGED
@@ -15,7 +15,7 @@ from pydantic import BaseModel
 from typing import Dict, Any, List, Set, Optional
 import asyncio
 import json
-from datetime import datetime, timedelta
+from datetime import datetime, timedelta, timezone
 import sys
 import os
 import logging
@@ -23,6 +23,12 @@ import threading
 import time
 import uuid  # CRITICAL: Was missing, needed for event_id generation
 
+
+def utc_now() -> datetime:
+    """Return current UTC time (Python 3.12+ compatible)."""
+    return datetime.now(timezone.utc)
+
+
 sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
 
 from src.graphs.combinedAgentGraph import graph
@@ -183,7 +189,7 @@ current_state: Dict[str, Any] = {
         "avg_confidence": 0.0,
         "high_priority_count": 0,
         "total_events": 0,
-        "last_updated": …
+        "last_updated": utc_now().isoformat()
     },
     "run_count": 0,
     "status": "initializing",
@@ -200,12 +206,12 @@ main_event_loop = None
 # Storage manager
 storage_manager = StorageManager()
 
-# WebSocket settings - RESILIENT for long scraping operations
-#…
-HEARTBEAT_INTERVAL = …
-HEARTBEAT_TIMEOUT = …
-HEARTBEAT_MISS_THRESHOLD = …
-SEND_TIMEOUT = …
+# WebSocket settings - ULTRA-RESILIENT for long scraping operations
+# Heavy graph cycles can take 2-3 minutes, so we need high tolerance
+HEARTBEAT_INTERVAL = 60.0     # Send ping every 60s (increased from 45s)
+HEARTBEAT_TIMEOUT = 45.0      # Wait 45s for pong (increased from 30s)
+HEARTBEAT_MISS_THRESHOLD = 5  # Allow 5 misses = ~5 minutes tolerance
+SEND_TIMEOUT = 15.0           # Increased for slow networks/heavy load
 
 class ConnectionManager:
     """Manages active WebSocket with heartbeat"""
@@ -218,7 +224,7 @@ class ConnectionManager:
         async with self._lock:
             meta = {
                 "heartbeat_task": asyncio.create_task(self._heartbeat_loop(websocket)),
-                "last_pong": …
+                "last_pong": utc_now(),
                 "misses": 0
             }
             self.active_connections[websocket] = meta
@@ -276,7 +282,7 @@ class ConnectionManager:
             if meta is None:
                 return
             last_pong = meta.get("last_pong")
-            if last_pong and (…
+            if last_pong and (utc_now() - last_pong).total_seconds() < (HEARTBEAT_INTERVAL + HEARTBEAT_TIMEOUT):
                 pong_received = True
                 meta['misses'] = 0
                 break
@@ -463,7 +469,7 @@ def run_graph_loop():
                 severity = event_data.get("severity", "medium")
                 impact_type = event_data.get("impact_type", "risk")
                 confidence = event_data.get("confidence_score", event_data.get("confidence", 0.5))
-                timestamp = event_data.get("timestamp", …
+                timestamp = event_data.get("timestamp", utc_now().isoformat())
 
                 # Check for duplicates
                 is_dup, _, _ = storage_manager.is_duplicate(summary)
@@ -525,7 +531,7 @@ async def database_polling_loop():
     Runs concurrently with graph thread.
     """
     global current_state
-    last_check = …
+    last_check = utc_now()
 
     logger.info("[DB_POLLER] Starting database polling loop")
 
@@ -535,7 +541,7 @@ async def database_polling_loop():
 
             # Get new feeds since last check
             new_feeds = storage_manager.get_feeds_since(last_check)
-            last_check = …
+            last_check = utc_now()
 
             if new_feeds:
                 logger.info(f"[DB_POLLER] Found {len(new_feeds)} new feeds")
@@ -556,7 +562,7 @@ async def database_polling_loop():
                 current_state['final_ranked_feed'] = unique_feeds + current_state.get('final_ranked_feed', [])
                 current_state['final_ranked_feed'] = current_state['final_ranked_feed'][:100]  # Keep last 100
                 current_state['status'] = 'operational'
-                current_state['last_update'] = …
+                current_state['last_update'] = utc_now().isoformat()
 
                 # Mark first run as complete (frontend loading screen can now hide)
                 if not current_state.get('first_run_complete'):
@@ -775,6 +781,116 @@ def get_national_threat_score():
             "error": str(e)
         }
 
+# ============================================
+# INTEL CONFIG API - User Keywords & Profiles
+# ============================================
+
+# Global intel config (loaded from file)
+INTEL_CONFIG_PATH = os.path.join(os.path.dirname(__file__), "data", "intel_config.json")
+
+# Default config structure
+DEFAULT_INTEL_CONFIG = {
+    "user_profiles": {
+        "twitter": [],
+        "facebook": [],
+        "linkedin": []
+    },
+    "user_keywords": [],
+    "user_products": []
+}
+
+
+def load_intel_config() -> dict:
+    """Load intel config from JSON file."""
+    try:
+        if os.path.exists(INTEL_CONFIG_PATH):
+            with open(INTEL_CONFIG_PATH, "r", encoding="utf-8") as f:
+                return json.load(f)
+    except Exception as e:
+        logger.warning(f"[Intel Config] Error loading config: {e}")
+    return DEFAULT_INTEL_CONFIG.copy()
+
+
+def save_intel_config(config: dict) -> bool:
+    """Save intel config to JSON file."""
+    try:
+        os.makedirs(os.path.dirname(INTEL_CONFIG_PATH), exist_ok=True)
+        with open(INTEL_CONFIG_PATH, "w", encoding="utf-8") as f:
+            json.dump(config, f, indent=2, ensure_ascii=False)
+        return True
+    except Exception as e:
+        logger.error(f"[Intel Config] Error saving config: {e}")
+        return False
+
+
+# Load config on startup
+intel_config = load_intel_config()
+
+
+@app.get("/api/intel/config")
+def get_intel_config():
+    """
+    Get current intelligence configuration.
+
+    Returns user-defined keywords, products, and social profiles to monitor.
+    """
+    global intel_config
+    intel_config = load_intel_config()  # Refresh from file
+    return {
+        "status": "success",
+        "config": intel_config
+    }
+
+
+class IntelConfigUpdate(BaseModel):
+    user_profiles: dict = None
+    user_keywords: list = None
+    user_products: list = None
+
+
+@app.post("/api/intel/config")
+def update_intel_config(config_update: IntelConfigUpdate):
+    """
+    Update intelligence configuration.
+
+    Accepts user-defined keywords, products, and social profiles.
+    Changes take effect on the next agent collection cycle.
+    """
+    global intel_config
+
+    try:
+        # Update fields if provided
+        if config_update.user_profiles is not None:
+            intel_config["user_profiles"] = config_update.user_profiles
+        if config_update.user_keywords is not None:
+            intel_config["user_keywords"] = config_update.user_keywords
+        if config_update.user_products is not None:
+            intel_config["user_products"] = config_update.user_products
+
+        # Save to file
+        if save_intel_config(intel_config):
+            logger.info(f"[Intel Config] Updated: {len(intel_config.get('user_keywords', []))} keywords, "
+                        f"{sum(len(v) for v in intel_config.get('user_profiles', {}).values())} profiles")
+            return {
+                "status": "updated",
+                "config": intel_config
+            }
+        else:
+            return {"status": "error", "error": "Failed to save configuration"}
+    except Exception as e:
+        logger.error(f"[Intel Config] Update error: {e}")
+        return {"status": "error", "error": str(e)}
+
+
+def get_user_intel_config() -> dict:
+    """
+    Get the current intel config for use by agents.
+    This function is called by social agents to get user-defined keywords and profiles.
+    """
+    global intel_config
+    return intel_config
+
+
 # ============================================
 # SITUATIONAL AWARENESS API ENDPOINTS (NEW)
 # ============================================
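A quick smoke test for the intel-config endpoints added in the hunk above — a minimal sketch assuming the FastAPI server is running on localhost:8000 (the frontend's default `API_BASE`) and that `requests` is installed:

```python
import requests

BASE = "http://localhost:8000"

# Read the current monitoring targets
current = requests.get(f"{BASE}/api/intel/config", timeout=10).json()
print(current["config"])

# Add keywords and a Twitter profile; fields left out of the body stay unchanged
update = {
    "user_keywords": ["load shedding", "fuel queue"],
    "user_profiles": {"twitter": ["@CBSL"], "facebook": [], "linkedin": []},
}
resp = requests.post(f"{BASE}/api/intel/config", json=update, timeout=10).json()
print(resp["status"])  # "updated" on success; picked up next collection cycle
```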
@@ -1096,41 +1212,49 @@ def record_topic_mention(topic: str, source: str = "manual", domain: str = "gene…
 # ============================================
 
 # Lazy-loaded anomaly detection components
-…
+_anomaly_models = {}  # {language: model}
 _vectorizer = None
 _language_detector = None
 
 
 def _load_anomaly_components():
-    """Load anomaly detection …
-    global …
+    """Load per-language anomaly detection models and vectorizer"""
+    global _anomaly_models, _vectorizer, _language_detector
 
-    if …
+    if _anomaly_models:
         return True
 
     try:
         import joblib
         from pathlib import Path
 
-        # Model…
-        models_dir = Path(__file__).parent / "models" / "anomaly-detection" / "src" / "components"
+        # Model directories
         output_dir = Path(__file__).parent / "models" / "anomaly-detection" / "output"
+        artifacts_dir = Path(__file__).parent / "models" / "anomaly-detection" / "artifacts" / "model_trainer"
 
-        #…
-        …
-        break
+        # Load per-language models
+        for lang in ["english", "sinhala", "tamil"]:
+            for search_dir in [artifacts_dir, output_dir]:
+                model_path = search_dir / f"isolation_forest_{lang}.joblib"
+                if model_path.exists():
+                    _anomaly_models[lang] = joblib.load(model_path)
+                    logger.info(f"[AnomalyAPI] Loaded {lang} model from {model_path.name}")
+                    break
 
-        if …
-            logger.warning("[AnomalyAPI] No trained …
+        # Fallback to legacy model if no per-language models found
+        if not _anomaly_models:
+            legacy_paths = [
+                output_dir / "isolation_forest_embeddings_only.joblib",
+                output_dir / "isolation_forest_model.joblib",
+            ]
+            for legacy_path in legacy_paths:
+                if legacy_path.exists():
+                    _anomaly_models["english"] = joblib.load(legacy_path)
+                    logger.info(f"[AnomalyAPI] Loaded legacy model: {legacy_path.name}")
+                    break
+
+        if not _anomaly_models:
+            logger.warning("[AnomalyAPI] No trained models found. Run training first.")
             return False
 
         # Load vectorizer and language detector
@@ -1140,7 +1264,7 @@ def _load_anomaly_components():
         _vectorizer = get_vectorizer()
         _language_detector = detect_language
 
-        logger.info("[AnomalyAPI] ✓ …
+        logger.info(f"[AnomalyAPI] ✓ Loaded models for: {list(_anomaly_models.keys())}")
        return True
 
     except Exception as e:
@@ -1151,7 +1275,7 @@
 @app.post("/api/predict")
 def predict_anomaly(texts: List[str] = None, text: str = None):
     """
-    Run anomaly detection on text(s).
+    Run anomaly detection on text(s) using per-language models.
 
     Args:
         texts: List of texts to analyze
@@ -1185,7 +1309,7 @@ def predict_anomaly(texts: List[str] = None, text: str = None):
             "message": "Model not trained yet. Using default scores."
         }
 
-    #…
+    # Process texts with per-language models
     predictions = []
     for t in texts:
         try:
@@ -1195,15 +1319,32 @@ def predict_anomaly(texts: List[str] = None, text: str = None):
             # Vectorize
             vector = _vectorizer.vectorize(t, lang)
 
-            #…
-            …
-            …
+            # Select appropriate model
+            if lang in _anomaly_models:
+                model = _anomaly_models[lang]
+                method = f"isolation_forest_{lang}"
+            elif "english" in _anomaly_models:
+                model = _anomaly_models["english"]
+                method = "isolation_forest_english_fallback"
+            else:
+                # No model available
+                predictions.append({
+                    "text": t[:100] + "..." if len(t) > 100 else t,
+                    "is_anomaly": False,
+                    "anomaly_score": 0.0,
+                    "language": lang,
+                    "method": "no_model"
+                })
+                continue
 
-            #…
-            …
-            …
-            …
+            # Predict: -1 = anomaly, 1 = normal
+            prediction = model.predict([vector])[0]
+
+            # Get anomaly score
+            if hasattr(model, 'decision_function'):
+                score = -model.decision_function([vector])[0]
+            elif hasattr(model, 'score_samples'):
+                score = -model.score_samples([vector])[0]
             else:
                 score = 1.0 if prediction == -1 else 0.0
 
@@ -1212,7 +1353,7 @@ def predict_anomaly(texts: List[str] = None, text: str = None):
                 "is_anomaly": prediction == -1,
                 "anomaly_score": float(score),
                 "language": lang,
-                "method": …
+                "method": method
             })
 
         except Exception as e:
@@ -1228,7 +1369,8 @@ def predict_anomaly(texts: List[str] = None, text: str = None):
         "predictions": predictions,
         "total": len(predictions),
         "anomalies_found": sum(1 for p in predictions if p.get("is_anomaly")),
-        "model_status": "loaded"
+        "model_status": "loaded",
+        "models_available": list(_anomaly_models.keys())
     }
 
     except Exception as e:
@@ -1302,8 +1444,10 @@ def get_anomalies(limit: int = 20, threshold: float = 0.5):
             "message": "Using severity + keyword scoring. Train ML model for advanced detection."
         }
 
-    # ML …
+    # ML Models are loaded - use per-language models for scoring
     anomalies = []
+    per_lang_counts = {"english": 0, "sinhala": 0, "tamil": 0}
+
     for feed in feeds:
         summary = feed.get("summary", "")
         if not summary:
@@ -1312,10 +1456,22 @@ def get_anomalies(limit: int = 20, threshold: float = 0.5):
         try:
             lang, _ = _language_detector(summary)
             vector = _vectorizer.vectorize(summary, lang)
-            prediction = _anomaly_model.predict([vector])[0]
 
-            …
-            …
+            # Select appropriate model
+            if lang in _anomaly_models:
+                model = _anomaly_models[lang]
+                method = f"isolation_forest_{lang}"
+            elif "english" in _anomaly_models:
+                model = _anomaly_models["english"]
+                method = "isolation_forest_english_fallback"
+            else:
+                continue
+
+            per_lang_counts[lang] = per_lang_counts.get(lang, 0) + 1
+            prediction = model.predict([vector])[0]
+
+            if hasattr(model, 'decision_function'):
+                score = -model.decision_function([vector])[0]
             else:
                 score = 1.0 if prediction == -1 else 0.0
 
@@ -1327,7 +1483,8 @@ def get_anomalies(limit: int = 20, threshold: float = 0.5):
                 **feed,
                 "anomaly_score": float(round(normalized_score, 3)),
                 "is_anomaly": prediction == -1,
-                "language": lang
+                "language": lang,
+                "detection_method": method
             })
 
             if len(anomalies) >= limit:
@@ -1344,7 +1501,9 @@ def get_anomalies(limit: int = 20, threshold: float = 0.5):
         "anomalies": anomalies,
         "total": len(anomalies),
         "threshold": threshold,
-        "model_status": "ml_active"
+        "model_status": "ml_active",
+        "models_loaded": list(_anomaly_models.keys()),
+        "per_language_counts": per_lang_counts
     }
 
     except Exception as e:
@@ -2200,7 +2359,7 @@ async def websocket_endpoint(websocket: WebSocket):
                 async with manager._lock:
                     meta = manager.active_connections.get(websocket)
                     if meta is not None:
-                        meta['last_pong'] = …
+                        meta['last_pong'] = utc_now()
                         meta['misses'] = 0
                 continue
         except json.JSONDecodeError:
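The `score = -model.decision_function([vector])[0]` lines above lean on scikit-learn's sign convention: for `IsolationForest`, `decision_function` returns negative values for outliers, so negating it yields a score where higher means more anomalous. A self-contained toy demonstration of that convention (synthetic data, not the repo's trained models):

```python
import numpy as np
from sklearn.ensemble import IsolationForest

rng = np.random.default_rng(42)
normal = rng.normal(0, 1, size=(200, 8))  # dense cluster = "normal" feeds
outlier = np.full((1, 8), 6.0)            # far-away point = anomalous feed

model = IsolationForest(random_state=42).fit(normal)

for name, x in [("normal", normal[:1]), ("outlier", outlier)]:
    pred = model.predict(x)[0]              # -1 = anomaly, 1 = normal
    score = -model.decision_function(x)[0]  # flip sign: higher = more anomalous
    print(f"{name}: prediction={pred}, anomaly_score={score:.3f}")
# The outlier yields prediction -1 and the larger anomaly_score.
```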
models/anomaly-detection/main.py
CHANGED
|
@@ -1,85 +1,138 @@
|
|
| 1 |
"""
|
| 2 |
-
|
| 3 |
-
|
| 4 |
"""
|
| 5 |
-
import os
|
| 6 |
import sys
|
| 7 |
-
import
|
| 8 |
from pathlib import Path
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
#
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
)
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
# Add src to path - AFTER logging is configured
|
| 23 |
-
sys.path.insert(0, str(Path(__file__).parent / "src"))
|
| 24 |
|
| 25 |
-
from src.pipeline import run_training_pipeline
|
| 26 |
-
from src.entity import PipelineConfig
|
| 27 |
|
| 28 |
|
| 29 |
-
def
|
| 30 |
-
"""
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
-
# Load environment variables
|
| 36 |
-
from dotenv import load_dotenv
|
| 37 |
-
load_dotenv()
|
| 38 |
-
|
| 39 |
-
# Create configuration
|
| 40 |
-
config = PipelineConfig()
|
| 41 |
-
|
| 42 |
-
# Run pipeline
|
| 43 |
try:
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
logger.info("\n--- Model Training ---")
|
| 66 |
-
logger.info(f"Best model: {artifact.model_trainer.best_model_name}")
|
| 67 |
-
logger.info(f"Best metrics: {artifact.model_trainer.best_model_metrics}")
|
| 68 |
-
logger.info(f"MLflow run: {artifact.model_trainer.mlflow_run_id}")
|
| 69 |
-
|
| 70 |
-
if artifact.model_trainer.n_anomalies:
|
| 71 |
-
logger.info(f"Anomalies detected: {artifact.model_trainer.n_anomalies}")
|
| 72 |
|
| 73 |
-
|
| 74 |
-
logger.info("PIPELINE COMPLETE")
|
| 75 |
-
logger.info("=" * 60)
|
| 76 |
|
| 77 |
-
return artifact
|
| 78 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
except Exception as e:
|
| 80 |
-
|
| 81 |
-
raise
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
if __name__ == "__main__":
|
| 85 |
-
main()
|
|
|
|
 """
+Anomaly Detection Training Pipeline
+Trains clustering and anomaly detection models on feed data
 """
 import sys
+import os
 from pathlib import Path
+from datetime import datetime
+
+# Load environment variables from root .env BEFORE other imports
+from dotenv import load_dotenv
+ROOT_DIR = Path(__file__).parent.parent.parent  # Go to ModelX-Ultimate
+load_dotenv(ROOT_DIR / ".env")  # Load root .env with MLflow credentials
+
+from src.components.data_ingestion import DataIngestion
+from src.components.data_validation import DataValidation
+from src.components.data_transformation import DataTransformation
+from src.components.model_trainer import ModelTrainer
+from src.exception.exception import AnomalyDetectionException
+from src.logging.logger import logging
+from src.entity.config_entity import (
+    DataIngestionConfig, DataValidationConfig,
+    DataTransformationConfig, ModelTrainerConfig, PipelineConfig
 )
+from src.constants.training_pipeline import MODELS_TO_TRAIN, MLFLOW_EXPERIMENT_NAME
 
 
+def train_pipeline(pipeline_config: PipelineConfig = None) -> dict:
+    """
+    Train the anomaly detection pipeline.
+
+    Args:
+        pipeline_config: Pipeline configuration (optional)
+
+    Returns:
+        dict with training results
+    """
+    result = {"status": "failed"}
+
+    if pipeline_config is None:
+        pipeline_config = PipelineConfig()
 
     try:
+        logging.info("\n" + "=" * 60)
+        logging.info("ANOMALY DETECTION TRAINING PIPELINE")
+        logging.info(f"Started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
+        logging.info(f"Models to train: {MODELS_TO_TRAIN}")
+        logging.info(f"MLflow experiment: {MLFLOW_EXPERIMENT_NAME}")
+        logging.info("=" * 60 + "\n")
+
+        # Data Ingestion
+        data_ingestion_config = pipeline_config.data_ingestion
+        data_ingestion = DataIngestion(data_ingestion_config)
+        logging.info("Starting data ingestion...")
+        data_ingestion_artifact = data_ingestion.initiate_data_ingestion()
+        logging.info("✓ Data ingestion completed")
+
+        # Data Validation
+        data_validation_config = pipeline_config.data_validation
+        data_validation = DataValidation(data_validation_config)
+        logging.info("Starting data validation...")
+        data_validation_artifact = data_validation.initiate_data_validation(
+            data_ingestion_artifact.raw_data_path
+        )
+        logging.info("✓ Data validation completed")
+
+        # Data Transformation
+        data_transformation_config = pipeline_config.data_transformation
+        data_transformation = DataTransformation(data_transformation_config)
+        logging.info("Starting data transformation...")
+        data_transformation_artifact = data_transformation.initiate_data_transformation(
+            data_validation_artifact.validated_data_path
+        )
+        logging.info("✓ Data transformation completed")
+
+        # Model Training
+        model_trainer_config = pipeline_config.model_trainer
+        model_trainer = ModelTrainer(model_trainer_config)
+        logging.info("Starting model training...")
+        model_trainer_artifact = model_trainer.initiate_model_trainer(
+            data_transformation_artifact.feature_store_path
+        )
+        logging.info("✓ Model training completed")
+
+        result = {
+            "status": "success",
+            "best_model": model_trainer_artifact.best_model_name,
+            "best_model_path": model_trainer_artifact.best_model_path,
+            "best_metrics": model_trainer_artifact.best_model_metrics,
+            "n_anomalies": model_trainer_artifact.n_anomalies,
+            "mlflow_run_id": model_trainer_artifact.mlflow_run_id,
+            "data_ingestion": {
+                "total_records": data_ingestion_artifact.total_records,
+                "from_sqlite": data_ingestion_artifact.records_from_sqlite,
+                "from_csv": data_ingestion_artifact.records_from_csv
+            },
+            "data_validation": {
+                "valid_records": data_validation_artifact.valid_records,
+                "validation_status": data_validation_artifact.validation_status
+            },
+            "data_transformation": {
+                "language_distribution": data_transformation_artifact.language_distribution
+            }
+        }
+
+        logging.info("\n" + "=" * 60)
+        logging.info("PIPELINE RESULTS")
+        logging.info("=" * 60)
+        logging.info(f"Status: {result['status']}")
+        logging.info(f"Best model: {result['best_model']}")
+        logging.info(f"Anomalies detected: {result['n_anomalies']}")
+        logging.info(f"MLflow run: {result.get('mlflow_run_id', 'N/A')}")
+        logging.info("=" * 60 + "\n")
+
+        logging.info("✓ Pipeline completed successfully!")
 
+    except Exception as e:
+        logging.error(f"✗ Pipeline failed: {str(e)}")
+        result = {
+            "status": "failed",
+            "error": str(e)
+        }
 
+    return result
 
 
+if __name__ == '__main__':
+    try:
+        results = train_pipeline()
+
+        if results["status"] == "failed":
+            logging.error("Pipeline failed - check logs for details")
+            sys.exit(1)
+
     except Exception as e:
+        logging.error(f"Pipeline crashed: {e}")
+        raise AnomalyDetectionException(e, sys)
models/anomaly-detection/src/components/data_ingestion.py
CHANGED
@@ -183,7 +183,7 @@ class DataIngestion:
 
         return df
 
-    def
+    def initiate_data_ingestion(self) -> DataIngestionArtifact:
         """
         Execute data ingestion pipeline.
 
@@ -228,6 +228,9 @@ class DataIngestion:
         output_path = Path(self.config.output_directory) / f"ingested_data_{timestamp}.parquet"
 
         if is_data_available:
+            # Convert timestamp column to datetime to avoid parquet conversion error
+            if "timestamp" in combined_df.columns:
+                combined_df["timestamp"] = pd.to_datetime(combined_df["timestamp"], errors="coerce")
             combined_df.to_parquet(output_path, index=False)
             logger.info(f"[DataIngestion] Saved {total_records} records to {output_path}")
         else:
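Why the coercion above is needed: pandas hands object-dtype columns to pyarrow untouched, and a column that mixes raw strings with Timestamp objects (easy to produce when SQLite rows and CSV rows are concatenated) has no single Arrow type, so to_parquet raises a type error. A minimal sketch of the failure mode and the fix, with made-up values and a Parquet engine such as pyarrow assumed to be installed:

    import pandas as pd

    # Concatenating sources often yields an object column mixing raw strings
    # and Timestamp objects; pyarrow cannot pick one Arrow type for it, so
    # to_parquet raises before the fix.
    df = pd.DataFrame({"timestamp": ["2024-01-01 10:00", pd.Timestamp("2024-01-02")]})

    # errors="coerce" parses what it can and turns the rest into NaT, leaving
    # a uniform datetime64[ns] column that serializes cleanly.
    df["timestamp"] = pd.to_datetime(df["timestamp"], errors="coerce")
    df.to_parquet("ingested_data_example.parquet", index=False)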
models/anomaly-detection/src/components/data_transformation.py
CHANGED
@@ -330,7 +330,7 @@ class DataTransformation:
         logger.info(f"[DataTransformation] Feature matrix shape: {feature_matrix.shape}")
         return feature_matrix
 
-    def
+    def initiate_data_transformation(self, data_path: str) -> DataTransformationArtifact:
         """
         Execute data transformation pipeline.
         Integrates with Vectorizer Agent Graph for LLM-enhanced processing.
 
@@ -409,6 +409,11 @@ class DataTransformation:
         embeddings_path = Path(self.config.output_directory) / f"embeddings_{timestamp}.npy"
         np.save(embeddings_path, embeddings)
 
+        # Save language labels for per-language model training
+        languages_path = Path(self.config.output_directory) / f"languages_{timestamp}.npy"
+        np.save(languages_path, df["language"].values)
+        logger.info(f"[DataTransformation] Saved language labels to {languages_path.name}")
+
         # Save feature matrix
         features_path = Path(self.config.output_directory) / f"features_{timestamp}.npy"
         np.save(features_path, feature_matrix)
models/anomaly-detection/src/components/data_validation.py
CHANGED
@@ -182,7 +182,7 @@ class DataValidation:
 
         return errors
 
-    def
+    def initiate_data_validation(self, data_path: str) -> DataValidationArtifact:
         """
         Execute data validation pipeline.
 
models/anomaly-detection/src/components/model_trainer.py
CHANGED
@@ -358,7 +358,7 @@ class ModelTrainer:
             return func(X, trial)
         return {"error": f"Unknown model: {model_name}"}
 
-    def
+    def initiate_model_trainer(self, feature_path: str) -> ModelTrainerArtifact:
         """
         Execute model training pipeline.
 
@@ -476,37 +476,88 @@ class ModelTrainer:
         logger.info(f"[ModelTrainer] Best model: {best_model['name'] if best_model else 'N/A'}")
 
         # ============================================
-        # TRAIN
+        # TRAIN PER-LANGUAGE MODELS FOR LIVE INFERENCE
         # ============================================
-        #
-        #
+        # Different BERT models produce embeddings in different vector spaces.
+        # We train separate Isolation Forest models per language to avoid
+        # mixing incompatible embeddings.
         try:
             # Check if features include extra metadata (> 768 dims)
             if X.shape[1] > 768:
-                # Extract only the first 768 dimensions (BERT embeddings)
-                X_embeddings_only = X[:, :768]
-                logger.info(f"[ModelTrainer] Embedding-only shape: {X_embeddings_only.shape}")
-
-                # Train Isolation Forest on embeddings only
-                embedding_model = IsolationForest(
-                    contamination=0.1,
-                    n_estimators=100,
-                    random_state=42,
-                    n_jobs=-1
-                )
-                embedding_model.fit(X_embeddings_only)
-
-                # Save to a dedicated path for the Vectorizer Agent
-                embedding_model_path = Path(self.config.output_directory) / "isolation_forest_embeddings_only.joblib"
-                joblib.dump(embedding_model, embedding_model_path)
-
-                logger.info(f"[ModelTrainer] Embedding-only model saved: {embedding_model_path}")
-                logger.info("[ModelTrainer] This model is for real-time inference by Vectorizer Agent")
+                X_embeddings = X[:, :768]  # Extract BERT embeddings only
             else:
-
+                X_embeddings = X
+
+            logger.info(f"[ModelTrainer] Training per-language models on {X_embeddings.shape[0]} samples...")
+
+            # Load language labels from the same directory as features
+            feature_dir = Path(feature_path).parent
+            lang_files = list(feature_dir.glob("languages_*.npy"))
+
+            if lang_files:
+                # Get most recent language file
+                latest_lang_file = max(lang_files, key=lambda p: p.stem)
+                languages = np.load(latest_lang_file, allow_pickle=True)
+                logger.info(f"[ModelTrainer] Loaded language labels from {latest_lang_file.name}")
+            else:
+                # Fallback: try to load from transformed data parquet
+                parquet_files = list(feature_dir.glob("transformed_*.parquet"))
+                if parquet_files:
+                    import pandas as pd
+                    latest_parquet = max(parquet_files, key=lambda p: p.stem)
+                    df_temp = pd.read_parquet(latest_parquet)
+                    if "language" in df_temp.columns:
+                        languages = df_temp["language"].values
+                        logger.info(f"[ModelTrainer] Loaded {len(languages)} language labels from parquet")
+                    else:
+                        languages = np.array(["english"] * len(X_embeddings))
+                        logger.warning("[ModelTrainer] No language column in parquet, defaulting to english")
+                else:
+                    languages = np.array(["english"] * len(X_embeddings))
+                    logger.warning("[ModelTrainer] No language data found, defaulting to english")
+
+            # Train per-language models
+            MIN_SAMPLES_PER_LANGUAGE = 10
+            per_lang_models = {}
+
+            for lang in ["english", "sinhala", "tamil"]:
+                lang_mask = languages == lang
+                X_lang = X_embeddings[lang_mask]
+
+                if len(X_lang) >= MIN_SAMPLES_PER_LANGUAGE:
+                    logger.info(f"[ModelTrainer] Training {lang} model on {len(X_lang)} samples...")
+
+                    lang_model = IsolationForest(
+                        contamination=0.1,
+                        n_estimators=100,
+                        random_state=42,
+                        n_jobs=-1
+                    )
+                    lang_model.fit(X_lang)
+
+                    # Save per-language model
+                    model_path = Path(self.config.output_directory) / f"isolation_forest_{lang}.joblib"
+                    joblib.dump(lang_model, model_path)
+                    per_lang_models[lang] = str(model_path)
+
+                    logger.info(f"[ModelTrainer] ✓ Saved: isolation_forest_{lang}.joblib ({len(X_lang)} samples)")
+                else:
+                    logger.warning(f"[ModelTrainer] Skipping {lang}: only {len(X_lang)} samples (min: {MIN_SAMPLES_PER_LANGUAGE})")
+
+            # Also save a legacy "embeddings_only" model for backward compatibility (trained on English)
+            if "english" in per_lang_models:
+                import shutil
+                english_model_path = Path(per_lang_models["english"])
+                legacy_path = Path(self.config.output_directory) / "isolation_forest_embeddings_only.joblib"
+                shutil.copy(english_model_path, legacy_path)
+                logger.info(f"[ModelTrainer] ✓ Legacy model copied: isolation_forest_embeddings_only.joblib")
+
+            logger.info(f"[ModelTrainer] Per-language training complete: {list(per_lang_models.keys())}")
+
         except Exception as e:
-            logger.warning(f"[ModelTrainer]
+            logger.warning(f"[ModelTrainer] Per-language model training failed: {e}")
+            import traceback
+            traceback.print_exc()
 
         return artifact
+
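On the consumer side, live inference would mirror the naming scheme the trainer uses when saving. A minimal sketch, assuming the trainer's output directory and a 768-dim embedding produced by the same language's BERT model; the score_embedding helper and MODELS_DIR path are illustrative, not part of this commit:

    from pathlib import Path

    import joblib
    import numpy as np

    MODELS_DIR = Path("artifacts/model_trainer")  # assumed output_directory

    def score_embedding(embedding: np.ndarray, language: str) -> float:
        # Anomaly score for one 768-dim embedding; negative means anomalous.
        model_path = MODELS_DIR / f"isolation_forest_{language}.joblib"
        if not model_path.exists():
            # Fall back to the legacy English-trained copy the trainer writes
            # for backward compatibility.
            model_path = MODELS_DIR / "isolation_forest_embeddings_only.joblib"
        model = joblib.load(model_path)
        # IsolationForest.decision_function returns negative values for outliers.
        return float(model.decision_function(embedding.reshape(1, -1))[0])

    # The embedding must come from the matching BERT model, e.g.
    # keshan/SinhalaBERTo for "sinhala", or the vector spaces won't line up.
    score = score_embedding(np.random.rand(768), "sinhala")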
models/anomaly-detection/src/constants/__init__.py
ADDED
@@ -0,0 +1 @@
+from .training_pipeline import *
models/anomaly-detection/src/constants/training_pipeline/__init__.py
ADDED
@@ -0,0 +1,65 @@
+"""
+Anomaly Detection Training Pipeline Constants
+"""
+import os
+
+# Pipeline configuration
+PIPELINE_NAME: str = "AnomalyDetection"
+ARTIFACT_DIR: str = "artifacts"
+
+# Data sources
+SQLITE_DB_PATH = os.getenv(
+    "SQLITE_DB_PATH",
+    os.path.join(os.path.dirname(__file__), "..", "..", "..", "..", "..", "data", "feeds", "feed_cache.db")
+)
+CSV_DIRECTORY = os.path.join(os.path.dirname(__file__), "..", "..", "..", "..", "..", "datasets", "political_feeds")
+
+# Data Ingestion
+DATA_INGESTION_DIR_NAME: str = "data_ingestion"
+DATA_INGESTION_FEATURE_STORE_DIR: str = "feature_store"
+DATA_INGESTION_INGESTED_DIR: str = "ingested"
+FILE_NAME: str = "ingested_data.parquet"
+MIN_TEXT_LENGTH: int = 10
+BATCH_SIZE: int = 1000
+
+# Data Validation
+DATA_VALIDATION_DIR_NAME: str = "data_validation"
+DATA_VALIDATION_VALID_DIR: str = "validated"
+DATA_VALIDATION_INVALID_DIR: str = "invalid"
+SCHEMA_FILE_PATH = os.path.join("data_schema", "schema.yaml")
+REQUIRED_COLUMNS = ["post_id", "timestamp", "platform", "category", "text", "content_hash"]
+
+# Data Transformation
+DATA_TRANSFORMATION_DIR_NAME: str = "data_transformation"
+DATA_TRANSFORMATION_TRANSFORMED_DATA_DIR: str = "transformed"
+FEATURE_STORE_FILE_NAME: str = "features.npy"
+
+# Language Models (Multilingual BERT)
+ENGLISH_MODEL: str = "distilbert-base-uncased"
+SINHALA_MODEL: str = "keshan/SinhalaBERTo"
+TAMIL_MODEL: str = "l3cube-pune/tamil-bert"
+VECTOR_DIM: int = 768
+
+# Model Training
+MODEL_TRAINER_DIR_NAME: str = "model_trainer"
+MODEL_TRAINER_TRAINED_MODEL_DIR: str = "trained_model"
+MODEL_FILE_NAME: str = "model.joblib"
+SAVED_MODEL_DIR = os.path.join("saved_models")
+
+# Models to train
+MODELS_TO_TRAIN = ["dbscan", "kmeans", "hdbscan", "isolation_forest", "lof"]
+
+# Optuna hyperparameter tuning
+N_OPTUNA_TRIALS: int = 50
+OPTUNA_TIMEOUT_SECONDS: int = 3600  # 1 hour
+
+# MLflow configuration
+MLFLOW_TRACKING_URI = os.getenv(
+    "MLFLOW_TRACKING_URI",
+    "https://dagshub.com/sliitguy/Model-X.mlflow"
+)
+MLFLOW_EXPERIMENT_NAME: str = "anomaly_detection_feeds"
+
+# Model thresholds
+MODEL_TRAINER_EXPECTED_SCORE: float = 0.3  # Silhouette score threshold
+MODEL_TRAINER_OVERFITTING_THRESHOLD: float = 0.1
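Because SQLITE_DB_PATH and MLFLOW_TRACKING_URI are read through os.getenv at import time, a deployment can redirect them without touching code, but only if the environment is set before the constants module is first imported. A small sketch (paths hypothetical, run from the pipeline root):

    import os

    # Must happen before the constants module is imported; the os.getenv
    # calls run once at import time, so later environment changes are ignored.
    os.environ["SQLITE_DB_PATH"] = "/data/feeds/feed_cache.db"
    os.environ["MLFLOW_TRACKING_URI"] = "http://localhost:5000"

    from src.constants import training_pipeline as tp

    assert tp.SQLITE_DB_PATH == "/data/feeds/feed_cache.db"
    assert tp.MLFLOW_TRACKING_URI == "http://localhost:5000"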
models/anomaly-detection/src/entity/config_entity.py
CHANGED
@@ -71,7 +71,7 @@ class ModelTrainerConfig:
     """Configuration for model training component"""
     # MLflow configuration
     mlflow_tracking_uri: str = field(default_factory=lambda: os.getenv(
-        "MLFLOW_TRACKING_URI", "https://dagshub.com/sliitguy/
+        "MLFLOW_TRACKING_URI", "https://dagshub.com/sliitguy/Model-X.mlflow"
     ))
     mlflow_username: str = field(default_factory=lambda: os.getenv(
         "MLFLOW_TRACKING_USERNAME", ""
models/anomaly-detection/src/exception/__init__.py
ADDED
@@ -0,0 +1 @@
+from .exception import AnomalyDetectionException
models/anomaly-detection/src/exception/exception.py
ADDED
@@ -0,0 +1,24 @@
+import sys
+
+
+class AnomalyDetectionException(Exception):
+    """Custom exception for Anomaly Detection pipeline."""
+
+    def __init__(self, error_message, error_details: sys):
+        self.error_message = error_message
+        _, _, exc_tb = error_details.exc_info()
+
+        self.lineno = exc_tb.tb_lineno
+        self.file_name = exc_tb.tb_frame.f_code.co_filename
+
+    def __str__(self):
+        return "Error occurred in python script name [{0}] line number [{1}] error message [{2}]".format(
+            self.file_name, self.lineno, str(self.error_message)
+        )
+
+
+if __name__ == '__main__':
+    try:
+        a = 1 / 0
+    except Exception as e:
+        raise AnomalyDetectionException(e, sys)
models/anomaly-detection/src/logging/__init__.py
ADDED
@@ -0,0 +1 @@
+from .logger import logging
models/anomaly-detection/src/logging/logger.py
ADDED
@@ -0,0 +1,32 @@
+"""
+Logging configuration for Anomaly Detection pipeline.
+Creates timestamped log files in the logs directory.
+"""
+import logging
+import os
+from datetime import datetime
+
+LOG_FILE = f"{datetime.now().strftime('%m_%d_%Y_%H_%M_%S')}.log"
+
+logs_path = os.path.join(os.getcwd(), "logs", LOG_FILE)
+
+os.makedirs(logs_path, exist_ok=True)
+
+LOG_FILE_PATH = os.path.join(logs_path, LOG_FILE)
+
+logging.basicConfig(
+    filename=LOG_FILE_PATH,
+    format="[ %(asctime)s ] %(lineno)d %(name)s - %(levelname)s - %(message)s",
+    level=logging.INFO
+)
+
+# Also add console handler for visibility
+console_handler = logging.StreamHandler()
+console_handler.setLevel(logging.INFO)
+console_handler.setFormatter(logging.Formatter(
+    "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+))
+
+# Get root logger and add console handler
+root_logger = logging.getLogger()
+root_logger.addHandler(console_handler)
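One caveat the old currency/weather entry points called out explicitly: a package named src/logging shadows the standard library's logging for any process that puts src/ itself (rather than its parent) on sys.path. A small sketch of the hazard under a hypothetical layout, assuming a fresh interpreter where stdlib logging is not yet cached in sys.modules:

    import sys
    from pathlib import Path

    # Hypothetical layout:
    #   src/
    #     logging/
    #       __init__.py, logger.py   <- this pipeline's logging package
    sys.path.insert(0, str(Path("src").resolve()))

    import logging  # now resolves to src/logging/, not the standard library

Importing through the package root instead (from src.logging.logger import logging), as the rewritten main.py files do, sidesteps the problem entirely.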
models/anomaly-detection/src/pipeline/training_pipeline.py
CHANGED
@@ -53,7 +53,7 @@ class TrainingPipeline:
         logger.info("=" * 50)
 
         ingestion = DataIngestion(self.config.data_ingestion)
-        artifact = ingestion.
+        artifact = ingestion.initiate_data_ingestion()
 
         if not artifact.is_data_available:
             raise ValueError("No data available for training")
@@ -67,7 +67,7 @@ class TrainingPipeline:
         logger.info("=" * 50)
 
         validation = DataValidation(self.config.data_validation)
-        artifact = validation.
+        artifact = validation.initiate_data_validation(ingestion_artifact.raw_data_path)
 
         return artifact
 
@@ -78,7 +78,7 @@ class TrainingPipeline:
         logger.info("=" * 50)
 
         transformation = DataTransformation(self.config.data_transformation)
-        artifact = transformation.
+        artifact = transformation.initiate_data_transformation(validation_artifact.validated_data_path)
 
         return artifact
 
@@ -89,7 +89,7 @@ class TrainingPipeline:
         logger.info("=" * 50)
 
         trainer = ModelTrainer(self.config.model_trainer)
-        artifact = trainer.
+        artifact = trainer.initiate_model_trainer(transformation_artifact.feature_store_path)
 
         return artifact
 
models/currency-volatility-prediction/main.py
CHANGED
@@ -1,87 +1,87 @@
 """
-
-
-Can run data collection, training, or prediction independently
+Currency Volatility Prediction Pipeline - USD/LKR Training
+Follows stock-price-prediction pattern with structured artifact flow
 """
-import
+from src.components.data_ingestion import CurrencyDataIngestion
+from src.components.model_trainer import CurrencyGRUTrainer
+from src.components.predictor import CurrencyPredictor
+from src.exception.exception import CurrencyPredictionException
+from src.logging.logger import logging
+from src.entity.config_entity import DataIngestionConfig, ModelTrainerConfig
+
 import sys
-import
+import os
 import argparse
-from pathlib import Path
 from datetime import datetime
 
-# CRITICAL: Configure logging BEFORE adding src/ to path
-# (src/logging/ directory would otherwise shadow the standard module)
-logging.basicConfig(
-    level=logging.INFO,
-    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
-)
-logger = logging.getLogger("currency_prediction")
-
-# Setup paths - AFTER logging is configured
-PIPELINE_ROOT = Path(__file__).parent
-sys.path.insert(0, str(PIPELINE_ROOT / "src"))
-
-
-def run_data_ingestion(period: str = "2y"):
-    """Run data ingestion from yfinance."""
-    from components.data_ingestion import CurrencyDataIngestion
-    from entity.config_entity import DataIngestionConfig
-
-    logger.info(f"Starting data ingestion ({period})...")
-
-    config = DataIngestionConfig(history_period=period)
-    ingestion = CurrencyDataIngestion(config)
-
-    data_path = ingestion.ingest_all()
-
-    df = ingestion.load_existing(data_path)
-
-    logger.info("Data Ingestion Complete!")
-    logger.info(f"Total records: {len(df)}")
-    logger.info(f"Features: {len(df.columns)}")
-    logger.info(f"Date range: {df['date'].min()} to {df['date'].max()}")
-    logger.info(f"Latest rate: {df['close'].iloc[-1]:.2f} LKR/USD")
-
-    return data_path
-
-
-def run_training(epochs: int = 100):
-    """Run GRU model training."""
-    from components.data_ingestion import CurrencyDataIngestion
-    from components.model_trainer import CurrencyGRUTrainer
-    from entity.config_entity import ModelTrainerConfig
-
-    logger.info("Starting model training...")
-
-    # Load data
-    ingestion = CurrencyDataIngestion()
-    df = ingestion.load_existing()
-
+def train_currency(period: str = "2y", epochs: int = 100) -> dict:
+    """
+    Train the currency prediction model.
+
+    Follows stock-price-prediction pattern with structured results.
+
+    Args:
+        period: Data period for yfinance (1y, 2y, 5y)
+        epochs: Number of training epochs
+
+    Returns:
+        dict with training results or error info
+    """
+    result = {"currency": "USD_LKR", "status": "failed"}
+
+    try:
+        logging.info(f"\n{'='*60}")
+        logging.info("CURRENCY PREDICTION PIPELINE - TRAINING")
+        logging.info(f"Started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
+        logging.info(f"{'='*60}")
+
+        # Step 1: Data Ingestion
+        logging.info("[USD_LKR] Starting data ingestion...")
+        config = DataIngestionConfig(history_period=period)
+        ingestion = CurrencyDataIngestion(config)
+        data_path = ingestion.ingest_all()
+        df = ingestion.load_existing(data_path)
+        logging.info(f"[USD_LKR] ✓ Data ingestion completed: {len(df)} records")
+
+        # Step 2: Model Training
+        logging.info("[USD_LKR] Starting model training...")
+        trainer_config = ModelTrainerConfig(epochs=epochs)
+        trainer = CurrencyGRUTrainer(trainer_config)
+        train_results = trainer.train(df=df, use_mlflow=True)
+        logging.info("[USD_LKR] ✓ Model training completed")
+
+        result = {
+            "currency": "USD_LKR",
+            "status": "success",
+            "model_path": train_results["model_path"],
+            "test_mae": train_results["test_mae"],
+            "rmse": train_results["rmse"],
+            "direction_accuracy": train_results["direction_accuracy"],
+            "epochs_trained": train_results["epochs_trained"]
+        }
+
+        logging.info(f"[USD_LKR] ✓ Pipeline completed successfully!")
+
+    except Exception as e:
+        logging.error(f"[USD_LKR] ✗ Pipeline failed: {str(e)}")
+        result = {
+            "currency": "USD_LKR",
+            "status": "failed",
+            "error": str(e)
+        }
+
+    return result
 
 
-def run_prediction():
-    """Run prediction for next day."""
-    from components.data_ingestion import CurrencyDataIngestion
-    from components.predictor import CurrencyPredictor
+def run_prediction() -> dict:
+    """
+    Run prediction for next day.
+
+    Returns:
+        Prediction dictionary
+    """
+    logging.info("Generating prediction...")
 
     predictor = CurrencyPredictor()
 
@@ -89,70 +89,78 @@ def run_prediction():
         ingestion = CurrencyDataIngestion()
         df = ingestion.load_existing()
         prediction = predictor.predict(df)
+        logging.info("[USD_LKR] ✓ Prediction generated using trained model")
     except FileNotFoundError:
-
+        logging.warning("[USD_LKR] Model not trained, using fallback")
        prediction = predictor.generate_fallback_prediction()
    except Exception as e:
-
+        logging.error(f"[USD_LKR] Error: {e}")
        prediction = predictor.generate_fallback_prediction()
 
    output_path = predictor.save_prediction(prediction)
 
    # Display
-
-
-
-
-
-
-
-
+    logging.info(f"\n{'='*50}")
+    logging.info(f"USD/LKR PREDICTION FOR {prediction['prediction_date']}")
+    logging.info(f"{'='*50}")
+    logging.info(f"Current Rate: {prediction['current_rate']:.2f} LKR/USD")
+    logging.info(f"Predicted Rate: {prediction['predicted_rate']:.2f} LKR/USD")
+    logging.info(f"Expected Change: {prediction['expected_change_pct']:+.3f}%")
+    logging.info(f"Direction: {prediction['direction_emoji']} LKR {prediction['direction']}")
+    logging.info(f"Volatility: {prediction['volatility_class']}")
 
    if prediction.get('weekly_trend'):
-
+        logging.info(f"Weekly Trend: {prediction['weekly_trend']:+.2f}%")
    if prediction.get('monthly_trend'):
-
+        logging.info(f"Monthly Trend: {prediction['monthly_trend']:+.2f}%")
 
-
-
+    logging.info(f"{'='*50}")
+    logging.info(f"Saved to: {output_path}")
 
    return prediction
 
 
 def run_full_pipeline():
-    """
-
-
-
-
-
-
-
-
-
-
-    # Step 2:
-    try:
-        run_training(epochs=100)
-    except Exception as e:
-        logger.error(f"Training failed: {e}")
-
-    # Step 3: Prediction
+    """
+    Run the complete pipeline: train → predict.
+    Following stock-price-prediction pattern.
+    """
+    logging.info("\n" + "="*70)
+    logging.info("CURRENCY PREDICTION PIPELINE - FULL RUN")
+    logging.info(f"Started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
+    logging.info("="*70 + "\n")
+
+    # Step 1: Training
+    result = train_currency(period="2y", epochs=100)
+
+    # Step 2: Prediction
     prediction = run_prediction()
 
-
-
-
+    # Print summary
+    logging.info("\n" + "="*70)
+    logging.info("TRAINING SUMMARY")
+    logging.info("="*70)
+
+    if result["status"] == "success":
+        logging.info(f"  ✓ USD_LKR: {result['model_path']}")
+        logging.info(f"    MAE: {result['test_mae']:.4f} LKR")
+        logging.info(f"    RMSE: {result['rmse']:.4f} LKR")
+        logging.info(f"    Direction Accuracy: {result['direction_accuracy']*100:.1f}%")
+    else:
+        logging.info(f"  ✗ USD_LKR: {result.get('error', 'Unknown error')[:50]}")
 
-
+    logging.info("="*70)
+    logging.info(f"Completed at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
+    logging.info("="*70 + "\n")
+
+    return result, prediction
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Currency Prediction Pipeline")
     parser.add_argument(
         "--mode",
-        choices=["
+        choices=["train", "predict", "full"],
         default="predict",
         help="Pipeline mode to run"
     )
@@ -171,11 +179,17 @@ if __name__ == "__main__":
 
     args = parser.parse_args()
 
-
-
-
-
-
-
-
-
+    try:
+        if args.mode == "train":
+            result = train_currency(period=args.period, epochs=args.epochs)
+            if result["status"] == "failed":
+                sys.exit(1)
+        elif args.mode == "predict":
+            run_prediction()
+        elif args.mode == "full":
+            result, prediction = run_full_pipeline()
+            if result["status"] == "failed":
+                sys.exit(1)
+    except Exception as e:
+        logging.error(f"Pipeline crashed: {e}")
+        raise CurrencyPredictionException(e, sys)
models/currency-volatility-prediction/src/components/model_trainer.py
CHANGED
@@ -5,6 +5,11 @@ Optimized for 8GB RAM laptops without GPU
 """
 import os
 import sys
+
+# Fix Windows console encoding issue with MLflow emoji output
+if sys.platform == 'win32':
+    sys.stdout.reconfigure(encoding='utf-8', errors='replace')
+
 import logging
 import numpy as np
 import pandas as pd
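The guard added above works because Python 3.7+ exposes sys.stdout as an io.TextIOWrapper with a reconfigure() method; Windows consoles often default to a legacy code page (e.g. cp1252), so MLflow's emoji output raises UnicodeEncodeError there. A standalone sketch:

    import sys

    # On Windows the console encoding frequently cannot represent "✓" or
    # emoji, and print() then raises UnicodeEncodeError.
    if sys.platform == "win32":
        # Re-wrap stdout as UTF-8; errors="replace" substitutes a placeholder
        # for anything the console still cannot render instead of crashing.
        sys.stdout.reconfigure(encoding="utf-8", errors="replace")

    print("Training complete ✓")  # safe after the reconfigure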
models/currency-volatility-prediction/src/components/predictor.py
CHANGED
@@ -62,7 +62,11 @@ class CurrencyPredictor:
         if not os.path.exists(model_path):
             raise FileNotFoundError(f"No trained model found at {model_path}")
 
-
+        # Load with compile=False to avoid Keras 2->3 mse serialization issues
+        # Then recompile with standard metrics
+        self._model = load_model(model_path, compile=False)
+        self._model.compile(optimizer='adam', loss='mse', metrics=['mae'])
+
         scalers = joblib.load(scaler_path)
 
         self._scalers = {
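The compile=False trick reflects a known Keras 2 -> Keras 3 migration issue: a model saved with a bare 'mse' loss/metric alias can fail to deserialize under the newer resolver. Loading with compile=False skips the saved training config entirely, and recompiling restores evaluate()/metrics with identifiers that are valid today. A minimal sketch (the .h5 path is illustrative):

    from tensorflow.keras.models import load_model

    # Skip deserializing the stored optimizer/loss/metric config, which is
    # where the old "mse" alias blows up under Keras 3.
    model = load_model("artifacts/models/usd_lkr_gru.h5", compile=False)

    # Recompile explicitly; for pure inference the loss is irrelevant, but
    # compiling restores model.evaluate() and metric tracking.
    model.compile(optimizer="adam", loss="mse", metrics=["mae"])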
models/currency-volatility-prediction/src/exception/__init__.py
CHANGED
@@ -0,0 +1 @@
+from src.exception.exception import CurrencyPredictionException
models/currency-volatility-prediction/src/exception/exception.py
CHANGED
@@ -1,22 +1,24 @@
 import sys
-from src.log_utils import logger
 
-class NetworkSecurityException(Exception):
-    def __init__(self,error_message,error_details:sys):
-        self.error_message = error_message
-        _,_,exc_tb = error_details.exc_info()
-
-        self.lineno=exc_tb.tb_lineno
-        self.file_name=exc_tb.tb_frame.f_code.co_filename
+
+class CurrencyPredictionException(Exception):
+    """Custom exception for Currency Prediction pipeline."""
+
+    def __init__(self, error_message, error_details: sys):
+        self.error_message = error_message
+        _, _, exc_tb = error_details.exc_info()
+
+        self.lineno = exc_tb.tb_lineno
+        self.file_name = exc_tb.tb_frame.f_code.co_filename
+
     def __str__(self):
-        return "Error
-
+        return "Error occurred in python script name [{0}] line number [{1}] error message [{2}]".format(
+            self.file_name, self.lineno, str(self.error_message)
+        )
+
 
-if __name__=='__main__':
+if __name__ == '__main__':
     try:
-
-        a=1/0
-        print("This will not be printed",a)
+        a = 1 / 0
     except Exception as e:
-
+        raise CurrencyPredictionException(e, sys)
models/currency-volatility-prediction/src/logging/__init__.py
CHANGED
@@ -0,0 +1 @@
+from src.logging.logger import logging
models/currency-volatility-prediction/src/logging/logger.py
CHANGED
@@ -1,15 +1,18 @@
+"""
+Logging configuration for Currency Prediction pipeline.
+Creates timestamped log files in the logs directory.
+"""
 import logging
 import os
 from datetime import datetime
 
-LOG_FILE=f"{datetime.now().strftime('%m_%d_%Y_%H_%M_%S')}.log"
+LOG_FILE = f"{datetime.now().strftime('%m_%d_%Y_%H_%M_%S')}.log"
 
-logs_path=os.path.join(os.getcwd(), "logs", LOG_FILE)
+logs_path = os.path.join(os.getcwd(), "logs", LOG_FILE)
 
 os.makedirs(logs_path, exist_ok=True)
-# Create the file only if it is not created
 
-LOG_FILE_PATH=os.path.join(logs_path, LOG_FILE)
+LOG_FILE_PATH = os.path.join(logs_path, LOG_FILE)
 
 logging.basicConfig(
     filename=LOG_FILE_PATH,
@@ -17,4 +20,13 @@ logging.basicConfig(
     level=logging.INFO
 )
 
+# Also add console handler for visibility
+console_handler = logging.StreamHandler()
+console_handler.setLevel(logging.INFO)
+console_handler.setFormatter(logging.Formatter(
+    "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+))
+
+# Get root logger and add console handler
+root_logger = logging.getLogger()
+root_logger.addHandler(console_handler)
models/weather-prediction/main.py
CHANGED
|
@@ -1,86 +1,154 @@
|
|
| 1 |
"""
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
Can run data collection, training, or prediction independently
|
| 5 |
"""
|
| 6 |
-
import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
import sys
|
| 8 |
-
import
|
| 9 |
import argparse
|
| 10 |
from pathlib import Path
|
| 11 |
from datetime import datetime
|
| 12 |
|
| 13 |
-
# CRITICAL: Configure logging BEFORE adding src/ to path
|
| 14 |
-
# (src/logging/ directory would otherwise shadow the standard module)
|
| 15 |
-
logging.basicConfig(
|
| 16 |
-
level=logging.INFO,
|
| 17 |
-
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
| 18 |
-
)
|
| 19 |
-
logger = logging.getLogger("weather_prediction")
|
| 20 |
-
|
| 21 |
-
# Setup paths - AFTER logging is configured
|
| 22 |
PIPELINE_ROOT = Path(__file__).parent
|
| 23 |
-
sys.path.insert(0, str(PIPELINE_ROOT / "src"))
|
| 24 |
|
| 25 |
|
| 26 |
-
def
|
| 27 |
-
"""
|
| 28 |
-
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
-
|
| 34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
-
|
| 37 |
|
| 38 |
-
df = ingestion.load_existing(data_path)
|
| 39 |
-
stats = ingestion.get_data_stats(df)
|
| 40 |
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
|
| 46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
|
| 55 |
-
|
| 56 |
|
| 57 |
-
ingestion = DataIngestion()
|
| 58 |
-
df = ingestion.load_existing()
|
| 59 |
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
)
|
| 64 |
|
| 65 |
-
|
| 66 |
-
|
| 67 |
|
| 68 |
-
|
| 69 |
-
try:
|
| 70 |
-
logger.info(f"Training {station_name}...")
|
| 71 |
-
result = trainer.train(
|
| 72 |
-
df=df,
|
| 73 |
-
station_name=station_name,
|
| 74 |
-
epochs=epochs,
|
| 75 |
-
use_mlflow=False # Disabled due to Windows Unicode encoding issues
|
| 76 |
-
)
|
| 77 |
-
results.append(result)
|
| 78 |
-
logger.info(f"[OK] {station_name}: MAE={result['test_mae']:.3f}")
|
| 79 |
-
except Exception as e:
|
| 80 |
-
logger.error(f"[FAIL] {station_name}: {e}")
|
| 81 |
|
| 82 |
-
|
| 83 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
|
| 85 |
|
| 86 |
def check_and_train_missing_models(priority_only: bool = True, epochs: int = 25):
|
|
@@ -95,8 +163,6 @@ def check_and_train_missing_models(priority_only: bool = True, epochs: int = 25)
|
|
| 95 |
Returns:
|
| 96 |
List of trained station names
|
| 97 |
"""
|
| 98 |
-
from entity.config_entity import WEATHER_STATIONS
|
| 99 |
-
|
| 100 |
models_dir = PIPELINE_ROOT / "artifacts" / "models"
|
| 101 |
models_dir.mkdir(parents=True, exist_ok=True)
|
| 102 |
|
|
@@ -113,44 +179,36 @@ def check_and_train_missing_models(priority_only: bool = True, epochs: int = 25)
|
|
| 113 |
missing_stations.append(station)
|
| 114 |
|
| 115 |
if not missing_stations:
|
| 116 |
-
|
| 117 |
return []
|
| 118 |
|
| 119 |
-
|
| 120 |
-
|
| 121 |
|
| 122 |
# Ensure we have data first
|
| 123 |
data_path = PIPELINE_ROOT / "artifacts" / "data"
|
| 124 |
existing_data = list(data_path.glob("weather_history_*.csv")) if data_path.exists() else []
|
| 125 |
|
| 126 |
if not existing_data:
|
| 127 |
-
|
| 128 |
try:
|
| 129 |
run_data_ingestion(months=3)
|
| 130 |
except Exception as e:
|
| 131 |
-
|
| 132 |
-
|
| 133 |
return []
|
| 134 |
|
| 135 |
-
# Train missing models
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
run_training(station=station, epochs=epochs)
|
| 141 |
-
trained.append(station)
|
| 142 |
-
except Exception as e:
|
| 143 |
-
logger.warning(f"[AUTO-TRAIN] Failed to train {station}: {e}")
|
| 144 |
-
|
| 145 |
-
logger.info(f"[AUTO-TRAIN] Auto-training complete. Trained {len(trained)} models: {', '.join(trained)}")
|
| 146 |
return trained
|
| 147 |
|
| 148 |
|
| 149 |
def run_prediction():
|
| 150 |
"""Run prediction for all districts."""
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
logger.info("Generating predictions...")
|
| 154 |
|
| 155 |
predictor = WeatherPredictor()
|
| 156 |
|
|
@@ -160,9 +218,9 @@ def run_prediction():
|
|
| 160 |
sys.path.insert(0, str(PIPELINE_ROOT.parent.parent / "src"))
|
| 161 |
from utils.utils import tool_rivernet_status
|
| 162 |
rivernet_data = tool_rivernet_status()
|
| 163 |
-
|
| 164 |
except Exception as e:
|
| 165 |
-
|
| 166 |
|
| 167 |
predictions = predictor.predict_all_districts(rivernet_data=rivernet_data)
|
| 168 |
output_path = predictor.save_predictions(predictions)
|
|
@@ -175,48 +233,49 @@ def run_prediction():
|
|
| 175 |
sev = p.get("severity", "normal")
|
| 176 |
severity_counts[sev] = severity_counts.get(sev, 0) + 1
|
| 177 |
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
|
| 188 |
return predictions
|
| 189 |
|
| 190 |
|
| 191 |
def run_full_pipeline():
|
| 192 |
-
"""
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
|
| 197 |
# Step 1: Data Ingestion
|
| 198 |
try:
|
| 199 |
run_data_ingestion(months=3)
|
| 200 |
except Exception as e:
|
| 201 |
-
|
| 202 |
-
|
| 203 |
|
| 204 |
# Step 2: Training (priority stations only)
|
| 205 |
priority_stations = ["COLOMBO", "KANDY", "JAFFNA", "BATTICALOA", "RATNAPURA"]
|
| 206 |
-
|
| 207 |
-
try:
|
| 208 |
-
run_training(station=station, epochs=50)
|
| 209 |
-
except Exception as e:
|
| 210 |
-
logger.warning(f"Training {station} failed: {e}")
|
| 211 |
|
| 212 |
# Step 3: Prediction
|
| 213 |
predictions = run_prediction()
|
| 214 |
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
|
|
|
| 218 |
|
| 219 |
-
return predictions
|
| 220 |
|
| 221 |
|
| 222 |
if __name__ == "__main__":
|
|
@@ -253,18 +312,34 @@ if __name__ == "__main__":
|
|
| 253 |
|
| 254 |
args = parser.parse_args()
|
| 255 |
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 266 |
check_and_train_missing_models(priority_only=True, epochs=25)
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
"""
|
| 2 |
+
Weather Prediction Pipeline - Multi-Station Training
|
| 3 |
+
Follows stock-price-prediction pattern with structured artifact flow
|
|
|
|
| 4 |
"""
|
| 5 |
+
from src.components.data_ingestion import DataIngestion
|
| 6 |
+
from src.components.model_trainer import WeatherLSTMTrainer
|
| 7 |
+
from src.components.predictor import WeatherPredictor
|
| 8 |
+
from src.exception.exception import WeatherPredictionException
|
| 9 |
+
from src.logging.logger import logging
|
| 10 |
+
from src.entity.config_entity import DataIngestionConfig, WEATHER_STATIONS
|
| 11 |
+
|
| 12 |
import sys
|
| 13 |
+
import os
|
| 14 |
import argparse
|
| 15 |
from pathlib import Path
|
| 16 |
from datetime import datetime
|
| 17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
PIPELINE_ROOT = Path(__file__).parent
|
|
|
|
| 19 |
|
| 20 |
|
| 21 |
+
def train_single_station(station_name: str, epochs: int = 100) -> dict:
|
| 22 |
+
"""
|
| 23 |
+
Train a model for a single weather station.
|
| 24 |
+
|
| 25 |
+
Follows stock-price-prediction pattern with structured results.
|
| 26 |
+
|
| 27 |
+
Args:
|
| 28 |
+
station_name: Weather station name (e.g., 'COLOMBO', 'KANDY')
|
| 29 |
+
epochs: Number of training epochs
|
| 30 |
+
|
| 31 |
+
Returns:
|
| 32 |
+
dict with training results or error info
|
| 33 |
+
"""
|
| 34 |
+
result = {"station": station_name, "status": "failed"}
|
| 35 |
|
| 36 |
+
try:
|
| 37 |
+
logging.info(f"\n{'='*60}")
|
| 38 |
+
logging.info(f"Training model for: {station_name}")
|
| 39 |
+
logging.info(f"{'='*60}")
|
| 40 |
+
|
| 41 |
+
# Data Ingestion
|
| 42 |
+
logging.info(f"[{station_name}] Loading data...")
|
| 43 |
+
ingestion = DataIngestion()
|
| 44 |
+
df = ingestion.load_existing()
|
| 45 |
+
logging.info(f"[{station_name}] ✓ Data loaded")
|
| 46 |
+
|
| 47 |
+
# Model Training
|
| 48 |
+
logging.info(f"[{station_name}] Starting model training...")
|
| 49 |
+
trainer = WeatherLSTMTrainer(
|
| 50 |
+
sequence_length=30,
|
| 51 |
+
lstm_units=[64, 32]
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
train_results = trainer.train(
|
| 55 |
+
df=df,
|
| 56 |
+
station_name=station_name,
|
| 57 |
+
epochs=epochs,
|
| 58 |
+
use_mlflow=False # Disabled due to Windows Unicode encoding issues
|
| 59 |
+
)
|
| 60 |
+
logging.info(f"[{station_name}] ✓ Model training completed")
|
| 61 |
+
|
| 62 |
+
result = {
|
| 63 |
+
"station": station_name,
|
| 64 |
+
"status": "success",
|
| 65 |
+
"model_path": train_results.get("model_path", ""),
|
| 66 |
+
"test_mae": train_results.get("test_mae", 0),
|
| 67 |
+
"test_mse": train_results.get("test_mse", 0),
|
| 68 |
+
"epochs_trained": epochs
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
logging.info(f"[{station_name}] ✓ Pipeline completed successfully!")
|
| 72 |
|
| 73 |
+
except Exception as e:
|
| 74 |
+
logging.error(f"[{station_name}] ✗ Pipeline failed: {str(e)}")
|
| 75 |
+
result = {
|
| 76 |
+
"station": station_name,
|
| 77 |
+
"status": "failed",
|
| 78 |
+
"error": str(e)
|
| 79 |
+
}
|
| 80 |
|
| 81 |
+
return result
|
| 82 |
|
|
|
|
|
|
|
| 83 |
|
| 84 |
+
def train_all_stations(stations: list = None, epochs: int = 100) -> list:
|
| 85 |
+
"""
|
| 86 |
+
Train models for all weather stations.
|
| 87 |
+
Each station gets its own model saved separately.
|
| 88 |
+
|
| 89 |
+
Follows stock-price-prediction pattern.
|
| 90 |
+
"""
|
| 91 |
+
stations_to_train = stations or list(WEATHER_STATIONS.keys())
|
| 92 |
|
| 93 |
+
logging.info("\n" + "="*70)
|
| 94 |
+
logging.info("WEATHER PREDICTION - MULTI-STATION TRAINING PIPELINE")
|
| 95 |
+
logging.info(f"Started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
| 96 |
+
logging.info(f"Stations to train: {stations_to_train}")
|
| 97 |
+
logging.info("="*70 + "\n")
|
| 98 |
|
| 99 |
+
results = []
|
| 100 |
+
successful = 0
|
| 101 |
+
failed = 0
|
| 102 |
|
| 103 |
+
for station_name in stations_to_train:
|
| 104 |
+
result = train_single_station(station_name, epochs)
|
| 105 |
+
results.append(result)
|
| 106 |
+
|
| 107 |
+
if result["status"] == "success":
|
| 108 |
+
successful += 1
|
| 109 |
+
else:
|
| 110 |
+
failed += 1
|
| 111 |
+
|
| 112 |
+
# Print summary
|
| 113 |
+
logging.info("\n" + "="*70)
|
| 114 |
+
logging.info("TRAINING SUMMARY")
|
| 115 |
+
logging.info("="*70)
|
| 116 |
+
logging.info(f"Total stations: {len(stations_to_train)}")
|
| 117 |
+
logging.info(f"Successful: {successful}")
|
| 118 |
+
logging.info(f"Failed: {failed}")
|
| 119 |
+
logging.info("-"*70)
|
| 120 |
+
|
| 121 |
+
for result in results:
|
| 122 |
+
if result["status"] == "success":
|
| 123 |
+
logging.info(f" ✓ {result['station']}: MAE={result['test_mae']:.3f}")
|
| 124 |
+
else:
|
| 125 |
+
logging.info(f" ✗ {result['station']}: {result.get('error', 'Unknown error')[:50]}")
|
| 126 |
+
|
| 127 |
+
logging.info("="*70)
|
| 128 |
+
logging.info(f"Completed at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
| 129 |
+
logging.info("="*70 + "\n")
|
| 130 |
|
| 131 |
+
return results
|
| 132 |
|
|
|
|
|
|
|
| 133 |
|
| 134 |
+
def run_data_ingestion(months: int = 12):
|
| 135 |
+
"""Run data ingestion for all stations."""
|
| 136 |
+
logging.info(f"Starting data ingestion ({months} months)...")
|
|
|
|
| 137 |
|
| 138 |
+
+    config = DataIngestionConfig(months_to_fetch=months)
+    ingestion = DataIngestion(config)
+    data_path = ingestion.ingest_all()
+
+    df = ingestion.load_existing(data_path)
+    stats = ingestion.get_data_stats(df)
+
+    logging.info("✓ Data Ingestion Complete!")
+    logging.info(f"Total records: {stats['total_records']}")
+    logging.info(f"Stations: {stats['stations']}")
+    logging.info(f"Date range: {stats['date_range']}")
+
+    return data_path


 def check_and_train_missing_models(priority_only: bool = True, epochs: int = 25):
@@ ... @@
     Returns:
         List of trained station names
     """
     models_dir = PIPELINE_ROOT / "artifacts" / "models"
     models_dir.mkdir(parents=True, exist_ok=True)
@@ ... @@
         missing_stations.append(station)

     if not missing_stations:
+        logging.info("[AUTO-TRAIN] All required models exist.")
         return []

+    logging.info(f"[AUTO-TRAIN] Missing models for: {', '.join(missing_stations)}")
+    logging.info("[AUTO-TRAIN] Starting automatic training...")

     # Ensure we have data first
     data_path = PIPELINE_ROOT / "artifacts" / "data"
     existing_data = list(data_path.glob("weather_history_*.csv")) if data_path.exists() else []

     if not existing_data:
+        logging.info("[AUTO-TRAIN] No training data found, ingesting...")
         try:
             run_data_ingestion(months=3)
         except Exception as e:
+            logging.error(f"[AUTO-TRAIN] Data ingestion failed: {e}")
+            logging.info("[AUTO-TRAIN] Cannot train without data. Please run: python main.py --mode ingest")
             return []

+    # Train missing models using structured function
+    results = train_all_stations(stations=missing_stations, epochs=epochs)
+
+    trained = [r["station"] for r in results if r["status"] == "success"]
+    logging.info(f"[AUTO-TRAIN] Auto-training complete. Trained {len(trained)} models.")
     return trained


 def run_prediction():
     """Run prediction for all districts."""
+    logging.info("Generating predictions...")

     predictor = WeatherPredictor()
@@ ... @@
         sys.path.insert(0, str(PIPELINE_ROOT.parent.parent / "src"))
         from utils.utils import tool_rivernet_status
         rivernet_data = tool_rivernet_status()
+        logging.info(f"✓ RiverNet data available: {len(rivernet_data.get('rivers', []))} rivers")
     except Exception as e:
+        logging.warning(f"RiverNet data unavailable: {e}")

     predictions = predictor.predict_all_districts(rivernet_data=rivernet_data)
     output_path = predictor.save_predictions(predictions)
@@ ... @@
         sev = p.get("severity", "normal")
         severity_counts[sev] = severity_counts.get(sev, 0) + 1

+    logging.info(f"\n{'='*50}")
+    logging.info(f"PREDICTIONS FOR {predictions['prediction_date']}")
+    logging.info(f"{'='*50}")
+    logging.info(f"Districts: {len(districts)}")
+    logging.info(f"Normal: {severity_counts['normal']}")
+    logging.info(f"Advisory: {severity_counts['advisory']}")
+    logging.info(f"Warning: {severity_counts['warning']}")
+    logging.info(f"Critical: {severity_counts['critical']}")
+    logging.info(f"Output: {output_path}")

     return predictions


 def run_full_pipeline():
+    """
+    Run the full pipeline: ingest → train → predict.
+    Following the stock-price-prediction pattern.
+    """
+    logging.info("\n" + "="*70)
+    logging.info("WEATHER PREDICTION PIPELINE - FULL RUN")
+    logging.info(f"Started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
+    logging.info("="*70 + "\n")

     # Step 1: Data Ingestion
     try:
         run_data_ingestion(months=3)
     except Exception as e:
+        logging.error(f"Data ingestion failed: {e}")
+        logging.info("Attempting to use existing data...")

     # Step 2: Training (priority stations only)
     priority_stations = ["COLOMBO", "KANDY", "JAFFNA", "BATTICALOA", "RATNAPURA"]
+    results = train_all_stations(stations=priority_stations, epochs=50)

     # Step 3: Prediction
     predictions = run_prediction()

+    logging.info("\n" + "="*70)
+    logging.info("PIPELINE COMPLETE!")
+    logging.info(f"Completed at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
+    logging.info("="*70 + "\n")
+
+    return results, predictions


 if __name__ == "__main__":
@@ ... @@
     args = parser.parse_args()

+    try:
+        if args.mode == "ingest":
+            run_data_ingestion(months=args.months)
+        elif args.mode == "train":
+            if args.station:
+                result = train_single_station(args.station, args.epochs)
+                if result["status"] == "failed":
+                    sys.exit(1)
+            else:
+                results = train_all_stations(epochs=args.epochs)
+                failed = sum(1 for r in results if r["status"] == "failed")
+                if failed > 0:
+                    logging.warning(f"{failed} stations failed to train")
+                    sys.exit(1)
+        elif args.mode == "auto-train":
+            # Explicitly auto-train missing models
+            check_and_train_missing_models(priority_only=True, epochs=25)
+        elif args.mode == "predict":
+            # Auto-train missing models before prediction (unless skipped)
+            if not args.skip_auto_train:
+                check_and_train_missing_models(priority_only=True, epochs=25)
+            run_prediction()
+        elif args.mode == "full":
+            results, predictions = run_full_pipeline()
+            failed = sum(1 for r in results if r["status"] == "failed")
+            if failed > 0:
+                logging.warning(f"{failed} stations failed to train")
+                sys.exit(1)
+    except Exception as e:
+        logging.error(f"Pipeline crashed: {e}")
+        raise WeatherPredictionException(e, sys)
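The argparse setup for these modes (lines 282-311 of the file) is collapsed in this view. As a hedged reconstruction, a parser consistent with the args.mode, args.months, args.epochs, args.station, and args.skip_auto_train references above might look like the sketch below; the flag names and defaults here are assumptions, not the file's actual code.

import argparse

# Hypothetical parser: only the attribute names are attested by the
# handler code above; descriptions and defaults are illustrative.
parser = argparse.ArgumentParser(description="Weather prediction pipeline")
parser.add_argument("--mode", choices=["ingest", "train", "auto-train", "predict", "full"], default="full")
parser.add_argument("--months", type=int, default=3)
parser.add_argument("--epochs", type=int, default=50)
parser.add_argument("--station", type=str, default=None)
parser.add_argument("--skip-auto-train", dest="skip_auto_train", action="store_true")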
models/weather-prediction/src/components/model_trainer.py
CHANGED
@@ -4,6 +4,11 @@ LSTM-based Weather Prediction Model Trainer
 """
 import os
 import sys
+
+# Fix Windows console encoding issue with MLflow emoji output
+if sys.platform == 'win32':
+    sys.stdout.reconfigure(encoding='utf-8', errors='replace')
+
 import logging
 import numpy as np
 import pandas as pd
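One hardening note on this guard: sys.stdout.reconfigure is available on ordinary text streams from Python 3.7 onward, but stdout is sometimes replaced by wrapper objects that lack it (test runners, some IDE consoles). A defensive variant - an assumption on our part, not what the file ships - is:

import sys

# Only reconfigure when the active stream actually supports it
if sys.platform == "win32" and hasattr(sys.stdout, "reconfigure"):
    sys.stdout.reconfigure(encoding="utf-8", errors="replace")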
models/weather-prediction/src/components/predictor.py
CHANGED
@@ -67,7 +67,10 @@ class WeatherPredictor:
             logger.warning(f"[PREDICTOR] No model for {station_name}")
             return None, None

-        self._models[station_name] = load_model(model_path)
+        # Load with compile=False to avoid Keras 2->3 mse serialization issues
+        # Then recompile with standard metrics
+        self._models[station_name] = load_model(model_path, compile=False)
+        self._models[station_name].compile(optimizer='adam', loss='mse', metrics=['mae'])
         self._scalers[station_name] = joblib.load(scaler_path)

         return self._models[station_name], self._scalers[station_name]
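Pulled out of the class, the load-then-recompile workaround is only a few lines. A minimal sketch, assuming TensorFlow/Keras and joblib as above (the helper name and paths are illustrative):

import joblib
from tensorflow.keras.models import load_model

def load_station_artifacts(model_path: str, scaler_path: str):
    # compile=False skips deserializing the saved loss/metric config,
    # which is where the legacy "mse" string breaks under Keras 3
    model = load_model(model_path, compile=False)
    model.compile(optimizer="adam", loss="mse", metrics=["mae"])
    scaler = joblib.load(scaler_path)
    return model, scaler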
models/weather-prediction/src/exception/__init__.py
CHANGED
@@ -0,0 +1 @@
+from src.exception.exception import WeatherPredictionException
models/weather-prediction/src/exception/exception.py
CHANGED
@@ -1,22 +1,24 @@
 import sys
-from src.log_utils import logger

-class NetworkSecurityException(Exception):
-    def __init__(self,error_message,error_details:sys):
-        self.error_message = error_message
-        _,_,exc_tb = error_details.exc_info()
-
-        self.lineno=exc_tb.tb_lineno
-        self.file_name=exc_tb.tb_frame.f_code.co_filename
+
+class WeatherPredictionException(Exception):
+    """Custom exception for Weather Prediction pipeline."""
+
+    def __init__(self, error_message, error_details: sys):
+        self.error_message = error_message
+        _, _, exc_tb = error_details.exc_info()
+
+        self.lineno = exc_tb.tb_lineno
+        self.file_name = exc_tb.tb_frame.f_code.co_filename
+
     def __str__(self):
-        return "Error
+        return "Error occurred in python script name [{0}] line number [{1}] error message [{2}]".format(
+            self.file_name, self.lineno, str(self.error_message)
+        )
+

-if __name__=='__main__':
+if __name__ == '__main__':
     try:
-        a=1/0
-        print("This will not be printed",a)
+        a = 1 / 0
     except Exception as e:
-        raise NetworkSecurityException(e,sys)
+        raise WeatherPredictionException(e, sys)
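Design note: passing the sys module as error_details simply gives the constructor access to sys.exc_info(), from which it captures the failing frame's file name and line number. The call pattern used across the pipeline looks like this (run_step is a hypothetical stand-in; the exception class is the one defined above):

import sys

try:
    run_step()  # hypothetical pipeline step that may raise
except Exception as e:
    # Re-raise with file/line context captured via sys.exc_info()
    raise WeatherPredictionException(e, sys)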
models/weather-prediction/src/logging/__init__.py
CHANGED
@@ -0,0 +1 @@
+from src.logging.logger import logging
models/weather-prediction/src/logging/logger.py
CHANGED
@@ -1,15 +1,18 @@
+"""
+Logging configuration for Weather Prediction pipeline.
+Creates timestamped log files in the logs directory.
+"""
 import logging
 import os
 from datetime import datetime

-LOG_FILE=f"{datetime.now().strftime('%m_%d_%Y_%H_%M_%S')}.log"
+LOG_FILE = f"{datetime.now().strftime('%m_%d_%Y_%H_%M_%S')}.log"

-logs_path=os.path.join(os.getcwd(), "logs", LOG_FILE)
+logs_path = os.path.join(os.getcwd(), "logs", LOG_FILE)

 os.makedirs(logs_path, exist_ok=True)
-# Create the file only if it is not created

-LOG_FILE_PATH=os.path.join(logs_path, LOG_FILE)
+LOG_FILE_PATH = os.path.join(logs_path, LOG_FILE)

 logging.basicConfig(
     filename=LOG_FILE_PATH,
@@ -17,4 +20,13 @@ logging.basicConfig(
     level=logging.INFO
 )

+# Also add console handler for visibility
+console_handler = logging.StreamHandler()
+console_handler.setLevel(logging.INFO)
+console_handler.setFormatter(logging.Formatter(
+    "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+))
+
+# Get root logger and add console handler
+root_logger = logging.getLogger()
+root_logger.addHandler(console_handler)
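Note that because logs_path already includes LOG_FILE, each run creates a directory named after the log file and writes the file inside it; both the old and new versions share that quirk. Usage elsewhere is a plain import; a minimal sketch:

# Importing the module triggers the configuration above once;
# messages then reach both the timestamped file and the console.
from src.logging.logger import logging

logging.info("Weather pipeline step started")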
src/graphs/economicalAgentGraph.py
CHANGED
@@ -60,9 +60,9 @@ class EconomicalGraphBuilder:

         main_graph = StateGraph(EconomicalAgentState)

-        main_graph.add_node("official_sources_module", official_subgraph.invoke)
-        main_graph.add_node("social_media_module", social_subgraph.invoke)
-        main_graph.add_node("feed_generation_module", feed_subgraph.invoke)
+        main_graph.add_node("official_sources_module", lambda state: official_subgraph.invoke(state))
+        main_graph.add_node("social_media_module", lambda state: social_subgraph.invoke(state))
+        main_graph.add_node("feed_generation_module", lambda state: feed_subgraph.invoke(state))
         main_graph.add_node("feed_aggregator", node.aggregate_and_store_feeds)

         main_graph.set_entry_point("official_sources_module")
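The same one-line change repeats in the intelligence, meteorological, political, and social builders below. The point of the lambda wrapper is to hand LangGraph a plain single-argument callable: a compiled subgraph's .invoke also accepts extra parameters (a config object, for instance), and the lambda pins the node signature down to just the state. A self-contained sketch of the pattern, with illustrative names not taken from this repo:

from typing import TypedDict
from langgraph.graph import StateGraph, END

class State(TypedDict):
    value: int

# A trivial compiled subgraph
sub = StateGraph(State)
sub.add_node("inc", lambda s: {"value": s["value"] + 1})
sub.set_entry_point("inc")
sub.add_edge("inc", END)
compiled_sub = sub.compile()

# Mount it in a parent graph via a single-argument lambda
main = StateGraph(State)
main.add_node("sub_module", lambda state: compiled_sub.invoke(state))
main.set_entry_point("sub_module")
main.add_edge("sub_module", END)

print(main.compile().invoke({"value": 1}))  # -> {'value': 2}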
src/graphs/intelligenceAgentGraph.py
CHANGED
@@ -60,9 +60,9 @@ class IntelligenceGraphBuilder:

         main_graph = StateGraph(IntelligenceAgentState)

-        main_graph.add_node("profile_monitoring_module", profile_subgraph.invoke)
-        main_graph.add_node("competitive_intelligence_module", intelligence_subgraph.invoke)
-        main_graph.add_node("feed_generation_module", feed_subgraph.invoke)
+        main_graph.add_node("profile_monitoring_module", lambda state: profile_subgraph.invoke(state))
+        main_graph.add_node("competitive_intelligence_module", lambda state: intelligence_subgraph.invoke(state))
+        main_graph.add_node("feed_generation_module", lambda state: feed_subgraph.invoke(state))
         main_graph.add_node("feed_aggregator", node.aggregate_and_store_feeds)

         main_graph.set_entry_point("profile_monitoring_module")
src/graphs/meteorologicalAgentGraph.py
CHANGED
@@ -60,9 +60,9 @@ class MeteorologicalGraphBuilder:

         main_graph = StateGraph(MeteorologicalAgentState)

-        main_graph.add_node("official_sources_module", official_subgraph.invoke)
-        main_graph.add_node("social_media_module", social_subgraph.invoke)
-        main_graph.add_node("feed_generation_module", feed_subgraph.invoke)
+        main_graph.add_node("official_sources_module", lambda state: official_subgraph.invoke(state))
+        main_graph.add_node("social_media_module", lambda state: social_subgraph.invoke(state))
+        main_graph.add_node("feed_generation_module", lambda state: feed_subgraph.invoke(state))
         main_graph.add_node("feed_aggregator", node.aggregate_and_store_feeds)

         main_graph.set_entry_point("official_sources_module")
src/graphs/politicalAgentGraph.py
CHANGED
@@ -59,9 +59,9 @@ class PoliticalGraphBuilder:

         main_graph = StateGraph(PoliticalAgentState)

-        main_graph.add_node("official_sources_module", official_subgraph.invoke)
-        main_graph.add_node("social_media_module", social_subgraph.invoke)
-        main_graph.add_node("feed_generation_module", feed_subgraph.invoke)
+        main_graph.add_node("official_sources_module", lambda state: official_subgraph.invoke(state))
+        main_graph.add_node("social_media_module", lambda state: social_subgraph.invoke(state))
+        main_graph.add_node("feed_generation_module", lambda state: feed_subgraph.invoke(state))
         main_graph.add_node("feed_aggregator", node.aggregate_and_store_feeds)

         main_graph.set_entry_point("official_sources_module")
src/graphs/socialAgentGraph.py
CHANGED
@@ -51,25 +51,39 @@ class SocialGraphBuilder:

         return subgraph.compile()

+    def build_user_targets_subgraph(self, node: SocialAgentNode) -> StateGraph:
+        """Build subgraph for user-defined keywords and profiles."""
+        subgraph = StateGraph(SocialAgentState)
+        subgraph.add_node("collect_user_targets", node.collect_user_defined_targets)
+        subgraph.set_entry_point("collect_user_targets")
+        subgraph.add_edge("collect_user_targets", END)
+        return subgraph.compile()
+
     def build_graph(self):
         node = SocialAgentNode(self.llm)

         trending_subgraph = self.build_trending_subgraph(node)
         social_subgraph = self.build_social_media_subgraph(node)
+        user_targets_subgraph = self.build_user_targets_subgraph(node)
         feed_subgraph = self.build_feed_generation_subgraph(node)

         main_graph = StateGraph(SocialAgentState)

-        main_graph.add_node("trending_module", trending_subgraph.invoke)
-        main_graph.add_node("social_media_module", social_subgraph.invoke)
-        main_graph.add_node("feed_generation_module", feed_subgraph.invoke)
+        main_graph.add_node("trending_module", lambda state: trending_subgraph.invoke(state))
+        main_graph.add_node("social_media_module", lambda state: social_subgraph.invoke(state))
+        main_graph.add_node("user_targets_module", lambda state: user_targets_subgraph.invoke(state))
+        main_graph.add_node("feed_generation_module", lambda state: feed_subgraph.invoke(state))
         main_graph.add_node("feed_aggregator", node.aggregate_and_store_feeds)

+        # Parallel entry points - all 3 modules start together
         main_graph.set_entry_point("trending_module")
         main_graph.set_entry_point("social_media_module")
+        main_graph.set_entry_point("user_targets_module")

+        # All modules converge to feed generation
         main_graph.add_edge("trending_module", "feed_generation_module")
         main_graph.add_edge("social_media_module", "feed_generation_module")
+        main_graph.add_edge("user_targets_module", "feed_generation_module")
         main_graph.add_edge("feed_generation_module", "feed_aggregator")
         main_graph.add_edge("feed_aggregator", END)
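For context on the repeated set_entry_point calls above: in LangGraph each call adds an edge from the virtual START node rather than replacing the previous one, so the three modules fan out in parallel and then converge on feed_generation_module. An equivalent explicit form for the builder above, assuming a langgraph version that exports START, would be:

from langgraph.graph import START

# Same fan-out as the three set_entry_point calls
main_graph.add_edge(START, "trending_module")
main_graph.add_edge(START, "social_media_module")
main_graph.add_edge(START, "user_targets_module")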
src/nodes/socialAgentNode.py
CHANGED
@@ -5,23 +5,44 @@ Monitors trending topics, events, people, social intelligence across geographic

 Updated: Uses Tool Factory pattern for parallel execution safety.
 Each agent instance gets its own private set of tools.
+
+Updated: Now loads user-defined keywords and profiles from intel config.
 """

 import json
 import uuid
+import os
+from typing import Dict, Any, List
 from datetime import datetime
 from src.states.socialAgentState import SocialAgentState
 from src.utils.tool_factory import create_tool_set
 from src.llms.groqllm import GroqLLM


+def load_intel_config() -> dict:
+    """Load intel config from JSON file (same as main.py)."""
+    config_path = os.path.join(os.path.dirname(__file__), "..", "..", "data", "intel_config.json")
+    default_config = {
+        "user_profiles": {"twitter": [], "facebook": [], "linkedin": []},
+        "user_keywords": [],
+        "user_products": []
+    }
+    try:
+        if os.path.exists(config_path):
+            with open(config_path, "r", encoding="utf-8") as f:
+                return json.load(f)
+    except Exception:
+        pass
+    return default_config
+
+
 class SocialAgentNode:
     """
     Modular Social Agent - Geographic social intelligence collection.
     Module 1: Trending Topics (Sri Lanka specific trends)
     Module 2: Social Media (Sri Lanka, Asia, World scopes)
     Module 3: Feed Generation (Categorize, Summarize, Format)
+    Module 4: User-Defined Keywords & Profiles (from frontend config)

     Thread Safety:
         Each SocialAgentNode instance creates its own private ToolSet,
@@ -40,6 +61,15 @@
         else:
             self.llm = llm

+        # Load user-defined intel config (keywords, profiles, products)
+        self.intel_config = load_intel_config()
+        self.user_keywords = self.intel_config.get("user_keywords", [])
+        self.user_profiles = self.intel_config.get("user_profiles", {})
+        self.user_products = self.intel_config.get("user_products", [])
+
+        print(f"[SocialAgent] Loaded {len(self.user_keywords)} user keywords, "
+              f"{sum(len(v) for v in self.user_profiles.values())} profiles")
+
         # Geographic scopes
         self.geographic_scopes = {
             "sri_lanka": ["sri lanka", "colombo", "srilanka"],
@@ -375,6 +405,111 @@

         return {"worker_results": world_results, "social_media_results": world_results}

+    def collect_user_defined_targets(self, state: SocialAgentState) -> Dict[str, Any]:
+        """
+        Module 2D: Collect data for USER-DEFINED keywords and profiles.
+        These are configured via the frontend Intelligence Settings UI.
+        """
+        print("[MODULE 2D] Collecting User-Defined Targets")
+
+        user_results = []
+
+        # Reload config to get latest user settings
+        self.intel_config = load_intel_config()
+        self.user_keywords = self.intel_config.get("user_keywords", [])
+        self.user_profiles = self.intel_config.get("user_profiles", {})
+        self.user_products = self.intel_config.get("user_products", [])
+
+        # Skip if no user config
+        if not self.user_keywords and not any(self.user_profiles.values()):
+            print("  ⏭️ No user-defined targets configured")
+            return {"worker_results": [], "user_target_results": []}
+
+        # ============================================
+        # Scrape USER KEYWORDS across Twitter
+        # ============================================
+        if self.user_keywords:
+            print(f"  📝 Scraping {len(self.user_keywords)} user keywords...")
+            twitter_tool = self.tools.get("scrape_twitter")
+
+            for keyword in self.user_keywords[:10]:  # Limit to 10 keywords
+                try:
+                    if twitter_tool:
+                        twitter_data = twitter_tool.invoke(
+                            {"query": keyword, "max_items": 5}
+                        )
+                        user_results.append({
+                            "source_tool": "scrape_twitter",
+                            "raw_content": str(twitter_data),
+                            "category": "user_keyword",
+                            "scope": "sri_lanka",
+                            "platform": "twitter",
+                            "keyword": keyword,
+                            "timestamp": datetime.utcnow().isoformat(),
+                        })
+                        print(f"    ✓ Keyword: '{keyword}'")
+                except Exception as e:
+                    print(f"    ⚠️ Keyword '{keyword}' error: {e}")
+
+        # ============================================
+        # Scrape USER PRODUCTS
+        # ============================================
+        if self.user_products:
+            print(f"  📦 Scraping {len(self.user_products)} user products...")
+            twitter_tool = self.tools.get("scrape_twitter")
+
+            for product in self.user_products[:5]:  # Limit to 5 products
+                try:
+                    if twitter_tool:
+                        twitter_data = twitter_tool.invoke(
+                            {"query": f"{product} review OR {product} Sri Lanka", "max_items": 3}
+                        )
+                        user_results.append({
+                            "source_tool": "scrape_twitter",
+                            "raw_content": str(twitter_data),
+                            "category": "user_product",
+                            "scope": "sri_lanka",
+                            "platform": "twitter",
+                            "product": product,
+                            "timestamp": datetime.utcnow().isoformat(),
+                        })
+                        print(f"    ✓ Product: '{product}'")
+                except Exception as e:
+                    print(f"    ⚠️ Product '{product}' error: {e}")
+
+        # ============================================
+        # Scrape USER TWITTER PROFILES
+        # ============================================
+        twitter_profiles = self.user_profiles.get("twitter", [])
+        if twitter_profiles:
+            print(f"  👤 Scraping {len(twitter_profiles)} Twitter profiles...")
+            twitter_tool = self.tools.get("scrape_twitter")
+
+            for profile in twitter_profiles[:10]:  # Limit to 10 profiles
+                try:
+                    # Clean profile handle
+                    handle = profile.replace("@", "").strip()
+                    if twitter_tool:
+                        # Search for tweets mentioning this profile
+                        twitter_data = twitter_tool.invoke(
+                            {"query": f"from:{handle} OR @{handle}", "max_items": 5}
+                        )
+                        user_results.append({
+                            "source_tool": "scrape_twitter",
+                            "raw_content": str(twitter_data),
+                            "category": "user_profile",
+                            "scope": "sri_lanka",
+                            "platform": "twitter",
+                            "profile": f"@{handle}",
+                            "timestamp": datetime.utcnow().isoformat(),
+                        })
+                        print(f"    ✓ Profile: @{handle}")
+                except Exception as e:
+                    print(f"    ⚠️ Profile @{profile} error: {e}")
+
+        print(f"  ✅ User targets: {len(user_results)} results collected")
+        return {"worker_results": user_results, "user_target_results": user_results}
+
     # ============================================
     # MODULE 3: FEED GENERATION
     # ============================================
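The shape of data/intel_config.json follows from the defaults in load_intel_config; a hypothetical example with invented values:

{
  "user_profiles": {
    "twitter": ["@example_handle"],
    "facebook": [],
    "linkedin": []
  },
  "user_keywords": ["fuel prices", "power cuts"],
  "user_products": ["ExampleProduct"]
}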
src/rag.py
CHANGED
@@ -42,6 +42,200 @@
     LANGCHAIN_AVAILABLE = False
     logger.warning("[RAG] LangChain not available")

+# Neo4j for graph-based retrieval
+try:
+    from neo4j import GraphDatabase
+    NEO4J_AVAILABLE = True
+except ImportError:
+    NEO4J_AVAILABLE = False
+    logger.warning("[RAG] Neo4j not available")
+
+
+# Keywords that indicate a graph/relationship query
+GRAPH_KEYWORDS = [
+    "connected", "related", "timeline", "before", "after",
+    "caused by", "followed by", "similar to", "linked",
+    "what happened", "sequence", "chain of events"
+]
+
+
+def is_graph_query(question: str) -> bool:
+    """Detect if question requires graph traversal."""
+    q_lower = question.lower()
+    return any(kw in q_lower for kw in GRAPH_KEYWORDS)
+
+
+class Neo4jRetriever:
+    """Graph-based retrieval for relationship queries with LAZY initialization."""
+
+    def __init__(self):
+        self.driver = None
+        self._initialized = False
+        self._init_attempted = False
+
+    def _lazy_init(self):
+        """Lazy initialization - only connect when actually needed."""
+        if self._init_attempted:
+            return self.driver is not None
+
+        self._init_attempted = True
+
+        if not NEO4J_AVAILABLE:
+            logger.info("[Neo4jRetriever] Neo4j package not installed")
+            return False
+
+        neo4j_uri = os.getenv("NEO4J_URI", "")
+        neo4j_user = os.getenv("NEO4J_USER", "neo4j")
+        neo4j_password = os.getenv("NEO4J_PASSWORD", "")
+
+        if not neo4j_uri or not neo4j_password:
+            logger.info("[Neo4jRetriever] Neo4j credentials not configured - skipping")
+            return False
+
+        try:
+            self.driver = GraphDatabase.driver(
+                neo4j_uri, auth=(neo4j_user, neo4j_password)
+            )
+            self.driver.verify_connectivity()
+            self._initialized = True
+            logger.info(f"[Neo4jRetriever] Connected to {neo4j_uri}")
+            return True
+        except Exception as e:
+            logger.warning(f"[Neo4jRetriever] Connection failed (will skip graph queries): {e}")
+            self.driver = None
+            return False
+
+    def get_related_events(self, keyword: str, limit: int = 5) -> List[Dict[str, Any]]:
+        """Find events containing keyword and their related events."""
+        if not self._lazy_init():
+            return []
+
+        try:
+            with self.driver.session() as session:
+                query = """
+                    MATCH (e:Event)
+                    WHERE toLower(e.summary) CONTAINS toLower($keyword)
+                    OPTIONAL MATCH (e)-[:SIMILAR_TO]-(related:Event)
+                    RETURN e.event_id as event_id,
+                           e.summary as summary,
+                           e.domain as domain,
+                           e.severity as severity,
+                           e.timestamp as timestamp,
+                           COLLECT(DISTINCT related.summary)[0..3] as related_summaries
+                    ORDER BY e.timestamp DESC
+                    LIMIT $limit
+                """
+                results = session.run(query, keyword=keyword, limit=limit)
+
+                events = []
+                for record in results:
+                    events.append({
+                        "event_id": record["event_id"],
+                        "content": record["summary"],
+                        "domain": record["domain"],
+                        "severity": record["severity"],
+                        "timestamp": record["timestamp"],
+                        "related": record["related_summaries"],
+                        "source": "neo4j_graph"
+                    })
+
+                logger.info(f"[Neo4jRetriever] Found {len(events)} events for '{keyword}'")
+                return events
+
+        except Exception as e:
+            logger.error(f"[Neo4jRetriever] Query error: {e}")
+            return []
+
+    def get_domain_events(self, domain: str, limit: int = 5) -> List[Dict[str, Any]]:
+        """Get recent events by domain with relationships."""
+        if not self._lazy_init():
+            return []
+
+        try:
+            with self.driver.session() as session:
+                query = """
+                    MATCH (e:Event)-[:BELONGS_TO]->(d:Domain {name: $domain})
+                    OPTIONAL MATCH (e)-[:SIMILAR_TO]-(related:Event)
+                    RETURN e.event_id as event_id,
+                           e.summary as summary,
+                           e.severity as severity,
+                           e.timestamp as timestamp,
+                           COUNT(related) as related_count
+                    ORDER BY e.timestamp DESC
+                    LIMIT $limit
+                """
+                results = session.run(query, domain=domain.lower(), limit=limit)
+
+                events = []
+                for record in results:
+                    events.append({
+                        "event_id": record["event_id"],
+                        "content": record["summary"],
+                        "domain": domain,
+                        "severity": record["severity"],
+                        "timestamp": record["timestamp"],
+                        "related_count": record["related_count"],
+                        "source": "neo4j_graph"
+                    })
+
+                return events
+
+        except Exception as e:
+            logger.error(f"[Neo4jRetriever] Domain query error: {e}")
+            return []
+
+    def get_event_chain(self, keyword: str, depth: int = 3) -> List[Dict[str, Any]]:
+        """Get temporal chain of related events."""
+        if not self._lazy_init():
+            return []
+
+        try:
+            with self.driver.session() as session:
+                query = """
+                    MATCH (start:Event)
+                    WHERE toLower(start.summary) CONTAINS toLower($keyword)
+                    OPTIONAL MATCH path = (start)-[:FOLLOWS|SIMILAR_TO*1..3]-(chain:Event)
+                    WITH start, COLLECT(DISTINCT chain) as chain_events
+                    RETURN start.event_id as start_id,
+                           start.summary as start_summary,
+                           start.timestamp as start_time,
+                           [e IN chain_events | {summary: e.summary, time: e.timestamp}][0..5] as chain
+                    LIMIT 1
+                """
+                result = session.run(query, keyword=keyword).single()
+
+                if result:
+                    return [{
+                        "event_id": result["start_id"],
+                        "content": result["start_summary"],
+                        "timestamp": result["start_time"],
+                        "chain": result["chain"],
+                        "source": "neo4j_chain"
+                    }]
+                return []
+
+        except Exception as e:
+            logger.error(f"[Neo4jRetriever] Chain query error: {e}")
+            return []
+
+    def get_stats(self) -> Dict[str, Any]:
+        """Get Neo4j graph statistics."""
+        if not self._initialized or not self.driver:
+            return {"status": "not_initialized" if not self._init_attempted else "disconnected"}
+
+        try:
+            with self.driver.session() as session:
+                event_count = session.run(
+                    "MATCH (e:Event) RETURN COUNT(e) as count"
+                ).single()["count"]
+
+                return {
+                    "status": "connected",
+                    "total_events": event_count
+                }
+        except Exception as e:
+            return {"status": "error", "error": str(e)}
+

 class MultiCollectionRetriever:
     COLLECTIONS = ["Roger_feeds"]
@@ -52,6 +246,10 @@
         )
         self.client = None
         self.collections: Dict[str, Any] = {}
+
+        # Thread pool for parallel queries
+        from concurrent.futures import ThreadPoolExecutor
+        self._executor = ThreadPoolExecutor(max_workers=4)

         if not CHROMA_AVAILABLE:
             logger.error("[RAG] ChromaDB not installed")
@@ -90,43 +288,68 @@
             logger.error(f"[RAG] ChromaDB initialization error: {e}")
             self.client = None

-    def search(
-        self, query: str, n_results: int = 5, domain_filter: Optional[str] = None
-    ) -> List[Dict[str, Any]]:
-        ...
-            results = collection.query(
-                query_texts=[query], n_results=n_results, where=where_filter
-            )
-        ...
-                        "domain": meta.get("domain", "unknown"),
-                    })
-        except Exception as e:
-            logger.warning(f"[RAG] ...")
+    def _query_single_collection(
+        self, name: str, collection, query: str, n_results: int, domain_filter: Optional[str]
+    ) -> List[Dict[str, Any]]:
+        """Query a single collection - used for parallel execution."""
+        results_list = []
+        try:
+            where_filter = None
+            if domain_filter:
+                where_filter = {"domain": domain_filter.lower()}
+
+            results = collection.query(
+                query_texts=[query], n_results=n_results, where=where_filter
+            )
+
+            if results["ids"] and results["ids"][0]:
+                for i, doc_id in enumerate(results["ids"][0]):
+                    doc = results["documents"][0][i] if results["documents"] else ""
+                    meta = results["metadatas"][0][i] if results["metadatas"] else {}
+                    distance = results["distances"][0][i] if results["distances"] else 0
+
+                    similarity = 1.0 - min(distance / 2.0, 1.0)
+
+                    results_list.append({
+                        "id": doc_id,
+                        "content": doc,
+                        "metadata": meta,
+                        "similarity": similarity,
+                        "collection": name,
+                        "domain": meta.get("domain", "unknown"),
+                    })
+
+        except Exception as e:
+            logger.warning(f"[RAG] Error querying {name}: {e}")
+
+        return results_list
+
+    def search(
+        self, query: str, n_results: int = 5, domain_filter: Optional[str] = None
+    ) -> List[Dict[str, Any]]:
+        """Search all collections in PARALLEL for faster results."""
+        if not self.client:
+            return []
+
+        # Submit parallel queries to all collections
+        from concurrent.futures import as_completed
+
+        futures = {}
+        for name, collection in self.collections.items():
+            future = self._executor.submit(
+                self._query_single_collection,
+                name, collection, query, n_results, domain_filter
+            )
+            futures[future] = name
+
+        # Collect results as they complete (fastest first)
+        all_results = []
+        for future in as_completed(futures, timeout=10.0):  # 10s timeout
+            try:
+                results = future.result()
+                all_results.extend(results)
+            except Exception as e:
+                logger.warning(f"[RAG] Parallel query failed for {futures[future]}: {e}")

         all_results.sort(key=lambda x: x["similarity"], reverse=True)
         return all_results[: n_results * 2]
@@ -152,6 +375,9 @@
 class RogerRAG:
     def __init__(self):
         self.retriever = MultiCollectionRetriever()
+        # Neo4j disabled for faster startup - uncomment when graph DB is configured
+        # self.neo4j_retriever = Neo4jRetriever()  # Graph-based retrieval
+        self.neo4j_retriever = None  # Disabled
         self.llm = None
         self.chat_history: List[Tuple[str, str]] = []
@@ -165,29 +391,51 @@
             logger.error("[RAG] GROQ_API_KEY not set")
             return

+            # Using Llama 4 Maverick 17B for fast, high-quality responses
             self.llm = ChatGroq(
                 api_key=api_key,
-                model="...",
+                model="meta-llama/llama-4-maverick-17b-128e-instruct",
                 temperature=0.3,
                 max_tokens=1024,
+                request_timeout=30,  # 30 second timeout
             )
-            logger.info("[RAG] Groq LLM initialized")
+            logger.info("[RAG] Groq LLM initialized with Llama 4 Maverick 17B")

         except Exception as e:
             logger.error(f"[RAG] LLM initialization error: {e}")

+    def _extract_keywords(self, question: str) -> List[str]:
+        """Extract key terms from question for graph search."""
+        # Remove common stopwords
+        stopwords = {
+            "what", "when", "where", "who", "why", "how", "is", "are", "was",
+            "were", "the", "a", "an", "to", "of", "in", "on", "for", "with",
+            "about", "related", "connected", "happened", "after", "before",
+            "show", "me", "tell", "find", "get", "events", "timeline"
+        }
+
+        words = question.lower().replace("?", "").replace(",", "").split()
+        keywords = [w for w in words if w not in stopwords and len(w) > 2]
+
+        return keywords[:5]  # Return top 5 keywords
+
-    def _format_context(self, docs: List[Dict[str, Any]]) -> str:
+    def _format_context(self, docs: List[Dict[str, Any]], include_graph: bool = False) -> str:
         if not docs:
             return "No relevant intelligence data found."

         context_parts = []
         now = datetime.now()
+
+        # Separate ChromaDB and Neo4j results
+        chroma_docs = [d for d in docs if d.get("source") != "neo4j_graph"]
+        graph_docs = [d for d in docs if d.get("source") == "neo4j_graph"]

-        for i, doc in enumerate(docs[:5], 1):
+        # Format ChromaDB results
+        for i, doc in enumerate(chroma_docs[:5], 1):
             meta = doc.get("metadata", {})
-            domain = meta.get("domain", "unknown")
+            domain = meta.get("domain", doc.get("domain", "unknown"))
             platform = meta.get("platform", "")
-            timestamp = meta.get("timestamp", "")
+            timestamp = meta.get("timestamp", doc.get("timestamp", ""))

             age_str = "unknown date"
             if timestamp:
@@ -199,7 +447,7 @@
                     "%d/%m/%Y",
                 ]:
                     try:
-                        ts_date = datetime.strptime(timestamp[:19], fmt)
+                        ts_date = datetime.strptime(str(timestamp)[:19], fmt)
                         days_old = (now - ts_date).days
                         if days_old == 0:
                             age_str = "TODAY"
@@ -224,6 +472,22 @@
                 f"TIMESTAMP: {timestamp} ({age_str})\n"
                 f"{doc['content']}\n"
             )
+
+        # Format Neo4j graph results (if any)
+        if graph_docs:
+            context_parts.append("\n=== RELATED EVENTS FROM KNOWLEDGE GRAPH ===\n")
+            for i, doc in enumerate(graph_docs[:3], 1):
+                related = doc.get("related", [])
+                related_str = ""
+                if related:
+                    related_str = f"\n  Related events: {', '.join(str(r)[:50] + '...' for r in related[:2])}"
+
+                context_parts.append(
+                    f"[Graph {i}] Domain: {doc.get('domain', 'unknown')} | "
+                    f"Severity: {doc.get('severity', 'unknown')}\n"
+                    f"{doc.get('content', '')[:500]}"
+                    f"{related_str}\n"
+                )

         return "\n---\n".join(context_parts)
@@ -266,11 +530,31 @@
         if use_history and self.chat_history:
             search_question = self._reformulate_question(question)

+        # ChromaDB semantic search (always)
         docs = self.retriever.search(
             search_question, n_results=5, domain_filter=domain_filter
         )
-
-        if not docs:
+
+        # Neo4j graph search (for relationship queries) - only if enabled
+        graph_docs = []
+        used_graph = False
+        if self.neo4j_retriever and is_graph_query(search_question):
+            logger.info(f"[RAG] Graph query detected: '{search_question}'")
+            used_graph = True
+
+            # Extract keywords for graph search
+            # Simple: use first nouns/keywords from question
+            keywords = self._extract_keywords(search_question)
+
+            for keyword in keywords[:2]:  # Limit to 2 keywords
+                graph_docs.extend(self.neo4j_retriever.get_related_events(keyword, limit=3))
+
+            logger.info(f"[RAG] Graph retrieval: {len(graph_docs)} docs from Neo4j")
+
+        # Merge results (ChromaDB + Neo4j)
+        all_docs = docs + graph_docs
+
+        if not all_docs:
             return {
                 "answer": "I couldn't find any relevant intelligence data to answer your question.",
                 "sources": [],
@@ -278,7 +562,7 @@
                 "reformulated": search_question if search_question != question else None,
             }

-        context = self._format_context(docs)
+        context = self._format_context(all_docs, include_graph=used_graph)

         if not self.llm:
             return {
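The graph-vs-semantic routing above is purely lexical, so it is cheap to sanity-check. A self-contained snippet showing the router's behaviour (example questions invented):

# Re-declared here so the snippet runs standalone; mirrors the code above
GRAPH_KEYWORDS = [
    "connected", "related", "timeline", "before", "after",
    "caused by", "followed by", "similar to", "linked",
    "what happened", "sequence", "chain of events"
]

def is_graph_query(question: str) -> bool:
    q_lower = question.lower()
    return any(kw in q_lower for kw in GRAPH_KEYWORDS)

print(is_graph_query("What happened after the flood warnings?"))  # True
print(is_graph_query("Summarise today's economic feeds"))         # False -> ChromaDB only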
src/storage/storage_manager.py
CHANGED
@@ -4,6 +4,7 @@ Unified storage manager orchestrating 3-tier deduplication pipeline
 """

 import logging
+import re
 from typing import Dict, Any, List, Optional, Tuple
 import csv
 from datetime import datetime
@@ -16,6 +17,14 @@ from .neo4j_graph import Neo4jGraph

 logger = logging.getLogger("storage_manager")

+# Trending detection integration
+try:
+    from ..utils.trending_detector import record_topic_mention
+    TRENDING_AVAILABLE = True
+except ImportError:
+    TRENDING_AVAILABLE = False
+    logger.warning("[StorageManager] Trending detector not available")
+

 class StorageManager:
     """
@@ -133,6 +142,10 @@
                 metadata=metadata,
             )

+            # Record keywords for trending detection
+            if TRENDING_AVAILABLE:
+                self._record_trending_mentions(summary, domain, metadata)
+
             self.stats["unique_stored"] += 1
             logger.debug(f"[STORE] Stored event {event_id[:8]}... in all databases")
@@ -140,6 +153,87 @@
             self.stats["errors"] += 1
             logger.error(f"[STORE] Error storing event: {e}")

+    def _extract_keywords(self, text: str, max_keywords: int = 5) -> List[str]:
+        """
+        Extract significant keywords from text for trending detection.
+
+        Args:
+            text: Text to extract keywords from
+            max_keywords: Maximum number of keywords to return
+
+        Returns:
+            List of significant single-word keywords
+        """
+        # Common stopwords to filter out
+        stopwords = {
+            "the", "is", "at", "which", "on", "a", "an", "and", "or", "but",
+            "in", "with", "to", "for", "of", "as", "by", "from", "that", "this",
+            "be", "are", "was", "were", "been", "being", "have", "has", "had",
+            "do", "does", "did", "will", "would", "could", "should", "may",
+            "might", "must", "shall", "can", "need", "dare", "ought", "used",
+            "सिंहल", "தமிழ்",  # Common Sinhala/Tamil particles
+        }
+
+        # Clean text
+        text = text.lower()
+        text = re.sub(r'http\S+|www\.\S+', '', text)  # Remove URLs
+        text = re.sub(r'[^\w\s]', ' ', text)  # Remove punctuation
+
+        # Split into words
+        words = text.split()
+
+        # Filter stopwords and short words
+        filtered = [w for w in words if w not in stopwords and len(w) > 2]
+
+        # Extract significant words (prioritize proper nouns, locations, etc.)
+        keywords = []
+
+        # Single important words (capitalized in original or long words)
+        for word in filtered[:20]:
+            if len(word) > 4:  # Longer words are often more significant
+                keywords.append(word)
+
+        # Deduplicate and limit
+        seen = set()
+        unique_keywords = []
+        for kw in keywords:
+            if kw not in seen:
+                seen.add(kw)
+                unique_keywords.append(kw)
+
+        return unique_keywords[:max_keywords]
+
+    def _record_trending_mentions(
+        self,
+        summary: str,
+        domain: str,
+        metadata: Optional[Dict[str, Any]] = None
+    ):
+        """
+        Extract keywords from summary and record them for trending detection.
+
+        Args:
+            summary: Event summary text
+            domain: Event domain (political, economical, etc.)
+            metadata: Optional metadata with platform info
+        """
+        try:
+            keywords = self._extract_keywords(summary)
+            source = metadata.get("platform", "scraper") if metadata else "scraper"
+
+            for keyword in keywords:
+                record_topic_mention(
+                    topic=keyword,
+                    source=source,
+                    domain=domain
+                )
+
+            if keywords:
+                logger.debug(f"[TRENDING] Recorded {len(keywords)} keywords: {keywords[:3]}...")
+
+        except Exception as e:
+            logger.warning(f"[TRENDING] Error recording mentions: {e}")
+
     def link_similar_events(self, event_id_1: str, event_id_2: str, similarity: float):
         """Create similarity link in Neo4j"""
         self.neo4j.link_similar_events(event_id_1, event_id_2, similarity)
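A standalone sketch of the same filtering that _extract_keywords performs, runnable outside the class (the sample text and the tiny stopword set are illustrative):

import re

def extract_keywords(text: str, stopwords: set, max_keywords: int = 5):
    text = re.sub(r'http\S+|www\.\S+', '', text.lower())  # strip URLs
    text = re.sub(r'[^\w\s]', ' ', text)                  # strip punctuation
    words = [w for w in text.split() if w not in stopwords and len(w) > 2]
    longish = [w for w in words[:20] if len(w) > 4]       # longer words ~ more significant
    seen, out = set(), []
    for w in longish:
        if w not in seen:
            seen.add(w)
            out.append(w)
    return out[:max_keywords]

print(extract_keywords("Heavy flooding reported in Ratnapura district", {"in"}))
# -> ['heavy', 'flooding', 'reported', 'ratnapura', 'district']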
src/utils/.browser_data/linkedin/BrowserMetrics-spare.pma
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bb9f8df61474d25e71fa00722318cd387396ca1736605e1248821cc0de3d3af8
+size 4194304

src/utils/.browser_data/linkedin/Crashpad/metadata
ADDED
Binary file (310 Bytes).

src/utils/.browser_data/linkedin/Crashpad/reports/1bb2b465-675d-47f0-b953-a844af38ce6b.dmp
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:574b5012a5ecf99fb5133d8821bed664eaca4a686a0197a07298449b3db67bed
+size 968496

src/utils/.browser_data/linkedin/Crashpad/reports/55792d7f-8397-4730-8518-c50a507a611a.dmp
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9edc190dddd703583b366c589e3593ab3eab8ae2c3ee8b0e7884d116aaff6be2
+size 4326864

src/utils/.browser_data/linkedin/Crashpad/reports/880fc1e0-3241-4d76-a26b-0f9d6135dcd6.dmp
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:009e8131265d00b6ef330dc1f2947daaa6b295c6bfdf47ebe974a64c8bc351a8
+size 11408000

src/utils/.browser_data/linkedin/Crashpad/settings.dat
ADDED
Binary file (40 Bytes).

src/utils/.browser_data/linkedin/Default/Account Web Data
ADDED
Binary file (77.8 kB).

src/utils/.browser_data/linkedin/Default/Account Web Data-journal
ADDED
File without changes