File size: 26,215 Bytes
b1f38ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ced8b5
 
 
 
 
 
 
 
 
 
b1f38ad
 
 
 
 
8ced8b5
 
 
 
 
 
 
b1f38ad
 
 
 
 
29e6847
 
b1f38ad
 
 
 
 
 
 
 
 
29e6847
 
b1f38ad
 
 
 
 
 
eff53bc
 
 
 
 
 
 
 
 
b1f38ad
 
 
 
 
 
 
 
 
 
 
 
 
 
8ced8b5
b1f38ad
8ced8b5
 
 
b1f38ad
 
 
 
 
2b05b19
b1f38ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ced8b5
 
 
 
 
 
b1f38ad
 
 
 
 
 
 
8ced8b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b1f38ad
 
 
2b05b19
 
 
 
 
 
b1f38ad
2b05b19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b1f38ad
 
2b05b19
b1f38ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b05b19
ef67196
b1f38ad
 
 
 
 
 
 
 
 
 
 
 
 
 
ef67196
b1f38ad
 
 
 
 
 
 
 
 
ef67196
b1f38ad
 
 
 
 
 
 
 
 
 
ef67196
b1f38ad
 
 
 
 
a4a7fb9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
from contextlib import asynccontextmanager
from datetime import datetime, timedelta, timezone
from typing import Optional, Annotated

from fastapi import FastAPI, HTTPException, Depends, status, Header

# SQLModel & Database Imports
from sqlmodel import Session, select
from database import create_db_and_tables, engine
from models import User 

# Security Imports
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
from passlib.context import CryptContext
from jose import JWTError, jwt

# Import your strategy logic
from strategy import train_models_and_backtest

# Import model manager for HMM-SVR models
from model_manager import (
    load_all_models,
    train_and_save_model,
    load_model,
    is_model_trained,
    get_model_info,
    get_cached_models
)

# --- 1. LIFESPAN (Create Tables on Startup) ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup hook: create DB tables, then warm the in-memory model cache.

    Everything before ``yield`` runs once before the app serves requests;
    nothing needs to happen on shutdown.
    """
    create_db_and_tables()

    # Pull every pre-trained HMM-SVR model from disk into memory.
    print("\n🚀 Starting AlgoQuant API...")
    models = load_all_models()
    if not models:
        print("ℹ️  No pre-trained models found. Train models using /api/models/train/{symbol}")
    else:
        print(f"✅ Loaded {len(models)} HMM-SVR models: {list(models.keys())}")
    yield

app = FastAPI(lifespan=lifespan)

# --- CONFIGURATION ---
import os
import warnings

# JWT signing key. The fallback is committed in source and therefore public:
# anyone who can read it can forge valid tokens. Always set SECRET_KEY in any
# deployed environment; the fallback only exists for local development.
_DEFAULT_SECRET_KEY = "algoquant_super_secret_key"
SECRET_KEY = os.getenv("SECRET_KEY", _DEFAULT_SECRET_KEY)
if SECRET_KEY == _DEFAULT_SECRET_KEY:
    warnings.warn(
        "SECRET_KEY environment variable is not set; using the insecure built-in "
        "default. Tokens issued by this server can be forged.",
        RuntimeWarning,
        stacklevel=1,
    )
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 43200  # 30 days (30 * 24 * 60)
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# --- CORS ---
# Browser origins allowed to call this API (with credentials): the local dev
# servers and the deployed frontend.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "http://localhost:3000",
        "http://127.0.0.1:3000",
        "https://algo-quant-pi.vercel.app"
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# --- Health Check Endpoint ---
@app.get("/")
async def root():
    """Liveness probe at the API root; always reports healthy."""
    payload = {"status": "healthy", "message": "AlgoQuant API is running"}
    return payload

@app.get("/health")
async def health_check():
    """Health probe returning a status flag and the current server time.

    The timestamp is timezone-aware UTC, so the ISO string carries an explicit
    ``+00:00`` offset. A naive string (as ``datetime.utcnow()`` produced, which
    is also deprecated since Python 3.12) is read as *local* time by
    JavaScript's ``Date`` and many other clients.
    """
    return {"status": "healthy", "timestamp": datetime.now(timezone.utc).isoformat()}

# --- Pydantic Models (For Request Body) ---
class UserCreate(BaseModel):
    """Request body for ``POST /api/signup``."""
    email: str
    password: str  # plaintext from the client; signup hashes it before storing
    name: Optional[str] = None  # optional; stored on the User row as ``name``

class UserLogin(BaseModel):
    """Request body for ``POST /api/login``; both fields must be non-empty."""
    email: str
    password: str  # plaintext; checked against the stored bcrypt hash

class BacktestRequest(BaseModel):
    """Request body for ``POST /api/backtest``."""
    ticker: str
    start_date: str  # forwarded as-is to the strategy; format is not validated here
    end_date: str  # forwarded as-is to the strategy; format is not validated here
    strategy: str = "hmm_svr"  # only logged by run_backtest; does not select an implementation
    # Strategy-specific parameters (all forwarded to train_models_and_backtest)
    short_window: int = 12
    long_window: int = 26
    n_states: int = 3

class Token(BaseModel):
    """Response model for signup/login: an encoded JWT and its type."""
    access_token: str
    token_type: str  # always "bearer" in this module

class SimulatedTradingRequest(BaseModel):
    """Request body for ``POST /api/simulated/start``."""
    symbol: str
    trade_amount: float  # forwarded to start_simulated_trading
    duration: int  # session length, counted in ``duration_unit``
    duration_unit: str = "minutes"  # expected: "minutes" or "days"

# --- DATABASE DEPENDENCY ---
def get_session():
    """FastAPI dependency that yields one SQLModel ``Session`` per request.

    The session is closed once the depending request finishes.
    """
    db_session = Session(engine)
    try:
        yield db_session
    finally:
        db_session.close()

# --- AUTH HELPERS ---
def verify_password(plain_password, hashed_password):
    """Check a plaintext password against a stored hash via the passlib context."""
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match

def get_password_hash(password):
    """Hash ``password`` with the module's passlib context (bcrypt scheme)."""
    hashed = pwd_context.hash(password)
    return hashed

