Pradeep Rajan commited on
Commit
dc68ce6
·
1 Parent(s): 5c7dec2

Initial deployment of Zyon Traders Backend10

Browse files
.env.example CHANGED
@@ -1,27 +1,137 @@
1
  # Zyon Traders Backend Environment Variables
2
  # Copy this file to .env and fill in your actual values
3
 
4
- # App Configuration
 
 
5
  DEBUG=false
6
- SECRET_KEY=your-secret-key-here
7
- ALLOWED_ORIGINS=https://your-frontend-domain.com
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
- # Supabase Configuration
10
  SUPABASE_PROJECT_URL=https://your-project.supabase.co
11
  SUPABASE_ANON_KEY=your-supabase-anon-key
 
12
 
13
- # Dhan Trading API
14
- DHAN_API_KEY=your-dhan-api-key
 
 
 
15
 
16
- # AI Services
17
- GEMINI_API_KEY=your-gemini-api-key
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  TOGETHER_API_KEY=your-together-api-key
19
  FIREWORKS_API_KEY=your-fireworks-api-key
20
- HUGGINGFACE_API_KEY=your-huggingface-api-key
21
  LANGCHAIN_API_KEY=your-langchain-api-key
 
22
 
23
- # External Services
 
 
 
24
  UPTIMEROBOT_API_KEY=your-uptimerobot-key
25
 
26
- # Logging
27
- LOG_LEVEL=INFO
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # Zyon Traders Backend Environment Variables
2
  # Copy this file to .env and fill in your actual values
3
 
4
+ # =============================================================================
5
+ # APPLICATION CONFIGURATION
6
+ # =============================================================================
7
  DEBUG=false
8
+ SECRET_KEY=your-secret-key-here-change-in-production
9
+ PORT=7860
10
+ LOG_LEVEL=INFO
11
+
12
+ # =============================================================================
13
+ # CORS AND SECURITY
14
+ # =============================================================================
15
+ # Comma-separated list of allowed origins
16
+ ALLOWED_ORIGINS=http://localhost:3000,http://localhost:5173,https://your-frontend-domain.com
17
+
18
+ # =============================================================================
19
+ # DATABASE CONFIGURATION
20
+ # =============================================================================
21
+ # Optional: Direct database URL
22
+ DATABASE_URL=postgresql://user:password@localhost:5432/zyon_traders
23
 
24
+ # Supabase Configuration (Primary database)
25
  SUPABASE_PROJECT_URL=https://your-project.supabase.co
26
  SUPABASE_ANON_KEY=your-supabase-anon-key
27
+ SUPABASE_SERVICE_ROLE_KEY=your-supabase-service-role-key
28
 
29
+ # =============================================================================
30
+ # DHAN API CONFIGURATION
31
+ # =============================================================================
32
+ # Base API URL (usually doesn't need to change)
33
+ DHAN_API_BASE_URL=https://api.dhan.co/v2
34
 
35
+ # Trading Mode: "live", "paper", or "sandbox"
36
+ DHAN_TRADING_MODE=live
37
+
38
+ # Live Trading Configuration
39
+ DHAN_CLIENT_ID=your-live-client-id
40
+ DHAN_TRADING_API_TOKEN=your-trading-api-token
41
+ DHAN_DATA_API_TOKEN=your-data-api-token
42
+
43
+ # Paper Trading Configuration (for testing without real money)
44
+ DHAN_PAPER_TRADING_CLIENT_ID=your-paper-client-id
45
+ DHAN_PAPER_TRADING_TOKEN=your-paper-trading-token
46
+
47
+ # WebSocket URLs (usually don't need to change)
48
+ DHAN_WS_MARKET_FEED_URL=wss://api-feed.dhan.co
49
+ DHAN_WS_ORDER_UPDATE_URL=wss://api-order-update.dhan.co
50
+ DHAN_WS_MARKET_DEPTH_URL=wss://depth-api-feed.dhan.co/twentydepth
51
+
52
+ # Legacy API configuration (for backward compatibility)
53
+ DHAN_API_KEY=your-legacy-api-key
54
+ DHAN_ACCESS_TOKEN=your-legacy-access-token
55
+
56
+ # =============================================================================
57
+ # AI SERVICES CONFIGURATION
58
+ # =============================================================================
59
+ # AI model API keys (optional)
60
  TOGETHER_API_KEY=your-together-api-key
61
  FIREWORKS_API_KEY=your-fireworks-api-key
62
+ GEMINI_API_KEY=your-gemini-api-key
63
  LANGCHAIN_API_KEY=your-langchain-api-key
64
+ HUGGINGFACE_API_KEY=your-huggingface-api-key
65
 
66
+ # =============================================================================
67
+ # EXTERNAL SERVICES
68
+ # =============================================================================
69
+ # Monitoring and alerting
70
  UPTIMEROBOT_API_KEY=your-uptimerobot-key
71
 
72
+ # Redis for caching (optional)
73
+ REDIS_URL=redis://localhost:6379
74
+
75
+ # =============================================================================
76
+ # HUGGINGFACE SPACES CONFIGURATION
77
+ # =============================================================================
78
+ # These are automatically set by HuggingFace Spaces
79
+ SPACE_ID=your-space-id
80
+ SPACE_AUTHOR_NAME=your-username
81
+
82
+ # =============================================================================
83
+ # PERFORMANCE AND LIMITS
84
+ # =============================================================================
85
+ # Rate limiting
86
+ RATE_LIMIT_PER_MINUTE=100
87
+
88
+ # WebSocket settings
89
+ WS_HEARTBEAT_INTERVAL=30
90
+ WS_MAX_CONNECTIONS=1000
91
+
92
+ # =============================================================================
93
+ # DHAN API TOKEN SETUP GUIDE
94
+ # =============================================================================
95
+ #
96
+ # 1. TRADING API TOKEN:
97
+ # - Used for: Placing orders, managing positions, checking funds
98
+ # - Get from: Dhan API portal under "Trading API"
99
+ # - Set as: DHAN_TRADING_API_TOKEN
100
+ #
101
+ # 2. DATA API TOKEN:
102
+ # - Used for: Real-time quotes, historical data, market feeds
103
+ # - Get from: Dhan API portal under "Data API"
104
+ # - Set as: DHAN_DATA_API_TOKEN
105
+ #
106
+ # 3. PAPER TRADING TOKEN:
107
+ # - Used for: Testing strategies without real money
108
+ # - Get from: Dhan API portal under "Paper Trading"
109
+ # - Set as: DHAN_PAPER_TRADING_TOKEN
110
+ # - Also set: DHAN_PAPER_TRADING_CLIENT_ID
111
+ #
112
+ # 4. TRADING MODE:
113
+ # - Set to "paper" for testing with paper trading
114
+ # - Set to "live" for real trading
115
+ # - The system will automatically use the appropriate tokens
116
+ #
117
+ # =============================================================================
118
+ # DEPLOYMENT NOTES
119
+ # =============================================================================
120
+ #
121
+ # For HuggingFace Spaces:
122
+ # - PORT should be 7860 (automatically set)
123
+ # - Add your HF Space URL to ALLOWED_ORIGINS
124
+ # - Set DEBUG=false for production
125
+ #
126
+ # For Local Development:
127
+ # - PORT can be 8000 or any available port
128
+ # - Set DEBUG=true
129
+ # - Add localhost URLs to ALLOWED_ORIGINS
130
+ #
131
+ # For Production Deployment:
132
+ # - Always set DEBUG=false
133
+ # - Use strong SECRET_KEY
134
+ # - Set appropriate ALLOWED_ORIGINS
135
+ # - Use environment-specific database URLs
136
+ #
137
+ # =============================================================================
Dockerfile CHANGED
@@ -4,7 +4,7 @@ FROM python:3.10-slim
4
  # Set environment variables
5
  ENV PYTHONDONTWRITEBYTECODE=1
6
  ENV PYTHONUNBUFFERED=1
7
- ENV PORT=8000
8
 
9
  # Set work directory
10
  WORKDIR /app
@@ -35,9 +35,9 @@ USER app
35
  # Expose port
36
  EXPOSE $PORT
37
 
38
- # Health check
39
  HEALTHCHECK --interval=30s --timeout=30s --start-period=5s --retries=3 \
40
- CMD curl -f http://localhost:$PORT/api/health || exit 1
41
 
42
  # Run the application (HuggingFace Spaces compatible)
43
  CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
 
4
  # Set environment variables
5
  ENV PYTHONDONTWRITEBYTECODE=1
6
  ENV PYTHONUNBUFFERED=1
7
+ ENV PORT=7860
8
 
9
  # Set work directory
10
  WORKDIR /app
 
35
  # Expose port
36
  EXPOSE $PORT
37
 
38
+ # Health check (use fixed port for HF Spaces)
39
  HEALTHCHECK --interval=30s --timeout=30s --start-period=5s --retries=3 \
40
+ CMD curl -f http://localhost:7860/api/health || exit 1
41
 
42
  # Run the application (HuggingFace Spaces compatible)
43
  CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
config/settings.py CHANGED
@@ -1,58 +1,101 @@
1
  """
2
  Application settings and configuration
3
- Environment-based configuration for production deployment
4
  """
5
 
6
  import os
7
  from typing import List, Optional
8
- from pydantic import BaseSettings, validator
9
  from functools import lru_cache
10
 
11
  class Settings(BaseSettings):
12
  """Application settings"""
13
 
14
  # Basic app config
15
- DEBUG: bool = False
16
- SECRET_KEY: str = "your-secret-key-change-in-production"
17
  API_V1_STR: str = "/api"
18
  PROJECT_NAME: str = "Zyon Traders API"
19
 
20
- # Server config
21
  HOST: str = "0.0.0.0"
22
- PORT: int = 8000
23
 
24
- # CORS settings
25
  ALLOWED_ORIGINS: List[str] = [
26
  "http://localhost:3000",
 
27
  "http://localhost:8080",
28
- "https://zyon-traders.netlify.app", # Your frontend URL
29
- "https://your-domain.com" # Replace with your actual domain
 
 
 
 
 
30
  ]
31
 
32
  # Database
33
  DATABASE_URL: Optional[str] = None
34
 
35
- # Supabase
36
- SUPABASE_PROJECT_URL: str
37
- SUPABASE_ANON_KEY: str
38
- SUPABASE_SERVICE_ROLE_KEY: Optional[str] = None
 
 
 
 
 
 
39
 
40
- # Dhan API
41
- DHAN_API_KEY: str
42
  DHAN_API_BASE_URL: str = "https://api.dhan.co/v2"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
  # AI Services
45
- TOGETHER_API_KEY: Optional[str] = None
46
- FIREWORKS_API_KEY: Optional[str] = None
47
- GEMINI_API_KEY: Optional[str] = None
48
- LANGCHAIN_API_KEY: Optional[str] = None
49
- HUGGINGFACE_API_KEY: Optional[str] = None
50
 
51
  # External Services
52
- UPTIMEROBOT_API_KEY: Optional[str] = None
53
 
54
  # Redis (for caching - optional)
55
- REDIS_URL: Optional[str] = None
56
 
57
  # Rate limiting
58
  RATE_LIMIT_PER_MINUTE: int = 100
@@ -62,24 +105,135 @@ class Settings(BaseSettings):
62
  WS_MAX_CONNECTIONS: int = 1000
63
 
64
  # Logging
65
- LOG_LEVEL: str = "INFO"
 
 
 
 
66
 
67
  @validator("ALLOWED_ORIGINS", pre=True)
68
- def assemble_cors_origins(cls, v: str) -> List[str]:
69
- if isinstance(v, str) and not v.startswith("["):
70
- return [i.strip() for i in v.split(",")]
71
- elif isinstance(v, (list, str)):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  return v
73
- raise ValueError(v)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
  class Config:
76
  env_file = ".env"
77
  case_sensitive = True
 
 
 
 
78
 
79
  @lru_cache()
80
  def get_settings() -> Settings:
81
  """Get cached settings instance"""
82
  return Settings()
83
 
84
- # For Render.com deployment, settings will be loaded from environment variables
85
  settings = get_settings()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """
2
  Application settings and configuration
3
+ Environment-based configuration for HuggingFace Spaces deployment
4
  """
5
 
6
  import os
7
  from typing import List, Optional
8
+ from pydantic import BaseSettings, validator, Field
9
  from functools import lru_cache
10
 
11
  class Settings(BaseSettings):
12
  """Application settings"""
13
 
14
  # Basic app config
15
+ DEBUG: bool = Field(default=False, env="DEBUG")
16
+ SECRET_KEY: str = Field(default="your-secret-key-change-in-production", env="SECRET_KEY")
17
  API_V1_STR: str = "/api"
18
  PROJECT_NAME: str = "Zyon Traders API"
19
 
20
+ # Server config - HF Spaces uses port 7860
21
  HOST: str = "0.0.0.0"
22
+ PORT: int = Field(default=7860, env="PORT")
23
 
24
+ # CORS settings - Updated for HF Spaces compatibility
25
  ALLOWED_ORIGINS: List[str] = [
26
  "http://localhost:3000",
27
+ "http://localhost:5173", # Vite dev server
28
  "http://localhost:8080",
29
+ "https://zyon-traders.netlify.app",
30
+ "https://zyonpackers-zyon-traders-backend.hf.space",
31
+ "https://huggingface.co",
32
+ "https://*.hf.space",
33
+ "https://*.netlify.app",
34
+ "https://*.vercel.app",
35
+ "https://*.surge.sh"
36
  ]
37
 
38
  # Database
39
  DATABASE_URL: Optional[str] = None
40
 
41
+ # Supabase - with defaults for HF Spaces
42
+ SUPABASE_PROJECT_URL: str = Field(
43
+ default="https://placeholder.supabase.co",
44
+ env="SUPABASE_PROJECT_URL"
45
+ )
46
+ SUPABASE_ANON_KEY: str = Field(
47
+ default="placeholder-anon-key",
48
+ env="SUPABASE_ANON_KEY"
49
+ )
50
+ SUPABASE_SERVICE_ROLE_KEY: Optional[str] = Field(default=None, env="SUPABASE_SERVICE_ROLE_KEY")
51
 
52
+ # Dhan API Configuration
 
53
  DHAN_API_BASE_URL: str = "https://api.dhan.co/v2"
54
+ DHAN_CLIENT_ID: str = "1107671523"
55
+
56
+ # Trading Environment Mode
57
+ DHAN_TRADING_MODE: str = Field(default="live", env="DHAN_TRADING_MODE")
58
+
59
+ # Trading API - For order management, portfolio, funds
60
+ DHAN_TRADING_API_TOKEN: str = Field(
61
+ default="eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzUxMiJ9.eyJpc3MiOiJkaGFuIiwicGFydG5lcklkIjoiIiwiZXhwIjoxNzU0NTQ1Nzg5LCJ0b2tlbkNvbnN1bWVyVHlwZSI6IlNFTEYiLCJ3ZWJob29rVXJsIjoiIiwiZGhhbkNsaWVudElkIjoiMTEwNzY3MTUyMyJ9.FJUHdkfv4MNNkM-_1AidgRGzch5xI8fBbpb-IZNF0O4sD6TkliY0pQaJ0mlTjX58UrCdzmbTDLziJwVmZeOmxA",
62
+ env="DHAN_TRADING_API_TOKEN"
63
+ )
64
+
65
+ # Data API - For real-time market data, quotes, historical data
66
+ DHAN_DATA_API_TOKEN: str = Field(
67
+ default="eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzUxMiJ9.eyJpc3MiOiJkaGFuIiwicGFydG5lcklkIjoiIiwiZXhwIjoxNzU0NTQ2MjI2LCJ0b2tlbkNvbnN1bWVyVHlwZSI6IlNFTEYiLCJ3ZWJob29rVXJsIjoiIiwiZGhhbkNsaWVudElkIjoiMTEwNzY3MTUyMyJ9.lvaxN5oI3h3_t3cTA-uCWsyh3909hdgBKmQMZDgspCxfkHN-Lh9QcoApy12GUdEvXgVs2oYwdUAdTnuPYuX78g",
68
+ env="DHAN_DATA_API_TOKEN"
69
+ )
70
+
71
+ # Paper Trading (Sandbox) API - For testing without real money
72
+ DHAN_PAPER_TRADING_CLIENT_ID: str = "2507087705"
73
+ DHAN_PAPER_TRADING_TOKEN: str = Field(
74
+ default="eyJhbGciOiJIUzUxMiIsInR5cCI6IkpXVCJ9.eyJ0b2tlbkNvbnN1bWVyVHlwZSI6IlNFTEYiLCJwYXJ0bmVySWQiOiIiLCJkaGFuQ2xpZW50SWQiOiIyNTA3MDg3NzA1Iiwid2ViaG9va1VybCI6Imh0dHBzOi8vMmZjNmMyYzc4MzQxNDIxYWIyMzc2YTM1YTdhNTViZTItOWE5NGRhZDEzNTc0NDMyM2I4N2YxNTc5ZS5mbHkuZGV2L2F1dGgvY2FsbGJhY2siLCJpc3MiOiJkaGFuIiwiZXhwIjoxNzU0NTgyMTQ1fQ.K_hIOXtkVmrqEfNH18nm-wuA1DxkReIAi7zWcBzDF-gEGYkZDzcnxPVNz2anDIzJM65Kr1SvV1L0brOfOawR3A",
75
+ env="DHAN_PAPER_TRADING_TOKEN"
76
+ )
77
+
78
+ # Dhan WebSocket endpoints
79
+ DHAN_WS_MARKET_FEED_URL: str = "wss://api-feed.dhan.co"
80
+ DHAN_WS_ORDER_UPDATE_URL: str = "wss://api-order-update.dhan.co"
81
+ DHAN_WS_MARKET_DEPTH_URL: str = "wss://depth-api-feed.dhan.co/twentydepth"
82
+
83
+ # Legacy support (for backward compatibility)
84
+ DHAN_API_KEY: Optional[str] = Field(default=None, env="DHAN_API_KEY")
85
+ DHAN_ACCESS_TOKEN: Optional[str] = Field(default=None, env="DHAN_ACCESS_TOKEN")
86
 
87
  # AI Services
88
+ TOGETHER_API_KEY: Optional[str] = Field(default=None, env="TOGETHER_API_KEY")
89
+ FIREWORKS_API_KEY: Optional[str] = Field(default=None, env="FIREWORKS_API_KEY")
90
+ GEMINI_API_KEY: Optional[str] = Field(default=None, env="GEMINI_API_KEY")
91
+ LANGCHAIN_API_KEY: Optional[str] = Field(default=None, env="LANGCHAIN_API_KEY")
92
+ HUGGINGFACE_API_KEY: Optional[str] = Field(default=None, env="HUGGINGFACE_API_KEY")
93
 
94
  # External Services
95
+ UPTIMEROBOT_API_KEY: Optional[str] = Field(default=None, env="UPTIMEROBOT_API_KEY")
96
 
97
  # Redis (for caching - optional)
98
+ REDIS_URL: Optional[str] = Field(default=None, env="REDIS_URL")
99
 
100
  # Rate limiting
101
  RATE_LIMIT_PER_MINUTE: int = 100
 
105
  WS_MAX_CONNECTIONS: int = 1000
106
 
107
  # Logging
108
+ LOG_LEVEL: str = Field(default="INFO", env="LOG_LEVEL")
109
+
110
+ # HuggingFace Spaces specific settings
111
+ SPACE_ID: Optional[str] = Field(default=None, env="SPACE_ID")
112
+ SPACE_AUTHOR_NAME: Optional[str] = Field(default=None, env="SPACE_AUTHOR_NAME")
113
 
114
  @validator("ALLOWED_ORIGINS", pre=True)
115
+ def assemble_cors_origins(cls, v) -> List[str]:
116
+ """Parse CORS origins from environment variable or list"""
117
+ if isinstance(v, str):
118
+ if v.startswith("["):
119
+ try:
120
+ import json
121
+ return json.loads(v)
122
+ except Exception:
123
+ pass
124
+ # Handle comma-separated string
125
+ origins = [i.strip() for i in v.split(",") if i.strip()]
126
+
127
+ # Add wildcard support for HF Spaces if needed
128
+ if any("hf.space" in origin for origin in origins):
129
+ origins.append("*") # Allow all for iframe compatibility
130
+
131
+ return origins
132
+ elif isinstance(v, list):
133
  return v
134
+ return [str(v)] if v else ["*"]
135
+
136
+ @validator("PORT", pre=True)
137
+ def validate_port(cls, v):
138
+ """Ensure port is integer and within valid range"""
139
+ try:
140
+ port = int(v)
141
+ if 1 <= port <= 65535:
142
+ return port
143
+ return 7860 # Default HF Spaces port
144
+ except (ValueError, TypeError):
145
+ return 7860
146
+
147
+ @validator("DEBUG", pre=True)
148
+ def validate_debug(cls, v):
149
+ """Parse DEBUG from various string formats"""
150
+ if isinstance(v, str):
151
+ return v.lower() in ("true", "1", "yes", "on")
152
+ return bool(v)
153
+
154
+ @validator("DHAN_TRADING_MODE")
155
+ def validate_trading_mode(cls, v):
156
+ """Validate trading mode is one of allowed values"""
157
+ allowed_modes = ["live", "paper", "sandbox"]
158
+ if v.lower() not in allowed_modes:
159
+ return "live" # Default to live
160
+ return v.lower()
161
+
162
+ def get_cors_origins(self) -> List[str]:
163
+ """Get CORS origins with HF Spaces compatibility"""
164
+ origins = self.ALLOWED_ORIGINS.copy()
165
+
166
+ # Ensure HF Spaces compatibility
167
+ if not any("*" in origin for origin in origins):
168
+ # Add wildcard for development and iframe compatibility
169
+ if self.DEBUG:
170
+ origins.append("*")
171
+
172
+ return origins
173
+
174
+ def is_production(self) -> bool:
175
+ """Check if running in production environment"""
176
+ return not self.DEBUG and "hf.space" in str(self.SPACE_ID or "")
177
+
178
+ def get_database_url(self) -> Optional[str]:
179
+ """Get database URL with fallback to Supabase"""
180
+ if self.DATABASE_URL:
181
+ return self.DATABASE_URL
182
+
183
+ # Construct from Supabase settings if available
184
+ if (self.SUPABASE_PROJECT_URL != "https://placeholder.supabase.co" and
185
+ self.SUPABASE_SERVICE_ROLE_KEY):
186
+ # Extract project ID from URL
187
+ try:
188
+ project_id = self.SUPABASE_PROJECT_URL.split("//")[1].split(".")[0]
189
+ return f"postgresql://postgres:[password]@db.{project_id}.supabase.co:5432/postgres"
190
+ except Exception:
191
+ pass
192
+
193
+ return None
194
 
195
  class Config:
196
  env_file = ".env"
197
  case_sensitive = True
198
+ # Allow extra fields for flexibility in different environments
199
+ extra = "allow"
200
+ # Enable environment variable parsing
201
+ env_file_encoding = 'utf-8'
202
 
203
  @lru_cache()
204
  def get_settings() -> Settings:
205
  """Get cached settings instance"""
206
  return Settings()
207
 
208
+ # Create settings instance for import
209
  settings = get_settings()
210
+
211
+ # Environment-specific configurations
212
+ def configure_for_environment():
213
+ """Configure settings based on environment"""
214
+ current_settings = get_settings()
215
+
216
+ # HuggingFace Spaces specific configuration
217
+ if current_settings.SPACE_ID:
218
+ # Ensure CORS is properly configured for HF Spaces
219
+ if "*" not in current_settings.ALLOWED_ORIGINS:
220
+ current_settings.ALLOWED_ORIGINS.append("*")
221
+
222
+ # Set production flags
223
+ current_settings.DEBUG = False
224
+
225
+ # Development environment
226
+ elif current_settings.DEBUG:
227
+ # Add development-friendly CORS
228
+ dev_origins = [
229
+ "http://localhost:3000",
230
+ "http://localhost:5173",
231
+ "http://127.0.0.1:3000",
232
+ "http://127.0.0.1:5173"
233
+ ]
234
+ for origin in dev_origins:
235
+ if origin not in current_settings.ALLOWED_ORIGINS:
236
+ current_settings.ALLOWED_ORIGINS.append(origin)
237
+
238
+ # Apply environment configuration
239
+ configure_for_environment()
main.py CHANGED
@@ -16,7 +16,8 @@ import os
16
  from datetime import datetime, timedelta
17
 
18
  # Import routers
19
- from routers import auth, signals, dhan, portfolio, screener, analytics
 
20
  from services.websocket_manager import WebSocketManager
21
  from config.settings import get_settings
22
  from utils.logging_config import setup_logging
@@ -62,13 +63,14 @@ app = FastAPI(
62
  lifespan=lifespan
63
  )
64
 
65
- # CORS middleware
66
  app.add_middleware(
67
  CORSMiddleware,
68
  allow_origins=settings.ALLOWED_ORIGINS,
69
  allow_credentials=True,
70
- allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
71
  allow_headers=["*"],
 
72
  )
73
 
74
  # Include routers
@@ -91,18 +93,31 @@ async def root():
91
 
92
  @app.get("/api/health")
93
  async def health_check():
94
- """Detailed health check"""
95
  return {
96
  "status": "healthy",
97
  "timestamp": datetime.utcnow().isoformat(),
98
  "version": "1.0.0",
 
 
99
  "services": {
100
  "database": "connected",
101
  "dhan_api": "connected",
102
- "ai_services": "connected"
 
 
 
 
 
 
103
  }
104
  }
105
 
 
 
 
 
 
106
  @app.websocket("/ws")
107
  async def websocket_endpoint(websocket: WebSocket):
108
  """WebSocket endpoint for real-time data"""
@@ -144,10 +159,14 @@ async def global_exception_handler(request, exc):
144
 
145
  if __name__ == "__main__":
146
  import uvicorn
 
 
147
  uvicorn.run(
148
  "main:app",
149
  host="0.0.0.0",
150
- port=int(os.getenv("PORT", 8000)),
151
  reload=settings.DEBUG,
152
- log_level="info"
 
 
153
  )
 
16
  from datetime import datetime, timedelta
17
 
18
  # Import routers
19
+ from routers import auth, signals, portfolio, screener, analytics
20
+ from routers import dhan_new as dhan
21
  from services.websocket_manager import WebSocketManager
22
  from config.settings import get_settings
23
  from utils.logging_config import setup_logging
 
63
  lifespan=lifespan
64
  )
65
 
66
+ # CORS middleware for Hugging Face Spaces compatibility
67
  app.add_middleware(
68
  CORSMiddleware,
69
  allow_origins=settings.ALLOWED_ORIGINS,
70
  allow_credentials=True,
71
+ allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD"],
72
  allow_headers=["*"],
73
+ expose_headers=["*"], # Important for HF Spaces
74
  )
75
 
76
  # Include routers
 
93
 
94
  @app.get("/api/health")
95
  async def health_check():
96
+ """Detailed health check for Hugging Face Spaces"""
97
  return {
98
  "status": "healthy",
99
  "timestamp": datetime.utcnow().isoformat(),
100
  "version": "1.0.0",
101
+ "environment": "huggingface_spaces",
102
+ "port": int(os.getenv("PORT", 7860)),
103
  "services": {
104
  "database": "connected",
105
  "dhan_api": "connected",
106
+ "ai_services": "connected",
107
+ "websocket": "available"
108
+ },
109
+ "endpoints": {
110
+ "api_base": "/api",
111
+ "websocket": "/ws",
112
+ "docs": "/docs" if settings.DEBUG else None
113
  }
114
  }
115
 
116
+ @app.get("/health")
117
+ async def simple_health_check():
118
+ """Simple health check for HF Spaces monitoring"""
119
+ return {"status": "ok", "timestamp": datetime.utcnow().isoformat()}
120
+
121
  @app.websocket("/ws")
122
  async def websocket_endpoint(websocket: WebSocket):
123
  """WebSocket endpoint for real-time data"""
 
159
 
160
  if __name__ == "__main__":
161
  import uvicorn
162
+ # Hugging Face Spaces uses port 7860 by default
163
+ port = int(os.getenv("PORT", 7860))
164
  uvicorn.run(
165
  "main:app",
166
  host="0.0.0.0",
167
+ port=port,
168
  reload=settings.DEBUG,
169
+ log_level="info",
170
+ proxy_headers=True, # Required for HF Spaces proxy
171
+ forwarded_allow_ips="*" # Allow forwarded IPs from HF proxy
172
  )
