File size: 5,539 Bytes
9b51d59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
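"""
Cleanup script for expired refresh tokens.

Run this periodically (e.g., as a daily cron job) to remove expired tokens.
Each run also logs refresh-token statistics before and after the cleanup and
reports token families with an unusually high rotation count.
"""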

import asyncio
import sys
import os
from datetime import datetime

# Add parent directory to path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from app.models.refresh_token_model import RefreshTokenModel
import logging

# Root logging for this standalone script: INFO and above, each line carrying
# a timestamp, the logger name and the level.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger shared by every step of the cleanup run.
logger = logging.getLogger(__name__)

async def cleanup_expired_tokens():
    """Delete expired refresh tokens through ``RefreshTokenModel``.

    Returns:
        Whatever ``RefreshTokenModel.cleanup_expired_tokens()`` reports as the
        number of removed tokens.

    Raises:
        Any exception raised by the model call, after it has been logged
        together with its traceback.
    """
    logger.info("Starting cleanup of expired refresh tokens...")

    try:
        removed = await RefreshTokenModel.cleanup_expired_tokens()
        logger.info(f"✅ Successfully cleaned up {removed} expired tokens")
    except Exception as err:
        logger.error(f"❌ Error during cleanup: {str(err)}", exc_info=True)
        raise

    return removed

async def get_token_statistics():
    """Count documents in the ``refresh_tokens`` collection and log them.

    The categories overlap: ``revoked``, ``expired``, ``used`` and
    ``remember_me`` are independent filters, so one token can be counted in
    several of them. ``active`` means "not revoked and not yet expired".

    Returns:
        dict mapping ``total``, ``active``, ``revoked``, ``expired``, ``used``
        and ``remember_me`` to the corresponding ``count_documents`` results.

    Raises:
        Any exception from the database layer, after it has been logged with
        its traceback.
    """
    from app.core.nosql_client import db

    try:
        collection = db["refresh_tokens"]

        # Read the clock once so the "active" and "expired" filters are
        # evaluated against the same instant; two separate reads could let a
        # token that expires in between be missed by both counts.
        # NOTE(review): utcnow() is naive (and deprecated since Python 3.12);
        # kept to match how the rest of this script reads the clock.
        now = datetime.utcnow()

        # Total tokens
        total = await collection.count_documents({})

        # Active tokens: not revoked and expiring strictly after `now`.
        active = await collection.count_documents({
            "revoked": False,
            "expires_at": {"$gt": now},
        })

        # Revoked tokens
        revoked = await collection.count_documents({"revoked": True})

        # Expired tokens: "$lte" is the exact complement of the "$gt" bound
        # above, so a token expiring exactly at `now` is not dropped from both.
        expired = await collection.count_documents({
            "expires_at": {"$lte": now},
        })

        # Used tokens
        used = await collection.count_documents({"used": True})

        # Remember me tokens
        remember_me = await collection.count_documents({"remember_me": True})

        logger.info("\n📊 Token Statistics:")
        logger.info(f"  Total tokens: {total}")
        logger.info(f"  Active tokens: {active}")
        logger.info(f"  Revoked tokens: {revoked}")
        logger.info(f"  Expired tokens: {expired}")
        logger.info(f"  Used tokens: {used}")
        logger.info(f"  Remember me tokens: {remember_me}")

        return {
            "total": total,
            "active": active,
            "revoked": revoked,
            "expired": expired,
            "used": used,
            "remember_me": remember_me,
        }

    except Exception as e:
        logger.error(f"❌ Error getting statistics: {str(e)}", exc_info=True)
        raise

async def check_suspicious_activity(threshold: int = 100):
    """Flag token families whose rotation count exceeds ``threshold``.

    Walks every Redis key matching ``token_family:*`` with cursor-based SCAN
    (a returned cursor of 0 ends the walk). It parses the JSON stored at each
    key and collects the families whose ``rotation_count`` is strictly
    greater than ``threshold``. Findings are logged.

    This is a best-effort report. If Redis cannot be reached or the scan
    fails, the error is logged and an empty list is returned, so the rest of
    the cleanup run still completes. A single malformed record is skipped
    with a warning instead of discarding everything already collected.

    Args:
        threshold: rotation count above which a family is flagged. Defaults
            to 100, the previously hard-coded limit.

    Returns:
        list of dicts with ``family_id``, ``customer_id``, ``rotation_count``
        and ``created_at`` for each flagged family. The list is empty if none
        were found or an error occurred.
    """
    from app.core.cache_client import get_redis
    import json

    try:
        redis = await get_redis()

        # Get all token families
        cursor = 0
        suspicious_families = []
        # SCAN may return the same key more than once during a full
        # iteration; track what has been checked so no family is reported
        # (and counted) twice.
        seen_keys = set()

        while True:
            cursor, keys = await redis.scan(cursor, match="token_family:*", count=100)

            for key in keys:
                if key in seen_keys:
                    continue
                seen_keys.add(key)

                data = await redis.get(key)
                if not data:
                    # The key may have expired or been deleted after SCAN
                    # returned it.
                    continue

                try:
                    family_data = json.loads(data)
                    rotation_count = family_data.get("rotation_count", 0)
                    is_suspicious = rotation_count > threshold
                except (ValueError, AttributeError, TypeError) as e:
                    # ValueError: payload is not valid/decodable JSON.
                    # AttributeError: JSON payload is not an object.
                    # TypeError: rotation_count is not comparable to a number.
                    logger.warning(f"⚠️  Skipping unreadable token family {key!r}: {e}")
                    continue

                # NOTE(review): this is an absolute rotation count; no time
                # window is checked here. If a "30 days" window is intended it
                # presumably comes from the family key's TTL, so confirm that.
                if not is_suspicious:
                    continue

                # Some client configurations return keys as bytes, and
                # bytes.split(":") raises TypeError, so decode first.
                key_text = key.decode("utf-8", "replace") if isinstance(key, bytes) else key
                suspicious_families.append({
                    "family_id": key_text.split(":")[-1],
                    "customer_id": family_data.get("customer_id"),
                    "rotation_count": rotation_count,
                    "created_at": family_data.get("created_at"),
                })

            if cursor == 0:
                break

        if suspicious_families:
            logger.warning(f"\n⚠️  Found {len(suspicious_families)} suspicious token families:")
            for family in suspicious_families:
                logger.warning(f"  - Family {family['family_id']}: {family['rotation_count']} rotations")
                logger.warning(f"    Customer: {family['customer_id']}")
        else:
            logger.info("\n✅ No suspicious token activity detected")

        return suspicious_families

    except Exception as e:
        # Deliberate best-effort: report the failure but never abort the run.
        logger.error(f"❌ Error checking suspicious activity: {str(e)}", exc_info=True)
        return []

async def main():
    """Run one complete maintenance pass and log a summary.

    Order of work:
    1. Log token statistics.
    2. Delete expired tokens.
    3. Log statistics again.
    4. Scan for suspicious token families.
    5. Log the number of deleted tokens and flagged families.

    Exceptions raised by the statistics or cleanup steps propagate to the
    caller.
    """
    rule = "=" * 60

    logger.info(rule)
    logger.info("Refresh Token Cleanup Script")
    logger.info(f"Run time: {datetime.utcnow().isoformat()}")
    logger.info(rule)

    # Snapshot of the collection before anything is removed.
    logger.info("\n📈 Statistics before cleanup:")
    await get_token_statistics()

    logger.info("\n🧹 Performing cleanup...")
    removed = await cleanup_expired_tokens()

    # Snapshot again so the effect of the cleanup is visible in the log.
    logger.info("\n📈 Statistics after cleanup:")
    await get_token_statistics()

    logger.info("\n🔍 Checking for suspicious activity...")
    flagged = await check_suspicious_activity()

    summary_lines = (
        "\n" + rule,
        "Cleanup Summary:",
        f"  Deleted tokens: {removed}",
        f"  Suspicious families: {len(flagged)}",
        rule,
    )
    for line in summary_lines:
        logger.info(line)


if __name__ == "__main__":
    # Script entry point (e.g. invoked from cron); drives the async pass to completion.
    asyncio.run(main())