def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    """Sign a JWT containing ``data`` plus an ``exp`` (expiry) claim.

    Args:
        data: Claims to embed, e.g. ``{"sub": email}``. Not mutated.
        expires_delta: Token lifetime. When falsy (omitted or a zero
            timedelta), ``ACCESS_TOKEN_EXPIRE_MINUTES`` is used.

    Returns:
        The encoded JWT string signed with ``SECRET_KEY`` / ``ALGORITHM``.
    """
    if not expires_delta:
        expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    # python-jose turns a datetime ``exp`` into a timestamp via
    # ``utctimetuple()``, which treats a *naive* value as if it were UTC.
    # ``datetime.now()`` is naive local time, so it shifted the expiry by the
    # server's UTC offset; an aware UTC datetime converts correctly.
    expire = datetime.now(timezone.utc) + expires_delta
    to_encode = {**data, "exp": expire}
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)

# --- THE GUARD (Protect Routes) ---
async def get_current_user(authorization: str = Header(None)):
    """FastAPI dependency: authenticate the request from its bearer token.

    Expects an ``Authorization: Bearer <jwt>`` header and returns the token's
    ``sub`` claim (the email that signup/login put there).

    Raises:
        HTTPException(401): header missing, scheme is not ``Bearer``, token is
            malformed/expired/badly signed, or the token has no ``sub`` claim.
    """
    if not authorization:
        raise HTTPException(status_code=401, detail="Missing Token")

    # Split "<scheme> <credentials>", tolerating surrounding/extra whitespace
    # and any casing of the scheme, but rejecting non-Bearer schemes.
    scheme, _, token = authorization.strip().partition(" ")
    token = token.strip()
    if scheme.lower() != "bearer" or not token:
        raise HTTPException(status_code=401, detail="Could not validate credentials")

    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        # JWTError is the base for signature, claim and expiry failures.
        raise HTTPException(status_code=401, detail="Could not validate credentials") from None

    email: Optional[str] = payload.get("sub")
    if email is None:
        raise HTTPException(status_code=401, detail="Invalid Token")
    return email