models/portfolio.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Pydantic models for portfolio management
3
+ """
4
+
5
+ from pydantic import BaseModel, Field, validator
6
+ from typing import List, Optional, Dict, Any
7
+ from datetime import datetime
8
+ from enum import Enum
9
+
10
+ class RiskLevel(str, Enum):
11
+ LOW = "low"
12
+ MEDIUM = "medium"
13
+ HIGH = "high"
14
+
15
+ class AssetClass(str, Enum):
16
+ EQUITY = "equity"
17
+ DEBT = "debt"
18
+ COMMODITY = "commodity"
19
+ CURRENCY = "currency"
20
+ HYBRID = "hybrid"
21
+
22
+ class PortfolioSummary(BaseModel):
23
+ total_value: float = Field(..., description="Total portfolio value")
24
+ total_invested: float = Field(..., description="Total amount invested")
25
+ total_pnl: float = Field(..., description="Total profit/loss")
26
+ pnl_percentage: float = Field(..., description="P&L percentage")
27
+ day_change: float = Field(..., description="Day's change in value")
28
+ day_change_percentage: float = Field(..., description="Day's change percentage")
29
+ holdings_count: int = Field(..., description="Number of holdings")
30
+ last_updated: datetime = Field(..., description="Last update timestamp")
31
+
32
+ class PortfolioHolding(BaseModel):
33
+ security_id: str = Field(..., description="Security ID")
34
+ symbol: str = Field(..., description="Trading symbol")
35
+ quantity: int = Field(..., description="Quantity held")
36
+ average_price: float = Field(..., description="Average buy price")
37
+ current_price: float = Field(..., description="Current market price")
38
+ current_value: float = Field(..., description="Current market value")
39
+ invested_value: float = Field(..., description="Total invested amount")
40
+ pnl: float = Field(..., description="Profit/loss amount")
41
+ pnl_percentage: float = Field(..., description="P&L percentage")
42
+ day_change: float = Field(..., description="Day's change")
43
+ day_change_percentage: float = Field(..., description="Day's change percentage")
44
+ sector: Optional[str] = Field(None, description="Sector classification")
45
+ asset_class: AssetClass = Field(..., description="Asset class")
46
+
47
+ class PortfolioPosition(BaseModel):
48
+ security_id: str = Field(..., description="Security ID")
49
+ symbol: str = Field(..., description="Trading symbol")
50
+ quantity: int = Field(..., description="Net quantity")
51
+ buy_quantity: int = Field(..., description="Buy quantity")
52
+ sell_quantity: int = Field(..., description="Sell quantity")
53
+ buy_average: float = Field(..., description="Average buy price")
54
+ sell_average: float = Field(..., description="Average sell price")
55
+ realized_pnl: float = Field(..., description="Realized P&L")
56
+ unrealized_pnl: float = Field(..., description="Unrealized P&L")
57
+ current_price: float = Field(..., description="Current market price")
58
+ product_type: str = Field(..., description="Product type")
59
+
60
+ class RiskMetrics(BaseModel):
61
+ beta: float = Field(..., description="Portfolio beta")
62
+ alpha: float = Field(..., description="Portfolio alpha")
63
+ sharpe_ratio: float = Field(..., description="Sharpe ratio")
64
+ sortino_ratio: float = Field(..., description="Sortino ratio")
65
+ volatility: float = Field(..., description="Portfolio volatility")
66
+ value_at_risk_1d: float = Field(..., description="1-day Value at Risk")
67
+ value_at_risk_5d: float = Field(..., description="5-day Value at Risk")
68
+ maximum_drawdown: float = Field(..., description="Maximum drawdown")
69
+ correlation_with_market: float = Field(..., description="Correlation with market")
70
+
71
+ class PerformanceMetrics(BaseModel):
72
+ total_return: float = Field(..., description="Total return")
73
+ annualized_return: float = Field(..., description="Annualized return")
74
+ best_day: float = Field(..., description="Best single day return")
75
+ worst_day: float = Field(..., description="Worst single day return")
76
+ positive_days: int = Field(..., description="Number of positive days")
77
+ negative_days: int = Field(..., description="Number of negative days")
78
+ win_rate: float = Field(..., description="Win rate percentage")
79
+ calmar_ratio: float = Field(..., description="Calmar ratio")
80
+ tracking_error: float = Field(..., description="Tracking error")
81
+
82
+ class SectorAllocation(BaseModel):
83
+ sector: str = Field(..., description="Sector name")
84
+ allocation_percentage: float = Field(..., description="Allocation percentage")
85
+ current_value: float = Field(..., description="Current value")
86
+ pnl: float = Field(..., description="Sector P&L")
87
+ pnl_percentage: float = Field(..., description="Sector P&L percentage")
88
+
89
+ class AssetAllocation(BaseModel):
90
+ asset_class: AssetClass = Field(..., description="Asset class")
91
+ allocation_percentage: float = Field(..., description="Allocation percentage")
92
+ current_value: float = Field(..., description="Current value")
93
+ target_percentage: Optional[float] = Field(None, description="Target allocation")
94
+
95
+ class RiskAssessment(BaseModel):
96
+ overall_risk_score: float = Field(..., ge=0, le=10, description="Overall risk score (0-10)")
97
+ risk_level: RiskLevel = Field(..., description="Risk level category")
98
+ concentration_risk: float = Field(..., description="Concentration risk score")
99
+ sector_concentration: Dict[str, float] = Field(..., description="Sector concentration")
100
+ liquidity_risk: float = Field(..., description="Liquidity risk score")
101
+ currency_risk: float = Field(..., description="Currency risk score")
102
+ recommendations: List[str] = Field(..., description="Risk mitigation recommendations")
103
+
104
+ class AllocationRequest(BaseModel):
105
+ target_allocations: List[AssetAllocation] = Field(..., description="Target asset allocations")
106
+ rebalancing_threshold: float = Field(5.0, ge=1, le=20, description="Rebalancing threshold percentage")
107
+ investment_amount: Optional[float] = Field(None, description="Additional investment amount")
108
+
109
+ class PortfolioOptimization(BaseModel):
110
+ optimization_objective: str = Field(..., description="Optimization objective")
111
+ expected_return: float = Field(..., description="Expected return")
112
+ expected_risk: float = Field(..., description="Expected risk")
113
+ efficient_frontier: List[Dict[str, float]] = Field(..., description="Efficient frontier points")
114
+ optimal_weights: Dict[str, float] = Field(..., description="Optimal asset weights")
115
+ suggested_changes: List[Dict[str, Any]] = Field(..., description="Suggested portfolio changes")
116
+
117
+ class RebalancingPlan(BaseModel):
118
+ current_allocation: List[AssetAllocation] = Field(..., description="Current allocation")
119
+ target_allocation: List[AssetAllocation] = Field(..., description="Target allocation")
120
+ required_trades: List[Dict[str, Any]] = Field(..., description="Required trades")
121
+ estimated_cost: float = Field(..., description="Estimated transaction cost")
122
+ tax_implications: float = Field(..., description="Estimated tax implications")
123
+ net_benefit: float = Field(..., description="Net benefit after costs")
124
+
125
+ class PortfolioAlert(BaseModel):
126
+ alert_type: str = Field(..., description="Type of alert")
127
+ severity: str = Field(..., description="Alert severity")
128
+ message: str = Field(..., description="Alert message")
129
+ affected_securities: List[str] = Field(..., description="Affected securities")
130
+ recommended_action: str = Field(..., description="Recommended action")
131
+ created_at: datetime = Field(..., description="Alert creation time")
132
+
133
+ class DiversificationMetrics(BaseModel):
134
+ herfindahl_index: float = Field(..., description="Herfindahl concentration index")
135
+ effective_holdings: int = Field(..., description="Effective number of holdings")
136
+ sector_diversification: float = Field(..., description="Sector diversification score")
137
+ geographic_diversification: float = Field(..., description="Geographic diversification score")
138
+ market_cap_diversification: float = Field(..., description="Market cap diversification score")
139
+ diversification_score: float = Field(..., ge=0, le=100, description="Overall diversification score")
140
+
141
+ class CorrelationMatrix(BaseModel):
142
+ securities: List[str] = Field(..., description="List of securities")
143
+ correlation_data: List[List[float]] = Field(..., description="Correlation matrix data")
144
+ average_correlation: float = Field(..., description="Average correlation")
145
+ max_correlation: float = Field(..., description="Maximum correlation")
146
+ min_correlation: float = Field(..., description="Minimum correlation")
147
+
148
+ class PortfolioComparison(BaseModel):
149
+ portfolio_return: float = Field(..., description="Portfolio return")
150
+ benchmark_return: float = Field(..., description="Benchmark return")
151
+ alpha: float = Field(..., description="Alpha vs benchmark")
152
+ beta: float = Field(..., description="Beta vs benchmark")
153
+ tracking_error: float = Field(..., description="Tracking error")
154
+ information_ratio: float = Field(..., description="Information ratio")
155
+ up_capture: float = Field(..., description="Up market capture ratio")
156
+ down_capture: float = Field(..., description="Down market capture ratio")
157
+
158
+ class StressTestResult(BaseModel):
159
+ scenario_name: str = Field(..., description="Stress test scenario name")
160
+ scenario_description: str = Field(..., description="Scenario description")
161
+ portfolio_impact: float = Field(..., description="Portfolio impact amount")
162
+ impact_percentage: float = Field(..., description="Impact as percentage")
163
+ worst_affected_holdings: List[Dict[str, Any]] = Field(..., description="Most affected holdings")
164
+ recommendations: List[str] = Field(..., description="Risk mitigation recommendations")
165
+
166
+ class TaxOptimization(BaseModel):
167
+ current_year: int = Field(..., description="Tax year")
168
+ realized_gains: float = Field(..., description="Realized capital gains")
169
+ realized_losses: float = Field(..., description="Realized capital losses")
170
+ unrealized_gains: float = Field(..., description="Unrealized capital gains")
171
+ unrealized_losses: float = Field(..., description="Unrealized capital losses")
172
+ tax_liability: float = Field(..., description="Estimated tax liability")
173
+ loss_harvesting_opportunities: List[Dict[str, Any]] = Field(..., description="Tax loss harvesting opportunities")
174
+ ltcg_stcg_breakdown: Dict[str, float] = Field(..., description="Long-term vs short-term gains")
175
+
176
+ class PortfolioInsights(BaseModel):
177
+ top_contributors: List[Dict[str, Any]] = Field(..., description="Top contributing holdings")
178
+ top_detractors: List[Dict[str, Any]] = Field(..., description="Top detracting holdings")
179
+ momentum_stocks: List[Dict[str, Any]] = Field(..., description="Stocks with positive momentum")
180
+ lagging_stocks: List[Dict[str, Any]] = Field(..., description="Underperforming stocks")
181
+ overweight_sectors: List[str] = Field(..., description="Overweight sectors")
182
+ underweight_sectors: List[str] = Field(..., description="Underweight sectors")
183
+ key_insights: List[str] = Field(..., description="Key portfolio insights")
184
+
185
+ class PortfolioHealthScore(BaseModel):
186
+ overall_score: float = Field(..., ge=0, le=100, description="Overall portfolio health score")
187
+ diversification_score: float = Field(..., description="Diversification component score")
188
+ performance_score: float = Field(..., description="Performance component score")
189
+ risk_score: float = Field(..., description="Risk management component score")
190
+ cost_efficiency_score: float = Field(..., description="Cost efficiency score")
191
+ liquidity_score: float = Field(..., description="Liquidity score")
192
+ score_breakdown: Dict[str, float] = Field(..., description="Detailed score breakdown")
193
+ improvement_suggestions: List[str] = Field(..., description="Suggestions for improvement")
194
+
195
+ # Response Models
196
+ class PortfolioResponse(BaseModel):
197
+ summary: PortfolioSummary
198
+ holdings: List[PortfolioHolding]
199
+ positions: List[PortfolioPosition]
200
+ risk_metrics: RiskMetrics
201
+ performance_metrics: PerformanceMetrics
202
+ sector_allocation: List[SectorAllocation]
203
+ asset_allocation: List[AssetAllocation]
204
+ last_updated: datetime
205
+
206
+ class PortfolioAnalyticsResponse(BaseModel):
207
+ portfolio_summary: PortfolioSummary
208
+ risk_assessment: RiskAssessment
209
+ performance_metrics: PerformanceMetrics
210
+ diversification_metrics: DiversificationMetrics
211
+ portfolio_insights: PortfolioInsights
212
+ health_score: PortfolioHealthScore
213
+ alerts: List[PortfolioAlert]
214
+ timestamp: datetime
215
+
216
+ # Request Models
217
+ class PortfolioOptimizationRequest(BaseModel):
218
+ objective: str = Field("max_sharpe", description="Optimization objective")
219
+ constraints: Dict[str, Any] = Field(default_factory=dict, description="Optimization constraints")
220
+ risk_tolerance: RiskLevel = Field(RiskLevel.MEDIUM, description="Risk tolerance")
221
+ investment_horizon: str = Field("medium", description="Investment horizon")
222
+
223
+ class RebalancingRequest(BaseModel):
224
+ target_allocation: List[AssetAllocation] = Field(..., description="Target asset allocation")
225
+ rebalancing_frequency: str = Field("quarterly", description="Rebalancing frequency")
226
+ threshold_percentage: float = Field(5.0, description="Rebalancing threshold")
227
+ tax_optimization: bool = Field(True, description="Consider tax implications")
228
+
229
+ class PortfolioBacktestRequest(BaseModel):
230
+ start_date: str = Field(..., description="Backtest start date")
231
+ end_date: str = Field(..., description="Backtest end date")
232
+ initial_amount: float = Field(..., description="Initial investment amount")
233
+ rebalancing_frequency: str = Field("monthly", description="Rebalancing frequency")
234
+ benchmark: str = Field("NIFTY50", description="Benchmark for comparison")
235
+
236
+ class StressTestRequest(BaseModel):
237
+ scenarios: List[str] = Field(..., description="Stress test scenarios to run")
238
+ confidence_levels: List[float] = Field([0.95, 0.99], description="Confidence levels for VaR")
239
+ time_horizons: List[int] = Field([1, 5, 21], description="Time horizons in days")
models/signals.py ADDED
@@ -0,0 +1,288 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Pydantic models for AI trading signals
3
+ """
4
+
5
+ from pydantic import BaseModel, Field, validator
6
+ from typing import List, Optional, Dict, Any, Union
7
+ from datetime import datetime
8
+ from enum import Enum
9
+
10
+ class SignalType(str, Enum):
11
+ BUY = "BUY"
12
+ SELL = "SELL"
13
+ HOLD = "HOLD"
14
+ STRONG_BUY = "STRONG_BUY"
15
+ STRONG_SELL = "STRONG_SELL"
16
+
17
+ class SignalStrength(str, Enum):
18
+ WEAK = "weak"
19
+ MODERATE = "moderate"
20
+ STRONG = "strong"
21
+ VERY_STRONG = "very_strong"
22
+
23
+ class TimeFrame(str, Enum):
24
+ INTRADAY = "intraday"
25
+ SHORT_TERM = "short_term" # 1-7 days
26
+ MEDIUM_TERM = "medium_term" # 1-4 weeks
27
+ LONG_TERM = "long_term" # 1+ months
28
+
29
+ class IndicatorType(str, Enum):
30
+ TECHNICAL = "technical"
31
+ FUNDAMENTAL = "fundamental"
32
+ SENTIMENT = "sentiment"
33
+ QUANTITATIVE = "quantitative"
34
+ HYBRID = "hybrid"
35
+
36
+ class MarketCondition(str, Enum):
37
+ BULLISH = "bullish"
38
+ BEARISH = "bearish"
39
+ SIDEWAYS = "sideways"
40
+ VOLATILE = "volatile"
41
+
42
+ class AnalysisType(str, Enum):
43
+ TECHNICAL = "technical"
44
+ FUNDAMENTAL = "fundamental"
45
+ QUANTITATIVE = "quantitative"
46
+ SENTIMENT = "sentiment"
47
+ AI_ML = "ai_ml"
48
+
49
+ # Core Signal Models
50
+ class TechnicalIndicator(BaseModel):
51
+ name: str = Field(..., description="Indicator name (e.g., RSI, MACD, SMA)")
52
+ value: float = Field(..., description="Current indicator value")
53
+ signal: SignalType = Field(..., description="Signal generated by this indicator")
54
+ strength: SignalStrength = Field(..., description="Signal strength")
55
+ description: str = Field(..., description="Human-readable description")
56
+
57
+ class FundamentalMetric(BaseModel):
58
+ metric_name: str = Field(..., description="Fundamental metric name")
59
+ current_value: float = Field(..., description="Current value")
60
+ industry_average: Optional[float] = Field(None, description="Industry average")
61
+ percentile_rank: Optional[float] = Field(None, description="Percentile rank")
62
+ trend: str = Field(..., description="Trend direction")
63
+ impact: str = Field(..., description="Impact on signal")
64
+
65
+ class SentimentIndicator(BaseModel):
66
+ source: str = Field(..., description="Sentiment data source")
67
+ score: float = Field(..., ge=-1, le=1, description="Sentiment score (-1 to 1)")
68
+ confidence: float = Field(..., ge=0, le=1, description="Confidence in sentiment")
69
+ key_factors: List[str] = Field(..., description="Key sentiment drivers")
70
+ news_count: int = Field(..., description="Number of news articles analyzed")
71
+
72
+ class AISignal(BaseModel):
73
+ model_name: str = Field(..., description="AI model that generated the signal")
74
+ prediction: SignalType = Field(..., description="Model prediction")
75
+ confidence_score: float = Field(..., ge=0, le=1, description="Model confidence")
76
+ feature_importance: Dict[str, float] = Field(..., description="Feature importance scores")
77
+ prediction_horizon: str = Field(..., description="Time horizon for prediction")
78
+
79
+ # Main Signal Response Model
80
+ class SignalResponse(BaseModel):
81
+ signal_id: str = Field(..., description="Unique signal identifier")
82
+ security_id: str = Field(..., description="Security ID")
83
+ symbol: str = Field(..., description="Trading symbol")
84
+ exchange_segment: str = Field(..., description="Exchange segment")
85
+
86
+ # Core signal information
87
+ signal_type: SignalType = Field(..., description="Primary signal type")
88
+ signal_strength: SignalStrength = Field(..., description="Overall signal strength")
89
+ confidence_score: float = Field(..., ge=0, le=1, description="Overall confidence")
90
+ time_frame: TimeFrame = Field(..., description="Signal time frame")
91
+
92
+ # Price and target information
93
+ current_price: float = Field(..., description="Current market price")
94
+ target_price: Optional[float] = Field(None, description="Target price")
95
+ stop_loss: Optional[float] = Field(None, description="Stop loss price")
96
+ entry_price_range: Optional[Dict[str, float]] = Field(None, description="Optimal entry range")
97
+
98
+ # Analysis components
99
+ technical_analysis: List[TechnicalIndicator] = Field(default_factory=list, description="Technical indicators")
100
+ fundamental_analysis: List[FundamentalMetric] = Field(default_factory=list, description="Fundamental metrics")
101
+ sentiment_analysis: List[SentimentIndicator] = Field(default_factory=list, description="Sentiment indicators")
102
+ ai_analysis: List[AISignal] = Field(default_factory=list, description="AI model predictions")
103
+
104
+ # Additional metadata
105
+ risk_rating: str = Field(..., description="Risk rating for the signal")
106
+ expected_return: Optional[float] = Field(None, description="Expected return percentage")
107
+ probability_of_success: Optional[float] = Field(None, description="Success probability")
108
+ market_condition: MarketCondition = Field(..., description="Current market condition")
109
+
110
+ # Timestamps
111
+ generated_at: datetime = Field(..., description="Signal generation timestamp")
112
+ valid_until: Optional[datetime] = Field(None, description="Signal expiry timestamp")
113
+
114
+ # Human-readable summary
115
+ summary: str = Field(..., description="Human-readable signal summary")
116
+ reasoning: str = Field(..., description="Detailed reasoning behind the signal")
117
+ key_catalysts: List[str] = Field(default_factory=list, description="Key catalysts driving the signal")
118
+
119
+ @validator('target_price', 'stop_loss')
120
+ def validate_price_levels(cls, v, values, field):
121
+ if v is not None and 'current_price' in values:
122
+ current_price = values['current_price']
123
+ signal_type = values.get('signal_type')
124
+
125
+ if field.name == 'target_price' and signal_type in ['BUY', 'STRONG_BUY'] and v <= current_price:
126
+ raise ValueError('Target price should be higher than current price for buy signals')
127
+ elif field.name == 'target_price' and signal_type in ['SELL', 'STRONG_SELL'] and v >= current_price:
128
+ raise ValueError('Target price should be lower than current price for sell signals')
129
+ elif field.name == 'stop_loss' and signal_type in ['BUY', 'STRONG_BUY'] and v >= current_price:
130
+ raise ValueError('Stop loss should be lower than current price for buy signals')
131
+ elif field.name == 'stop_loss' and signal_type in ['SELL', 'STRONG_SELL'] and v <= current_price:
132
+ raise ValueError('Stop loss should be higher than current price for sell signals')
133
+
134
+ return v
135
+
136
+ # Request Models
137
+ class SecurityRequest(BaseModel):
138
+ security_id: str = Field(..., description="Security ID")
139
+ exchange_segment: str = Field(..., description="Exchange segment")
140
+
141
+ class SignalRequest(BaseModel):
142
+ securities: List[SecurityRequest] = Field(..., description="Securities to analyze")
143
+ analysis_types: List[AnalysisType] = Field(..., description="Types of analysis to perform")
144
+ time_frames: List[TimeFrame] = Field(..., description="Time frames for analysis")
145
+ preferences: Optional[Dict[str, Any]] = Field(None, description="User preferences for signal generation")
146
+
147
+ @validator('securities')
148
+ def validate_securities_count(cls, v):
149
+ if len(v) > 50:
150
+ raise ValueError('Maximum 50 securities allowed per request')
151
+ return v
152
+
153
+ class SignalFilter(BaseModel):
154
+ signal_types: Optional[List[SignalType]] = Field(None, description="Filter by signal types")
155
+ signal_strength: Optional[List[SignalStrength]] = Field(None, description="Filter by signal strength")
156
+ time_frames: Optional[List[TimeFrame]] = Field(None, description="Filter by time frames")
157
+ min_confidence: Optional[float] = Field(None, ge=0, le=1, description="Minimum confidence score")
158
+ symbols: Optional[List[str]] = Field(None, description="Filter by symbols")
159
+ date_from: Optional[datetime] = Field(None, description="Filter signals from date")
160
+ date_to: Optional[datetime] = Field(None, description="Filter signals to date")
161
+ risk_rating: Optional[List[str]] = Field(None, description="Filter by risk rating")
162
+
163
+ class UserPreferences(BaseModel):
164
+ risk_tolerance: str = Field("medium", description="Risk tolerance level")
165
+ investment_horizon: TimeFrame = Field(TimeFrame.MEDIUM_TERM, description="Preferred investment horizon")
166
+ preferred_indicators: List[str] = Field(default_factory=list, description="Preferred technical indicators")
167
+ sectors_to_include: List[str] = Field(default_factory=list, description="Preferred sectors")
168
+ sectors_to_exclude: List[str] = Field(default_factory=list, description="Sectors to avoid")
169
+ min_market_cap: Optional[float] = Field(None, description="Minimum market cap filter")
170
+ max_market_cap: Optional[float] = Field(None, description="Maximum market cap filter")
171
+ esg_preference: bool = Field(False, description="Prefer ESG-compliant stocks")
172
+
173
+ # Market Analysis Models
174
+ class MarketTrend(BaseModel):
175
+ trend_direction: str = Field(..., description="Overall trend direction")
176
+ trend_strength: SignalStrength = Field(..., description="Trend strength")
177
+ support_levels: List[float] = Field(..., description="Key support levels")
178
+ resistance_levels: List[float] = Field(..., description="Key resistance levels")
179
+ trend_duration: int = Field(..., description="Trend duration in days")
180
+
181
+ class SectorAnalysis(BaseModel):
182
+ sector_name: str = Field(..., description="Sector name")
183
+ performance: float = Field(..., description="Sector performance percentage")
184
+ outlook: str = Field(..., description="Sector outlook")
185
+ key_drivers: List[str] = Field(..., description="Key sector drivers")
186
+ top_stocks: List[str] = Field(..., description="Top performing stocks in sector")
187
+ recommendation: SignalType = Field(..., description="Sector recommendation")
188
+
189
+ class MarketAnalysisResponse(BaseModel):
190
+ overall_market_sentiment: str = Field(..., description="Overall market sentiment")
191
+ market_trend: MarketTrend = Field(..., description="Current market trend")
192
+ volatility_index: float = Field(..., description="Market volatility index")
193
+ fear_greed_index: int = Field(..., ge=0, le=100, description="Fear & Greed Index")
194
+ sector_analysis: List[SectorAnalysis] = Field(..., description="Sector-wise analysis")
195
+ macro_factors: List[str] = Field(..., description="Key macroeconomic factors")
196
+ market_events: List[str] = Field(..., description="Upcoming market events")
197
+ generated_at: datetime = Field(..., description="Analysis generation time")
198
+
199
+ # Backtesting and Performance Models
200
+ class BacktestResult(BaseModel):
201
+ start_date: datetime = Field(..., description="Backtest start date")
202
+ end_date: datetime = Field(..., description="Backtest end date")
203
+ total_signals: int = Field(..., description="Total signals generated")
204
+ successful_signals: int = Field(..., description="Number of successful signals")
205
+ accuracy_rate: float = Field(..., ge=0, le=1, description="Signal accuracy rate")
206
+ total_return: float = Field(..., description="Total return percentage")
207
+ annualized_return: float = Field(..., description="Annualized return")
208
+ max_drawdown: float = Field(..., description="Maximum drawdown")
209
+ sharpe_ratio: float = Field(..., description="Sharpe ratio")
210
+ win_rate: float = Field(..., description="Win rate percentage")
211
+ average_win: float = Field(..., description="Average winning trade")
212
+ average_loss: float = Field(..., description="Average losing trade")
213
+ profit_factor: float = Field(..., description="Profit factor")
214
+
215
+ class SignalPerformance(BaseModel):
216
+ signal_id: str = Field(..., description="Signal identifier")
217
+ entry_price: float = Field(..., description="Actual entry price")
218
+ exit_price: Optional[float] = Field(None, description="Actual exit price")
219
+ return_percentage: Optional[float] = Field(None, description="Return percentage")
220
+ holding_period: Optional[int] = Field(None, description="Holding period in days")
221
+ status: str = Field(..., description="Signal status (open/closed/expired)")
222
+ outcome: Optional[str] = Field(None, description="Signal outcome (success/failure)")
223
+
224
+ # Alert and Notification Models
225
+ class SignalAlert(BaseModel):
226
+ alert_id: str = Field(..., description="Alert identifier")
227
+ user_id: str = Field(..., description="User identifier")
228
+ symbol: str = Field(..., description="Symbol to watch")
229
+ condition: str = Field(..., description="Alert condition")
230
+ threshold: float = Field(..., description="Alert threshold")
231
+ signal_type: SignalType = Field(..., description="Signal type to trigger")
232
+ is_active: bool = Field(True, description="Alert active status")
233
+ created_at: datetime = Field(..., description="Alert creation time")
234
+ triggered_at: Optional[datetime] = Field(None, description="Alert trigger time")
235
+
236
+ class NotificationPreference(BaseModel):
237
+ email_enabled: bool = Field(True, description="Email notifications enabled")
238
+ sms_enabled: bool = Field(False, description="SMS notifications enabled")
239
+ push_enabled: bool = Field(True, description="Push notifications enabled")
240
+ signal_types: List[SignalType] = Field(..., description="Signal types to notify")
241
+ min_confidence: float = Field(0.7, description="Minimum confidence for notifications")
242
+ time_frames: List[TimeFrame] = Field(..., description="Time frames for notifications")
243
+
244
+ # Custom Strategy Models
245
+ class StrategyRule(BaseModel):
246
+ rule_type: str = Field(..., description="Type of rule (technical/fundamental/custom)")
247
+ parameter: str = Field(..., description="Parameter name")
248
+ operator: str = Field(..., description="Comparison operator")
249
+ value: Union[float, str] = Field(..., description="Threshold value")
250
+ weight: float = Field(1.0, description="Rule weight in strategy")
251
+
252
+ class CustomStrategy(BaseModel):
253
+ strategy_id: str = Field(..., description="Strategy identifier")
254
+ name: str = Field(..., description="Strategy name")
255
+ description: str = Field(..., description="Strategy description")
256
+ rules: List[StrategyRule] = Field(..., description="Strategy rules")
257
+ signal_threshold: float = Field(..., description="Signal generation threshold")
258
+ risk_management: Dict[str, Any] = Field(..., description="Risk management parameters")
259
+ backtest_results: Optional[BacktestResult] = Field(None, description="Backtesting results")
260
+ is_active: bool = Field(True, description="Strategy active status")
261
+ created_by: str = Field(..., description="Strategy creator")
262
+ created_at: datetime = Field(..., description="Strategy creation time")
263
+
264
+ # Signal Statistics and Analytics
265
+ class SignalStatistics(BaseModel):
266
+ total_signals_generated: int = Field(..., description="Total signals generated")
267
+ signals_by_type: Dict[str, int] = Field(..., description="Signals breakdown by type")
268
+ signals_by_strength: Dict[str, int] = Field(..., description="Signals breakdown by strength")
269
+ average_confidence: float = Field(..., description="Average confidence score")
270
+ success_rate: float = Field(..., description="Overall success rate")
271
+ top_performing_symbols: List[Dict[str, Any]] = Field(..., description="Best performing symbols")
272
+ model_performance: Dict[str, float] = Field(..., description="Individual model performance")
273
+ last_updated: datetime = Field(..., description="Last update timestamp")
274
+
275
+ # Batch Processing Models
276
+ class BatchSignalRequest(BaseModel):
277
+ securities: List[SecurityRequest] = Field(..., description="Securities to analyze")
278
+ analysis_config: Dict[str, Any] = Field(..., description="Analysis configuration")
279
+ callback_url: Optional[str] = Field(None, description="Callback URL for results")
280
+ priority: str = Field("normal", description="Processing priority")
281
+
282
+ class BatchSignalResponse(BaseModel):
283
+ batch_id: str = Field(..., description="Batch processing ID")
284
+ status: str = Field(..., description="Batch processing status")
285
+ total_securities: int = Field(..., description="Total securities in batch")
286
+ completed_securities: int = Field(..., description="Completed securities count")
287
+ estimated_completion: Optional[datetime] = Field(None, description="Estimated completion time")
288
+ results_url: Optional[str] = Field(None, description="Results download URL")
routers/analytics.py ADDED
@@ -0,0 +1,479 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Analytics Router
3
+ Handles portfolio analytics, market analysis, and reporting
4
+ """
5
+
6
+ from fastapi import APIRouter, HTTPException, Depends, Query
7
+ from typing import List, Dict, Any, Optional
8
+ import logging
9
+ from datetime import datetime, timedelta
10
+
11
+ from config.settings import get_settings
12
+ from services.auth import get_current_user
13
+ from services.dhan_api_manager import dhan_api_manager, DhanAPIError
14
+
15
+ logger = logging.getLogger(__name__)
16
+ router = APIRouter()
17
+ settings = get_settings()
18
+
19
+ @router.get("/portfolio-analytics")
20
+ async def get_portfolio_analytics(
21
+ period: str = Query("1Y", description="Analysis period (1M, 3M, 6M, 1Y, 2Y)"),
22
+ current_user: Dict = Depends(get_current_user)
23
+ ):
24
+ """Get comprehensive portfolio analytics"""
25
+ try:
26
+ # Get holdings and positions
27
+ holdings_result = await dhan_api_manager.get_holdings()
28
+ positions_result = await dhan_api_manager.get_positions()
29
+
30
+ holdings = holdings_result.get("data", [])
31
+ positions = positions_result.get("data", [])
32
+
33
+ # Calculate analytics (mock implementation)
34
+ total_value = sum(holding.get("currentValue", 0) for holding in holdings)
35
+ total_pnl = sum(holding.get("pnl", 0) for holding in holdings)
36
+ total_invested = total_value - total_pnl
37
+
38
+ return {
39
+ "portfolio_value": total_value,
40
+ "total_invested": total_invested,
41
+ "total_pnl": total_pnl,
42
+ "pnl_percentage": (total_pnl / total_invested * 100) if total_invested > 0 else 0,
43
+ "holdings_count": len(holdings),
44
+ "positions_count": len(positions),
45
+ "period": period,
46
+ "analytics": {
47
+ "best_performer": get_best_performer(holdings),
48
+ "worst_performer": get_worst_performer(holdings),
49
+ "sector_allocation": calculate_sector_allocation(holdings),
50
+ "risk_metrics": calculate_risk_metrics(holdings),
51
+ "performance_metrics": calculate_performance_metrics(holdings, period)
52
+ },
53
+ "timestamp": datetime.utcnow().isoformat()
54
+ }
55
+
56
+ except DhanAPIError as e:
57
+ logger.error(f"Dhan API error in portfolio analytics: {e}")
58
+ raise HTTPException(status_code=e.status_code or 500, detail=str(e))
59
+ except Exception as e:
60
+ logger.error(f"Error generating portfolio analytics: {e}")
61
+ raise HTTPException(status_code=500, detail="Failed to generate portfolio analytics")
62
+
63
+ @router.get("/market-analytics")
64
+ async def get_market_analytics(
65
+ indices: List[str] = Query(["NIFTY", "SENSEX"], description="Market indices to analyze"),
66
+ current_user: Dict = Depends(get_current_user)
67
+ ):
68
+ """Get market analytics and trends"""
69
+ try:
70
+ # Mock market analytics
71
+ market_data = {}
72
+
73
+ for index in indices:
74
+ # In a real implementation, you'd fetch live market data
75
+ mock_data = {
76
+ "NIFTY": {
77
+ "current_value": 22150.50,
78
+ "change": 185.75,
79
+ "change_percent": 0.85,
80
+ "volume": 156789000,
81
+ "market_cap": 28500000000000,
82
+ "pe_ratio": 22.5,
83
+ "dividend_yield": 1.25
84
+ },
85
+ "SENSEX": {
86
+ "current_value": 72890.25,
87
+ "change": 425.80,
88
+ "change_percent": 0.59,
89
+ "volume": 98456000,
90
+ "market_cap": 35600000000000,
91
+ "pe_ratio": 24.8,
92
+ "dividend_yield": 1.15
93
+ }
94
+ }
95
+
96
+ market_data[index] = mock_data.get(index, {
97
+ "current_value": 0,
98
+ "change": 0,
99
+ "change_percent": 0,
100
+ "volume": 0
101
+ })
102
+
103
+ return {
104
+ "market_data": market_data,
105
+ "market_sentiment": {
106
+ "overall": "Bullish",
107
+ "fear_greed_index": 72,
108
+ "volatility_index": 18.5,
109
+ "advance_decline_ratio": 1.8
110
+ },
111
+ "sector_trends": [
112
+ {"sector": "IT", "trend": "Bullish", "change_percent": 2.1},
113
+ {"sector": "Banking", "trend": "Bearish", "change_percent": -0.8},
114
+ {"sector": "Auto", "trend": "Neutral", "change_percent": 0.3}
115
+ ],
116
+ "timestamp": datetime.utcnow().isoformat()
117
+ }
118
+
119
+ except Exception as e:
120
+ logger.error(f"Error generating market analytics: {e}")
121
+ raise HTTPException(status_code=500, detail="Failed to generate market analytics")
122
+
123
+ @router.get("/trading-analytics")
124
+ async def get_trading_analytics(
125
+ period: str = Query("1M", description="Analysis period"),
126
+ current_user: Dict = Depends(get_current_user)
127
+ ):
128
+ """Get trading performance analytics"""
129
+ try:
130
+ # Get orders for the period
131
+ orders_result = await dhan_api_manager.get_orders()
132
+ orders = orders_result.get("data", [])
133
+
134
+ # Calculate trading analytics (mock implementation)
135
+ total_trades = len(orders)
136
+ successful_trades = len([order for order in orders if order.get("orderStatus") == "COMPLETE"])
137
+
138
+ return {
139
+ "trading_summary": {
140
+ "total_trades": total_trades,
141
+ "successful_trades": successful_trades,
142
+ "success_rate": (successful_trades / total_trades * 100) if total_trades > 0 else 0,
143
+ "win_rate": 65.5, # Mock data
144
+ "average_profit": 2.8, # Mock data
145
+ "average_loss": -1.5, # Mock data
146
+ "profit_factor": 1.85, # Mock data
147
+ "sharpe_ratio": 1.42 # Mock data
148
+ },
149
+ "trade_distribution": {
150
+ "by_type": {
151
+ "BUY": len([o for o in orders if o.get("transactionType") == "BUY"]),
152
+ "SELL": len([o for o in orders if o.get("transactionType") == "SELL"])
153
+ },
154
+ "by_product": {
155
+ "CNC": len([o for o in orders if o.get("productType") == "CNC"]),
156
+ "MIS": len([o for o in orders if o.get("productType") == "MIS"]),
157
+ "NRML": len([o for o in orders if o.get("productType") == "NRML"])
158
+ }
159
+ },
160
+ "monthly_performance": generate_monthly_performance(),
161
+ "period": period,
162
+ "timestamp": datetime.utcnow().isoformat()
163
+ }
164
+
165
+ except DhanAPIError as e:
166
+ logger.error(f"Dhan API error in trading analytics: {e}")
167
+ raise HTTPException(status_code=e.status_code or 500, detail=str(e))
168
+ except Exception as e:
169
+ logger.error(f"Error generating trading analytics: {e}")
170
+ raise HTTPException(status_code=500, detail="Failed to generate trading analytics")
171
+
172
+ @router.get("/risk-analytics")
173
+ async def get_risk_analytics(
174
+ current_user: Dict = Depends(get_current_user)
175
+ ):
176
+ """Get portfolio risk analytics"""
177
+ try:
178
+ # Get portfolio data
179
+ holdings_result = await dhan_api_manager.get_holdings()
180
+ holdings = holdings_result.get("data", [])
181
+
182
+ # Calculate risk metrics (mock implementation)
183
+ portfolio_value = sum(holding.get("currentValue", 0) for holding in holdings)
184
+
185
+ return {
186
+ "risk_metrics": {
187
+ "portfolio_beta": 1.15,
188
+ "value_at_risk_1d": portfolio_value * 0.02, # 2% VaR
189
+ "value_at_risk_5d": portfolio_value * 0.045, # 4.5% VaR
190
+ "maximum_drawdown": 12.5,
191
+ "volatility": 18.7,
192
+ "correlation_with_market": 0.85,
193
+ "concentration_risk": calculate_concentration_risk(holdings)
194
+ },
195
+ "stress_testing": {
196
+ "market_crash_scenario": {
197
+ "scenario": "20% market decline",
198
+ "portfolio_impact": portfolio_value * -0.23,
199
+ "impact_percentage": -23.0
200
+ },
201
+ "sector_shock": {
202
+ "scenario": "IT sector 15% decline",
203
+ "portfolio_impact": portfolio_value * -0.08,
204
+ "impact_percentage": -8.0
205
+ }
206
+ },
207
+ "risk_score": 7.2, # Out of 10
208
+ "risk_category": "Moderate-High",
209
+ "recommendations": [
210
+ "Consider diversifying into defensive sectors",
211
+ "Reduce exposure to high-beta stocks",
212
+ "Add some fixed-income instruments"
213
+ ],
214
+ "timestamp": datetime.utcnow().isoformat()
215
+ }
216
+
217
+ except DhanAPIError as e:
218
+ logger.error(f"Dhan API error in risk analytics: {e}")
219
+ raise HTTPException(status_code=e.status_code or 500, detail=str(e))
220
+ except Exception as e:
221
+ logger.error(f"Error generating risk analytics: {e}")
222
+ raise HTTPException(status_code=500, detail="Failed to generate risk analytics")
223
+
224
+ @router.get("/performance-report")
225
+ async def generate_performance_report(
226
+ start_date: str = Query(..., description="Start date (YYYY-MM-DD)"),
227
+ end_date: str = Query(..., description="End date (YYYY-MM-DD)"),
228
+ format: str = Query("json", description="Report format (json, pdf)"),
229
+ current_user: Dict = Depends(get_current_user)
230
+ ):
231
+ """Generate comprehensive performance report"""
232
+ try:
233
+ # Parse dates
234
+ start_dt = datetime.strptime(start_date, "%Y-%m-%d")
235
+ end_dt = datetime.strptime(end_date, "%Y-%m-%d")
236
+
237
+ if start_dt >= end_dt:
238
+ raise HTTPException(status_code=400, detail="Start date must be before end date")
239
+
240
+ # Get portfolio data
241
+ holdings_result = await dhan_api_manager.get_holdings()
242
+ orders_result = await dhan_api_manager.get_orders()
243
+
244
+ holdings = holdings_result.get("data", [])
245
+ orders = orders_result.get("data", [])
246
+
247
+ # Generate report
248
+ report = {
249
+ "report_period": {
250
+ "start_date": start_date,
251
+ "end_date": end_date,
252
+ "duration_days": (end_dt - start_dt).days
253
+ },
254
+ "portfolio_summary": {
255
+ "current_value": sum(h.get("currentValue", 0) for h in holdings),
256
+ "total_pnl": sum(h.get("pnl", 0) for h in holdings),
257
+ "number_of_holdings": len(holdings),
258
+ "number_of_trades": len(orders)
259
+ },
260
+ "performance_metrics": {
261
+ "absolute_return": 8.5, # Mock data
262
+ "annualized_return": 12.8,
263
+ "volatility": 18.5,
264
+ "sharpe_ratio": 1.42,
265
+ "max_drawdown": 15.2,
266
+ "calmar_ratio": 0.85
267
+ },
268
+ "benchmark_comparison": {
269
+ "portfolio_return": 8.5,
270
+ "nifty_return": 6.2,
271
+ "alpha": 2.3,
272
+ "beta": 1.15,
273
+ "outperformance": 2.3
274
+ },
275
+ "top_performers": get_top_performers(holdings, 5),
276
+ "worst_performers": get_worst_performers(holdings, 5),
277
+ "sector_allocation": calculate_sector_allocation(holdings),
278
+ "trading_summary": calculate_trading_summary(orders),
279
+ "generated_at": datetime.utcnow().isoformat()
280
+ }
281
+
282
+ if format.lower() == "pdf":
283
+ # In a real implementation, you'd generate a PDF
284
+ return {"message": "PDF generation not implemented", "report_data": report}
285
+
286
+ return report
287
+
288
+ except DhanAPIError as e:
289
+ logger.error(f"Dhan API error in performance report: {e}")
290
+ raise HTTPException(status_code=e.status_code or 500, detail=str(e))
291
+ except ValueError as e:
292
+ raise HTTPException(status_code=400, detail=f"Invalid date format: {e}")
293
+ except Exception as e:
294
+ logger.error(f"Error generating performance report: {e}")
295
+ raise HTTPException(status_code=500, detail="Failed to generate performance report")
296
+
297
+ @router.get("/tax-analytics")
298
+ async def get_tax_analytics(
299
+ financial_year: str = Query("2024-25", description="Financial year"),
300
+ current_user: Dict = Depends(get_current_user)
301
+ ):
302
+ """Get tax analytics and optimization suggestions"""
303
+ try:
304
+ # Get holdings and calculate tax implications
305
+ holdings_result = await dhan_api_manager.get_holdings()
306
+ holdings = holdings_result.get("data", [])
307
+
308
+ # Mock tax calculations
309
+ total_realized_gains = 125000 # Mock data
310
+ total_realized_losses = 45000
311
+ net_gains = total_realized_gains - total_realized_losses
312
+
313
+ return {
314
+ "tax_summary": {
315
+ "financial_year": financial_year,
316
+ "total_realized_gains": total_realized_gains,
317
+ "total_realized_losses": total_realized_losses,
318
+ "net_realized_gains": net_gains,
319
+ "tax_liability": calculate_tax_liability(net_gains),
320
+ "unrealized_gains": sum(h.get("pnl", 0) for h in holdings if h.get("pnl", 0) > 0),
321
+ "unrealized_losses": sum(h.get("pnl", 0) for h in holdings if h.get("pnl", 0) < 0)
322
+ },
323
+ "tax_optimization": {
324
+ "loss_harvesting_opportunities": identify_loss_harvesting(holdings),
325
+ "long_term_vs_short_term": {
326
+ "long_term_gains": 80000,
327
+ "short_term_gains": 45000,
328
+ "recommendation": "Consider holding short-term positions longer for LTCG benefits"
329
+ }
330
+ },
331
+ "recommendations": [
332
+ "Book profits in loss-making stocks before year-end",
333
+ "Consider SIP investments for tax-saving mutual funds",
334
+ "Review holding period for better tax efficiency"
335
+ ],
336
+ "timestamp": datetime.utcnow().isoformat()
337
+ }
338
+
339
+ except DhanAPIError as e:
340
+ logger.error(f"Dhan API error in tax analytics: {e}")
341
+ raise HTTPException(status_code=e.status_code or 500, detail=str(e))
342
+ except Exception as e:
343
+ logger.error(f"Error generating tax analytics: {e}")
344
+ raise HTTPException(status_code=500, detail="Failed to generate tax analytics")
345
+
346
+ # Helper functions
347
+ def get_best_performer(holdings):
348
+ """Get best performing holding"""
349
+ if not holdings:
350
+ return None
351
+
352
+ best = max(holdings, key=lambda h: h.get("pnlPercent", 0))
353
+ return {
354
+ "symbol": best.get("tradingSymbol"),
355
+ "pnl_percent": best.get("pnlPercent", 0),
356
+ "pnl_amount": best.get("pnl", 0)
357
+ }
358
+
359
+ def get_worst_performer(holdings):
360
+ """Get worst performing holding"""
361
+ if not holdings:
362
+ return None
363
+
364
+ worst = min(holdings, key=lambda h: h.get("pnlPercent", 0))
365
+ return {
366
+ "symbol": worst.get("tradingSymbol"),
367
+ "pnl_percent": worst.get("pnlPercent", 0),
368
+ "pnl_amount": worst.get("pnl", 0)
369
+ }
370
+
371
+ def calculate_sector_allocation(holdings):
372
+ """Calculate sector-wise allocation"""
373
+ # Mock implementation
374
+ return {
375
+ "IT": 35.5,
376
+ "Banking": 25.2,
377
+ "Pharmaceuticals": 15.8,
378
+ "Auto": 12.3,
379
+ "Oil & Gas": 11.2
380
+ }
381
+
382
+ def calculate_risk_metrics(holdings):
383
+ """Calculate portfolio risk metrics"""
384
+ return {
385
+ "beta": 1.15,
386
+ "volatility": 18.5,
387
+ "max_drawdown": 12.3,
388
+ "value_at_risk": 2.1
389
+ }
390
+
391
+ def calculate_performance_metrics(holdings, period):
392
+ """Calculate performance metrics"""
393
+ return {
394
+ "absolute_return": 8.5,
395
+ "annualized_return": 12.8,
396
+ "sharpe_ratio": 1.42,
397
+ "alpha": 2.3
398
+ }
399
+
400
+ def generate_monthly_performance():
401
+ """Generate mock monthly performance data"""
402
+ return [
403
+ {"month": "Jan", "return": 2.1},
404
+ {"month": "Feb", "return": -1.5},
405
+ {"month": "Mar", "return": 3.8},
406
+ {"month": "Apr", "return": 1.2},
407
+ {"month": "May", "return": -0.8},
408
+ {"month": "Jun", "return": 2.5}
409
+ ]
410
+
411
+ def calculate_concentration_risk(holdings):
412
+ """Calculate portfolio concentration risk"""
413
+ if not holdings:
414
+ return 0
415
+
416
+ total_value = sum(h.get("currentValue", 0) for h in holdings)
417
+ if total_value == 0:
418
+ return 0
419
+
420
+ # Calculate Herfindahl index
421
+ weights = [h.get("currentValue", 0) / total_value for h in holdings]
422
+ herfindahl = sum(w * w for w in weights)
423
+
424
+ return round(herfindahl * 100, 2)
425
+
426
+ def get_top_performers(holdings, count):
427
+ """Get top performing holdings"""
428
+ sorted_holdings = sorted(holdings, key=lambda h: h.get("pnlPercent", 0), reverse=True)
429
+ return [
430
+ {
431
+ "symbol": h.get("tradingSymbol"),
432
+ "pnl_percent": h.get("pnlPercent", 0),
433
+ "pnl_amount": h.get("pnl", 0)
434
+ }
435
+ for h in sorted_holdings[:count]
436
+ ]
437
+
438
+ def get_worst_performers(holdings, count):
439
+ """Get worst performing holdings"""
440
+ sorted_holdings = sorted(holdings, key=lambda h: h.get("pnlPercent", 0))
441
+ return [
442
+ {
443
+ "symbol": h.get("tradingSymbol"),
444
+ "pnl_percent": h.get("pnlPercent", 0),
445
+ "pnl_amount": h.get("pnl", 0)
446
+ }
447
+ for h in sorted_holdings[:count]
448
+ ]
449
+
450
+ def calculate_trading_summary(orders):
451
+ """Calculate trading summary"""
452
+ return {
453
+ "total_trades": len(orders),
454
+ "buy_orders": len([o for o in orders if o.get("transactionType") == "BUY"]),
455
+ "sell_orders": len([o for o in orders if o.get("transactionType") == "SELL"]),
456
+ "completed_orders": len([o for o in orders if o.get("orderStatus") == "COMPLETE"]),
457
+ "pending_orders": len([o for o in orders if o.get("orderStatus") == "PENDING"])
458
+ }
459
+
460
+ def calculate_tax_liability(net_gains):
461
+ """Calculate tax liability based on gains"""
462
+ if net_gains <= 100000: # LTCG exemption limit
463
+ return 0
464
+ else:
465
+ # 10% LTCG tax above 1 lakh
466
+ return (net_gains - 100000) * 0.10
467
+
468
+ def identify_loss_harvesting(holdings):
469
+ """Identify loss harvesting opportunities"""
470
+ loss_opportunities = []
471
+ for holding in holdings:
472
+ if holding.get("pnl", 0) < 0:
473
+ loss_opportunities.append({
474
+ "symbol": holding.get("tradingSymbol"),
475
+ "current_loss": holding.get("pnl", 0),
476
+ "tax_benefit": abs(holding.get("pnl", 0)) * 0.30 # Assuming 30% tax bracket
477
+ })
478
+
479
+ return loss_opportunities[:5] # Top 5 opportunities
routers/dhan.py CHANGED
@@ -23,6 +23,7 @@ from models.trading import (
23
  HistoricalDataResponse
24
  )
25
  from services.auth import get_current_user
 
26
  from utils.rate_limiter import RateLimiter
27
 
28
  logger = logging.getLogger(__name__)
 
23
  HistoricalDataResponse
24
  )
25
  from services.auth import get_current_user
26
+ from services.dhan_api_manager import dhan_api_manager, DhanAPIError
27
  from utils.rate_limiter import RateLimiter
28
 
29
  logger = logging.getLogger(__name__)
routers/dhan_new.py ADDED
@@ -0,0 +1,362 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Updated Dhan API Router
3
+ Handles all Dhan API operations with proper token assignment by API type
4
+ Uses the new DhanAPIManager for different API categories
5
+ """
6
+
7
+ from fastapi import APIRouter, HTTPException, Depends
8
+ from typing import List, Optional, Dict, Any
9
+ import logging
10
+ from datetime import datetime
11
+
12
+ from services.auth import get_current_user
13
+ from services.dhan_api_manager import dhan_api_manager, DhanAPIError, DhanAPIType
14
+ from models.trading import (
15
+ QuoteRequest,
16
+ QuoteResponse,
17
+ OrderRequest,
18
+ OrderResponse,
19
+ HoldingsResponse,
20
+ PositionsResponse,
21
+ FundsResponse,
22
+ HistoricalDataRequest,
23
+ HistoricalDataResponse
24
+ )
25
+
26
+ logger = logging.getLogger(__name__)
27
+ router = APIRouter()
28
+
29
+ # ============================================================================
30
+ # DATA API ENDPOINTS (uses DHAN_DATA_API_TOKEN)
31
+ # ============================================================================
32
+
33
+ @router.get("/securities", response_model=List[Dict[str, Any]])
34
+ async def get_securities(
35
+ exchange: Optional[str] = "NSE",
36
+ current_user: Dict = Depends(get_current_user)
37
+ ):
38
+ """Get all securities/instruments using Data API"""
39
+ try:
40
+ result = await dhan_api_manager.get_security_master()
41
+
42
+ # Filter by exchange if specified
43
+ if exchange and exchange != "ALL":
44
+ securities = result.get("data", [])
45
+ filtered = [s for s in securities if s.get("SEM_EXM_EXCH_ID") == exchange]
46
+ return filtered
47
+
48
+ return result.get("data", [])
49
+
50
+ except DhanAPIError as e:
51
+ logger.error(f"Dhan API error fetching securities: {e}")
52
+ raise HTTPException(status_code=e.status_code or 500, detail=str(e))
53
+ except Exception as e:
54
+ logger.error(f"Error fetching securities: {e}")
55
+ raise HTTPException(status_code=500, detail="Failed to fetch securities")
56
+
57
+ @router.post("/quotes", response_model=List[QuoteResponse])
58
+ async def get_quotes(
59
+ request: QuoteRequest,
60
+ current_user: Dict = Depends(get_current_user)
61
+ ):
62
+ """Get live quotes for multiple securities using Data API"""
63
+ try:
64
+ instruments = [
65
+ {
66
+ "securityId": instrument.security_id,
67
+ "exchangeSegment": instrument.exchange_segment
68
+ }
69
+ for instrument in request.instruments
70
+ ]
71
+
72
+ result = await dhan_api_manager.get_live_feed(instruments)
73
+ return result.get("data", [])
74
+
75
+ except DhanAPIError as e:
76
+ logger.error(f"Dhan API error fetching quotes: {e}")
77
+ raise HTTPException(status_code=e.status_code or 500, detail=str(e))
78
+ except Exception as e:
79
+ logger.error(f"Error fetching quotes: {e}")
80
+ raise HTTPException(status_code=500, detail="Failed to fetch quotes")
81
+
82
+ @router.get("/quote/{security_id}")
83
+ async def get_single_quote(
84
+ security_id: str,
85
+ exchange_segment: str = "NSE_EQ",
86
+ current_user: Dict = Depends(get_current_user)
87
+ ):
88
+ """Get detailed quote for a single security using Data API"""
89
+ try:
90
+ result = await dhan_api_manager.get_quote(security_id, exchange_segment)
91
+ return result.get("data", {})
92
+
93
+ except DhanAPIError as e:
94
+ logger.error(f"Dhan API error fetching single quote: {e}")
95
+ raise HTTPException(status_code=e.status_code or 500, detail=str(e))
96
+ except Exception as e:
97
+ logger.error(f"Error fetching single quote: {e}")
98
+ raise HTTPException(status_code=500, detail="Failed to fetch quote")
99
+
100
+ @router.post("/historical", response_model=HistoricalDataResponse)
101
+ async def get_historical_data(
102
+ request: HistoricalDataRequest,
103
+ current_user: Dict = Depends(get_current_user)
104
+ ):
105
+ """Get historical data for charting using Data API"""
106
+ try:
107
+ data = {
108
+ "securityId": request.security_id,
109
+ "exchangeSegment": request.exchange_segment,
110
+ "instrument": request.instrument,
111
+ "interval": request.interval,
112
+ "fromDate": request.from_date,
113
+ "toDate": request.to_date
114
+ }
115
+
116
+ result = await dhan_api_manager.get_historical_data(data)
117
+ return result.get("data", {})
118
+
119
+ except DhanAPIError as e:
120
+ logger.error(f"Dhan API error fetching historical data: {e}")
121
+ raise HTTPException(status_code=e.status_code or 500, detail=str(e))
122
+ except Exception as e:
123
+ logger.error(f"Error fetching historical data: {e}")
124
+ raise HTTPException(status_code=500, detail="Failed to fetch historical data")
125
+
126
+ @router.post("/intraday")
127
+ async def get_intraday_data(
128
+ request: HistoricalDataRequest,
129
+ current_user: Dict = Depends(get_current_user)
130
+ ):
131
+ """Get intraday data using Data API"""
132
+ try:
133
+ data = {
134
+ "securityId": request.security_id,
135
+ "exchangeSegment": request.exchange_segment,
136
+ "instrument": request.instrument,
137
+ "interval": request.interval,
138
+ }
139
+
140
+ result = await dhan_api_manager.get_intraday_data(data)
141
+ return result.get("data", {})
142
+
143
+ except DhanAPIError as e:
144
+ logger.error(f"Dhan API error fetching intraday data: {e}")
145
+ raise HTTPException(status_code=e.status_code or 500, detail=str(e))
146
+ except Exception as e:
147
+ logger.error(f"Error fetching intraday data: {e}")
148
+ raise HTTPException(status_code=500, detail="Failed to fetch intraday data")
149
+
150
+ @router.get("/optionchain/{underlying_security_id}")
151
+ async def get_option_chain(
152
+ underlying_security_id: str,
153
+ exchange_segment: str = "NSE_FNO",
154
+ current_user: Dict = Depends(get_current_user)
155
+ ):
156
+ """Get option chain data using Data API"""
157
+ try:
158
+ result = await dhan_api_manager.get_option_chain(underlying_security_id, exchange_segment)
159
+ return result.get("data", {})
160
+
161
+ except DhanAPIError as e:
162
+ logger.error(f"Dhan API error fetching option chain: {e}")
163
+ raise HTTPException(status_code=e.status_code or 500, detail=str(e))
164
+ except Exception as e:
165
+ logger.error(f"Error fetching option chain: {e}")
166
+ raise HTTPException(status_code=500, detail="Failed to fetch option chain")
167
+
168
+ # ============================================================================
169
+ # TRADING API ENDPOINTS (uses DHAN_TRADING_API_TOKEN or DHAN_PAPER_TRADING_TOKEN)
170
+ # ============================================================================
171
+
172
+ @router.post("/orders", response_model=OrderResponse)
173
+ async def place_order(
174
+ order: OrderRequest,
175
+ current_user: Dict = Depends(get_current_user)
176
+ ):
177
+ """Place a new order using Trading API (live or paper trading based on mode)"""
178
+ try:
179
+ order_data = {
180
+ "securityId": order.security_id,
181
+ "exchangeSegment": order.exchange_segment,
182
+ "transactionType": order.transaction_type,
183
+ "quantity": order.quantity,
184
+ "orderType": order.order_type,
185
+ "productType": order.product_type,
186
+ "price": order.price if order.price else 0,
187
+ "triggerPrice": order.trigger_price if order.trigger_price else 0,
188
+ "validity": order.validity,
189
+ "disclosedQuantity": order.disclosed_quantity if order.disclosed_quantity else 0,
190
+ "afterMarketOrder": order.after_market_order if order.after_market_order else False
191
+ }
192
+
193
+ result = await dhan_api_manager.place_order(order_data)
194
+ return result.get("data", {})
195
+
196
+ except DhanAPIError as e:
197
+ logger.error(f"Dhan API error placing order: {e}")
198
+ raise HTTPException(status_code=e.status_code or 500, detail=str(e))
199
+ except Exception as e:
200
+ logger.error(f"Error placing order: {e}")
201
+ raise HTTPException(status_code=500, detail="Failed to place order")
202
+
203
+ @router.get("/orders")
204
+ async def get_orders(current_user: Dict = Depends(get_current_user)):
205
+ """Get user orders using Trading API"""
206
+ try:
207
+ result = await dhan_api_manager.get_orders()
208
+ return {"orders": result.get("data", [])}
209
+
210
+ except DhanAPIError as e:
211
+ logger.error(f"Dhan API error fetching orders: {e}")
212
+ raise HTTPException(status_code=e.status_code or 500, detail=str(e))
213
+ except Exception as e:
214
+ logger.error(f"Error fetching orders: {e}")
215
+ raise HTTPException(status_code=500, detail="Failed to fetch orders")
216
+
217
+ @router.delete("/orders/{order_id}")
218
+ async def cancel_order(
219
+ order_id: str,
220
+ current_user: Dict = Depends(get_current_user)
221
+ ):
222
+ """Cancel an order using Trading API"""
223
+ try:
224
+ result = await dhan_api_manager.cancel_order(order_id)
225
+ return {"message": "Order cancelled successfully", "data": result.get("data", {})}
226
+
227
+ except DhanAPIError as e:
228
+ logger.error(f"Dhan API error cancelling order: {e}")
229
+ raise HTTPException(status_code=e.status_code or 500, detail=str(e))
230
+ except Exception as e:
231
+ logger.error(f"Error cancelling order: {e}")
232
+ raise HTTPException(status_code=500, detail="Failed to cancel order")
233
+
234
+ @router.put("/orders/{order_id}")
235
+ async def modify_order(
236
+ order_id: str,
237
+ modifications: Dict[str, Any],
238
+ current_user: Dict = Depends(get_current_user)
239
+ ):
240
+ """Modify an existing order using Trading API"""
241
+ try:
242
+ result = await dhan_api_manager.modify_order(order_id, modifications)
243
+ return {"message": "Order modified successfully", "data": result.get("data", {})}
244
+
245
+ except DhanAPIError as e:
246
+ logger.error(f"Dhan API error modifying order: {e}")
247
+ raise HTTPException(status_code=e.status_code or 500, detail=str(e))
248
+ except Exception as e:
249
+ logger.error(f"Error modifying order: {e}")
250
+ raise HTTPException(status_code=500, detail="Failed to modify order")
251
+
252
+ @router.get("/holdings", response_model=HoldingsResponse)
253
+ async def get_holdings(current_user: Dict = Depends(get_current_user)):
254
+ """Get user holdings using Trading API"""
255
+ try:
256
+ result = await dhan_api_manager.get_holdings()
257
+ return {"holdings": result.get("data", [])}
258
+
259
+ except DhanAPIError as e:
260
+ logger.error(f"Dhan API error fetching holdings: {e}")
261
+ raise HTTPException(status_code=e.status_code or 500, detail=str(e))
262
+ except Exception as e:
263
+ logger.error(f"Error fetching holdings: {e}")
264
+ raise HTTPException(status_code=500, detail="Failed to fetch holdings")
265
+
266
+ @router.get("/positions", response_model=PositionsResponse)
267
+ async def get_positions(current_user: Dict = Depends(get_current_user)):
268
+ """Get user positions using Trading API"""
269
+ try:
270
+ result = await dhan_api_manager.get_positions()
271
+ return {"positions": result.get("data", [])}
272
+
273
+ except DhanAPIError as e:
274
+ logger.error(f"Dhan API error fetching positions: {e}")
275
+ raise HTTPException(status_code=e.status_code or 500, detail=str(e))
276
+ except Exception as e:
277
+ logger.error(f"Error fetching positions: {e}")
278
+ raise HTTPException(status_code=500, detail="Failed to fetch positions")
279
+
280
+ @router.get("/funds", response_model=FundsResponse)
281
+ async def get_funds(current_user: Dict = Depends(get_current_user)):
282
+ """Get user funds and margin using Trading API"""
283
+ try:
284
+ result = await dhan_api_manager.get_funds()
285
+ return result.get("data", {})
286
+
287
+ except DhanAPIError as e:
288
+ logger.error(f"Dhan API error fetching funds: {e}")
289
+ raise HTTPException(status_code=e.status_code or 500, detail=str(e))
290
+ except Exception as e:
291
+ logger.error(f"Error fetching funds: {e}")
292
+ raise HTTPException(status_code=500, detail="Failed to fetch funds")
293
+
294
+ # ============================================================================
295
+ # CONFIGURATION AND STATUS ENDPOINTS
296
+ # ============================================================================
297
+
298
+ @router.get("/config")
299
+ async def get_api_config(current_user: Dict = Depends(get_current_user)):
300
+ """Get current API configuration and mode"""
301
+ try:
302
+ return {
303
+ "trading_mode": dhan_api_manager.get_current_mode(),
304
+ "is_paper_trading": dhan_api_manager.is_paper_trading(),
305
+ "base_url": dhan_api_manager.base_url,
306
+ "client_id": dhan_api_manager.get_client_id_for_api_type(
307
+ DhanAPIType.PAPER_TRADING if dhan_api_manager.is_paper_trading() else DhanAPIType.TRADING
308
+ ),
309
+ "websocket_credentials": dhan_api_manager.get_websocket_credentials(),
310
+ "timestamp": datetime.utcnow().isoformat()
311
+ }
312
+ except Exception as e:
313
+ logger.error(f"Error fetching API config: {e}")
314
+ raise HTTPException(status_code=500, detail="Failed to fetch API configuration")
315
+
316
+ @router.get("/market-status")
317
+ async def get_market_status(current_user: Dict = Depends(get_current_user)):
318
+ """Get current market status"""
319
+ try:
320
+ # This endpoint might not exist in Dhan API, so we'll create a mock response
321
+ # based on time and market hours
322
+ now = datetime.now()
323
+ is_weekend = now.weekday() >= 5 # Saturday = 5, Sunday = 6
324
+
325
+ # NSE trading hours: 9:15 AM to 3:30 PM
326
+ market_open_time = now.replace(hour=9, minute=15, second=0, microsecond=0)
327
+ market_close_time = now.replace(hour=15, minute=30, second=0, microsecond=0)
328
+
329
+ is_market_hours = market_open_time <= now <= market_close_time and not is_weekend
330
+
331
+ return {
332
+ "market_status": "OPEN" if is_market_hours else "CLOSED",
333
+ "timestamp": now.isoformat(),
334
+ "next_open": market_open_time.isoformat() if not is_market_hours else None,
335
+ "next_close": market_close_time.isoformat() if is_market_hours else None,
336
+ "trading_mode": dhan_api_manager.get_current_mode()
337
+ }
338
+
339
+ except Exception as e:
340
+ logger.error(f"Error getting market status: {e}")
341
+ raise HTTPException(status_code=500, detail="Failed to get market status")
342
+
343
+ @router.post("/switch-mode")
344
+ async def switch_trading_mode(
345
+ mode: str,
346
+ current_user: Dict = Depends(get_current_user)
347
+ ):
348
+ """Switch between live and paper trading modes (requires restart)"""
349
+ try:
350
+ if mode not in ["live", "paper"]:
351
+ raise HTTPException(status_code=400, detail="Mode must be 'live' or 'paper'")
352
+
353
+ return {
354
+ "message": f"Trading mode switch to '{mode}' requested. Please restart the application to apply changes.",
355
+ "current_mode": dhan_api_manager.get_current_mode(),
356
+ "requested_mode": mode,
357
+ "note": "This requires updating the DHAN_TRADING_MODE environment variable"
358
+ }
359
+
360
+ except Exception as e:
361
+ logger.error(f"Error switching trading mode: {e}")
362
+ raise HTTPException(status_code=500, detail="Failed to switch trading mode")
routers/screener.py ADDED
@@ -0,0 +1,402 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Market Screener Router
3
+ Handles stock screening, filtering, and discovery
4
+ """
5
+
6
+ from fastapi import APIRouter, HTTPException, Depends, Query
7
+ from typing import List, Dict, Any, Optional
8
+ import logging
9
+ from datetime import datetime, timedelta
10
+
11
+ from config.settings import get_settings
12
+ from services.auth import get_current_user
13
+ from services.dhan_api_manager import dhan_api_manager, DhanAPIError
14
+ from models.trading import ExchangeSegment
15
+
16
+ logger = logging.getLogger(__name__)
17
+ router = APIRouter()
18
+ settings = get_settings()
19
+
20
+ @router.get("/stocks")
21
+ async def screen_stocks(
22
+ exchange: str = Query("NSE", description="Exchange segment"),
23
+ market_cap_min: Optional[float] = Query(None, description="Minimum market cap"),
24
+ market_cap_max: Optional[float] = Query(None, description="Maximum market cap"),
25
+ price_min: Optional[float] = Query(None, description="Minimum price"),
26
+ price_max: Optional[float] = Query(None, description="Maximum price"),
27
+ volume_min: Optional[int] = Query(None, description="Minimum volume"),
28
+ pe_ratio_min: Optional[float] = Query(None, description="Minimum P/E ratio"),
29
+ pe_ratio_max: Optional[float] = Query(None, description="Maximum P/E ratio"),
30
+ sector: Optional[str] = Query(None, description="Sector filter"),
31
+ limit: int = Query(50, description="Maximum results to return"),
32
+ current_user: Dict = Depends(get_current_user)
33
+ ):
34
+ """Screen stocks based on various criteria"""
35
+ try:
36
+ # Get securities from Dhan API
37
+ securities_result = await dhan_api_manager.get_security_master()
38
+ securities = securities_result.get("data", [])
39
+
40
+ # Filter by exchange
41
+ if exchange != "ALL":
42
+ securities = [s for s in securities if s.get("SEM_EXM_EXCH_ID") == exchange]
43
+
44
+ # Apply filters (this is a simplified implementation)
45
+ # In a production system, you'd want to get live market data for each security
46
+ filtered_securities = []
47
+
48
+ for security in securities[:limit]: # Limit initial processing
49
+ try:
50
+ # Basic filtering based on security master data
51
+ if security.get("SEM_INSTRUMENT_NAME") == "EQUITY":
52
+ security_data = {
53
+ "security_id": security.get("SEM_SMST_SECURITY_ID"),
54
+ "symbol": security.get("SEM_TRADING_SYMBOL"),
55
+ "instrument_name": security.get("SEM_INSTRUMENT_NAME"),
56
+ "exchange": security.get("SEM_EXM_EXCH_ID"),
57
+ "lot_size": security.get("SEM_LOT_UNITS", 1),
58
+ "tick_size": security.get("SEM_TICK_SIZE"),
59
+ "series": security.get("SEM_SERIES"),
60
+ "custom_symbol": security.get("SEM_CUSTOM_SYMBOL")
61
+ }
62
+
63
+ # Add to results if it passes basic criteria
64
+ if security_data["symbol"]:
65
+ filtered_securities.append(security_data)
66
+
67
+ except Exception as e:
68
+ logger.warning(f"Error processing security {security}: {e}")
69
+ continue
70
+
71
+ return {
72
+ "securities": filtered_securities[:limit],
73
+ "total_count": len(filtered_securities),
74
+ "filters_applied": {
75
+ "exchange": exchange,
76
+ "market_cap_min": market_cap_min,
77
+ "market_cap_max": market_cap_max,
78
+ "price_min": price_min,
79
+ "price_max": price_max,
80
+ "volume_min": volume_min,
81
+ "pe_ratio_min": pe_ratio_min,
82
+ "pe_ratio_max": pe_ratio_max,
83
+ "sector": sector
84
+ },
85
+ "timestamp": datetime.utcnow().isoformat()
86
+ }
87
+
88
+ except DhanAPIError as e:
89
+ logger.error(f"Dhan API error in stock screening: {e}")
90
+ raise HTTPException(status_code=e.status_code or 500, detail=str(e))
91
+ except Exception as e:
92
+ logger.error(f"Error screening stocks: {e}")
93
+ raise HTTPException(status_code=500, detail="Failed to screen stocks")
94
+
95
+ @router.get("/top-gainers")
96
+ async def get_top_gainers(
97
+ exchange: str = Query("NSE_EQ", description="Exchange segment"),
98
+ limit: int = Query(20, description="Number of top gainers to return"),
99
+ current_user: Dict = Depends(get_current_user)
100
+ ):
101
+ """Get top gaining stocks"""
102
+ try:
103
+ # This is a mock implementation
104
+ # In a real system, you'd fetch live market data and sort by gains
105
+ mock_gainers = [
106
+ {
107
+ "symbol": "RELIANCE",
108
+ "security_id": "2885",
109
+ "exchange_segment": "NSE_EQ",
110
+ "ltp": 2650.0,
111
+ "change": 45.50,
112
+ "change_percent": 1.75,
113
+ "volume": 2500000
114
+ },
115
+ {
116
+ "symbol": "TCS",
117
+ "security_id": "11723",
118
+ "exchange_segment": "NSE_EQ",
119
+ "ltp": 3890.0,
120
+ "change": 55.25,
121
+ "change_percent": 1.44,
122
+ "volume": 1800000
123
+ }
124
+ ]
125
+
126
+ return {
127
+ "top_gainers": mock_gainers[:limit],
128
+ "exchange": exchange,
129
+ "timestamp": datetime.utcnow().isoformat()
130
+ }
131
+
132
+ except Exception as e:
133
+ logger.error(f"Error fetching top gainers: {e}")
134
+ raise HTTPException(status_code=500, detail="Failed to fetch top gainers")
135
+
136
+ @router.get("/top-losers")
137
+ async def get_top_losers(
138
+ exchange: str = Query("NSE_EQ", description="Exchange segment"),
139
+ limit: int = Query(20, description="Number of top losers to return"),
140
+ current_user: Dict = Depends(get_current_user)
141
+ ):
142
+ """Get top losing stocks"""
143
+ try:
144
+ # This is a mock implementation
145
+ mock_losers = [
146
+ {
147
+ "symbol": "HDFC",
148
+ "security_id": "1333",
149
+ "exchange_segment": "NSE_EQ",
150
+ "ltp": 1580.0,
151
+ "change": -25.75,
152
+ "change_percent": -1.60,
153
+ "volume": 3200000
154
+ },
155
+ {
156
+ "symbol": "INFY",
157
+ "security_id": "1594",
158
+ "exchange_segment": "NSE_EQ",
159
+ "ltp": 1685.0,
160
+ "change": -18.50,
161
+ "change_percent": -1.09,
162
+ "volume": 2100000
163
+ }
164
+ ]
165
+
166
+ return {
167
+ "top_losers": mock_losers[:limit],
168
+ "exchange": exchange,
169
+ "timestamp": datetime.utcnow().isoformat()
170
+ }
171
+
172
+ except Exception as e:
173
+ logger.error(f"Error fetching top losers: {e}")
174
+ raise HTTPException(status_code=500, detail="Failed to fetch top losers")
175
+
176
+ @router.get("/most-active")
177
+ async def get_most_active(
178
+ exchange: str = Query("NSE_EQ", description="Exchange segment"),
179
+ limit: int = Query(20, description="Number of most active stocks to return"),
180
+ current_user: Dict = Depends(get_current_user)
181
+ ):
182
+ """Get most active stocks by volume"""
183
+ try:
184
+ # This is a mock implementation
185
+ mock_active = [
186
+ {
187
+ "symbol": "RELIANCE",
188
+ "security_id": "2885",
189
+ "exchange_segment": "NSE_EQ",
190
+ "ltp": 2650.0,
191
+ "volume": 5500000,
192
+ "value": 14575000000,
193
+ "change_percent": 1.75
194
+ },
195
+ {
196
+ "symbol": "HDFC",
197
+ "security_id": "1333",
198
+ "exchange_segment": "NSE_EQ",
199
+ "ltp": 1580.0,
200
+ "volume": 4200000,
201
+ "value": 6636000000,
202
+ "change_percent": -1.60
203
+ }
204
+ ]
205
+
206
+ return {
207
+ "most_active": mock_active[:limit],
208
+ "exchange": exchange,
209
+ "timestamp": datetime.utcnow().isoformat()
210
+ }
211
+
212
+ except Exception as e:
213
+ logger.error(f"Error fetching most active stocks: {e}")
214
+ raise HTTPException(status_code=500, detail="Failed to fetch most active stocks")
215
+
216
+ @router.get("/technical-scan")
217
+ async def technical_scan(
218
+ pattern: str = Query("bullish_engulfing", description="Technical pattern to scan"),
219
+ exchange: str = Query("NSE_EQ", description="Exchange segment"),
220
+ timeframe: str = Query("1D", description="Timeframe for analysis"),
221
+ limit: int = Query(50, description="Maximum results"),
222
+ current_user: Dict = Depends(get_current_user)
223
+ ):
224
+ """Scan stocks based on technical patterns"""
225
+ try:
226
+ # This is a mock implementation
227
+ # In a real system, you'd analyze historical data for technical patterns
228
+ mock_results = [
229
+ {
230
+ "symbol": "TCS",
231
+ "security_id": "11723",
232
+ "exchange_segment": "NSE_EQ",
233
+ "pattern": pattern,
234
+ "confidence": 0.85,
235
+ "ltp": 3890.0,
236
+ "signal": "BUY",
237
+ "target": 4100.0,
238
+ "stop_loss": 3750.0
239
+ },
240
+ {
241
+ "symbol": "INFY",
242
+ "security_id": "1594",
243
+ "exchange_segment": "NSE_EQ",
244
+ "pattern": pattern,
245
+ "confidence": 0.78,
246
+ "ltp": 1685.0,
247
+ "signal": "BUY",
248
+ "target": 1780.0,
249
+ "stop_loss": 1620.0
250
+ }
251
+ ]
252
+
253
+ return {
254
+ "scan_results": mock_results[:limit],
255
+ "pattern": pattern,
256
+ "exchange": exchange,
257
+ "timeframe": timeframe,
258
+ "timestamp": datetime.utcnow().isoformat()
259
+ }
260
+
261
+ except Exception as e:
262
+ logger.error(f"Error in technical scan: {e}")
263
+ raise HTTPException(status_code=500, detail="Failed to perform technical scan")
264
+
265
+ @router.get("/sector-performance")
266
+ async def get_sector_performance(
267
+ current_user: Dict = Depends(get_current_user)
268
+ ):
269
+ """Get sector-wise performance"""
270
+ try:
271
+ # This is a mock implementation
272
+ mock_sectors = [
273
+ {
274
+ "sector": "Information Technology",
275
+ "change_percent": 2.15,
276
+ "market_cap": 12500000000000,
277
+ "top_stocks": ["TCS", "INFY", "WIPRO"],
278
+ "stocks_advanced": 15,
279
+ "stocks_declined": 5
280
+ },
281
+ {
282
+ "sector": "Banking",
283
+ "change_percent": -0.85,
284
+ "market_cap": 8900000000000,
285
+ "top_stocks": ["HDFC", "ICICI", "SBI"],
286
+ "stocks_advanced": 8,
287
+ "stocks_declined": 12
288
+ },
289
+ {
290
+ "sector": "Oil & Gas",
291
+ "change_percent": 1.65,
292
+ "market_cap": 6700000000000,
293
+ "top_stocks": ["RELIANCE", "ONGC", "IOC"],
294
+ "stocks_advanced": 12,
295
+ "stocks_declined": 8
296
+ }
297
+ ]
298
+
299
+ return {
300
+ "sector_performance": mock_sectors,
301
+ "timestamp": datetime.utcnow().isoformat()
302
+ }
303
+
304
+ except Exception as e:
305
+ logger.error(f"Error fetching sector performance: {e}")
306
+ raise HTTPException(status_code=500, detail="Failed to fetch sector performance")
307
+
308
+ @router.get("/custom-scan")
309
+ async def custom_scan(
310
+ criteria: str = Query(..., description="Custom scan criteria as JSON string"),
311
+ current_user: Dict = Depends(get_current_user)
312
+ ):
313
+ """Run custom stock scan with user-defined criteria"""
314
+ try:
315
+ import json
316
+
317
+ # Parse criteria
318
+ try:
319
+ scan_criteria = json.loads(criteria)
320
+ except json.JSONDecodeError:
321
+ raise HTTPException(status_code=400, detail="Invalid criteria format")
322
+
323
+ # This is a mock implementation
324
+ # In a real system, you'd apply the custom criteria to screen stocks
325
+ mock_results = [
326
+ {
327
+ "symbol": "RELIANCE",
328
+ "security_id": "2885",
329
+ "exchange_segment": "NSE_EQ",
330
+ "matches": ["price_above_sma_50", "rsi_oversold"],
331
+ "score": 0.82
332
+ }
333
+ ]
334
+
335
+ return {
336
+ "scan_results": mock_results,
337
+ "criteria": scan_criteria,
338
+ "total_matches": len(mock_results),
339
+ "timestamp": datetime.utcnow().isoformat()
340
+ }
341
+
342
+ except HTTPException:
343
+ raise
344
+ except Exception as e:
345
+ logger.error(f"Error in custom scan: {e}")
346
+ raise HTTPException(status_code=500, detail="Failed to run custom scan")
347
+
348
+ @router.get("/search")
349
+ async def search_securities(
350
+ query: str = Query(..., description="Search query"),
351
+ exchange: Optional[str] = Query(None, description="Exchange filter"),
352
+ instrument_type: Optional[str] = Query(None, description="Instrument type filter"),
353
+ limit: int = Query(20, description="Maximum results"),
354
+ current_user: Dict = Depends(get_current_user)
355
+ ):
356
+ """Search securities by symbol or name"""
357
+ try:
358
+ # Get securities from Dhan API
359
+ securities_result = await dhan_api_manager.get_security_master()
360
+ securities = securities_result.get("data", [])
361
+
362
+ # Filter and search
363
+ query_lower = query.lower()
364
+ matching_securities = []
365
+
366
+ for security in securities:
367
+ symbol = security.get("SEM_TRADING_SYMBOL", "").lower()
368
+ custom_symbol = security.get("SEM_CUSTOM_SYMBOL", "").lower()
369
+
370
+ if (query_lower in symbol or query_lower in custom_symbol):
371
+ # Apply additional filters
372
+ if exchange and security.get("SEM_EXM_EXCH_ID") != exchange:
373
+ continue
374
+ if instrument_type and security.get("SEM_INSTRUMENT_NAME") != instrument_type:
375
+ continue
376
+
377
+ matching_securities.append({
378
+ "security_id": security.get("SEM_SMST_SECURITY_ID"),
379
+ "symbol": security.get("SEM_TRADING_SYMBOL"),
380
+ "instrument_name": security.get("SEM_INSTRUMENT_NAME"),
381
+ "exchange": security.get("SEM_EXM_EXCH_ID"),
382
+ "custom_symbol": security.get("SEM_CUSTOM_SYMBOL"),
383
+ "series": security.get("SEM_SERIES"),
384
+ "lot_size": security.get("SEM_LOT_UNITS", 1)
385
+ })
386
+
387
+ if len(matching_securities) >= limit:
388
+ break
389
+
390
+ return {
391
+ "results": matching_securities,
392
+ "query": query,
393
+ "total_found": len(matching_securities),
394
+ "timestamp": datetime.utcnow().isoformat()
395
+ }
396
+
397
+ except DhanAPIError as e:
398
+ logger.error(f"Dhan API error in search: {e}")
399
+ raise HTTPException(status_code=e.status_code or 500, detail=str(e))
400
+ except Exception as e:
401
+ logger.error(f"Error searching securities: {e}")
402
+ raise HTTPException(status_code=500, detail="Failed to search securities")
services/ai_engine.py ADDED
@@ -0,0 +1,520 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AI Engine Service
3
+ Handles AI-powered trading signal generation, market analysis, and ML models
4
+ """
5
+
6
+ import logging
7
+ from typing import Dict, List, Any, Optional
8
+ from datetime import datetime, timedelta
9
+ import asyncio
10
+ import random
11
+
12
+ from models.signals import (
13
+ SignalResponse, SignalType, SignalStrength, TimeFrame,
14
+ TechnicalIndicator, FundamentalMetric, SentimentIndicator,
15
+ AISignal, MarketAnalysisResponse, BacktestResult
16
+ )
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+ class AIEngine:
21
+ """Service for AI-powered trading analysis and signal generation"""
22
+
23
+ def __init__(self):
24
+ self.models = {
25
+ "technical_model": {"accuracy": 0.72, "confidence": 0.85},
26
+ "fundamental_model": {"accuracy": 0.68, "confidence": 0.78},
27
+ "sentiment_model": {"accuracy": 0.65, "confidence": 0.70},
28
+ "ensemble_model": {"accuracy": 0.76, "confidence": 0.82}
29
+ }
30
+
31
+ async def generate_trading_signals(
32
+ self,
33
+ market_data: List[Dict[str, Any]],
34
+ user_preferences: Optional[Dict[str, Any]] = None
35
+ ) -> List[SignalResponse]:
36
+ """Generate AI trading signals for given market data"""
37
+ try:
38
+ signals = []
39
+
40
+ for data in market_data:
41
+ signal = await self._generate_single_signal(data, user_preferences)
42
+ if signal:
43
+ signals.append(signal)
44
+
45
+ return signals
46
+
47
+ except Exception as e:
48
+ logger.error(f"Error generating trading signals: {e}")
49
+ raise
50
+
51
+ async def _generate_single_signal(
52
+ self,
53
+ market_data: Dict[str, Any],
54
+ user_preferences: Optional[Dict[str, Any]] = None
55
+ ) -> Optional[SignalResponse]:
56
+ """Generate a single trading signal"""
57
+ try:
58
+ symbol = market_data.get("symbol", "UNKNOWN")
59
+ current_price = market_data.get("ltp", 0)
60
+
61
+ # Technical Analysis
62
+ technical_indicators = self._analyze_technical_indicators(market_data)
63
+
64
+ # Fundamental Analysis
65
+ fundamental_metrics = self._analyze_fundamentals(market_data)
66
+
67
+ # Sentiment Analysis
68
+ sentiment_indicators = self._analyze_sentiment(symbol)
69
+
70
+ # AI Model Predictions
71
+ ai_predictions = self._run_ai_models(market_data)
72
+
73
+ # Combine all analyses to generate final signal
74
+ final_signal = self._combine_analyses(
75
+ technical_indicators, fundamental_metrics,
76
+ sentiment_indicators, ai_predictions
77
+ )
78
+
79
+ if not final_signal:
80
+ return None
81
+
82
+ # Calculate target and stop loss
83
+ target_price, stop_loss = self._calculate_price_targets(
84
+ current_price, final_signal["signal_type"], final_signal["strength"]
85
+ )
86
+
87
+ return SignalResponse(
88
+ signal_id=f"SIG_{symbol}_{int(datetime.utcnow().timestamp())}",
89
+ security_id=market_data.get("security_id", ""),
90
+ symbol=symbol,
91
+ exchange_segment=market_data.get("exchange_segment", "NSE_EQ"),
92
+ signal_type=final_signal["signal_type"],
93
+ signal_strength=final_signal["strength"],
94
+ confidence_score=final_signal["confidence"],
95
+ time_frame=TimeFrame.SHORT_TERM,
96
+ current_price=current_price,
97
+ target_price=target_price,
98
+ stop_loss=stop_loss,
99
+ technical_analysis=technical_indicators,
100
+ fundamental_analysis=fundamental_metrics,
101
+ sentiment_analysis=sentiment_indicators,
102
+ ai_analysis=ai_predictions,
103
+ risk_rating=self._calculate_risk_rating(final_signal["confidence"]),
104
+ expected_return=self._calculate_expected_return(current_price, target_price),
105
+ probability_of_success=final_signal["confidence"],
106
+ market_condition="bullish", # Mock
107
+ generated_at=datetime.utcnow(),
108
+ valid_until=datetime.utcnow() + timedelta(days=7),
109
+ summary=self._generate_signal_summary(final_signal, symbol),
110
+ reasoning=self._generate_reasoning(technical_indicators, fundamental_metrics),
111
+ key_catalysts=self._identify_key_catalysts(symbol)
112
+ )
113
+
114
+ except Exception as e:
115
+ logger.error(f"Error generating signal for {symbol}: {e}")
116
+ return None
117
+
118
+ def _analyze_technical_indicators(self, market_data: Dict[str, Any]) -> List[TechnicalIndicator]:
119
+ """Analyze technical indicators"""
120
+ indicators = []
121
+
122
+ # Mock RSI analysis
123
+ rsi_value = random.uniform(30, 70)
124
+ rsi_signal = SignalType.BUY if rsi_value < 40 else SignalType.SELL if rsi_value > 60 else SignalType.HOLD
125
+ rsi_strength = SignalStrength.STRONG if abs(rsi_value - 50) > 20 else SignalStrength.MODERATE
126
+
127
+ indicators.append(TechnicalIndicator(
128
+ name="RSI",
129
+ value=rsi_value,
130
+ signal=rsi_signal,
131
+ strength=rsi_strength,
132
+ description=f"RSI at {rsi_value:.1f} indicates {'oversold' if rsi_value < 40 else 'overbought' if rsi_value > 60 else 'neutral'} conditions"
133
+ ))
134
+
135
+ # Mock MACD analysis
136
+ macd_value = random.uniform(-5, 5)
137
+ macd_signal = SignalType.BUY if macd_value > 0 else SignalType.SELL
138
+ macd_strength = SignalStrength.STRONG if abs(macd_value) > 3 else SignalStrength.MODERATE
139
+
140
+ indicators.append(TechnicalIndicator(
141
+ name="MACD",
142
+ value=macd_value,
143
+ signal=macd_signal,
144
+ strength=macd_strength,
145
+ description=f"MACD at {macd_value:.2f} shows {'bullish' if macd_value > 0 else 'bearish'} momentum"
146
+ ))
147
+
148
+ # Mock Moving Average analysis
149
+ ma_signal = random.choice([SignalType.BUY, SignalType.SELL, SignalType.HOLD])
150
+ ma_strength = random.choice([SignalStrength.WEAK, SignalStrength.MODERATE, SignalStrength.STRONG])
151
+
152
+ indicators.append(TechnicalIndicator(
153
+ name="SMA_50",
154
+ value=market_data.get("ltp", 0) * random.uniform(0.95, 1.05),
155
+ signal=ma_signal,
156
+ strength=ma_strength,
157
+ description="Price action relative to 50-day moving average"
158
+ ))
159
+
160
+ return indicators
161
+
162
+ def _analyze_fundamentals(self, market_data: Dict[str, Any]) -> List[FundamentalMetric]:
163
+ """Analyze fundamental metrics"""
164
+ metrics = []
165
+
166
+ # Mock P/E ratio analysis
167
+ pe_ratio = random.uniform(15, 35)
168
+ metrics.append(FundamentalMetric(
169
+ metric_name="P/E Ratio",
170
+ current_value=pe_ratio,
171
+ industry_average=22.5,
172
+ percentile_rank=random.uniform(0.2, 0.8),
173
+ trend="improving" if pe_ratio < 25 else "deteriorating",
174
+ impact="positive" if pe_ratio < 20 else "negative"
175
+ ))
176
+
177
+ # Mock ROE analysis
178
+ roe = random.uniform(10, 25)
179
+ metrics.append(FundamentalMetric(
180
+ metric_name="ROE",
181
+ current_value=roe,
182
+ industry_average=18.0,
183
+ percentile_rank=random.uniform(0.3, 0.9),
184
+ trend="stable",
185
+ impact="positive" if roe > 15 else "neutral"
186
+ ))
187
+
188
+ return metrics
189
+
190
+ def _analyze_sentiment(self, symbol: str) -> List[SentimentIndicator]:
191
+ """Analyze market sentiment"""
192
+ indicators = []
193
+
194
+ # Mock news sentiment
195
+ news_score = random.uniform(-0.5, 0.8)
196
+ indicators.append(SentimentIndicator(
197
+ source="News Analysis",
198
+ score=news_score,
199
+ confidence=random.uniform(0.6, 0.9),
200
+ key_factors=["earnings report", "sector outlook", "market conditions"],
201
+ news_count=random.randint(10, 50)
202
+ ))
203
+
204
+ # Mock social media sentiment
205
+ social_score = random.uniform(-0.3, 0.6)
206
+ indicators.append(SentimentIndicator(
207
+ source="Social Media",
208
+ score=social_score,
209
+ confidence=random.uniform(0.5, 0.8),
210
+ key_factors=["trader discussions", "volume alerts", "price action"],
211
+ news_count=random.randint(100, 500)
212
+ ))
213
+
214
+ return indicators
215
+
216
+ def _run_ai_models(self, market_data: Dict[str, Any]) -> List[AISignal]:
217
+ """Run AI/ML models for predictions"""
218
+ predictions = []
219
+
220
+ for model_name, model_info in self.models.items():
221
+ prediction = random.choice([SignalType.BUY, SignalType.SELL, SignalType.HOLD])
222
+ confidence = random.uniform(0.6, model_info["confidence"])
223
+
224
+ predictions.append(AISignal(
225
+ model_name=model_name,
226
+ prediction=prediction,
227
+ confidence_score=confidence,
228
+ feature_importance={
229
+ "price_momentum": random.uniform(0.1, 0.3),
230
+ "volume_profile": random.uniform(0.1, 0.3),
231
+ "market_correlation": random.uniform(0.1, 0.3),
232
+ "volatility": random.uniform(0.1, 0.3)
233
+ },
234
+ prediction_horizon="5-7 days"
235
+ ))
236
+
237
+ return predictions
238
+
239
+ def _combine_analyses(
240
+ self,
241
+ technical: List[TechnicalIndicator],
242
+ fundamental: List[FundamentalMetric],
243
+ sentiment: List[SentimentIndicator],
244
+ ai_predictions: List[AISignal]
245
+ ) -> Optional[Dict[str, Any]]:
246
+ """Combine all analyses to generate final signal"""
247
+
248
+ # Weight different analysis types
249
+ weights = {
250
+ "technical": 0.4,
251
+ "fundamental": 0.3,
252
+ "sentiment": 0.1,
253
+ "ai": 0.2
254
+ }
255
+
256
+ # Calculate weighted scores for each signal type
257
+ buy_score = 0
258
+ sell_score = 0
259
+ hold_score = 0
260
+
261
+ # Technical indicators
262
+ for indicator in technical:
263
+ score = 1.0 if indicator.strength == SignalStrength.STRONG else 0.7 if indicator.strength == SignalStrength.MODERATE else 0.3
264
+ if indicator.signal == SignalType.BUY:
265
+ buy_score += score * weights["technical"] / len(technical)
266
+ elif indicator.signal == SignalType.SELL:
267
+ sell_score += score * weights["technical"] / len(technical)
268
+ else:
269
+ hold_score += score * weights["technical"] / len(technical)
270
+
271
+ # AI predictions
272
+ for prediction in ai_predictions:
273
+ score = prediction.confidence_score
274
+ if prediction.prediction == SignalType.BUY:
275
+ buy_score += score * weights["ai"] / len(ai_predictions)
276
+ elif prediction.prediction == SignalType.SELL:
277
+ sell_score += score * weights["ai"] / len(ai_predictions)
278
+ else:
279
+ hold_score += score * weights["ai"] / len(ai_predictions)
280
+
281
+ # Sentiment analysis
282
+ avg_sentiment = sum(s.score for s in sentiment) / len(sentiment) if sentiment else 0
283
+ if avg_sentiment > 0.2:
284
+ buy_score += weights["sentiment"]
285
+ elif avg_sentiment < -0.2:
286
+ sell_score += weights["sentiment"]
287
+ else:
288
+ hold_score += weights["sentiment"]
289
+
290
+ # Fundamental contribution (simplified)
291
+ buy_score += weights["fundamental"] * 0.5 # Mock positive fundamentals
292
+
293
+ # Determine final signal
294
+ max_score = max(buy_score, sell_score, hold_score)
295
+
296
+ if max_score < 0.4: # Threshold for signal generation
297
+ return None
298
+
299
+ if buy_score == max_score:
300
+ signal_type = SignalType.STRONG_BUY if max_score > 0.8 else SignalType.BUY
301
+ elif sell_score == max_score:
302
+ signal_type = SignalType.STRONG_SELL if max_score > 0.8 else SignalType.SELL
303
+ else:
304
+ signal_type = SignalType.HOLD
305
+
306
+ strength = SignalStrength.STRONG if max_score > 0.8 else SignalStrength.MODERATE if max_score > 0.6 else SignalStrength.WEAK
307
+
308
+ return {
309
+ "signal_type": signal_type,
310
+ "strength": strength,
311
+ "confidence": max_score
312
+ }
313
+
314
+ def _calculate_price_targets(self, current_price: float, signal_type: SignalType, strength: SignalStrength) -> tuple:
315
+ """Calculate target price and stop loss"""
316
+ if signal_type in [SignalType.BUY, SignalType.STRONG_BUY]:
317
+ target_multiplier = 1.08 if strength == SignalStrength.STRONG else 1.05
318
+ stop_multiplier = 0.95 if strength == SignalStrength.STRONG else 0.97
319
+
320
+ target_price = current_price * target_multiplier
321
+ stop_loss = current_price * stop_multiplier
322
+ elif signal_type in [SignalType.SELL, SignalType.STRONG_SELL]:
323
+ target_multiplier = 0.92 if strength == SignalStrength.STRONG else 0.95
324
+ stop_multiplier = 1.05 if strength == SignalStrength.STRONG else 1.03
325
+
326
+ target_price = current_price * target_multiplier
327
+ stop_loss = current_price * stop_multiplier
328
+ else:
329
+ target_price = None
330
+ stop_loss = None
331
+
332
+ return target_price, stop_loss
333
+
334
+ def _calculate_risk_rating(self, confidence: float) -> str:
335
+ """Calculate risk rating based on confidence"""
336
+ if confidence > 0.8:
337
+ return "Low"
338
+ elif confidence > 0.6:
339
+ return "Medium"
340
+ else:
341
+ return "High"
342
+
343
+ def _calculate_expected_return(self, current_price: float, target_price: Optional[float]) -> Optional[float]:
344
+ """Calculate expected return percentage"""
345
+ if target_price and current_price:
346
+ return round(((target_price - current_price) / current_price) * 100, 2)
347
+ return None
348
+
349
+ def _generate_signal_summary(self, signal: Dict[str, Any], symbol: str) -> str:
350
+ """Generate human-readable signal summary"""
351
+ signal_type = signal["signal_type"]
352
+ strength = signal["strength"]
353
+
354
+ if signal_type in [SignalType.BUY, SignalType.STRONG_BUY]:
355
+ return f"{strength.value.title()} BUY signal for {symbol} based on positive technical and AI analysis"
356
+ elif signal_type in [SignalType.SELL, SignalType.STRONG_SELL]:
357
+ return f"{strength.value.title()} SELL signal for {symbol} due to bearish indicators"
358
+ else:
359
+ return f"HOLD recommendation for {symbol} with mixed signals"
360
+
361
+ def _generate_reasoning(self, technical: List[TechnicalIndicator], fundamental: List[FundamentalMetric]) -> str:
362
+ """Generate detailed reasoning"""
363
+ reasons = []
364
+
365
+ # Technical reasons
366
+ for indicator in technical:
367
+ if indicator.signal != SignalType.HOLD:
368
+ reasons.append(f"{indicator.name} shows {indicator.signal.value.lower()} signal")
369
+
370
+ # Fundamental reasons
371
+ for metric in fundamental:
372
+ if metric.impact == "positive":
373
+ reasons.append(f"{metric.metric_name} is favorable")
374
+
375
+ return ". ".join(reasons[:3]) + "."
376
+
377
+ def _identify_key_catalysts(self, symbol: str) -> List[str]:
378
+ """Identify key catalysts for the signal"""
379
+ catalysts = [
380
+ "Strong quarterly earnings",
381
+ "Positive sector outlook",
382
+ "Technical breakout pattern",
383
+ "Improved fundamentals",
384
+ "Market momentum"
385
+ ]
386
+
387
+ return random.sample(catalysts, random.randint(2, 4))
388
+
389
+ async def generate_market_analysis(self) -> MarketAnalysisResponse:
390
+ """Generate comprehensive market analysis"""
391
+ try:
392
+ # This would typically involve complex market analysis
393
+ # For now, returning mock data
394
+
395
+ return MarketAnalysisResponse(
396
+ overall_market_sentiment="Bullish",
397
+ market_trend={
398
+ "trend_direction": "Upward",
399
+ "trend_strength": SignalStrength.MODERATE,
400
+ "support_levels": [21800, 21600, 21400],
401
+ "resistance_levels": [22200, 22400, 22600],
402
+ "trend_duration": 15
403
+ },
404
+ volatility_index=18.5,
405
+ fear_greed_index=65,
406
+ sector_analysis=[
407
+ {
408
+ "sector_name": "Information Technology",
409
+ "performance": 2.1,
410
+ "outlook": "Positive",
411
+ "key_drivers": ["Strong earnings", "Export growth"],
412
+ "top_stocks": ["TCS", "INFY", "WIPRO"],
413
+ "recommendation": SignalType.BUY
414
+ }
415
+ ],
416
+ macro_factors=["GDP growth", "Inflation control", "FII inflows"],
417
+ market_events=["Q4 earnings season", "Budget announcements"],
418
+ generated_at=datetime.utcnow()
419
+ )
420
+
421
+ except Exception as e:
422
+ logger.error(f"Error generating market analysis: {e}")
423
+ raise
424
+
425
+ async def run_backtesting(self, symbols: List[str], period: str, strategy_config: Dict[str, Any]) -> BacktestResult:
426
+ """Run backtesting on trading strategies"""
427
+ try:
428
+ # Mock backtesting results
429
+ total_signals = random.randint(50, 200)
430
+ successful_signals = int(total_signals * random.uniform(0.6, 0.8))
431
+
432
+ return BacktestResult(
433
+ start_date=datetime.utcnow() - timedelta(days=365),
434
+ end_date=datetime.utcnow(),
435
+ total_signals=total_signals,
436
+ successful_signals=successful_signals,
437
+ accuracy_rate=successful_signals / total_signals,
438
+ total_return=random.uniform(8, 25),
439
+ annualized_return=random.uniform(12, 18),
440
+ max_drawdown=random.uniform(8, 15),
441
+ sharpe_ratio=random.uniform(1.2, 2.1),
442
+ win_rate=random.uniform(60, 75),
443
+ average_win=random.uniform(3, 8),
444
+ average_loss=random.uniform(-2, -5),
445
+ profit_factor=random.uniform(1.5, 2.5)
446
+ )
447
+
448
+ except Exception as e:
449
+ logger.error(f"Error running backtesting: {e}")
450
+ raise
451
+
452
+ async def analyze_stock_sentiment(self, symbol: str) -> Dict[str, Any]:
453
+ """Analyze sentiment for a specific stock"""
454
+ try:
455
+ # Mock sentiment analysis
456
+ sentiment_score = random.uniform(-0.5, 0.8)
457
+
458
+ if sentiment_score > 0.5:
459
+ label = "Very Positive"
460
+ elif sentiment_score > 0.2:
461
+ label = "Positive"
462
+ elif sentiment_score > -0.2:
463
+ label = "Neutral"
464
+ elif sentiment_score > -0.5:
465
+ label = "Negative"
466
+ else:
467
+ label = "Very Negative"
468
+
469
+ return {
470
+ "score": sentiment_score,
471
+ "label": label,
472
+ "confidence": random.uniform(0.7, 0.95),
473
+ "factors": ["earnings report", "analyst upgrades", "sector performance"],
474
+ "news_summary": f"Recent news about {symbol} has been generally {label.lower()}"
475
+ }
476
+
477
+ except Exception as e:
478
+ logger.error(f"Error analyzing sentiment for {symbol}: {e}")
479
+ raise
480
+
481
+ async def get_available_strategies(self) -> List[Dict[str, Any]]:
482
+ """Get list of available AI trading strategies"""
483
+ return [
484
+ {
485
+ "id": "momentum_strategy",
486
+ "name": "Momentum Trading",
487
+ "description": "Identifies stocks with strong price momentum",
488
+ "accuracy": 72.5,
489
+ "risk_level": "Medium"
490
+ },
491
+ {
492
+ "id": "mean_reversion",
493
+ "name": "Mean Reversion",
494
+ "description": "Identifies oversold/overbought conditions",
495
+ "accuracy": 68.3,
496
+ "risk_level": "Low"
497
+ },
498
+ {
499
+ "id": "breakout_strategy",
500
+ "name": "Breakout Detection",
501
+ "description": "Identifies technical breakout patterns",
502
+ "accuracy": 75.1,
503
+ "risk_level": "Medium-High"
504
+ }
505
+ ]
506
+
507
+ async def create_custom_strategy(self, user_id: str, config: Dict[str, Any]) -> str:
508
+ """Create a custom AI trading strategy"""
509
+ try:
510
+ # In a real implementation, would train/configure the strategy
511
+ strategy_id = f"custom_{user_id}_{int(datetime.utcnow().timestamp())}"
512
+
513
+ # Mock strategy creation
514
+ logger.info(f"Created custom strategy {strategy_id} for user {user_id}")
515
+
516
+ return strategy_id
517
+
518
+ except Exception as e:
519
+ logger.error(f"Error creating custom strategy: {e}")
520
+ raise
services/dhan_api_manager.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Dhan API Manager
3
+ Centralized management for different Dhan API types and tokens
4
+ Handles Trading API, Data API, and Paper Trading API calls
5
+ """
6
+
7
+ from typing import Dict, Any, Optional, Literal
8
+ from enum import Enum
9
+ import httpx
10
+ import logging
11
+ from datetime import datetime
12
+
13
+ from config.settings import get_settings
14
+
15
+ logger = logging.getLogger(__name__)
16
+ settings = get_settings()
17
+
18
+
19
+ class DhanAPIType(Enum):
20
+ """Types of Dhan APIs"""
21
+ TRADING = "trading" # Order management, portfolio, funds
22
+ DATA = "data" # Market data, quotes, historical data
23
+ PAPER_TRADING = "paper" # Sandbox/paper trading
24
+
25
+
26
+ class DhanAPIError(Exception):
27
+ """Custom exception for Dhan API errors"""
28
+ def __init__(self, message: str, status_code: Optional[int] = None, response_data: Optional[Dict] = None):
29
+ super().__init__(message)
30
+ self.status_code = status_code
31
+ self.response_data = response_data
32
+
33
+
34
+ class DhanAPIManager:
35
+ """Centralized manager for all Dhan API interactions"""
36
+
37
+ def __init__(self):
38
+ self.settings = settings
39
+ self.base_url = settings.DHAN_API_BASE_URL
40
+
41
+ # API endpoint mappings by type
42
+ self.api_endpoints = {
43
+ # Trading API endpoints
44
+ DhanAPIType.TRADING: {
45
+ "orders": "/orders",
46
+ "positions": "/positions",
47
+ "holdings": "/holdings",
48
+ "funds": "/funds",
49
+ "margin": "/margin",
50
+ "ledger": "/ledger",
51
+ "statements": "/statements",
52
+ },
53
+
54
+ # Data API endpoints
55
+ DhanAPIType.DATA: {
56
+ "securitymaster": "/securitymaster",
57
+ "marketfeed": "/marketfeed",
58
+ "quotes": "/marketfeed/quote",
59
+ "ltp": "/marketfeed/ltp",
60
+ "ohlc": "/marketfeed/ohlc",
61
+ "historical": "/charts/historical",
62
+ "intraday": "/charts/intraday",
63
+ "optionchain": "/optionchain",
64
+ },
65
+
66
+ # Paper Trading endpoints (same as trading but with different token)
67
+ DhanAPIType.PAPER_TRADING: {
68
+ "orders": "/orders",
69
+ "positions": "/positions",
70
+ "holdings": "/holdings",
71
+ "funds": "/funds",
72
+ "margin": "/margin",
73
+ }
74
+ }
75
+
76
+ def get_token_for_api_type(self, api_type: DhanAPIType) -> str:
77
+ """Get the appropriate token for the API type"""
78
+ if api_type == DhanAPIType.TRADING:
79
+ if self.settings.DHAN_TRADING_MODE == "paper":
80
+ return self.settings.DHAN_PAPER_TRADING_TOKEN
81
+ return self.settings.DHAN_TRADING_API_TOKEN
82
+ elif api_type == DhanAPIType.DATA:
83
+ return self.settings.DHAN_DATA_API_TOKEN
84
+ elif api_type == DhanAPIType.PAPER_TRADING:
85
+ return self.settings.DHAN_PAPER_TRADING_TOKEN
86
+ else:
87
+ raise ValueError(f"Unknown API type: {api_type}")
88
+
89
+ def get_client_id_for_api_type(self, api_type: DhanAPIType) -> str:
90
+ """Get the appropriate client ID for the API type"""
91
+ if api_type == DhanAPIType.PAPER_TRADING or self.settings.DHAN_TRADING_MODE == "paper":
92
+ return self.settings.DHAN_PAPER_TRADING_CLIENT_ID
93
+ return self.settings.DHAN_CLIENT_ID
94
+
95
+ def get_headers(self, api_type: DhanAPIType) -> Dict[str, str]:
96
+ """Get API headers for the specified API type"""
97
+ token = self.get_token_for_api_type(api_type)
98
+ return {
99
+ "Content-Type": "application/json",
100
+ "access-token": token,
101
+ "Accept": "application/json"
102
+ }
103
+
104
+ async def make_request(
105
+ self,
106
+ api_type: DhanAPIType,
107
+ endpoint: str,
108
+ method: str = "GET",
109
+ data: Optional[Dict] = None,
110
+ params: Optional[Dict] = None
111
+ ) -> Dict[str, Any]:
112
+ """Make API request with appropriate token and error handling"""
113
+
114
+ # Validate endpoint exists for this API type
115
+ if api_type in self.api_endpoints and endpoint not in self.api_endpoints[api_type].values():
116
+ # Check if endpoint is a direct path
117
+ if not endpoint.startswith("/"):
118
+ # Try to find the endpoint in the mapping
119
+ mapped_endpoints = self.api_endpoints[api_type]
120
+ if endpoint in mapped_endpoints:
121
+ endpoint = mapped_endpoints[endpoint]
122
+ else:
123
+ logger.warning(f"Endpoint '{endpoint}' not found in {api_type.value} API mapping")
124
+
125
+ url = f"{self.base_url}{endpoint}"
126
+ headers = self.get_headers(api_type)
127
+
128
+ try:
129
+ async with httpx.AsyncClient(timeout=30.0) as client:
130
+ response = await client.request(
131
+ method=method.upper(),
132
+ url=url,
133
+ headers=headers,
134
+ json=data if method.upper() != "GET" else None,
135
+ params=params if method.upper() == "GET" else None
136
+ )
137
+
138
+ response.raise_for_status()
139
+ result = response.json()
140
+
141
+ logger.info(f"Dhan {api_type.value} API {method} {endpoint}: {response.status_code}")
142
+ return result
143
+
144
+ except httpx.HTTPStatusError as e:
145
+ error_msg = f"Dhan {api_type.value} API HTTP error: {e.response.status_code}"
146
+ try:
147
+ error_detail = e.response.json()
148
+ error_msg += f" - {error_detail.get('message', str(error_detail))}"
149
+ except:
150
+ error_msg += f" - {e.response.text}"
151
+
152
+ logger.error(error_msg)
153
+ raise DhanAPIError(error_msg, e.response.status_code, error_detail if 'error_detail' in locals() else None)
154
+
155
+ except httpx.RequestError as e:
156
+ error_msg = f"Dhan {api_type.value} API request error: {str(e)}"
157
+ logger.error(error_msg)
158
+ raise DhanAPIError(error_msg)
159
+
160
+ except Exception as e:
161
+ error_msg = f"Dhan {api_type.value} API unexpected error: {str(e)}"
162
+ logger.error(error_msg)
163
+ raise DhanAPIError(error_msg)
164
+
165
+ # Trading API Methods
166
+ async def place_order(self, order_data: Dict[str, Any]) -> Dict[str, Any]:
167
+ """Place an order using Trading API"""
168
+ api_type = DhanAPIType.PAPER_TRADING if self.settings.DHAN_TRADING_MODE == "paper" else DhanAPIType.TRADING
169
+ return await self.make_request(api_type, "/orders", "POST", order_data)
170
+
171
+ async def get_orders(self) -> Dict[str, Any]:
172
+ """Get orders using Trading API"""
173
+ api_type = DhanAPIType.PAPER_TRADING if self.settings.DHAN_TRADING_MODE == "paper" else DhanAPIType.TRADING
174
+ return await self.make_request(api_type, "/orders", "GET")
175
+
176
+ async def get_positions(self) -> Dict[str, Any]:
177
+ """Get positions using Trading API"""
178
+ api_type = DhanAPIType.PAPER_TRADING if self.settings.DHAN_TRADING_MODE == "paper" else DhanAPIType.TRADING
179
+ return await self.make_request(api_type, "/positions", "GET")
180
+
181
+ async def get_holdings(self) -> Dict[str, Any]:
182
+ """Get holdings using Trading API"""
183
+ api_type = DhanAPIType.PAPER_TRADING if self.settings.DHAN_TRADING_MODE == "paper" else DhanAPIType.TRADING
184
+ return await self.make_request(api_type, "/holdings", "GET")
185
+
186
+ async def get_funds(self) -> Dict[str, Any]:
187
+ """Get funds using Trading API"""
188
+ api_type = DhanAPIType.PAPER_TRADING if self.settings.DHAN_TRADING_MODE == "paper" else DhanAPIType.TRADING
189
+ return await self.make_request(api_type, "/funds", "GET")
190
+
191
+ async def cancel_order(self, order_id: str) -> Dict[str, Any]:
192
+ """Cancel an order using Trading API"""
193
+ api_type = DhanAPIType.PAPER_TRADING if self.settings.DHAN_TRADING_MODE == "paper" else DhanAPIType.TRADING
194
+ return await self.make_request(api_type, f"/orders/{order_id}", "DELETE")
195
+
196
+ async def modify_order(self, order_id: str, modifications: Dict[str, Any]) -> Dict[str, Any]:
197
+ """Modify an order using Trading API"""
198
+ api_type = DhanAPIType.PAPER_TRADING if self.settings.DHAN_TRADING_MODE == "paper" else DhanAPIType.TRADING
199
+ return await self.make_request(api_type, f"/orders/{order_id}", "PUT", modifications)
200
+
201
+ # Data API Methods
202
+ async def get_security_master(self) -> Dict[str, Any]:
203
+ """Get security master using Data API"""
204
+ return await self.make_request(DhanAPIType.DATA, "/securitymaster", "GET")
205
+
206
+ async def get_live_feed(self, instruments: list) -> Dict[str, Any]:
207
+ """Get live market feed using Data API"""
208
+ data = {"instruments": instruments}
209
+ return await self.make_request(DhanAPIType.DATA, "/marketfeed/ltp", "POST", data)
210
+
211
+ async def get_quote(self, security_id: str, exchange_segment: str) -> Dict[str, Any]:
212
+ """Get quote using Data API"""
213
+ params = {"securityId": security_id, "exchangeSegment": exchange_segment}
214
+ return await self.make_request(DhanAPIType.DATA, "/marketfeed/quote", "GET", params=params)
215
+
216
+ async def get_ohlc(self, instruments: list) -> Dict[str, Any]:
217
+ """Get OHLC data using Data API"""
218
+ data = {"instruments": instruments}
219
+ return await self.make_request(DhanAPIType.DATA, "/marketfeed/ohlc", "POST", data)
220
+
221
+ async def get_historical_data(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
222
+ """Get historical data using Data API"""
223
+ return await self.make_request(DhanAPIType.DATA, "/charts/historical", "POST", request_data)
224
+
225
+ async def get_intraday_data(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
226
+ """Get intraday data using Data API"""
227
+ return await self.make_request(DhanAPIType.DATA, "/charts/intraday", "POST", request_data)
228
+
229
+ async def get_option_chain(self, underlying_security_id: str, exchange_segment: str) -> Dict[str, Any]:
230
+ """Get option chain using Data API"""
231
+ params = {"underlyingSecurityId": underlying_security_id, "exchangeSegment": exchange_segment}
232
+ return await self.make_request(DhanAPIType.DATA, "/optionchain", "GET", params=params)
233
+
234
+ # Utility Methods
235
+ def get_current_mode(self) -> str:
236
+ """Get current trading mode"""
237
+ return self.settings.DHAN_TRADING_MODE
238
+
239
+ def is_paper_trading(self) -> bool:
240
+ """Check if currently in paper trading mode"""
241
+ return self.settings.DHAN_TRADING_MODE == "paper"
242
+
243
+ def get_websocket_credentials(self) -> Dict[str, str]:
244
+ """Get WebSocket credentials for current mode"""
245
+ if self.is_paper_trading():
246
+ return {
247
+ "client_id": self.settings.DHAN_PAPER_TRADING_CLIENT_ID,
248
+ "access_token": self.settings.DHAN_PAPER_TRADING_TOKEN
249
+ }
250
+ else:
251
+ return {
252
+ "client_id": self.settings.DHAN_CLIENT_ID,
253
+ "access_token": self.settings.DHAN_DATA_API_TOKEN
254
+ }
255
+
256
+
257
+ # Global instance
258
+ dhan_api_manager = DhanAPIManager()
services/dhan_websocket.py ADDED
@@ -0,0 +1,513 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Dhan WebSocket Service for Real-time Market Data
3
+ Handles live market feed, order updates, and market depth data
4
+ """
5
+
6
+ import json
7
+ import asyncio
8
+ import logging
9
+ from typing import Dict, List, Optional, Callable, Any
10
+ from datetime import datetime
11
+ import websockets
12
+ from websockets.exceptions import ConnectionClosed, WebSocketException
13
+ from dataclasses import dataclass
14
+ from enum import Enum
15
+
16
+ from config.settings import get_settings
17
+ from .dhan_api_manager import dhan_api_manager
18
+
19
+ logger = logging.getLogger(__name__)
20
+ settings = get_settings()
21
+
22
+
23
+ class DhanWSMessageType(Enum):
24
+ """Dhan WebSocket message types"""
25
+ MARKET_FEED_SUBSCRIBE = 15
26
+ ORDER_UPDATE_LOGIN = 42
27
+ MARKET_DEPTH_SUBSCRIBE = 23
28
+ HEARTBEAT = 1
29
+
30
+
31
+ @dataclass
32
+ class MarketDataUpdate:
33
+ """Market data update structure"""
34
+ security_id: str
35
+ exchange_segment: str
36
+ ltp: float
37
+ change: float
38
+ change_percent: float
39
+ volume: int
40
+ open_price: float
41
+ high_price: float
42
+ low_price: float
43
+ close_price: float
44
+ timestamp: datetime
45
+ total_buy_qty: Optional[int] = None
46
+ total_sell_qty: Optional[int] = None
47
+
48
+
49
+ @dataclass
50
+ class OrderUpdate:
51
+ """Order update structure"""
52
+ order_id: str
53
+ client_id: str
54
+ security_id: str
55
+ exchange_segment: str
56
+ order_status: str
57
+ transaction_type: str
58
+ quantity: int
59
+ price: float
60
+ traded_qty: int
61
+ traded_price: float
62
+ avg_traded_price: float
63
+ order_datetime: datetime
64
+ update_datetime: datetime
65
+
66
+
67
+ @dataclass
68
+ class MarketDepthUpdate:
69
+ """Market depth update structure"""
70
+ security_id: str
71
+ exchange_segment: str
72
+ bids: List[Dict[str, float]] # [{"price": float, "quantity": int, "orders": int}]
73
+ asks: List[Dict[str, float]] # [{"price": float, "quantity": int, "orders": int}]
74
+ timestamp: datetime
75
+
76
+
77
+ class DhanWebSocketClient:
78
+ """Dhan WebSocket client for real-time data"""
79
+
80
+ def __init__(self):
81
+ self.settings = settings
82
+ self.market_feed_ws: Optional[websockets.WebSocketServerProtocol] = None
83
+ self.order_update_ws: Optional[websockets.WebSocketServerProtocol] = None
84
+ self.market_depth_ws: Optional[websockets.WebSocketServerProtocol] = None
85
+
86
+ # Connection status
87
+ self.is_market_feed_connected = False
88
+ self.is_order_update_connected = False
89
+ self.is_market_depth_connected = False
90
+
91
+ # Subscribed instruments
92
+ self.subscribed_instruments: List[Dict[str, str]] = []
93
+ self.subscribed_depth_instruments: List[Dict[str, str]] = []
94
+
95
+ # Callbacks
96
+ self.market_data_callback: Optional[Callable[[MarketDataUpdate], None]] = None
97
+ self.order_update_callback: Optional[Callable[[OrderUpdate], None]] = None
98
+ self.market_depth_callback: Optional[Callable[[MarketDepthUpdate], None]] = None
99
+
100
+ # Heartbeat tasks
101
+ self.market_feed_heartbeat_task: Optional[asyncio.Task] = None
102
+ self.order_update_heartbeat_task: Optional[asyncio.Task] = None
103
+ self.market_depth_heartbeat_task: Optional[asyncio.Task] = None
104
+
105
+ # Reconnection parameters
106
+ self.max_reconnect_attempts = 5
107
+ self.reconnect_delay = 5 # seconds
108
+
109
+ async def connect_market_feed(self) -> bool:
110
+ """Connect to Dhan market feed WebSocket"""
111
+ try:
112
+ credentials = dhan_api_manager.get_websocket_credentials()
113
+ url = f"{self.settings.DHAN_WS_MARKET_FEED_URL}?version=2&token={credentials['access_token']}&clientId={credentials['client_id']}&authType=2"
114
+
115
+ self.market_feed_ws = await websockets.connect(url)
116
+ self.is_market_feed_connected = True
117
+
118
+ logger.info("Connected to Dhan market feed WebSocket")
119
+
120
+ # Start listening for messages
121
+ asyncio.create_task(self._listen_market_feed())
122
+
123
+ # Start heartbeat
124
+ self.market_feed_heartbeat_task = asyncio.create_task(self._market_feed_heartbeat())
125
+
126
+ return True
127
+
128
+ except Exception as e:
129
+ logger.error(f"Failed to connect to market feed WebSocket: {e}")
130
+ self.is_market_feed_connected = False
131
+ return False
132
+
133
+ async def connect_order_updates(self) -> bool:
134
+ """Connect to Dhan order update WebSocket"""
135
+ try:
136
+ self.order_update_ws = await websockets.connect(self.settings.DHAN_WS_ORDER_UPDATE_URL)
137
+
138
+ # Send authorization message
139
+ credentials = dhan_api_manager.get_websocket_credentials()
140
+ auth_message = {
141
+ "LoginReq": {
142
+ "MsgCode": DhanWSMessageType.ORDER_UPDATE_LOGIN.value,
143
+ "ClientId": credentials['client_id'],
144
+ "Token": credentials['access_token']
145
+ },
146
+ "UserType": "SELF"
147
+ }
148
+
149
+ await self.order_update_ws.send(json.dumps(auth_message))
150
+ self.is_order_update_connected = True
151
+
152
+ logger.info("Connected to Dhan order update WebSocket")
153
+
154
+ # Start listening for messages
155
+ asyncio.create_task(self._listen_order_updates())
156
+
157
+ return True
158
+
159
+ except Exception as e:
160
+ logger.error(f"Failed to connect to order update WebSocket: {e}")
161
+ self.is_order_update_connected = False
162
+ return False
163
+
164
+ async def connect_market_depth(self) -> bool:
165
+ """Connect to Dhan market depth WebSocket"""
166
+ try:
167
+ credentials = dhan_api_manager.get_websocket_credentials()
168
+ url = f"{self.settings.DHAN_WS_MARKET_DEPTH_URL}?token={credentials['access_token']}&clientId={credentials['client_id']}&authType=2"
169
+
170
+ self.market_depth_ws = await websockets.connect(url)
171
+ self.is_market_depth_connected = True
172
+
173
+ logger.info("Connected to Dhan market depth WebSocket")
174
+
175
+ # Start listening for messages
176
+ asyncio.create_task(self._listen_market_depth())
177
+
178
+ # Start heartbeat
179
+ self.market_depth_heartbeat_task = asyncio.create_task(self._market_depth_heartbeat())
180
+
181
+ return True
182
+
183
+ except Exception as e:
184
+ logger.error(f"Failed to connect to market depth WebSocket: {e}")
185
+ self.is_market_depth_connected = False
186
+ return False
187
+
188
+ async def subscribe_market_data(self, instruments: List[Dict[str, str]]) -> bool:
189
+ """Subscribe to market data for instruments"""
190
+ if not self.is_market_feed_connected or not self.market_feed_ws:
191
+ logger.error("Market feed WebSocket not connected")
192
+ return False
193
+
194
+ try:
195
+ # Limit to 5000 instruments per connection
196
+ if len(instruments) > 5000:
197
+ logger.warning("Too many instruments, limiting to 5000")
198
+ instruments = instruments[:5000]
199
+
200
+ subscribe_message = {
201
+ "RequestCode": DhanWSMessageType.MARKET_FEED_SUBSCRIBE.value,
202
+ "InstrumentCount": len(instruments),
203
+ "InstrumentList": [
204
+ {
205
+ "ExchangeSegment": instrument["exchange_segment"],
206
+ "SecurityId": instrument["security_id"]
207
+ }
208
+ for instrument in instruments
209
+ ]
210
+ }
211
+
212
+ await self.market_feed_ws.send(json.dumps(subscribe_message))
213
+ self.subscribed_instruments.extend(instruments)
214
+
215
+ logger.info(f"Subscribed to {len(instruments)} instruments for market data")
216
+ return True
217
+
218
+ except Exception as e:
219
+ logger.error(f"Failed to subscribe to market data: {e}")
220
+ return False
221
+
222
+ async def subscribe_market_depth(self, instruments: List[Dict[str, str]]) -> bool:
223
+ """Subscribe to market depth for instruments (max 50 per connection)"""
224
+ if not self.is_market_depth_connected or not self.market_depth_ws:
225
+ logger.error("Market depth WebSocket not connected")
226
+ return False
227
+
228
+ try:
229
+ # Limit to 50 instruments per connection
230
+ if len(instruments) > 50:
231
+ logger.warning("Too many instruments for market depth, limiting to 50")
232
+ instruments = instruments[:50]
233
+
234
+ subscribe_message = {
235
+ "RequestCode": DhanWSMessageType.MARKET_DEPTH_SUBSCRIBE.value,
236
+ "InstrumentCount": len(instruments),
237
+ "InstrumentList": [
238
+ {
239
+ "ExchangeSegment": instrument["exchange_segment"],
240
+ "SecurityId": instrument["security_id"]
241
+ }
242
+ for instrument in instruments
243
+ ]
244
+ }
245
+
246
+ await self.market_depth_ws.send(json.dumps(subscribe_message))
247
+ self.subscribed_depth_instruments.extend(instruments)
248
+
249
+ logger.info(f"Subscribed to {len(instruments)} instruments for market depth")
250
+ return True
251
+
252
+ except Exception as e:
253
+ logger.error(f"Failed to subscribe to market depth: {e}")
254
+ return False
255
+
256
+ async def _listen_market_feed(self):
257
+ """Listen for market feed messages"""
258
+ try:
259
+ async for message in self.market_feed_ws:
260
+ try:
261
+ data = json.loads(message)
262
+ await self._process_market_data(data)
263
+ except json.JSONDecodeError:
264
+ logger.error(f"Failed to decode market feed message: {message}")
265
+ except Exception as e:
266
+ logger.error(f"Error processing market feed message: {e}")
267
+
268
+ except ConnectionClosed:
269
+ logger.warning("Market feed WebSocket connection closed")
270
+ self.is_market_feed_connected = False
271
+ await self._reconnect_market_feed()
272
+ except Exception as e:
273
+ logger.error(f"Error in market feed listener: {e}")
274
+ self.is_market_feed_connected = False
275
+
276
+ async def _listen_order_updates(self):
277
+ """Listen for order update messages"""
278
+ try:
279
+ async for message in self.order_update_ws:
280
+ try:
281
+ data = json.loads(message)
282
+ await self._process_order_update(data)
283
+ except json.JSONDecodeError:
284
+ logger.error(f"Failed to decode order update message: {message}")
285
+ except Exception as e:
286
+ logger.error(f"Error processing order update message: {e}")
287
+
288
+ except ConnectionClosed:
289
+ logger.warning("Order update WebSocket connection closed")
290
+ self.is_order_update_connected = False
291
+ await self._reconnect_order_updates()
292
+ except Exception as e:
293
+ logger.error(f"Error in order update listener: {e}")
294
+ self.is_order_update_connected = False
295
+
296
+ async def _listen_market_depth(self):
297
+ """Listen for market depth messages"""
298
+ try:
299
+ async for message in self.market_depth_ws:
300
+ try:
301
+ data = json.loads(message)
302
+ await self._process_market_depth(data)
303
+ except json.JSONDecodeError:
304
+ logger.error(f"Failed to decode market depth message: {message}")
305
+ except Exception as e:
306
+ logger.error(f"Error processing market depth message: {e}")
307
+
308
+ except ConnectionClosed:
309
+ logger.warning("Market depth WebSocket connection closed")
310
+ self.is_market_depth_connected = False
311
+ await self._reconnect_market_depth()
312
+ except Exception as e:
313
+ logger.error(f"Error in market depth listener: {e}")
314
+ self.is_market_depth_connected = False
315
+
316
+ async def _process_market_data(self, data: Dict[str, Any]):
317
+ """Process market data update"""
318
+ try:
319
+ # Parse market data from Dhan format
320
+ if self.market_data_callback and "SecurityId" in data:
321
+ market_update = MarketDataUpdate(
322
+ security_id=str(data.get("SecurityId", "")),
323
+ exchange_segment=data.get("ExchangeSegment", ""),
324
+ ltp=float(data.get("LastTradedPrice", 0)),
325
+ change=float(data.get("NetChange", 0)),
326
+ change_percent=float(data.get("PercentChange", 0)),
327
+ volume=int(data.get("TotalTradedQty", 0)),
328
+ open_price=float(data.get("OpenPrice", 0)),
329
+ high_price=float(data.get("HighPrice", 0)),
330
+ low_price=float(data.get("LowPrice", 0)),
331
+ close_price=float(data.get("ClosePrice", 0)),
332
+ total_buy_qty=data.get("TotalBuyQty"),
333
+ total_sell_qty=data.get("TotalSellQty"),
334
+ timestamp=datetime.now()
335
+ )
336
+
337
+ await self.market_data_callback(market_update)
338
+
339
+ except Exception as e:
340
+ logger.error(f"Error processing market data: {e}")
341
+
342
+ async def _process_order_update(self, data: Dict[str, Any]):
343
+ """Process order update"""
344
+ try:
345
+ if self.order_update_callback and "OrderNo" in data:
346
+ order_update = OrderUpdate(
347
+ order_id=str(data.get("OrderNo", "")),
348
+ client_id=str(data.get("ClientId", "")),
349
+ security_id=str(data.get("SecurityId", "")),
350
+ exchange_segment=data.get("Segment", ""),
351
+ order_status=data.get("Status", ""),
352
+ transaction_type=data.get("TxnType", ""),
353
+ quantity=int(data.get("Quantity", 0)),
354
+ price=float(data.get("Price", 0)),
355
+ traded_qty=int(data.get("TradedQty", 0)),
356
+ traded_price=float(data.get("TradedPrice", 0)),
357
+ avg_traded_price=float(data.get("AvgTradedPrice", 0)),
358
+ order_datetime=datetime.now(),
359
+ update_datetime=datetime.now()
360
+ )
361
+
362
+ await self.order_update_callback(order_update)
363
+
364
+ except Exception as e:
365
+ logger.error(f"Error processing order update: {e}")
366
+
367
+ async def _process_market_depth(self, data: Dict[str, Any]):
368
+ """Process market depth update"""
369
+ try:
370
+ if self.market_depth_callback and "SecurityId" in data:
371
+ bids = []
372
+ asks = []
373
+
374
+ # Parse bid and ask data
375
+ for i in range(1, 21): # 20 levels
376
+ bid_price = data.get(f"BidPrice{i}")
377
+ bid_qty = data.get(f"BidQty{i}")
378
+ bid_orders = data.get(f"BidOrders{i}")
379
+
380
+ if bid_price and bid_qty:
381
+ bids.append({
382
+ "price": float(bid_price),
383
+ "quantity": int(bid_qty),
384
+ "orders": int(bid_orders or 0)
385
+ })
386
+
387
+ ask_price = data.get(f"AskPrice{i}")
388
+ ask_qty = data.get(f"AskQty{i}")
389
+ ask_orders = data.get(f"AskOrders{i}")
390
+
391
+ if ask_price and ask_qty:
392
+ asks.append({
393
+ "price": float(ask_price),
394
+ "quantity": int(ask_qty),
395
+ "orders": int(ask_orders or 0)
396
+ })
397
+
398
+ depth_update = MarketDepthUpdate(
399
+ security_id=str(data.get("SecurityId", "")),
400
+ exchange_segment=data.get("ExchangeSegment", ""),
401
+ bids=bids,
402
+ asks=asks,
403
+ timestamp=datetime.now()
404
+ )
405
+
406
+ await self.market_depth_callback(depth_update)
407
+
408
+ except Exception as e:
409
+ logger.error(f"Error processing market depth: {e}")
410
+
411
+ async def _market_feed_heartbeat(self):
412
+ """Send heartbeat to keep market feed connection alive"""
413
+ while self.is_market_feed_connected:
414
+ try:
415
+ await asyncio.sleep(30) # Send heartbeat every 30 seconds
416
+ if self.market_feed_ws and self.is_market_feed_connected:
417
+ await self.market_feed_ws.ping()
418
+ except Exception as e:
419
+ logger.error(f"Market feed heartbeat error: {e}")
420
+ break
421
+
422
+ async def _market_depth_heartbeat(self):
423
+ """Send heartbeat to keep market depth connection alive"""
424
+ while self.is_market_depth_connected:
425
+ try:
426
+ await asyncio.sleep(30) # Send heartbeat every 30 seconds
427
+ if self.market_depth_ws and self.is_market_depth_connected:
428
+ await self.market_depth_ws.ping()
429
+ except Exception as e:
430
+ logger.error(f"Market depth heartbeat error: {e}")
431
+ break
432
+
433
+ async def _reconnect_market_feed(self):
434
+ """Reconnect to market feed WebSocket"""
435
+ for attempt in range(self.max_reconnect_attempts):
436
+ logger.info(f"Attempting to reconnect market feed (attempt {attempt + 1})")
437
+ await asyncio.sleep(self.reconnect_delay)
438
+
439
+ if await self.connect_market_feed():
440
+ # Re-subscribe to instruments
441
+ if self.subscribed_instruments:
442
+ await self.subscribe_market_data(self.subscribed_instruments)
443
+ return
444
+
445
+ logger.error("Failed to reconnect market feed after maximum attempts")
446
+
447
+ async def _reconnect_order_updates(self):
448
+ """Reconnect to order update WebSocket"""
449
+ for attempt in range(self.max_reconnect_attempts):
450
+ logger.info(f"Attempting to reconnect order updates (attempt {attempt + 1})")
451
+ await asyncio.sleep(self.reconnect_delay)
452
+
453
+ if await self.connect_order_updates():
454
+ return
455
+
456
+ logger.error("Failed to reconnect order updates after maximum attempts")
457
+
458
+ async def _reconnect_market_depth(self):
459
+ """Reconnect to market depth WebSocket"""
460
+ for attempt in range(self.max_reconnect_attempts):
461
+ logger.info(f"Attempting to reconnect market depth (attempt {attempt + 1})")
462
+ await asyncio.sleep(self.reconnect_delay)
463
+
464
+ if await self.connect_market_depth():
465
+ # Re-subscribe to instruments
466
+ if self.subscribed_depth_instruments:
467
+ await self.subscribe_market_depth(self.subscribed_depth_instruments)
468
+ return
469
+
470
+ logger.error("Failed to reconnect market depth after maximum attempts")
471
+
472
+ def set_market_data_callback(self, callback: Callable[[MarketDataUpdate], None]):
473
+ """Set callback for market data updates"""
474
+ self.market_data_callback = callback
475
+
476
+ def set_order_update_callback(self, callback: Callable[[OrderUpdate], None]):
477
+ """Set callback for order updates"""
478
+ self.order_update_callback = callback
479
+
480
+ def set_market_depth_callback(self, callback: Callable[[MarketDepthUpdate], None]):
481
+ """Set callback for market depth updates"""
482
+ self.market_depth_callback = callback
483
+
484
+ async def disconnect_all(self):
485
+ """Disconnect all WebSocket connections"""
486
+ try:
487
+ # Cancel heartbeat tasks
488
+ if self.market_feed_heartbeat_task:
489
+ self.market_feed_heartbeat_task.cancel()
490
+ if self.market_depth_heartbeat_task:
491
+ self.market_depth_heartbeat_task.cancel()
492
+
493
+ # Close connections
494
+ if self.market_feed_ws:
495
+ await self.market_feed_ws.close()
496
+ if self.order_update_ws:
497
+ await self.order_update_ws.close()
498
+ if self.market_depth_ws:
499
+ await self.market_depth_ws.close()
500
+
501
+ # Reset connection status
502
+ self.is_market_feed_connected = False
503
+ self.is_order_update_connected = False
504
+ self.is_market_depth_connected = False
505
+
506
+ logger.info("All WebSocket connections closed")
507
+
508
+ except Exception as e:
509
+ logger.error(f"Error disconnecting WebSockets: {e}")
510
+
511
+
512
+ # Global WebSocket client instance
513
+ dhan_ws_client = DhanWebSocketClient()
services/market_data.py ADDED
@@ -0,0 +1,396 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Market Data Service
3
+ Handles market data fetching, caching, and processing
4
+ """
5
+
6
+ import logging
7
+ from typing import Dict, List, Any, Optional
8
+ from datetime import datetime, timedelta
9
+ import asyncio
10
+ import random
11
+
12
+ from services.dhan_api_manager import dhan_api_manager, DhanAPIError
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+ class MarketDataService:
17
+ """Service for market data operations"""
18
+
19
+ def __init__(self):
20
+ self.dhan_api = dhan_api_manager
21
+ self.cache = {}
22
+ self.cache_ttl = 60 # 1 minute cache
23
+
24
+ async def get_market_data_batch(self, securities: List[Dict[str, str]]) -> List[Dict[str, Any]]:
25
+ """Get market data for multiple securities"""
26
+ try:
27
+ market_data = []
28
+
29
+ for security in securities:
30
+ data = await self.get_single_security_data(
31
+ security.get("security_id", ""),
32
+ security.get("exchange_segment", "NSE_EQ")
33
+ )
34
+ if data:
35
+ market_data.append(data)
36
+
37
+ return market_data
38
+
39
+ except Exception as e:
40
+ logger.error(f"Error getting batch market data: {e}")
41
+ raise
42
+
43
+ async def get_single_security_data(self, security_id: str, exchange_segment: str) -> Optional[Dict[str, Any]]:
44
+ """Get market data for a single security"""
45
+ try:
46
+ cache_key = f"{security_id}_{exchange_segment}"
47
+
48
+ # Check cache first
49
+ if cache_key in self.cache:
50
+ cached_data, timestamp = self.cache[cache_key]
51
+ if datetime.utcnow() - timestamp < timedelta(seconds=self.cache_ttl):
52
+ return cached_data
53
+
54
+ # Fetch from Dhan API
55
+ try:
56
+ result = await self.dhan_api.get_quote(security_id, exchange_segment)
57
+
58
+ if result and "data" in result:
59
+ data = result["data"]
60
+
61
+ # Standardize the data format
62
+ standardized_data = {
63
+ "security_id": security_id,
64
+ "exchange_segment": exchange_segment,
65
+ "symbol": data.get("tradingSymbol", ""),
66
+ "ltp": float(data.get("LTP", 0)),
67
+ "change": float(data.get("change", 0)),
68
+ "change_percent": float(data.get("pChange", 0)),
69
+ "volume": int(data.get("volume", 0)),
70
+ "open": float(data.get("open", 0)),
71
+ "high": float(data.get("high", 0)),
72
+ "low": float(data.get("low", 0)),
73
+ "close": float(data.get("close", 0)),
74
+ "timestamp": datetime.utcnow().isoformat()
75
+ }
76
+
77
+ # Cache the data
78
+ self.cache[cache_key] = (standardized_data, datetime.utcnow())
79
+
80
+ return standardized_data
81
+
82
+ except DhanAPIError as e:
83
+ logger.warning(f"Dhan API error for {security_id}: {e}")
84
+ # Fall back to mock data
85
+ return self._generate_mock_data(security_id, exchange_segment)
86
+
87
+ return None
88
+
89
+ except Exception as e:
90
+ logger.error(f"Error getting single security data for {security_id}: {e}")
91
+ return self._generate_mock_data(security_id, exchange_segment)
92
+
93
+ def _generate_mock_data(self, security_id: str, exchange_segment: str) -> Dict[str, Any]:
94
+ """Generate mock market data for testing"""
95
+
96
+ # Mock data based on common symbols
97
+ mock_prices = {
98
+ "2885": 2650.0, # RELIANCE
99
+ "11723": 3890.0, # TCS
100
+ "1333": 1580.0, # HDFC
101
+ "1594": 1685.0, # INFY
102
+ }
103
+
104
+ base_price = mock_prices.get(security_id, random.uniform(100, 5000))
105
+ change_percent = random.uniform(-3, 3)
106
+ change = base_price * change_percent / 100
107
+
108
+ return {
109
+ "security_id": security_id,
110
+ "exchange_segment": exchange_segment,
111
+ "symbol": f"SYMBOL_{security_id}",
112
+ "ltp": round(base_price + change, 2),
113
+ "change": round(change, 2),
114
+ "change_percent": round(change_percent, 2),
115
+ "volume": random.randint(100000, 5000000),
116
+ "open": round(base_price * random.uniform(0.98, 1.02), 2),
117
+ "high": round(base_price * random.uniform(1.01, 1.05), 2),
118
+ "low": round(base_price * random.uniform(0.95, 0.99), 2),
119
+ "close": round(base_price, 2),
120
+ "timestamp": datetime.utcnow().isoformat()
121
+ }
122
+
123
+ async def get_historical_data(
124
+ self,
125
+ security_id: str,
126
+ exchange_segment: str,
127
+ from_date: str,
128
+ to_date: str,
129
+ interval: str = "day"
130
+ ) -> List[Dict[str, Any]]:
131
+ """Get historical data for a security"""
132
+ try:
133
+ # Try to get from Dhan API
134
+ request_data = {
135
+ "securityId": security_id,
136
+ "exchangeSegment": exchange_segment,
137
+ "instrument": "EQUITY",
138
+ "interval": interval,
139
+ "fromDate": from_date,
140
+ "toDate": to_date
141
+ }
142
+
143
+ try:
144
+ result = await self.dhan_api.get_historical_data(request_data)
145
+ if result and "data" in result:
146
+ return result["data"]
147
+ except DhanAPIError as e:
148
+ logger.warning(f"Dhan API error getting historical data: {e}")
149
+
150
+ # Fall back to mock data
151
+ return self._generate_mock_historical_data(from_date, to_date, interval)
152
+
153
+ except Exception as e:
154
+ logger.error(f"Error getting historical data: {e}")
155
+ return []
156
+
157
+ def _generate_mock_historical_data(self, from_date: str, to_date: str, interval: str) -> List[Dict[str, Any]]:
158
+ """Generate mock historical data"""
159
+ try:
160
+ start_date = datetime.strptime(from_date, "%Y-%m-%d")
161
+ end_date = datetime.strptime(to_date, "%Y-%m-%d")
162
+
163
+ data_points = []
164
+ current_date = start_date
165
+ base_price = random.uniform(1000, 3000)
166
+
167
+ while current_date <= end_date:
168
+ # Simulate price movement
169
+ change_percent = random.uniform(-2, 2)
170
+ open_price = base_price
171
+ close_price = base_price * (1 + change_percent / 100)
172
+ high_price = max(open_price, close_price) * random.uniform(1.001, 1.02)
173
+ low_price = min(open_price, close_price) * random.uniform(0.98, 0.999)
174
+
175
+ data_points.append({
176
+ "timestamp": current_date.strftime("%Y-%m-%d"),
177
+ "open": round(open_price, 2),
178
+ "high": round(high_price, 2),
179
+ "low": round(low_price, 2),
180
+ "close": round(close_price, 2),
181
+ "volume": random.randint(100000, 2000000)
182
+ })
183
+
184
+ base_price = close_price
185
+ current_date += timedelta(days=1)
186
+
187
+ return data_points
188
+
189
+ except Exception as e:
190
+ logger.error(f"Error generating mock historical data: {e}")
191
+ return []
192
+
193
+ async def get_live_quotes(self, symbols: List[str]) -> Dict[str, Dict[str, Any]]:
194
+ """Get live quotes for multiple symbols"""
195
+ try:
196
+ quotes = {}
197
+
198
+ for symbol in symbols:
199
+ # This is a simplified implementation
200
+ # In reality, you'd need to map symbols to security_ids
201
+ quote_data = await self.get_single_security_data(symbol, "NSE_EQ")
202
+ if quote_data:
203
+ quotes[symbol] = quote_data
204
+
205
+ return quotes
206
+
207
+ except Exception as e:
208
+ logger.error(f"Error getting live quotes: {e}")
209
+ return {}
210
+
211
+ async def get_market_indices(self) -> Dict[str, Dict[str, Any]]:
212
+ """Get major market indices data"""
213
+ try:
214
+ indices = {}
215
+
216
+ # Major Indian indices
217
+ index_mapping = {
218
+ "NIFTY": {"security_id": "26000", "exchange_segment": "NSE_INDEX"},
219
+ "SENSEX": {"security_id": "1", "exchange_segment": "BSE_INDEX"},
220
+ "BANKNIFTY": {"security_id": "26009", "exchange_segment": "NSE_INDEX"}
221
+ }
222
+
223
+ for index_name, mapping in index_mapping.items():
224
+ try:
225
+ data = await self.get_single_security_data(
226
+ mapping["security_id"],
227
+ mapping["exchange_segment"]
228
+ )
229
+ if data:
230
+ indices[index_name] = data
231
+ except Exception as e:
232
+ logger.warning(f"Error getting data for {index_name}: {e}")
233
+ # Add mock data for the index
234
+ indices[index_name] = self._generate_mock_index_data(index_name)
235
+
236
+ return indices
237
+
238
+ except Exception as e:
239
+ logger.error(f"Error getting market indices: {e}")
240
+ return {}
241
+
242
+ def _generate_mock_index_data(self, index_name: str) -> Dict[str, Any]:
243
+ """Generate mock index data"""
244
+ base_values = {
245
+ "NIFTY": 22000,
246
+ "SENSEX": 72000,
247
+ "BANKNIFTY": 47000
248
+ }
249
+
250
+ base_value = base_values.get(index_name, 10000)
251
+ change_percent = random.uniform(-1, 1.5)
252
+ change = base_value * change_percent / 100
253
+
254
+ return {
255
+ "symbol": index_name,
256
+ "ltp": round(base_value + change, 2),
257
+ "change": round(change, 2),
258
+ "change_percent": round(change_percent, 2),
259
+ "timestamp": datetime.utcnow().isoformat()
260
+ }
261
+
262
+ async def get_sector_data(self) -> List[Dict[str, Any]]:
263
+ """Get sector-wise performance data"""
264
+ try:
265
+ sectors = [
266
+ {
267
+ "sector": "Information Technology",
268
+ "change_percent": random.uniform(-2, 3),
269
+ "market_cap": random.uniform(8000000000000, 15000000000000),
270
+ "stocks_count": 25
271
+ },
272
+ {
273
+ "sector": "Banking",
274
+ "change_percent": random.uniform(-1.5, 2),
275
+ "market_cap": random.uniform(6000000000000, 12000000000000),
276
+ "stocks_count": 30
277
+ },
278
+ {
279
+ "sector": "Oil & Gas",
280
+ "change_percent": random.uniform(-1, 2.5),
281
+ "market_cap": random.uniform(4000000000000, 8000000000000),
282
+ "stocks_count": 15
283
+ },
284
+ {
285
+ "sector": "Pharmaceuticals",
286
+ "change_percent": random.uniform(-0.5, 1.8),
287
+ "market_cap": random.uniform(3000000000000, 6000000000000),
288
+ "stocks_count": 20
289
+ },
290
+ {
291
+ "sector": "Automotive",
292
+ "change_percent": random.uniform(-2, 1.5),
293
+ "market_cap": random.uniform(2000000000000, 5000000000000),
294
+ "stocks_count": 18
295
+ }
296
+ ]
297
+
298
+ return sectors
299
+
300
+ except Exception as e:
301
+ logger.error(f"Error getting sector data: {e}")
302
+ return []
303
+
304
+ async def get_top_movers(self, mover_type: str = "gainers", limit: int = 20) -> List[Dict[str, Any]]:
305
+ """Get top gainers or losers"""
306
+ try:
307
+ # This would typically fetch real data from the API
308
+ # For now, generating mock data
309
+
310
+ movers = []
311
+
312
+ for i in range(limit):
313
+ if mover_type == "gainers":
314
+ change_percent = random.uniform(3, 15)
315
+ else: # losers
316
+ change_percent = random.uniform(-15, -3)
317
+
318
+ base_price = random.uniform(100, 5000)
319
+ change = base_price * change_percent / 100
320
+
321
+ movers.append({
322
+ "symbol": f"STOCK{i+1}",
323
+ "security_id": f"{1000+i}",
324
+ "ltp": round(base_price + change, 2),
325
+ "change": round(change, 2),
326
+ "change_percent": round(change_percent, 2),
327
+ "volume": random.randint(500000, 10000000)
328
+ })
329
+
330
+ # Sort by change percentage
331
+ movers.sort(key=lambda x: x["change_percent"], reverse=(mover_type == "gainers"))
332
+
333
+ return movers
334
+
335
+ except Exception as e:
336
+ logger.error(f"Error getting top {mover_type}: {e}")
337
+ return []
338
+
339
+ async def get_option_chain(self, underlying_symbol: str, expiry_date: Optional[str] = None) -> Dict[str, Any]:
340
+ """Get option chain data"""
341
+ try:
342
+ # This is a complex endpoint that would need significant implementation
343
+ # For now, returning mock structure
344
+
345
+ return {
346
+ "underlying": underlying_symbol,
347
+ "expiry_date": expiry_date or (datetime.utcnow() + timedelta(days=30)).strftime("%Y-%m-%d"),
348
+ "underlying_price": random.uniform(2000, 3000),
349
+ "options": {
350
+ "calls": [
351
+ {
352
+ "strike": strike,
353
+ "ltp": random.uniform(10, 200),
354
+ "change": random.uniform(-20, 20),
355
+ "volume": random.randint(1000, 100000),
356
+ "oi": random.randint(5000, 500000)
357
+ }
358
+ for strike in range(2000, 3200, 50)
359
+ ],
360
+ "puts": [
361
+ {
362
+ "strike": strike,
363
+ "ltp": random.uniform(10, 200),
364
+ "change": random.uniform(-20, 20),
365
+ "volume": random.randint(1000, 100000),
366
+ "oi": random.randint(5000, 500000)
367
+ }
368
+ for strike in range(2000, 3200, 50)
369
+ ]
370
+ }
371
+ }
372
+
373
+ except Exception as e:
374
+ logger.error(f"Error getting option chain: {e}")
375
+ return {}
376
+
377
+ def clear_cache(self):
378
+ """Clear the market data cache"""
379
+ self.cache.clear()
380
+ logger.info("Market data cache cleared")
381
+
382
+ def get_cache_stats(self) -> Dict[str, Any]:
383
+ """Get cache statistics"""
384
+ now = datetime.utcnow()
385
+ active_entries = 0
386
+
387
+ for cache_key, (data, timestamp) in self.cache.items():
388
+ if now - timestamp < timedelta(seconds=self.cache_ttl):
389
+ active_entries += 1
390
+
391
+ return {
392
+ "total_entries": len(self.cache),
393
+ "active_entries": active_entries,
394
+ "cache_ttl": self.cache_ttl,
395
+ "last_cleared": now.isoformat()
396
+ }
services/portfolio_service.py ADDED
@@ -0,0 +1,374 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Portfolio Service
3
+ Handles portfolio calculations, analytics, and optimization
4
+ """
5
+
6
+ import logging
7
+ from typing import Dict, List, Any, Optional
8
+ from datetime import datetime, timedelta
9
+ import asyncio
10
+ import numpy as np
11
+ import pandas as pd
12
+
13
+ from models.portfolio import (
14
+ PortfolioSummary, PortfolioHolding, PerformanceMetrics,
15
+ AllocationRequest, RiskMetrics
16
+ )
17
+ from services.dhan_api_manager import dhan_api_manager
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+ class PortfolioService:
22
+ """Service for portfolio management and analytics"""
23
+
24
+ def __init__(self):
25
+ self.dhan_api = dhan_api_manager
26
+
27
+ async def get_portfolio_summary(self, user_id: str) -> PortfolioSummary:
28
+ """Get comprehensive portfolio summary"""
29
+ try:
30
+ # Get holdings and positions
31
+ holdings_result = await self.dhan_api.get_holdings()
32
+ holdings = holdings_result.get("data", [])
33
+
34
+ # Calculate summary metrics
35
+ total_value = sum(holding.get("currentValue", 0) for holding in holdings)
36
+ total_invested = sum(holding.get("avgPrice", 0) * holding.get("totalQty", 0) for holding in holdings)
37
+ total_pnl = total_value - total_invested
38
+ pnl_percentage = (total_pnl / total_invested * 100) if total_invested > 0 else 0
39
+
40
+ # Mock day's change calculation
41
+ day_change = total_value * 0.015 # Mock 1.5% daily change
42
+ day_change_percentage = 1.5
43
+
44
+ return PortfolioSummary(
45
+ total_value=total_value,
46
+ total_invested=total_invested,
47
+ total_pnl=total_pnl,
48
+ pnl_percentage=pnl_percentage,
49
+ day_change=day_change,
50
+ day_change_percentage=day_change_percentage,
51
+ holdings_count=len(holdings),
52
+ last_updated=datetime.utcnow()
53
+ )
54
+
55
+ except Exception as e:
56
+ logger.error(f"Error getting portfolio summary: {e}")
57
+ raise
58
+
59
+ async def get_holdings(self, user_id: str) -> List[Dict[str, Any]]:
60
+ """Get user's holdings"""
61
+ try:
62
+ holdings_result = await self.dhan_api.get_holdings()
63
+ return holdings_result.get("data", [])
64
+ except Exception as e:
65
+ logger.error(f"Error getting holdings: {e}")
66
+ raise
67
+
68
+ async def calculate_performance_metrics(
69
+ self,
70
+ user_id: str,
71
+ start_date: datetime,
72
+ end_date: datetime
73
+ ) -> PerformanceMetrics:
74
+ """Calculate portfolio performance metrics"""
75
+ try:
76
+ # Mock implementation - in real system would analyze historical data
77
+ days = (end_date - start_date).days
78
+
79
+ return PerformanceMetrics(
80
+ total_return=8.5,
81
+ annualized_return=12.8,
82
+ best_day=3.2,
83
+ worst_day=-2.8,
84
+ positive_days=int(days * 0.6),
85
+ negative_days=int(days * 0.4),
86
+ win_rate=60.0,
87
+ calmar_ratio=0.85,
88
+ tracking_error=2.1
89
+ )
90
+
91
+ except Exception as e:
92
+ logger.error(f"Error calculating performance metrics: {e}")
93
+ raise
94
+
95
+ async def get_asset_allocation(self, user_id: str) -> Dict[str, Any]:
96
+ """Get current asset allocation breakdown"""
97
+ try:
98
+ holdings = await self.get_holdings(user_id)
99
+
100
+ # Mock allocation calculation
101
+ # In real implementation, would categorize holdings by sector, market cap, etc.
102
+
103
+ return {
104
+ "sectors": {
105
+ "Information Technology": 35.5,
106
+ "Banking": 25.2,
107
+ "Pharmaceuticals": 15.8,
108
+ "Automotive": 12.3,
109
+ "Oil & Gas": 11.2
110
+ },
111
+ "market_caps": {
112
+ "Large Cap": 65.0,
113
+ "Mid Cap": 25.0,
114
+ "Small Cap": 10.0
115
+ },
116
+ "geography": {
117
+ "Domestic": 85.0,
118
+ "International": 15.0
119
+ },
120
+ "concentration": {
121
+ "top_5_holdings": 45.2,
122
+ "top_10_holdings": 67.8
123
+ },
124
+ "diversification_score": 75.5
125
+ }
126
+
127
+ except Exception as e:
128
+ logger.error(f"Error calculating asset allocation: {e}")
129
+ raise
130
+
131
+ async def calculate_rebalancing(
132
+ self,
133
+ current_portfolio: List[Dict],
134
+ target_allocation: AllocationRequest
135
+ ) -> Dict[str, Any]:
136
+ """Calculate portfolio rebalancing suggestions"""
137
+ try:
138
+ # Mock rebalancing calculation
139
+ total_value = sum(holding.get("currentValue", 0) for holding in current_portfolio)
140
+
141
+ return {
142
+ "current": {
143
+ "equity": 80.0,
144
+ "debt": 15.0,
145
+ "cash": 5.0
146
+ },
147
+ "target": {
148
+ "equity": 70.0,
149
+ "debt": 25.0,
150
+ "cash": 5.0
151
+ },
152
+ "trades": [
153
+ {
154
+ "action": "sell",
155
+ "security": "RELIANCE",
156
+ "quantity": 50,
157
+ "amount": 125000
158
+ },
159
+ {
160
+ "action": "buy",
161
+ "security": "DEBT_FUND",
162
+ "amount": 125000
163
+ }
164
+ ],
165
+ "cost": 500.0, # Transaction costs
166
+ "tax_impact": 2500.0 # Estimated tax
167
+ }
168
+
169
+ except Exception as e:
170
+ logger.error(f"Error calculating rebalancing: {e}")
171
+ raise
172
+
173
+ async def calculate_correlation_matrix(self, holdings: List[Dict]) -> List[List[float]]:
174
+ """Calculate correlation matrix for portfolio holdings"""
175
+ try:
176
+ # Mock correlation matrix
177
+ n_holdings = min(len(holdings), 10) # Limit to 10 holdings
178
+
179
+ # Generate mock correlation matrix
180
+ np.random.seed(42) # For consistent results
181
+ correlation_matrix = np.random.uniform(0.1, 0.8, (n_holdings, n_holdings))
182
+
183
+ # Make matrix symmetric and set diagonal to 1
184
+ for i in range(n_holdings):
185
+ for j in range(i, n_holdings):
186
+ if i == j:
187
+ correlation_matrix[i][j] = 1.0
188
+ else:
189
+ correlation_matrix[j][i] = correlation_matrix[i][j]
190
+
191
+ return correlation_matrix.tolist()
192
+
193
+ except Exception as e:
194
+ logger.error(f"Error calculating correlation matrix: {e}")
195
+ raise
196
+
197
+ async def get_portfolio_for_var_calculation(self, user_id: str) -> Dict[str, Any]:
198
+ """Get portfolio data formatted for VaR calculation"""
199
+ try:
200
+ holdings = await self.get_holdings(user_id)
201
+
202
+ return {
203
+ "holdings": holdings,
204
+ "total_value": sum(h.get("currentValue", 0) for h in holdings),
205
+ "weights": [h.get("currentValue", 0) for h in holdings],
206
+ "symbols": [h.get("tradingSymbol", "") for h in holdings]
207
+ }
208
+
209
+ except Exception as e:
210
+ logger.error(f"Error getting portfolio for VaR: {e}")
211
+ raise
212
+
213
+ async def optimize_portfolio(
214
+ self,
215
+ current_portfolio: List[Dict],
216
+ objective: str,
217
+ constraints: Dict[str, Any]
218
+ ) -> Dict[str, Any]:
219
+ """Run portfolio optimization"""
220
+ try:
221
+ # Mock optimization results
222
+ return {
223
+ "current": {
224
+ "expected_return": 12.5,
225
+ "volatility": 18.2,
226
+ "sharpe_ratio": 0.65
227
+ },
228
+ "optimal": {
229
+ "expected_return": 14.2,
230
+ "volatility": 16.8,
231
+ "sharpe_ratio": 0.82
232
+ },
233
+ "expected_return": 14.2,
234
+ "expected_volatility": 16.8,
235
+ "sharpe_ratio": 0.82,
236
+ "changes": [
237
+ {
238
+ "security": "RELIANCE",
239
+ "current_weight": 15.0,
240
+ "optimal_weight": 12.0,
241
+ "change": -3.0
242
+ },
243
+ {
244
+ "security": "TCS",
245
+ "current_weight": 10.0,
246
+ "optimal_weight": 13.0,
247
+ "change": 3.0
248
+ }
249
+ ]
250
+ }
251
+
252
+ except Exception as e:
253
+ logger.error(f"Error optimizing portfolio: {e}")
254
+ raise
255
+
256
+ async def get_holdings_with_cost_basis(self, user_id: str) -> List[Dict[str, Any]]:
257
+ """Get holdings with detailed cost basis information"""
258
+ try:
259
+ holdings = await self.get_holdings(user_id)
260
+
261
+ # Add mock cost basis data
262
+ for holding in holdings:
263
+ holding["purchase_date"] = "2023-01-15" # Mock data
264
+ holding["cost_basis"] = holding.get("avgPrice", 0) * holding.get("totalQty", 0)
265
+ holding["days_held"] = 180 # Mock data
266
+ holding["is_long_term"] = holding["days_held"] > 365
267
+
268
+ return holdings
269
+
270
+ except Exception as e:
271
+ logger.error(f"Error getting holdings with cost basis: {e}")
272
+ raise
273
+
274
+ async def calculate_tax_optimization(
275
+ self,
276
+ holdings: List[Dict],
277
+ tax_year: int
278
+ ) -> Dict[str, Any]:
279
+ """Calculate tax optimization strategies"""
280
+ try:
281
+ # Mock tax optimization calculation
282
+ total_gains = sum(h.get("pnl", 0) for h in holdings if h.get("pnl", 0) > 0)
283
+ total_losses = sum(h.get("pnl", 0) for h in holdings if h.get("pnl", 0) < 0)
284
+
285
+ return {
286
+ "current_liability": max(0, total_gains * 0.15), # Mock 15% tax
287
+ "loss_harvesting": [
288
+ {
289
+ "security": h.get("tradingSymbol"),
290
+ "unrealized_loss": h.get("pnl", 0),
291
+ "tax_benefit": abs(h.get("pnl", 0)) * 0.15
292
+ }
293
+ for h in holdings if h.get("pnl", 0) < -5000 # Loss > 5000
294
+ ][:5], # Top 5 opportunities
295
+ "wash_sale_warnings": [],
296
+ "actions": [
297
+ "Consider booking losses in HDFC before year-end",
298
+ "Hold TCS for long-term capital gains treatment"
299
+ ],
300
+ "potential_savings": abs(total_losses) * 0.15
301
+ }
302
+
303
+ except Exception as e:
304
+ logger.error(f"Error calculating tax optimization: {e}")
305
+ raise
306
+
307
+ async def compare_with_benchmark(
308
+ self,
309
+ user_id: str,
310
+ benchmark: str,
311
+ start_date: datetime,
312
+ end_date: datetime
313
+ ) -> Dict[str, Any]:
314
+ """Compare portfolio performance with benchmark"""
315
+ try:
316
+ # Mock benchmark comparison
317
+ return {
318
+ "portfolio_return": 12.5,
319
+ "benchmark_return": 8.7,
320
+ "alpha": 3.8,
321
+ "beta": 1.15,
322
+ "tracking_error": 4.2,
323
+ "information_ratio": 0.90,
324
+ "up_capture": 110.5,
325
+ "down_capture": 85.2
326
+ }
327
+
328
+ except Exception as e:
329
+ logger.error(f"Error comparing with benchmark: {e}")
330
+ raise
331
+
332
+ async def simulate_portfolio_changes(
333
+ self,
334
+ current_portfolio: List[Dict],
335
+ proposed_changes: List[Dict[str, Any]]
336
+ ) -> Dict[str, Any]:
337
+ """Simulate the impact of portfolio changes"""
338
+ try:
339
+ # Mock simulation results
340
+ current_value = sum(h.get("currentValue", 0) for h in current_portfolio)
341
+
342
+ return {
343
+ "current": {
344
+ "value": current_value,
345
+ "risk": 18.5,
346
+ "expected_return": 12.0,
347
+ "sharpe_ratio": 0.65
348
+ },
349
+ "projected": {
350
+ "value": current_value * 1.05, # 5% increase assumed
351
+ "risk": 16.8,
352
+ "expected_return": 13.2,
353
+ "sharpe_ratio": 0.78
354
+ },
355
+ "impact": {
356
+ "risk_change": -1.7,
357
+ "return_change": 1.2,
358
+ "sharpe_improvement": 0.13
359
+ },
360
+ "risk_change": {
361
+ "before": 18.5,
362
+ "after": 16.8,
363
+ "improvement": True
364
+ },
365
+ "recommendations": [
366
+ "The proposed changes will improve risk-adjusted returns",
367
+ "Consider gradual implementation over 2-3 weeks",
368
+ "Monitor market conditions before executing"
369
+ ]
370
+ }
371
+
372
+ except Exception as e:
373
+ logger.error(f"Error simulating portfolio changes: {e}")
374
+ raise
services/risk_engine.py ADDED
@@ -0,0 +1,390 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Risk Engine Service
3
+ Handles portfolio risk assessment, VaR calculations, and stress testing
4
+ """
5
+
6
+ import logging
7
+ from typing import Dict, List, Any, Optional
8
+ from datetime import datetime, timedelta
9
+ import numpy as np
10
+ import pandas as pd
11
+
12
+ from models.portfolio import RiskAssessment, RiskLevel
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+ class RiskEngine:
17
+ """Service for portfolio risk analysis and management"""
18
+
19
+ def __init__(self):
20
+ self.confidence_levels = [0.95, 0.99]
21
+ self.time_horizons = [1, 5, 21] # 1 day, 1 week, 1 month
22
+
23
+ async def assess_portfolio_risk(self, holdings: List[Dict[str, Any]]) -> RiskAssessment:
24
+ """Perform comprehensive portfolio risk assessment"""
25
+ try:
26
+ # Calculate various risk metrics
27
+ concentration_risk = self._calculate_concentration_risk(holdings)
28
+ sector_concentration = self._calculate_sector_concentration(holdings)
29
+ liquidity_risk = self._calculate_liquidity_risk(holdings)
30
+ currency_risk = self._calculate_currency_risk(holdings)
31
+
32
+ # Calculate overall risk score
33
+ overall_risk_score = self._calculate_overall_risk_score(
34
+ concentration_risk, liquidity_risk, currency_risk
35
+ )
36
+
37
+ # Determine risk level
38
+ risk_level = self._determine_risk_level(overall_risk_score)
39
+
40
+ # Generate recommendations
41
+ recommendations = self._generate_risk_recommendations(
42
+ overall_risk_score, concentration_risk, sector_concentration
43
+ )
44
+
45
+ return RiskAssessment(
46
+ overall_risk_score=overall_risk_score,
47
+ risk_level=risk_level,
48
+ concentration_risk=concentration_risk,
49
+ sector_concentration=sector_concentration,
50
+ liquidity_risk=liquidity_risk,
51
+ currency_risk=currency_risk,
52
+ recommendations=recommendations
53
+ )
54
+
55
+ except Exception as e:
56
+ logger.error(f"Error assessing portfolio risk: {e}")
57
+ raise
58
+
59
+ async def calculate_value_at_risk(
60
+ self,
61
+ portfolio_data: Dict[str, Any],
62
+ confidence_level: float = 0.95,
63
+ time_horizon: int = 1
64
+ ) -> Dict[str, Any]:
65
+ """Calculate Value at Risk using multiple methods"""
66
+ try:
67
+ portfolio_value = portfolio_data.get("total_value", 0)
68
+ holdings = portfolio_data.get("holdings", [])
69
+
70
+ # Mock VaR calculations (in real system would use historical data)
71
+ # Historical VaR
72
+ historical_var = portfolio_value * 0.02 * np.sqrt(time_horizon)
73
+
74
+ # Parametric VaR
75
+ portfolio_volatility = 0.18 # Mock 18% annual volatility
76
+ daily_volatility = portfolio_volatility / np.sqrt(252)
77
+ z_score = 1.645 if confidence_level == 0.95 else 2.326 # For 95% and 99%
78
+ parametric_var = portfolio_value * z_score * daily_volatility * np.sqrt(time_horizon)
79
+
80
+ # Monte Carlo VaR
81
+ monte_carlo_var = portfolio_value * 0.025 * np.sqrt(time_horizon)
82
+
83
+ # Expected Shortfall (Conditional VaR)
84
+ expected_shortfall = parametric_var * 1.3 # Typically 30% higher than VaR
85
+
86
+ return {
87
+ "historical": historical_var,
88
+ "parametric": parametric_var,
89
+ "monte_carlo": monte_carlo_var,
90
+ "expected_shortfall": expected_shortfall,
91
+ "portfolio_value": portfolio_value,
92
+ "confidence_level": confidence_level,
93
+ "time_horizon": time_horizon
94
+ }
95
+
96
+ except Exception as e:
97
+ logger.error(f"Error calculating VaR: {e}")
98
+ raise
99
+
100
+ async def run_stress_testing(
101
+ self,
102
+ portfolio_data: List[Dict[str, Any]],
103
+ scenario: str
104
+ ) -> Dict[str, Any]:
105
+ """Run stress testing scenarios on portfolio"""
106
+ try:
107
+ total_value = sum(holding.get("currentValue", 0) for holding in portfolio_data)
108
+
109
+ # Define stress scenarios
110
+ scenarios = {
111
+ "market_crash": {
112
+ "description": "20% market decline scenario",
113
+ "market_impact": -0.20,
114
+ "sector_impacts": {
115
+ "Banking": -0.25,
116
+ "IT": -0.15,
117
+ "Auto": -0.30,
118
+ "Pharma": -0.10
119
+ }
120
+ },
121
+ "interest_rate_shock": {
122
+ "description": "2% interest rate increase",
123
+ "market_impact": -0.12,
124
+ "sector_impacts": {
125
+ "Banking": 0.05, # Banks benefit from higher rates
126
+ "Real Estate": -0.20,
127
+ "IT": -0.10,
128
+ "FMCG": -0.08
129
+ }
130
+ },
131
+ "inflation_surge": {
132
+ "description": "High inflation scenario",
133
+ "market_impact": -0.15,
134
+ "sector_impacts": {
135
+ "FMCG": -0.20,
136
+ "Auto": -0.18,
137
+ "IT": -0.05,
138
+ "Commodities": 0.10
139
+ }
140
+ }
141
+ }
142
+
143
+ scenario_data = scenarios.get(scenario, scenarios["market_crash"])
144
+
145
+ # Calculate portfolio impact
146
+ portfolio_impact = total_value * scenario_data["market_impact"]
147
+ impact_percentage = scenario_data["market_impact"] * 100
148
+
149
+ # Calculate most affected holdings (mock)
150
+ worst_affected = [
151
+ {
152
+ "symbol": "HDFC",
153
+ "current_value": 150000,
154
+ "stressed_value": 120000,
155
+ "impact": -20.0
156
+ },
157
+ {
158
+ "symbol": "MARUTI",
159
+ "current_value": 200000,
160
+ "stressed_value": 140000,
161
+ "impact": -30.0
162
+ }
163
+ ]
164
+
165
+ recommendations = self._generate_stress_test_recommendations(scenario, impact_percentage)
166
+
167
+ return {
168
+ "scenario_description": scenario_data["description"],
169
+ "portfolio_impact": portfolio_impact,
170
+ "impact_percentage": impact_percentage,
171
+ "worst_affected_holdings": worst_affected,
172
+ "recommendations": recommendations,
173
+ "recovery_time_estimate": "6-12 months",
174
+ "hedging_suggestions": [
175
+ "Consider defensive sector allocation",
176
+ "Add hedge instruments like gold or bonds",
177
+ "Implement stop-loss orders for high-risk positions"
178
+ ]
179
+ }
180
+
181
+ except Exception as e:
182
+ logger.error(f"Error running stress testing: {e}")
183
+ raise
184
+
185
+ def _calculate_concentration_risk(self, holdings: List[Dict[str, Any]]) -> float:
186
+ """Calculate portfolio concentration risk using Herfindahl index"""
187
+ try:
188
+ total_value = sum(holding.get("currentValue", 0) for holding in holdings)
189
+ if total_value == 0:
190
+ return 0.0
191
+
192
+ # Calculate weights
193
+ weights = [holding.get("currentValue", 0) / total_value for holding in holdings]
194
+
195
+ # Calculate Herfindahl index
196
+ herfindahl_index = sum(w * w for w in weights)
197
+
198
+ # Convert to risk score (0-100)
199
+ concentration_risk = herfindahl_index * 100
200
+
201
+ return round(concentration_risk, 2)
202
+
203
+ except Exception as e:
204
+ logger.error(f"Error calculating concentration risk: {e}")
205
+ return 50.0 # Default moderate risk
206
+
207
+ def _calculate_sector_concentration(self, holdings: List[Dict[str, Any]]) -> Dict[str, float]:
208
+ """Calculate sector-wise concentration"""
209
+ try:
210
+ # Mock sector allocation (in real system would get from security master)
211
+ sector_mapping = {
212
+ "RELIANCE": "Oil & Gas",
213
+ "TCS": "Information Technology",
214
+ "HDFC": "Banking",
215
+ "INFY": "Information Technology",
216
+ "ICICI": "Banking",
217
+ "WIPRO": "Information Technology"
218
+ }
219
+
220
+ sector_values = {}
221
+ total_value = sum(holding.get("currentValue", 0) for holding in holdings)
222
+
223
+ for holding in holdings:
224
+ symbol = holding.get("tradingSymbol", "")
225
+ sector = sector_mapping.get(symbol, "Others")
226
+ value = holding.get("currentValue", 0)
227
+
228
+ if sector in sector_values:
229
+ sector_values[sector] += value
230
+ else:
231
+ sector_values[sector] = value
232
+
233
+ # Convert to percentages
234
+ sector_percentages = {}
235
+ for sector, value in sector_values.items():
236
+ sector_percentages[sector] = round((value / total_value) * 100, 2) if total_value > 0 else 0
237
+
238
+ return sector_percentages
239
+
240
+ except Exception as e:
241
+ logger.error(f"Error calculating sector concentration: {e}")
242
+ return {}
243
+
244
+ def _calculate_liquidity_risk(self, holdings: List[Dict[str, Any]]) -> float:
245
+ """Calculate portfolio liquidity risk"""
246
+ try:
247
+ # Mock liquidity scoring based on market cap and trading volume
248
+ total_value = sum(holding.get("currentValue", 0) for holding in holdings)
249
+
250
+ liquidity_scores = []
251
+ for holding in holdings:
252
+ # Mock liquidity score (in real system would use actual volume data)
253
+ symbol = holding.get("tradingSymbol", "")
254
+ if symbol in ["RELIANCE", "TCS", "HDFC", "INFY"]:
255
+ liquidity_score = 90 # High liquidity
256
+ elif symbol in ["WIPRO", "ICICI"]:
257
+ liquidity_score = 70 # Medium liquidity
258
+ else:
259
+ liquidity_score = 50 # Lower liquidity
260
+
261
+ weight = holding.get("currentValue", 0) / total_value if total_value > 0 else 0
262
+ liquidity_scores.append(liquidity_score * weight)
263
+
264
+ weighted_liquidity = sum(liquidity_scores)
265
+ liquidity_risk = 100 - weighted_liquidity # Higher score = lower risk
266
+
267
+ return round(liquidity_risk, 2)
268
+
269
+ except Exception as e:
270
+ logger.error(f"Error calculating liquidity risk: {e}")
271
+ return 30.0 # Default moderate liquidity risk
272
+
273
+ def _calculate_currency_risk(self, holdings: List[Dict[str, Any]]) -> float:
274
+ """Calculate currency risk exposure"""
275
+ try:
276
+ # For Indian equities, currency risk is generally low
277
+ # In real system would consider international holdings
278
+
279
+ total_value = sum(holding.get("currentValue", 0) for holding in holdings)
280
+ international_exposure = 0 # Mock - no international holdings
281
+
282
+ currency_risk = (international_exposure / total_value) * 50 if total_value > 0 else 0
283
+
284
+ return round(currency_risk, 2)
285
+
286
+ except Exception as e:
287
+ logger.error(f"Error calculating currency risk: {e}")
288
+ return 5.0 # Default low currency risk
289
+
290
+ def _calculate_overall_risk_score(
291
+ self,
292
+ concentration_risk: float,
293
+ liquidity_risk: float,
294
+ currency_risk: float
295
+ ) -> float:
296
+ """Calculate overall portfolio risk score"""
297
+ try:
298
+ # Weighted average of different risk components
299
+ weights = {
300
+ "concentration": 0.4,
301
+ "liquidity": 0.3,
302
+ "currency": 0.1,
303
+ "market": 0.2 # Mock market risk component
304
+ }
305
+
306
+ market_risk = 45.0 # Mock market risk score
307
+
308
+ overall_score = (
309
+ concentration_risk * weights["concentration"] +
310
+ liquidity_risk * weights["liquidity"] +
311
+ currency_risk * weights["currency"] +
312
+ market_risk * weights["market"]
313
+ )
314
+
315
+ return round(min(overall_score, 100), 1)
316
+
317
+ except Exception as e:
318
+ logger.error(f"Error calculating overall risk score: {e}")
319
+ return 50.0 # Default moderate risk
320
+
321
+ def _determine_risk_level(self, risk_score: float) -> RiskLevel:
322
+ """Determine risk level based on risk score"""
323
+ if risk_score < 30:
324
+ return RiskLevel.LOW
325
+ elif risk_score < 70:
326
+ return RiskLevel.MEDIUM
327
+ else:
328
+ return RiskLevel.HIGH
329
+
330
+ def _generate_risk_recommendations(
331
+ self,
332
+ overall_risk: float,
333
+ concentration_risk: float,
334
+ sector_concentration: Dict[str, float]
335
+ ) -> List[str]:
336
+ """Generate risk mitigation recommendations"""
337
+ recommendations = []
338
+
339
+ if overall_risk > 70:
340
+ recommendations.append("Portfolio has high risk. Consider immediate diversification.")
341
+
342
+ if concentration_risk > 50:
343
+ recommendations.append("High concentration risk detected. Diversify across more securities.")
344
+
345
+ # Check for sector concentration
346
+ for sector, percentage in sector_concentration.items():
347
+ if percentage > 40:
348
+ recommendations.append(f"Overweight in {sector} sector ({percentage}%). Consider rebalancing.")
349
+
350
+ if overall_risk < 30:
351
+ recommendations.append("Portfolio appears well-diversified with low risk.")
352
+ elif overall_risk < 50:
353
+ recommendations.append("Portfolio has moderate risk levels. Monitor regularly.")
354
+
355
+ # Add general recommendations
356
+ if len(recommendations) == 0:
357
+ recommendations.append("Portfolio risk levels are within acceptable ranges.")
358
+
359
+ recommendations.append("Regularly review and rebalance your portfolio.")
360
+ recommendations.append("Consider your risk tolerance and investment horizon.")
361
+
362
+ return recommendations[:5] # Limit to top 5 recommendations
363
+
364
+ def _generate_stress_test_recommendations(self, scenario: str, impact_percentage: float) -> List[str]:
365
+ """Generate recommendations based on stress test results"""
366
+ recommendations = []
367
+
368
+ if abs(impact_percentage) > 25:
369
+ recommendations.append("High portfolio vulnerability detected. Consider defensive positioning.")
370
+
371
+ if scenario == "market_crash":
372
+ recommendations.extend([
373
+ "Consider increasing cash allocation during market uncertainty",
374
+ "Add defensive sectors like utilities and consumer staples",
375
+ "Implement systematic stop-loss orders"
376
+ ])
377
+ elif scenario == "interest_rate_shock":
378
+ recommendations.extend([
379
+ "Reduce exposure to interest-sensitive sectors",
380
+ "Consider floating-rate instruments",
381
+ "Review debt component of portfolio"
382
+ ])
383
+ elif scenario == "inflation_surge":
384
+ recommendations.extend([
385
+ "Add inflation-hedged assets like commodities",
386
+ "Consider real estate and infrastructure exposure",
387
+ "Review companies with pricing power"
388
+ ])
389
+
390
+ return recommendations[:5] # Limit to top 5
services/websocket_manager.py CHANGED
@@ -1,6 +1,7 @@
1
  """