# --- ROUTES ---

@app.post("/api/signup", response_model=Token)
def signup(user_data: UserCreate, session: Session = Depends(get_session)):
    """Register a new user and immediately return an access token (auto-login).

    Raises:
        HTTPException(400): email or password empty, or email already registered.
        HTTPException(500): any unexpected error; the DB transaction is rolled back.
    """
    # Mirror /api/login, which rejects empty credentials with a 400: an account
    # created with an empty password could otherwise never log in.
    if not user_data.email or not user_data.password:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Email and password are required"
        )

    try:
        # 1. Reject duplicate emails.
        # NOTE(review): two concurrent signups can both pass this check; if the
        # email column is unique, the second commit fails and surfaces as a 500.
        statement = select(User).where(User.email == user_data.email)
        existing_user = session.exec(statement).first()

        if existing_user:
            raise HTTPException(status_code=400, detail="Email already registered")

        # 2. Hash the password; only the hash is stored on the User row.
        hashed_pwd = get_password_hash(user_data.password)
        new_user = User(
            email=user_data.email,
            name=user_data.name,
            hashed_password=hashed_pwd
        )

        # 3. Persist.
        session.add(new_user)
        session.commit()
        session.refresh(new_user)

        # 4. Auto-login: return a token straight away.
        access_token = create_access_token(data={"sub": new_user.email})
        return {"access_token": access_token, "token_type": "bearer"}

    except HTTPException:
        raise
    except Exception as e:
        session.rollback()
        print(f"Signup error: {str(e)}")
        raise HTTPException(status_code=500, detail="Internal server error during signup")

@app.post("/api/login", response_model=Token)
def login(user_data: UserLogin, session: Session = Depends(get_session)):
    """Exchange email + password for a bearer token.

    Responds 400 when either credential is empty, and 401 with the same message
    for an unknown email or a wrong password. Any other error is printed
    server-side and reported as 500.
    """
    try:
        if not (user_data.email and user_data.password):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Email and password are required"
            )

        query = select(User).where(User.email == user_data.email)
        account = session.exec(query).first()

        if not (account and verify_password(user_data.password, account.hashed_password)):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Incorrect email or password",
            )

        # Lifetime comes from create_access_token's default expiry.
        token = create_access_token(data={"sub": account.email})
        return {"access_token": token, "token_type": "bearer"}

    except HTTPException:
        raise
    except Exception as exc:
        print(f"Login error: {str(exc)}")
        raise HTTPException(status_code=500, detail="Internal server error during login")

@app.post("/api/backtest")
def run_backtest(
    req: BacktestRequest,
    current_user: str = Depends(get_current_user)
):
    """Run a backtest for ``req.ticker`` between ``start_date`` and ``end_date``.

    NOTE(review): ``req.strategy`` is only logged; every request is routed to
    ``train_models_and_backtest`` regardless of its value.
    """
    print(f"User {current_user} is running {req.strategy} backtest...")

    tuning = {
        "short_window": req.short_window,
        "long_window": req.long_window,
        "n_states": req.n_states,
    }
    return train_models_and_backtest(req.ticker, req.start_date, req.end_date, **tuning)


@app.get("/api/backtest/strategies")
def get_backtest_strategies(current_user: str = Depends(get_current_user)):
    """List the backtest strategies on offer (none are registered yet)."""
    registered = []
    return {"strategies": registered}


# --- MODEL MANAGEMENT ROUTES (HMM-SVR) ---

@app.post("/api/models/train/{symbol}")
def train_model(symbol: str, current_user: str = Depends(get_current_user)):
    """
    Train an HMM-SVR model for ``symbol`` via ``train_and_save_model`` (3 states).

    The symbol is upper-cased and must end in ``USDT``. Training errors reported
    by the model manager become HTTP 400. The model is presumably persisted to
    disk, so ``load_all_models`` can pick it up on the next startup.

    Example: POST /api/models/train/BTCUSDT
    """
    print(f"[API] User {current_user} requested model training for {symbol}")

    normalized = symbol.upper()
    if not normalized.endswith('USDT'):
        raise HTTPException(
            status_code=400,
            detail="Symbol must end with USDT (e.g., BTCUSDT, ETHUSDT)"
        )

    outcome = train_and_save_model(normalized, n_states=3)
    if "error" in outcome:
        raise HTTPException(status_code=400, detail=outcome["error"])

    return {
        "success": True,
        "message": f"Model trained and saved for {normalized}",
        "details": outcome
    }


@app.get("/api/models/status/{symbol}")
def get_model_status(symbol: str, current_user: str = Depends(get_current_user)):
    """
    Report whether a model is trained for ``symbol`` (upper-cased), with its metadata.

    NOTE(review): /api/models/signal saves auto-trained models under the base
    symbol (e.g. ``BTC``), so looking up ``BTCUSDT`` here may miss them — confirm.
    """
    normalized = symbol.upper()

    if is_model_trained(normalized):
        return {
            "trained": True,
            "symbol": normalized,
            "info": get_model_info(normalized)
        }

    return {
        "trained": False,
        "symbol": normalized,
        "message": f"No model found for {normalized}. Train it using POST /api/models/train/{normalized}"
    }


@app.get("/api/models")
def list_models(current_user: str = Depends(get_current_user)):
    """
    Report models cached in memory alongside model files found in ``MODEL_DIR``.

    ``total_count`` is the number of distinct symbols across both sources.
    """
    import os
    from model_manager import MODEL_DIR

    suffix = '_hmm_svr.pkl'
    cached = get_cached_models()

    # Symbols with a model file on disk, loaded or not.
    if os.path.exists(MODEL_DIR):
        disk_models = [
            entry.replace(suffix, '').upper()
            for entry in os.listdir(MODEL_DIR)
            if entry.endswith(suffix)
        ]
    else:
        disk_models = []

    distinct_symbols = set(cached.keys()) | set(disk_models)
    return {
        "loaded_models": cached,
        "available_on_disk": disk_models,
        "total_count": len(distinct_symbols)
    }


@app.post("/api/models/reload")
def reload_models(current_user: str = Depends(get_current_user)):
    """
    Re-read every model from disk into the in-memory cache.

    NOTE(review): ``count`` sums the mapping's values, while the startup hook
    counts with ``len()``. Presumably the values are per-symbol success flags —
    confirm against ``load_all_models``.
    """
    reloaded = load_all_models()
    symbols = [name for name in reloaded]
    return {
        "success": True,
        "loaded_models": symbols,
        "count": sum(reloaded.values())
    }


def _signal_action_for_position(target_position, regime_label):
    """Map a target position multiplier to ``(action, color, description)`` for the UI.

    Five levels: 0x (WAIT, or STAY OUT under the 'Crash' regime), 0.5x, 1x, 2x,
    3x. Any unmatched value falls back to the standard 1x label.
    """
    if target_position == 0:
        if regime_label == 'Crash':
            return "STAY OUT", "red", "🚨 Crash Protocol: Safety override activated"
        return "WAIT", "yellow", "Bearish trend - waiting for reversal"
    if target_position == 3:
        return "STRONG BUY (3x)", "green", "🚀 Max Leverage: Safe regime + very low risk!"
    if target_position == 2:
        return "BUY (2x)", "cyan", "📈 Medium Leverage: Favorable conditions"
    if target_position == 0.5:
        return "CAUTIOUS BUY (0.5x)", "orange", "⚠️ Defensive: High risk detected"
    return "BUY (1x)", "blue", "✅ Standard bullish position"


def _risk_level_label(risk_ratio):
    """Bucket a risk ratio: "Low" below 0.5, "High" above 1.5, else "Moderate"."""
    if risk_ratio < 0.5:
        return "Low"
    if risk_ratio > 1.5:
        return "High"
    return "Moderate"