2
  WebSocket Manager for Zyon Traders Backend
3
  Handles real-time data distribution to connected clients
 
4
  """
5
 
6
  import asyncio
@@ -11,6 +12,8 @@ from fastapi import WebSocket
11
  from datetime import datetime
12
  import random
13
 
 
 
14
  logger = logging.getLogger(__name__)
15
 
16
 
@@ -20,8 +23,15 @@ class WebSocketManager:
20
  def __init__(self):
21
  self.active_connections: List[WebSocket] = []
22
  self.subscriptions: Dict[WebSocket, Set[str]] = {}
 
23
  self.background_task: asyncio.Task = None
24
  self.running = False
 
 
 
 
 
 
25
 
26
  async def connect(self, websocket: WebSocket):
27
  """Accept new WebSocket connection"""
@@ -42,6 +52,25 @@ class WebSocketManager:
42
  """Subscribe to specific symbols for real-time updates"""
43
  if websocket in self.subscriptions:
44
  self.subscriptions[websocket].update(symbols)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  await websocket.send_text(json.dumps({
46
  "type": "subscription_success",
47
  "symbols": symbols,
@@ -108,31 +137,33 @@ class WebSocketManager:
108
 
109
  while self.running:
110
  try:
111
- # Simulate price updates
112
- for symbol in symbols:
113
- # Generate realistic price movement
114
- base_price = {
115
- "NIFTY": 22000,
116
- "SENSEX": 72000,
117
- "BANKNIFTY": 47000,
118
- "RELIANCE": 2500,
119
- "TCS": 3800,
120
- "INFY": 1700,
121
- "HDFC": 1600
122
- }.get(symbol, 1000)
123
-
124
- change_percent = random.uniform(-2, 2)
125
- price = base_price * (1 + change_percent / 100)
126
-
127
- data = {
128
- "price": round(price, 2),
129
- "change": round(base_price * change_percent / 100, 2),
130
- "change_percent": round(change_percent, 2),
131
- "volume": random.randint(100000, 10000000),
132
- "last_updated": datetime.utcnow().isoformat()
133
- }
134
-
135
- await self.broadcast_to_symbol_subscribers(symbol, data)
 
 
136
 
137
  # Wait before next update
138
  await asyncio.sleep(2) # Update every 2 seconds
@@ -141,9 +172,135 @@ class WebSocketManager:
141
  logger.error(f"Error in market data simulation: {e}")
142
  await asyncio.sleep(5)
143
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  async def start_background_tasks(self):
145
- """Start background tasks for data simulation"""
146
  self.running = True
 
 
 
 
 
147
  self.background_task = asyncio.create_task(self.simulate_market_data())
148
  logger.info("WebSocket background tasks started")
149
 
 
1
  """