@app.get("/api/models/signal/{symbol}")
def get_instant_signal(symbol: str, current_user: str = Depends(get_current_user)):
    """
    Get an instant trading signal for a symbol using its trained HMM-SVR model.

    If no model exists under the base symbol (``BTCUSDT`` -> ``BTC``) or the
    full symbol, one is trained first. Price history comes from Yahoo Finance
    (``BTC-USD``). Failures are reported as ``{"success": False, "error": ...}``
    rather than HTTP errors.

    Example: GET /api/models/signal/BTCUSDT
    """
    from model_manager import is_model_trained, calculate_signal_and_position, train_and_save_model
    import yfinance as yf
    import pandas as pd

    symbol = symbol.upper()
    base_symbol = symbol.replace('USDT', '')
    yahoo_symbol = f"{base_symbol}-USD"  # Yahoo Finance ticker format

    # Auto-train when no model exists under either naming.
    if not is_model_trained(base_symbol) and not is_model_trained(symbol):
        print(f"[SignalAPI] No model found for {base_symbol}, training now...")

        try:
            # Train from Yahoo data with the Binance symbol as fallback; save
            # under the base symbol (BNB), not the Yahoo form (BNB-USD).
            train_result = train_and_save_model(
                symbol=yahoo_symbol,
                n_states=3,
                binance_symbol=symbol,
                save_as=base_symbol
            )
        except Exception as e:
            return {
                "success": False,
                "error": f"Model training failed: {str(e)}"
            }

        if train_result and 'error' not in train_result:
            print(f"[SignalAPI] ✅ Model trained for {base_symbol} with {train_result.get('train_days', 0)} days")
        else:
            # train_result may be None/empty; calling .get() on None used to
            # crash and surface as a misleading "Model training failed" message.
            train_error = train_result.get('error', 'Unknown error') if train_result else 'Unknown error'
            return {
                "success": False,
                "error": f"Failed to train model: {train_error}",
                "action_required": "Insufficient data to train model"
            }

    try:
        # 450 days of history for feature calculation (per the original note).
        end_date = datetime.now()
        start_date = end_date - timedelta(days=450)

        df = yf.download(yahoo_symbol, start=start_date, end=end_date, progress=False, auto_adjust=True)

        if df.empty:
            return {
                "success": False,
                "error": f"Could not fetch price data for {yahoo_symbol}"
            }

        # Flatten MultiIndex columns down to the level holding the OHLC names.
        if isinstance(df.columns, pd.MultiIndex):
            level = 0 if 'Close' in df.columns.get_level_values(0) else 1
            df.columns = df.columns.get_level_values(level)

        # Model lookup uses the base symbol; the data came from the Yahoo symbol.
        result = calculate_signal_and_position(
            symbol=base_symbol,
            recent_data=df,
            short_window=12,
            long_window=26
        )

        if result is None or 'error' in result:
            return {
                "success": False,
                "error": result.get('error', 'Unknown error') if result else "Failed to calculate signal"
            }

        ema_signal = result.get('ema_signal', 0)
        target_position = result.get('target_position', 0)
        position_multiplier = result.get('position_size_multiplier', 1.0)
        regime = result.get('regime', 1)
        regime_label = result.get('regime_label', 'Normal')
        risk_ratio = result.get('risk_ratio', 1.0)

        action, action_color, action_description = _signal_action_for_position(target_position, regime_label)

        if regime == 0:
            regime_description = "Low volatility"
        elif regime_label == 'Crash':
            regime_description = "High volatility - danger"
        else:
            regime_description = "Normal volatility"

        return {
            "success": True,
            "symbol": symbol,
            "current_price": result.get('close_price', 0),
            "signal": {
                "action": action,
                "action_color": action_color,
                "action_description": action_description,
                "ema_trend": "Bullish" if ema_signal == 1 else "Bearish",
                "position_multiplier": position_multiplier,
                "target_position": target_position,
                "signal_stability": result.get('signal_stability', 0.5),
                "ema_gap_percent": result.get('ema_gap_percent', 0)  # trend strength
            },
            "regime": {
                "state": regime,
                "label": regime_label,
                "description": regime_description
            },
            "risk": {
                "ratio": risk_ratio,
                "level": _risk_level_label(risk_ratio),
                "predicted_volatility": result.get('predicted_vol', 0)
            },
            "technicals": {
                "ema_short": result.get('ema_short', 0),
                "ema_long": result.get('ema_long', 0)
            },
            "reasoning": result.get('reasoning', ''),
            "timestamp": datetime.now().isoformat()
        }

    except Exception as e:
        return {
            "success": False,
            "error": f"Error calculating signal: {str(e)}"
        }


# --- SIMULATED TRADING ROUTES ---

@app.get("/api/simulated/trades")
def get_simulated_trades(
    limit: int = 50,
    current_user: str = Depends(get_current_user)
):
    """Return the caller's recent simulated trades (``limit`` is passed through)."""
    from simulated_endpoints import get_simulated_trades_endpoint as fetch_trades
    return fetch_trades(limit, current_user)


@app.get("/api/simulated/sessions")
def get_simulated_sessions(current_user: str = Depends(get_current_user)):
    """Return every simulated trading session belonging to the caller."""
    from simulated_endpoints import get_simulated_sessions_endpoint as fetch_sessions
    return fetch_sessions(current_user)


@app.get("/api/simulated/portfolio")
def get_simulated_portfolio(current_user: str = Depends(get_current_user)):
    """Return the caller's database-backed simulated wallet, seeding it on first use."""
    from simulated_exchange import get_portfolio_summary
    from database import initialize_portfolio_if_empty

    # First-time users get a starting balance (10k USDT per the earlier note;
    # the amount is defined inside initialize_portfolio_if_empty).
    initialize_portfolio_if_empty(user_email=current_user)
    return get_portfolio_summary(user_email=current_user)


@app.post("/api/simulated/start")
def start_simulated_session(req: SimulatedTradingRequest, current_user: str = Depends(get_current_user)):
    """Start an HMM-SVR trading bot session for the caller.

    ``duration`` is counted in ``duration_unit`` ("minutes" or "days") and
    converted to minutes for ``start_simulated_trading``.

    Raises:
        HTTPException(400): trade amount not a finite positive number, duration
            not positive, unknown duration unit, or an error reported by the bot.
    """
    import math
    from simulated_trading import start_simulated_trading
    from database import initialize_portfolio_if_empty

    # Validate before touching the portfolio. NaN compares False against every
    # bound, so it needs an explicit isfinite() check.
    if not math.isfinite(req.trade_amount) or req.trade_amount <= 0:
        raise HTTPException(status_code=400, detail="Trade amount must be a positive number")
    if req.duration <= 0:
        raise HTTPException(status_code=400, detail="Duration must be positive")

    minutes_per_unit = {"minutes": 1, "days": 24 * 60}
    if req.duration_unit not in minutes_per_unit:
        # Previously any other unit was silently treated as minutes.
        raise HTTPException(status_code=400, detail="duration_unit must be 'minutes' or 'days'")

    # Initialize portfolio with 10k USDT if this is a new user
    initialize_portfolio_if_empty(user_email=current_user)

    result = start_simulated_trading(
        user_email=current_user,
        symbol=req.symbol,
        trade_amount=req.trade_amount,
        duration_minutes=req.duration * minutes_per_unit[req.duration_unit]
    )

    if "error" in result:
        raise HTTPException(status_code=400, detail=result["error"])
    return result


@app.post("/api/simulated/stop/{session_id}")
def stop_simulated_session(session_id: str, current_user: str = Depends(get_current_user)):
    """Stop the bot session ``session_id``; unknown sessions yield HTTP 404.

    NOTE(review): ``current_user`` is authenticated but never checked against the
    session's owner, so any logged-in user who knows a session_id can stop it.
    """
    from simulated_trading import stop_simulated_trading

    outcome = stop_simulated_trading(session_id)
    if "error" not in outcome:
        return outcome
    raise HTTPException(status_code=404, detail=outcome["error"])


@app.get("/api/simulated/session/{session_id}")
def get_simulated_session(session_id: str, current_user: str = Depends(get_current_user)):
    """Return status for bot session ``session_id``; unknown sessions yield HTTP 404.

    NOTE(review): no ownership check against ``current_user``.
    """
    from simulated_trading import get_simulated_session_status

    # Named so it does not shadow the module-level ``fastapi.status`` import.
    session_status = get_simulated_session_status(session_id)
    if "error" not in session_status:
        return session_status
    raise HTTPException(status_code=404, detail=session_status["error"])


# --- MANUAL TRADING ROUTES (Market Page) ---