2
  WebSocket Manager for Zyon Traders Backend
3
  Handles real-time data distribution to connected clients
4
+ Integrates with Dhan WebSocket service for real-time market data
5
  """
6
 
7
  import asyncio
 
12
  from datetime import datetime
13
  import random
14
 
15
+ from .dhan_websocket import dhan_ws_client, MarketDataUpdate, OrderUpdate, MarketDepthUpdate
16
+
17
  logger = logging.getLogger(__name__)
18
 
19
 
 
23
  def __init__(self):
24
  self.active_connections: List[WebSocket] = []
25
  self.subscriptions: Dict[WebSocket, Set[str]] = {}
26
+ self.symbol_to_instrument: Dict[str, Dict[str, str]] = {}
27
  self.background_task: asyncio.Task = None
28
  self.running = False
29
+ self.dhan_connected = False
30
+
31
+ # Set up Dhan WebSocket callbacks
32
+ dhan_ws_client.set_market_data_callback(self._on_market_data_update)
33
+ dhan_ws_client.set_order_update_callback(self._on_order_update)
34
+ dhan_ws_client.set_market_depth_callback(self._on_market_depth_update)
35
 
36
  async def connect(self, websocket: WebSocket):
37
  """Accept new WebSocket connection"""
 
52
  """Subscribe to specific symbols for real-time updates"""
53
  if websocket in self.subscriptions:
54
  self.subscriptions[websocket].update(symbols)
55
+
56
+ # Convert symbols to Dhan instruments and subscribe to Dhan WebSocket
57
+ instruments_to_subscribe = []
58
+ for symbol in symbols:
59
+ # Map common symbols to Dhan instrument format
60
+ instrument = self._get_instrument_from_symbol(symbol)
61
+ if instrument:
62
+ self.symbol_to_instrument[symbol] = instrument
63
+ instruments_to_subscribe.append(instrument)
64
+
65
+ # Subscribe to Dhan WebSocket if we have instruments
66
+ if instruments_to_subscribe and self.dhan_connected:
67
+ await dhan_ws_client.subscribe_market_data(instruments_to_subscribe)
68
+
69
+ # Also subscribe to market depth for the first few instruments (limited to 50)
70
+ depth_instruments = instruments_to_subscribe[:50] # Limit to 50 for market depth
71
+ if depth_instruments:
72
+ await dhan_ws_client.subscribe_market_depth(depth_instruments)
73
+
74
  await websocket.send_text(json.dumps({
75
  "type": "subscription_success",
76
  "symbols": symbols,
 
137
 
138
  while self.running:
139
  try:
140
+ # Only simulate if Dhan is not connected
141
+ if not self.dhan_connected:
142
+ # Simulate price updates
143
+ for symbol in symbols:
144
+ # Generate realistic price movement
145
+ base_price = {
146
+ "NIFTY": 22000,
147
+ "SENSEX": 72000,
148
+ "BANKNIFTY": 47000,
149
+ "RELIANCE": 2500,
150
+ "TCS": 3800,
151
+ "INFY": 1700,
152
+ "HDFC": 1600
153
+ }.get(symbol, 1000)
154
+
155
+ change_percent = random.uniform(-2, 2)
156
+ price = base_price * (1 + change_percent / 100)
157
+
158
+ data = {
159
+ "price": round(price, 2),
160
+ "change": round(base_price * change_percent / 100, 2),
161
+ "change_percent": round(change_percent, 2),
162
+ "volume": random.randint(100000, 10000000),
163
+ "last_updated": datetime.utcnow().isoformat()
164
+ }
165
+
166
+ await self.broadcast_to_symbol_subscribers(symbol, data)
167
 
168
  # Wait before next update
169
  await asyncio.sleep(2) # Update every 2 seconds
 
172
  logger.error(f"Error in market data simulation: {e}")
173
  await asyncio.sleep(5)
174
 
175
+ def _get_instrument_from_symbol(self, symbol: str) -> Dict[str, str]:
176
+ """Convert symbol to Dhan instrument format"""
177
+ # Common symbol mappings (this should be expanded based on actual requirements)
178
+ symbol_mappings = {
179
+ "NIFTY": {"security_id": "26000", "exchange_segment": "NSE_INDEX"},
180
+ "BANKNIFTY": {"security_id": "26009", "exchange_segment": "NSE_INDEX"},
181
+ "SENSEX": {"security_id": "1", "exchange_segment": "BSE_INDEX"},
182
+ "RELIANCE": {"security_id": "2885", "exchange_segment": "NSE_EQ"},
183
+ "TCS": {"security_id": "11723", "exchange_segment": "NSE_EQ"},
184
+ "INFY": {"security_id": "1594", "exchange_segment": "NSE_EQ"},
185
+ "HDFC": {"security_id": "1333", "exchange_segment": "NSE_EQ"},
186
+ }
187
+
188
+ return symbol_mappings.get(symbol.upper())
189
+
190
+ def _get_symbol_from_instrument(self, security_id: str, exchange_segment: str) -> str:
191
+ """Convert Dhan instrument back to symbol"""
192
+ # Reverse lookup
193
+ for symbol, instrument in self.symbol_to_instrument.items():
194
+ if (instrument["security_id"] == security_id and
195
+ instrument["exchange_segment"] == exchange_segment):
196
+ return symbol
197
+
198
+ # Fallback to security_id if no mapping found
199
+ return f"{exchange_segment}:{security_id}"
200
+
201
+ async def _on_market_data_update(self, update: MarketDataUpdate):
202
+ """Handle market data updates from Dhan WebSocket"""
203
+ try:
204
+ symbol = self._get_symbol_from_instrument(update.security_id, update.exchange_segment)
205
+
206
+ data = {
207
+ "security_id": update.security_id,
208
+ "exchange_segment": update.exchange_segment,
209
+ "price": update.ltp,
210
+ "change": update.change,
211
+ "change_percent": update.change_percent,
212
+ "volume": update.volume,
213
+ "open": update.open_price,
214
+ "high": update.high_price,
215
+ "low": update.low_price,
216
+ "close": update.close_price,
217
+ "total_buy_qty": update.total_buy_qty,
218
+ "total_sell_qty": update.total_sell_qty,
219
+ "timestamp": update.timestamp.isoformat()
220
+ }
221
+
222
+ await self.broadcast_to_symbol_subscribers(symbol, data)
223
+
224
+ except Exception as e:
225
+ logger.error(f"Error processing market data update: {e}")
226
+
227
+ async def _on_order_update(self, update: OrderUpdate):
228
+ """Handle order updates from Dhan WebSocket"""
229
+ try:
230
+ data = {
231
+ "type": "order_update",
232
+ "order_id": update.order_id,
233
+ "client_id": update.client_id,
234
+ "security_id": update.security_id,
235
+ "exchange_segment": update.exchange_segment,
236
+ "status": update.order_status,
237
+ "transaction_type": update.transaction_type,
238
+ "quantity": update.quantity,
239
+ "price": update.price,
240
+ "traded_qty": update.traded_qty,
241
+ "traded_price": update.traded_price,
242
+ "avg_traded_price": update.avg_traded_price,
243
+ "timestamp": update.update_datetime.isoformat()
244
+ }
245
+
246
+ await self.broadcast_to_all(data)
247
+
248
+ except Exception as e:
249
+ logger.error(f"Error processing order update: {e}")
250
+
251
+ async def _on_market_depth_update(self, update: MarketDepthUpdate):
252
+ """Handle market depth updates from Dhan WebSocket"""
253
+ try:
254
+ symbol = self._get_symbol_from_instrument(update.security_id, update.exchange_segment)
255
+
256
+ data = {
257
+ "type": "market_depth",
258
+ "security_id": update.security_id,
259
+ "exchange_segment": update.exchange_segment,
260
+ "bids": update.bids,
261
+ "asks": update.asks,
262
+ "timestamp": update.timestamp.isoformat()
263
+ }
264
+
265
+ await self.broadcast_to_symbol_subscribers(symbol, data)
266
+
267
+ except Exception as e:
268
+ logger.error(f"Error processing market depth update: {e}")
269
+
270
+ async def connect_to_dhan(self):
271
+ """Connect to Dhan WebSocket services"""
272
+ try:
273
+ # Connect to market feed
274
+ market_feed_connected = await dhan_ws_client.connect_market_feed()
275
+
276
+ # Connect to order updates
277
+ order_update_connected = await dhan_ws_client.connect_order_updates()
278
+
279
+ # Connect to market depth (optional)
280
+ market_depth_connected = await dhan_ws_client.connect_market_depth()
281
+
282
+ self.dhan_connected = market_feed_connected
283
+
284
+ if self.dhan_connected:
285
+ logger.info("Connected to Dhan WebSocket services")
286
+ else:
287
+ logger.error("Failed to connect to Dhan WebSocket services")
288
+
289
+ return self.dhan_connected
290
+
291
+ except Exception as e:
292
+ logger.error(f"Error connecting to Dhan WebSocket: {e}")
293
+ self.dhan_connected = False
294
+ return False
295
+
296
  async def start_background_tasks(self):
297
+ """Start background tasks for data simulation and Dhan connection"""
298
  self.running = True
299
+
300
+ # Try to connect to Dhan WebSocket first
301
+ await self.connect_to_dhan()
302
+
303
+ # Start simulation task (will only simulate if Dhan is not connected)
304
  self.background_task = asyncio.create_task(self.simulate_market_data())
305
  logger.info("WebSocket background tasks started")
306
 
utils/cache.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Cache utilities for Redis and in-memory caching
3
+ """
4
+
5
+ import logging
6
+ import json
7
+ from typing import Any, Optional
8
+ from datetime import datetime, timedelta
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+ # In-memory cache as fallback
13
+ _memory_cache = {}
14
+ _cache_timestamps = {}
15
+
16
+ class MemoryCache:
17
+ """Simple in-memory cache implementation"""
18
+
19
+ def __init__(self):
20
+ self.cache = {}
21
+ self.timestamps = {}
22
+
23
+ async def get(self, key: str) -> Optional[Any]:
24
+ """Get value from cache"""
25
+ try:
26
+ if key in self.cache and key in self.timestamps:
27
+ # Check if not expired (default 5 minutes)
28
+ if datetime.utcnow() - self.timestamps[key] < timedelta(minutes=5):
29
+ return self.cache[key]
30
+ else:
31
+ # Remove expired entry
32
+ del self.cache[key]
33
+ del self.timestamps[key]
34
+ return None
35
+ except Exception as e:
36
+ logger.error(f"Error getting from memory cache: {e}")
37
+ return None
38
+
39
+ async def set(self, key: str, value: Any) -> bool:
40
+ """Set value in cache"""
41
+ try:
42
+ self.cache[key] = value
43
+ self.timestamps[key] = datetime.utcnow()
44
+ return True
45
+ except Exception as e:
46
+ logger.error(f"Error setting memory cache: {e}")
47
+ return False
48
+
49
+ async def setex(self, key: str, ttl: int, value: Any) -> bool:
50
+ """Set value with expiration time"""
51
+ try:
52
+ self.cache[key] = value
53
+ self.timestamps[key] = datetime.utcnow()
54
+ # TTL is ignored in this simple implementation
55
+ return True
56
+ except Exception as e:
57
+ logger.error(f"Error setting memory cache with TTL: {e}")
58
+ return False
59
+
60
+ async def delete(self, key: str) -> bool:
61
+ """Delete key from cache"""
62
+ try:
63
+ if key in self.cache:
64
+ del self.cache[key]
65
+ if key in self.timestamps:
66
+ del self.timestamps[key]
67
+ return True
68
+ except Exception as e:
69
+ logger.error(f"Error deleting from memory cache: {e}")
70
+ return False
71
+
72
+ async def exists(self, key: str) -> bool:
73
+ """Check if key exists in cache"""
74
+ try:
75
+ if key in self.cache and key in self.timestamps:
76
+ # Check if not expired
77
+ if datetime.utcnow() - self.timestamps[key] < timedelta(minutes=5):
78
+ return True
79
+ else:
80
+ # Remove expired entry
81
+ del self.cache[key]
82
+ del self.timestamps[key]
83
+ return False
84
+ except Exception as e:
85
+ logger.error(f"Error checking memory cache existence: {e}")
86
+ return False
87
+
88
+ def clear(self):
89
+ """Clear all cache"""
90
+ self.cache.clear()
91
+ self.timestamps.clear()
92
+
93
+ def size(self) -> int:
94
+ """Get cache size"""
95
+ return len(self.cache)
96
+
97
+ # Global memory cache instance
98
+ memory_cache = MemoryCache()
99
+
100
+ def get_redis_client():
101
+ """Get Redis client (returns None if Redis not available)"""
102
+ try:
103
+ # In a real implementation, you would:
104
+ # import redis
105
+ # from config.settings import get_settings
106
+ # settings = get_settings()
107
+ # if settings.REDIS_URL:
108
+ # return redis.from_url(settings.REDIS_URL)
109
+
110
+ # For now, return None to use memory cache fallback
111
+ return None
112
+ except Exception as e:
113
+ logger.warning(f"Redis not available, using memory cache: {e}")
114
+ return None
115
+
116
+ async def get_cached_data(key: str, default: Any = None) -> Any:
117
+ """Get data from cache with fallback to memory cache"""
118
+ try:
119
+ redis_client = get_redis_client()
120
+
121
+ if redis_client:
122
+ # Try Redis first
123
+ try:
124
+ data = await redis_client.get(key)
125
+ if data:
126
+ return json.loads(data)
127
+ except Exception as e:
128
+ logger.warning(f"Redis get error: {e}")
129
+
130
+ # Fallback to memory cache
131
+ data = await memory_cache.get(key)
132
+ return data if data is not None else default
133
+
134
+ except Exception as e:
135
+ logger.error(f"Error getting cached data: {e}")
136
+ return default
137
+
138
+ async def set_cached_data(key: str, value: Any, ttl: int = 300) -> bool:
139
+ """Set data in cache with fallback to memory cache"""
140
+ try:
141
+ redis_client = get_redis_client()
142
+
143
+ if redis_client:
144
+ # Try Redis first
145
+ try:
146
+ await redis_client.setex(key, ttl, json.dumps(value))
147
+ return True
148
+ except Exception as e:
149
+ logger.warning(f"Redis set error: {e}")
150
+
151
+ # Fallback to memory cache
152
+ return await memory_cache.setex(key, ttl, value)
153
+
154
+ except Exception as e:
155
+ logger.error(f"Error setting cached data: {e}")
156
+ return False
157
+
158
+ async def delete_cached_data(key: str) -> bool:
159
+ """Delete data from cache"""
160
+ try:
161
+ redis_client = get_redis_client()
162
+
163
+ if redis_client:
164
+ try:
165
+ await redis_client.delete(key)
166
+ return True
167
+ except Exception as e:
168
+ logger.warning(f"Redis delete error: {e}")
169
+
170
+ # Fallback to memory cache
171
+ return await memory_cache.delete(key)
172
+
173
+ except Exception as e:
174
+ logger.error(f"Error deleting cached data: {e}")
175
+ return False
176
+
177
+ async def cache_exists(key: str) -> bool:
178
+ """Check if key exists in cache"""
179
+ try:
180
+ redis_client = get_redis_client()
181
+
182
+ if redis_client:
183
+ try:
184
+ return await redis_client.exists(key)
185
+ except Exception as e:
186
+ logger.warning(f"Redis exists error: {e}")
187
+
188
+ # Fallback to memory cache
189
+ return await memory_cache.exists(key)
190
+
191
+ except Exception as e:
192
+ logger.error(f"Error checking cache existence: {e}")
193
+ return False
194
+
195
+ def clear_all_cache():
196
+ """Clear all cache data"""
197
+ try:
198
+ redis_client = get_redis_client()
199
+
200
+ if redis_client:
201
+ try:
202
+ redis_client.flushall()
203
+ except Exception as e:
204
+ logger.warning(f"Redis clear error: {e}")
205
+
206
+ # Clear memory cache
207
+ memory_cache.clear()
208
+ logger.info("All cache cleared")
209
+
210
+ except Exception as e:
211
+ logger.error(f"Error clearing cache: {e}")
212
+
213
+ def get_cache_stats() -> dict:
214
+ """Get cache statistics"""
215
+ try:
216
+ stats = {
217
+ "memory_cache_size": memory_cache.size(),
218
+ "redis_available": get_redis_client() is not None,
219
+ "timestamp": datetime.utcnow().isoformat()
220
+ }
221
+
222
+ redis_client = get_redis_client()
223
+ if redis_client:
224
+ try:
225
+ info = redis_client.info()
226
+ stats["redis_info"] = {
227
+ "used_memory": info.get("used_memory_human"),
228
+ "connected_clients": info.get("connected_clients"),
229
+ "total_commands_processed": info.get("total_commands_processed")
230
+ }
231
+ except Exception as e:
232
+ logger.warning(f"Could not get Redis stats: {e}")
233
+
234
+ return stats
235
+
236
+ except Exception as e:
237
+ logger.error(f"Error getting cache stats: {e}")
238
+ return {"error": str(e)}
239
+
240
+ # Decorator for caching function results
241
+ def cache_result(ttl: int = 300, key_prefix: str = ""):
242
+ """Decorator to cache function results"""
243
+ def decorator(func):
244
+ async def wrapper(*args, **kwargs):
245
+ # Generate cache key
246
+ cache_key = f"{key_prefix}:{func.__name__}:{str(args)}:{str(sorted(kwargs.items()))}"
247
+
248
+ # Try to get from cache
249
+ cached_result = await get_cached_data(cache_key)
250
+ if cached_result is not None:
251
+ return cached_result
252
+
253
+ # Execute function and cache result
254
+ result = await func(*args, **kwargs)
255
+ await set_cached_data(cache_key, result, ttl)
256
+
257
+ return result
258
+ return wrapper
259
+ return decorator
utils/rate_limiter.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Rate Limiter Utility
3
+ Simple rate limiting for API calls
4
+ """
5
+
6
+ import asyncio
7
+ import time
8
+ from typing import Dict
9
+ from collections import deque
10
+
11
+
12
+ class RateLimiter:
13
+ """Simple rate limiter using sliding window"""
14
+
15
+ def __init__(self, calls: int, period: int):
16
+ """
17
+ Args:
18
+ calls: Number of calls allowed
19
+ period: Time period in seconds
20
+ """
21
+ self.calls = calls
22
+ self.period = period
23
+ self.requests = deque()
24
+ self._lock = asyncio.Lock()
25
+
26
+ async def acquire(self):
27
+ """Acquire permission to make a request"""
28
+ async with self._lock:
29
+ now = time.time()
30
+
31
+ # Remove old requests outside the window
32
+ while self.requests and self.requests[0] <= now - self.period:
33
+ self.requests.popleft()
34
+
35
+ # Check if we can make a request
36
+ if len(self.requests) < self.calls:
37
+ self.requests.append(now)
38
+ return
39
+
40
+ # Need to wait
41
+ sleep_time = self.requests[0] + self.period - now
42
+ if sleep_time > 0:
43
+ await asyncio.sleep(sleep_time)
44
+ await self.acquire() # Recursively try again
45
+ else:
46
+ self.requests.append(now)