class ManualBuyRequest(BaseModel):
    """Request body for ``POST /api/market/buy``."""
    symbol: str  # e.g., 'BTC', 'ETH'
    usdt_amount: float  # Amount in USDT to spend

class ManualSellRequest(BaseModel):
    """Request body for ``POST /api/market/sell``."""
    symbol: str  # e.g., 'BTC', 'ETH'
    quantity: float  # Amount of asset to sell

class ManualSellPercentRequest(BaseModel):
    """Request body for ``POST /api/market/sell-percent``."""
    symbol: str  # e.g., 'BTC', 'ETH'
    percentage: float  # Percentage of holdings to sell (0-100)


@app.post("/api/market/buy")
def manual_buy(req: ManualBuyRequest, current_user: str = Depends(get_current_user)):
    """
    Execute a manual buy order from the Market page.

    Independent of the automated trading bots; the portfolio update and trade
    log entry are done by ``execute_manual_buy``.

    Raises:
        HTTPException(400): amount is not a finite number of at least 1 USDT,
            or the trade was rejected by ``execute_manual_buy``.
    """
    import math
    from manual_trading import execute_manual_buy
    from database import initialize_portfolio_if_empty

    # Ensure user has portfolio initialized
    initialize_portfolio_if_empty(user_email=current_user)

    # NaN fails every comparison, so it would slip past both range checks
    # below (and likely past downstream balance checks too); reject it, and
    # infinity, explicitly.
    if not math.isfinite(req.usdt_amount):
        raise HTTPException(status_code=400, detail="Amount must be a finite number")

    if req.usdt_amount <= 0:
        raise HTTPException(status_code=400, detail="Amount must be positive")

    if req.usdt_amount < 1:
        raise HTTPException(status_code=400, detail="Minimum buy amount is 1 USDT")

    success, trade_info, error = execute_manual_buy(
        symbol=req.symbol,
        usdt_amount=req.usdt_amount,
        user_email=current_user
    )

    if not success:
        raise HTTPException(status_code=400, detail=error)

    return {
        "success": True,
        "message": f"Successfully bought {trade_info['quantity']:.8f} {req.symbol}",
        "trade": trade_info
    }


@app.post("/api/market/sell")
def manual_sell(req: ManualSellRequest, current_user: str = Depends(get_current_user)):
    """
    Execute a manual sell order from the Market page.

    Independent of the automated trading bots; the portfolio update and trade
    log entry are done by ``execute_manual_sell``.

    Raises:
        HTTPException(400): quantity is not a finite positive number, or the
            trade was rejected by ``execute_manual_sell``.
    """
    import math
    from manual_trading import execute_manual_sell
    from database import initialize_portfolio_if_empty

    # Ensure user has portfolio initialized
    initialize_portfolio_if_empty(user_email=current_user)

    # NaN passes ``<= 0`` (every NaN comparison is False), so check finiteness first.
    if not math.isfinite(req.quantity):
        raise HTTPException(status_code=400, detail="Quantity must be a finite number")

    if req.quantity <= 0:
        raise HTTPException(status_code=400, detail="Quantity must be positive")

    success, trade_info, error = execute_manual_sell(
        symbol=req.symbol,
        quantity=req.quantity,
        user_email=current_user
    )

    if not success:
        raise HTTPException(status_code=400, detail=error)

    return {
        "success": True,
        "message": f"Successfully sold {trade_info['quantity']:.8f} {req.symbol}",
        "trade": trade_info
    }


@app.post("/api/market/sell-percent")
def manual_sell_percent(req: ManualSellPercentRequest, current_user: str = Depends(get_current_user)):
    """
    Sell a percentage of the caller's holdings of one asset.

    Backs quick "Sell 25%", "Sell 50%", "Sell All" actions.

    Raises:
        HTTPException(400): percentage outside (0, 100] or not a number, no
            holdings of the asset, or the trade was rejected.
    """
    from manual_trading import execute_manual_sell, get_user_balance
    from database import initialize_portfolio_if_empty

    # Ensure user has portfolio initialized
    initialize_portfolio_if_empty(user_email=current_user)

    # Chained comparison written positively so NaN (for which every comparison
    # is False) is rejected too; ``p <= 0 or p > 100`` let NaN through.
    if not (0 < req.percentage <= 100):
        raise HTTPException(status_code=400, detail="Percentage must be between 0 and 100")

    # Get current balance
    balance = get_user_balance(req.symbol.upper(), current_user)
    if balance <= 0:
        raise HTTPException(status_code=400, detail=f"No {req.symbol} holdings to sell")

    quantity_to_sell = balance * (req.percentage / 100)

    success, trade_info, error = execute_manual_sell(
        symbol=req.symbol,
        quantity=quantity_to_sell,
        user_email=current_user
    )

    if not success:
        raise HTTPException(status_code=400, detail=error)

    return {
        "success": True,
        "message": f"Successfully sold {req.percentage}% ({trade_info['quantity']:.8f}) {req.symbol}",
        "trade": trade_info
    }


@app.get("/api/market/trades")
def get_manual_trades(limit: int = 50, current_user: str = Depends(get_current_user)):
    """Return the caller's manual trade history, wrapped as ``{"trades": [...]}``."""
    from manual_trading import get_manual_trade_history
    return {"trades": get_manual_trade_history(current_user, limit)}


@app.get("/api/market/prices")
def get_market_prices(current_user: str = Depends(get_current_user)):
    """
    Snapshot of current prices for every supported asset.

    Lets the page render before its live price stream connects.
    """
    from manual_trading import get_prices_for_assets
    return {"prices": get_prices_for_assets()}


@app.get("/api/market/assets")
def get_supported_assets(current_user: str = Depends(get_current_user)):
    """Return display metadata for the assets enabled for manual trading.

    Only entries whose symbol appears in ``manual_trading.SUPPORTED_ASSETS``
    are returned, in the catalogue order below.
    """
    from manual_trading import SUPPORTED_ASSETS

    # (symbol, display name, logo glyph, brand colour)
    catalogue = (
        ("BTC", "Bitcoin", "₿", "#F7931A"),
        ("ETH", "Ethereum", "Ξ", "#627EEA"),
        ("SOL", "Solana", "◎", "#14F195"),
        ("LINK", "Chainlink", "⬡", "#2A5ADA"),
        ("DOGE", "Dogecoin", "Ð", "#C2A633"),
        ("BNB", "BNB", "⬡", "#F3BA2F"),
    )

    enabled = [
        {"symbol": ticker, "name": label, "logo": glyph, "color": colour}
        for ticker, label, glyph, colour in catalogue
        if ticker in SUPPORTED_ASSETS
    ]
    return {"assets": enabled}


@app.get("/api/market/cost-basis/{symbol}")
def get_cost_basis(symbol: str, current_user: str = Depends(get_current_user)):
    """
    Average cost basis, amount invested, and estimated unrealized PnL for one asset.

    PnL is computed against the current price net of the ``TRADING_FEE`` a sale
    would incur, so the UI can preview the outcome before selling. When no price
    is available or nothing is held, value and PnL figures are reported as 0.0.
    """
    from manual_trading import get_asset_cost_basis, get_current_price_from_binance, TRADING_FEE

    asset = symbol.upper()
    cost_info = get_asset_cost_basis(asset, current_user)
    current_price = get_current_price_from_binance(asset, "USDT")

    current_value = 0.0
    unrealized_pnl = 0.0
    unrealized_pnl_percent = 0.0

    if current_price and cost_info['balance'] > 0:
        invested = cost_info['total_invested']
        current_value = current_price * cost_info['balance']
        net_value = current_value - current_value * TRADING_FEE
        unrealized_pnl = net_value - invested
        if invested > 0:
            unrealized_pnl_percent = ((net_value / invested) - 1) * 100
        else:
            unrealized_pnl_percent = 0.0

    return {
        "symbol": asset,
        "balance": cost_info['balance'],
        "avg_cost_basis": cost_info['avg_cost_basis'],
        "total_invested": cost_info['total_invested'],
        "current_price": current_price,
        "current_value": current_value,
        "unrealized_pnl": unrealized_pnl,
        "unrealized_pnl_percent": unrealized_pnl_percent
    }