Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, HTTPException, Request, Depends, Cookie, Header | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.responses import HTMLResponse, FileResponse, RedirectResponse, JSONResponse | |
| from pydantic import BaseModel | |
| import os | |
| import json | |
| import firebase_admin | |
| from firebase_admin import credentials, db | |
| import base64 | |
| import time | |
| import hashlib | |
| import secrets | |
| import jwt | |
| from datetime import datetime, timedelta | |
| from dotenv import load_dotenv | |
| from typing import List, Dict, Any, Optional | |
# Load environment variables (from a local .env file, if present)
load_dotenv()
app = FastAPI(title="CatGPT Model Manager", description="Modern web interface for managing Discord bot AI models")
# Authentication Configuration
ADMIN_KEY = os.getenv('ADMIN_KEY')
if not ADMIN_KEY:
    # NOTE: this fallback key is visible in source; never deploy without ADMIN_KEY.
    print("WARNING: ADMIN_KEY not set. Using default key for development.")
    ADMIN_KEY = "dev-admin-key-please-change"
# Bot Management Configuration (for auto-configuration)
BOT_API_URL = os.getenv('BOT_API_URL')
DISCORD_ADMIN_KEY = os.getenv('DISCORD_ADMIN_KEY')
# JWT Configuration
# `or` (not a getenv default) so that an empty JWT_SECRET variable does not
# become an empty HMAC signing key. When unset, the secret is derived from
# ADMIN_KEY, so rotating the admin key also invalidates all issued tokens.
JWT_SECRET = os.getenv('JWT_SECRET') or ADMIN_KEY + '-jwt-secret'
JWT_ALGORITHM = "HS256"
JWT_EXPIRATION_HOURS = 24 * 7  # 7 days
print(f"π ADMIN_KEY configured: {'Yes' if ADMIN_KEY else 'No'}")
print(f"π ADMIN_KEY length: {len(ADMIN_KEY) if ADMIN_KEY else 0} characters")
# Never log any part of ADMIN_KEY itself (a prefix print used to live here).
print(f"π JWT authentication enabled with {JWT_EXPIRATION_HOURS}h expiration")
# Pydantic models for authentication
class LoginRequest(BaseModel):
    """JSON body accepted by the login endpoint."""

    # Plain-text admin key submitted by the user; checked against ADMIN_KEY.
    admin_key: str
def create_jwt_token() -> str:
    """Create a signed admin session JWT.

    The payload carries ``admin: True`` plus ``iat`` (issued now) and ``exp``
    (JWT_EXPIRATION_HOURS later) claims, signed with JWT_SECRET using
    JWT_ALGORITHM.

    Returns:
        The encoded token as returned by ``jwt.encode``.
    """
    from datetime import timezone

    # Take a single timezone-aware timestamp for both claims:
    # datetime.utcnow() is deprecated since Python 3.12, and calling it twice
    # let iat and exp come from slightly different instants.
    now = datetime.now(timezone.utc)
    payload = {
        "exp": now + timedelta(hours=JWT_EXPIRATION_HOURS),
        "iat": now,
        "admin": True,
    }
    return jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)
def verify_jwt_token(token: str) -> bool:
    """Return True if ``token`` is a valid, unexpired admin JWT.

    The signature is verified with JWT_SECRET, only JWT_ALGORITHM is
    accepted, and an ``exp`` claim must be present. Any decoding or
    validation failure yields False rather than an exception.
    """
    try:
        payload = jwt.decode(
            token,
            JWT_SECRET,
            algorithms=[JWT_ALGORITHM],
            # Every token we issue has an expiry; refuse ones that do not.
            options={"require": ["exp"]},
        )
    except jwt.InvalidTokenError:
        # Base class of ExpiredSignatureError, MissingRequiredClaimError, etc.
        return False
    # Return a real bool, as the annotation promises.
    return payload.get("admin") is True
def get_token_from_request(authorization: Optional[str] = Header(None), session_token: Optional[str] = Cookie(None)) -> Optional[str]:
    """Extract the session token from the request.

    A well-formed ``Authorization: Bearer <token>`` header takes precedence;
    otherwise the ``session_token`` cookie is returned (None when absent).
    """
    if authorization:
        scheme, _, credentials_value = authorization.partition(" ")
        credentials_value = credentials_value.strip()
        # The auth scheme is case-insensitive (RFC 7235). An empty credential
        # ("Bearer " or stray spaces) must not shadow a valid cookie.
        if scheme.lower() == "bearer" and credentials_value:
            return credentials_value
    # Fallback to cookie
    return session_token
def verify_session(token: Optional[str] = Depends(get_token_from_request)) -> bool:
    """FastAPI dependency reporting whether the request carries a valid session token."""
    # An absent or empty token short-circuits without attempting a decode.
    return verify_jwt_token(token) if token else False
def require_auth(authenticated: bool = Depends(verify_session)):
    """FastAPI dependency that only lets authenticated requests through.

    Returns:
        True when the session is valid.

    Raises:
        HTTPException: 401 "Authentication required" otherwise.
    """
    if authenticated:
        return True
    raise HTTPException(status_code=401, detail="Authentication required")
class FirebaseManager:
    """Data access for the ``catgptbot`` node of the Firebase Realtime Database.

    On construction it tries to initialize the Firebase Admin SDK from the
    ``FIREBASE_SERVICE_ACCOUNT_B64`` / ``FIREBASE_URL`` environment variables.
    When that is not possible it falls back to in-memory mock data. Every data
    method checks ``self.db_ref`` to choose between the two backends.

    Data methods do not raise: readers return ``{}`` (or ``0``) on error, and
    writers return a ``(success, message)`` tuple.
    """

    # Per-category cap on active models, enforced by add_model/update_model.
    MAX_ACTIVE_MODELS = 25

    def __init__(self):
        self.app = None
        self.db_ref = None
        # Define the mock stores up front so the mock code paths can never hit
        # AttributeError, however initialize() exits.
        self.mock_data: Dict[str, Any] = {}
        self.mock_logs: Dict[str, Any] = {}
        self.initialize()

    def initialize(self):
        """Initialize the Firebase Admin SDK, falling back to mock data.

        Returns:
            True when Firebase is ready (``self.db_ref`` is set), False when
            the mock backend is in use.
        """
        try:
            database_url = os.getenv('FIREBASE_URL', 'https://genaibot-28c30-default-rtdb.asia-southeast1.firebasedatabase.app/')
            service_account_b64 = os.getenv('FIREBASE_SERVICE_ACCOUNT_B64')
            if not service_account_b64:
                print("Warning: FIREBASE_SERVICE_ACCOUNT_B64 not set, using mock data")
                # Actually load the mock data; returning without it left every
                # mock code path failing on the missing attribute.
                self.use_mock_data()
                return False
            # The service-account JSON arrives base64-encoded in the environment.
            service_account_json = base64.b64decode(service_account_b64).decode('utf-8')
            service_account_dict = json.loads(service_account_json)
            # firebase_admin keeps a process-wide app registry; reuse the default
            # app if it is already initialized instead of initializing twice.
            if not firebase_admin._apps:
                cred = credentials.Certificate(service_account_dict)
                self.app = firebase_admin.initialize_app(cred, {
                    'databaseURL': database_url
                })
            else:
                self.app = firebase_admin.get_app()
            self.db_ref = db.reference('catgptbot')
            print("Firebase initialized successfully")
            # Write a marker entry to check the logging path end to end.
            test_log = {
                "timestamp": "2024-01-01T00:00:00.000Z",
                "action": "system_start",
                "category": "system",
                "displayName": "Firebase System",
                "urn": "system://firebase/initialized"
            }
            success, message = self.add_log_entry(test_log)
            print(f"π§ͺ Test log entry: {success} - {message}")
            return True
        except Exception as e:
            print(f"Firebase initialization error: {e}")
            # Make every data method take the mock branch consistently.
            self.db_ref = None
            self.use_mock_data()
            return False

    def use_mock_data(self):
        """Replace the mock stores with built-in sample models and one sample log."""
        print("π Using mock data for development")
        self.mock_logs = {
            "1640995200000": {
                "timestamp": "2024-01-01T12:00:00.000Z",
                "action": "added",
                "category": "pony",
                "displayName": "Test Model",
                "urn": "urn:air:sdxl:checkpoint:civitai:123456@789012"
            }
        }
        print(f"π Initialized mock logs with {len(self.mock_logs)} test entries")
        self.mock_data = {
            "pony": {
                "models": {
                    "1640995200000": {
                        "displayName": "Pony Diffusion V6",
                        "urn": "@cf/bytedance/stable-diffusion-xl-lightning",
                        "category": "pony",
                        "isActive": True,
                        "tags": ["anime", "realistic"],
                        "metadata": {"nsfw": False}
                    },
                    "1640995300000": {
                        "displayName": "Pony XL Enhanced",
                        "urn": "@cf/stabilityai/stable-diffusion-xl-base-1.0",
                        "category": "pony",
                        "isActive": True,
                        "tags": ["anime", "pony"],
                        "metadata": {"nsfw": True}
                    }
                }
            },
            "illustrious": {
                "models": {
                    "1640995400000": {
                        "displayName": "Illustrious XL",
                        "urn": "@cf/black-forest-labs/flux-1-schnell",
                        "category": "illustrious",
                        "isActive": False,
                        "tags": ["illustration", "art"],
                        "metadata": {"nsfw": True}
                    }
                }
            },
            "sdxl": {
                "models": {
                    "1640995500000": {
                        "displayName": "SDXL Lightning",
                        "urn": "@cf/bytedance/stable-diffusion-xl-lightning",
                        "category": "sdxl",
                        "isActive": True,
                        "tags": ["realistic", "fast"],
                        "metadata": {"nsfw": False}
                    }
                }
            }
        }
        print("Using mock data for development")

    @staticmethod
    def _parse_tags(value: Any) -> List[str]:
        """Normalize a comma-separated string (or a list/tuple) into stripped, non-empty tags."""
        if value is None:
            return []
        items = value if isinstance(value, (list, tuple)) else str(value).split(',')
        return [tag for tag in (str(item).strip() for item in items) if tag]

    def get_models_by_category(self, category: str) -> Dict[str, Any]:
        """Return ``{model_id: model}`` for ``category`` ({} if missing or on error)."""
        try:
            if self.db_ref:
                models_ref = self.db_ref.child(f'models/categories/{category}/models')
                return models_ref.get() or {}
            return self.mock_data.get(category, {}).get("models", {})
        except Exception as e:
            print(f"Error getting models for {category}: {e}")
            return {}

    def get_all_models(self) -> Dict[str, Any]:
        """Return the whole ``{category: {"models": {...}}}`` tree ({} on error)."""
        try:
            if self.db_ref:
                models_ref = self.db_ref.child('models/categories')
                return models_ref.get() or {}
            return self.mock_data
        except Exception as e:
            print(f"Error getting all models: {e}")
            return {}

    def add_model(self, category: str, model_data: Dict[str, Any]) -> tuple[bool, str]:
        """Store ``model_data`` under a new millisecond-timestamp id.

        NOTE: the cap counts *active* models but is applied to every insert,
        whatever ``model_data['isActive']`` says.
        """
        try:
            active_count = self.count_active_models(category)
            if active_count >= self.MAX_ACTIVE_MODELS:
                return False, f"Cannot add model. {category.upper()} already has {self.MAX_ACTIVE_MODELS} models (limit reached)."
            model_id = str(int(time.time() * 1000))
            if self.db_ref:
                model_path = f'models/categories/{category}/models/{model_id}'
                self.db_ref.child(model_path).set(model_data)
            else:
                self.mock_data.setdefault(category, {}).setdefault("models", {})[model_id] = model_data
            return True, f"Model added successfully with ID: {model_id}"
        except Exception as e:
            return False, f"Error adding model: {str(e)}"

    def update_model(self, category: str, model_id: str, field: str, value: Any) -> tuple[bool, str]:
        """Set one field of a model.

        ``tags`` accepts a comma-separated string or a list and is stored as a
        list. ``nsfw`` is written to ``metadata.nsfw``. Any other field is
        written verbatim. Activating a model is refused once the category has
        MAX_ACTIVE_MODELS active models, unless that model is already active.
        """
        try:
            if field == 'isActive' and value:
                existing = self.get_models_by_category(category).get(model_id)
                # Re-activating an already-counted model doesn't change the
                # active total, so it must not be rejected at the limit.
                already_active = isinstance(existing, dict) and bool(existing.get('isActive', True))
                if not already_active and self.count_active_models(category) >= self.MAX_ACTIVE_MODELS:
                    return False, f"Cannot activate model. {category.upper()} already has {self.MAX_ACTIVE_MODELS} active models."
            if self.db_ref:
                model_path = f'models/categories/{category}/models/{model_id}'
                if field == 'tags':
                    self.db_ref.child(f'{model_path}/tags').set(self._parse_tags(value))
                elif field == 'nsfw':
                    self.db_ref.child(f'{model_path}/metadata/nsfw').set(value)
                else:
                    self.db_ref.child(f'{model_path}/{field}').set(value)
            else:
                model = self.mock_data.get(category, {}).get("models", {}).get(model_id)
                # Unknown mock models are silently ignored.
                if model is not None:
                    if field == 'tags':
                        model['tags'] = self._parse_tags(value)
                    elif field == 'nsfw':
                        model.setdefault("metadata", {})["nsfw"] = value
                    else:
                        model[field] = value
            return True, "Model updated successfully"
        except Exception as e:
            return False, f"Error updating model: {str(e)}"

    def delete_model(self, category: str, model_id: str) -> tuple[bool, str]:
        """Delete a model; deleting a missing mock model still reports success."""
        try:
            if self.db_ref:
                model_path = f'models/categories/{category}/models/{model_id}'
                self.db_ref.child(model_path).delete()
            else:
                self.mock_data.get(category, {}).get("models", {}).pop(model_id, None)
            return True, "Model deleted successfully"
        except Exception as e:
            return False, f"Error deleting model: {str(e)}"

    def count_active_models(self, category: str) -> int:
        """Count models in ``category`` whose ``isActive`` is truthy (missing counts as active)."""
        try:
            models = self.get_models_by_category(category)
            return sum(1 for model in models.values() if model.get('isActive', True))
        except Exception as e:
            print(f"Error counting models for {category}: {e}")
            return 0

    def add_log_entry(self, log_entry: dict) -> tuple[bool, str]:
        """Store ``log_entry`` under a millisecond-timestamp key.

        NOTE: two entries written within the same millisecond share a key, so
        the later one overwrites the earlier.
        """
        try:
            log_id = str(int(time.time() * 1000))  # millisecond precision
            if self.db_ref:
                print(f"π₯ Adding log to Firebase: {log_entry}")
                self.db_ref.child(f'logs/{log_id}').set(log_entry)
                print(f"π₯ Successfully added log {log_id} to Firebase")
            else:
                print(f"π Adding log to mock data: {log_entry}")
                # Guard for instances that bypassed __init__.
                if not hasattr(self, 'mock_logs'):
                    self.mock_logs = {}
                    print("π Initialized mock_logs dictionary")
                self.mock_logs[log_id] = log_entry
                print(f"π Successfully added log {log_id} to mock data. Total logs: {len(self.mock_logs)}")
            return True, "Log entry added successfully"
        except Exception as e:
            print(f"π₯ Error in add_log_entry: {str(e)}")
            return False, f"Error adding log entry: {str(e)}"

    def get_logs(self) -> dict:
        """Return ``{log_id: entry}`` for all logs ({} on error)."""
        try:
            if self.db_ref:
                logs_data = self.db_ref.child('logs').get() or {}
                print(f"Retrieved {len(logs_data)} log entries from Firebase")
                return logs_data
            mock_logs = getattr(self, 'mock_logs', {})
            print(f"Retrieved {len(mock_logs)} log entries from mock data")
            return mock_logs
        except Exception as e:
            print(f"Error getting logs: {str(e)}")
            return {}
# Initialize Firebase
# Module-level manager shared by the route handlers below; constructing it
# immediately attempts Firebase initialization.
firebase = FirebaseManager()


# Pydantic models
class ModelCreate(BaseModel):
    """JSON body for adding a model to a category."""

    displayName: str
    urn: str
    # Comma-separated tag string; the add endpoint splits it into a list.
    tags: str = ""
    isActive: bool = True
    # Stored under metadata.nsfw by the add endpoint.
    nsfw: bool = False
class ModelUpdate(BaseModel):
    """JSON body for changing a single field of a model."""

    # Field name to change; "tags" and "nsfw" get special handling in
    # FirebaseManager.update_model, anything else is written verbatim.
    field: str
    # New value; its expected type depends on ``field``.
    value: Any
class LogEntry(BaseModel):
    """JSON body for recording one activity-log entry."""

    timestamp: str  # presumably an ISO-8601 string, like the ones built in this file
    action: str
    category: str
    displayName: str
    urn: str
# Authentication Routes
async def login(request: LoginRequest, req: Request):
    """Authenticate with the admin key and issue a JWT session.

    On success the JWT is returned in the JSON body (``token``, for the
    localStorage fallback) and also set as the ``session_token`` cookie. The
    cookie is ``Secure; SameSite=None`` on Hugging Face Spaces (``hf.space``
    host), ``Secure; SameSite=Lax`` on other HTTPS deployments and
    ``SameSite=Lax`` on plain HTTP.

    Raises:
        HTTPException: 401 when the supplied key does not match ADMIN_KEY.
    """
    print(f"π Request from: {req.client.host if req.client else 'unknown'}")
    # Constant-time comparison so response timing reveals nothing about the
    # key. Compare UTF-8 bytes: compare_digest rejects non-ASCII str input.
    # (Key lengths and match results are deliberately no longer logged.)
    key_matches = secrets.compare_digest(
        request.admin_key.encode("utf-8"), ADMIN_KEY.encode("utf-8")
    )
    if not key_matches:
        print("β Login failed - invalid key")
        raise HTTPException(status_code=401, detail="Invalid admin key")

    jwt_token = create_jwt_token()
    print("β Login successful, JWT token created")
    # Detect environment for optimal cookie settings
    host = str(req.headers.get("host", ""))
    is_hf = "hf.space" in host
    # Behind a reverse proxy the client-facing scheme arrives in X-Forwarded-Proto.
    is_https = req.headers.get("x-forwarded-proto") == "https" or req.url.scheme == "https" or is_hf
    print(f"π Host: {host}")
    print(f"π Is HuggingFace: {is_hf}")
    print(f"π Is HTTPS: {is_https}")
    response_obj = JSONResponse({
        "success": True,
        "message": "Authenticated successfully",
        "token": jwt_token,  # Also returned for the localStorage fallback
        "expires_in": JWT_EXPIRATION_HOURS * 3600  # seconds
    })
    cookie_settings = {
        "key": "session_token",
        "value": jwt_token,
        "max_age": JWT_EXPIRATION_HOURS * 3600,
        # Readable from JS on purpose (localStorage fallback); the cookie
        # therefore gives no protection against token theft via XSS.
        "httponly": False,
        # is_https is always True on HF, so HF cookies are always Secure.
        "secure": is_https,
        # The app is used cross-origin on HF, which needs SameSite=None.
        "samesite": "none" if is_hf else "lax"
    }
    # Log the cookie attributes, never the session token itself.
    loggable_settings = {k: v for k, v in cookie_settings.items() if k != "value"}
    print(f"πͺ Cookie settings: {loggable_settings}")
    response_obj.set_cookie(**cookie_settings)
    return response_obj
async def logout(token: Optional[str] = Depends(get_token_from_request)):
    """Log out by expiring the ``session_token`` cookie.

    JWTs are stateless, so the token itself stays valid until it expires;
    clients holding a copy (e.g. in localStorage) must discard it themselves.
    True revocation would need a server-side deny-list.
    """
    response_obj = JSONResponse({"success": True, "message": "Logged out successfully"})
    # login() sets the cookie either Secure/SameSite=None (HTTPS, HF Spaces)
    # or non-Secure/SameSite=Lax (plain HTTP). Browsers ignore a Secure
    # Set-Cookie received over plain HTTP, so a single Secure deletion never
    # cleared the dev cookie. Emit one expiring Set-Cookie per variant; the
    # browser applies each one it accepts.
    for secure, samesite in ((False, "lax"), (True, "none")):
        response_obj.set_cookie(
            key="session_token",
            value="",
            max_age=0,
            httponly=False,
            secure=secure,
            samesite=samesite
        )
    return response_obj
async def auth_status(authenticated: bool = Depends(verify_session), token: Optional[str] = Depends(get_token_from_request)):
    """Report whether the current request carries a valid session."""
    has_token = bool(token)
    print(f"π Auth status check - Authenticated: {authenticated}, Token present: {has_token}")
    if has_token:
        # Only a short prefix is logged, never the whole token.
        print(f"π Token first 10 chars: {token[:10]}...")
    return {"authenticated": authenticated}
async def debug_info(req: Request):
    """Return configuration hints for troubleshooting.

    NOTE(review): there is no auth dependency here, so ``admin_key_length``
    is visible to anyone; consider removing it or gating the endpoint.
    """
    host_header = req.headers.get("host")
    on_hf = "hf.space" in str(host_header or "")
    return {
        "admin_key_set": bool(ADMIN_KEY),
        "admin_key_length": len(ADMIN_KEY) if ADMIN_KEY else 0,
        "jwt_enabled": True,
        "host": host_header,
        "user_agent": req.headers.get("user-agent"),
        "is_hf": on_hf,
        "environment": "production" if on_hf else "development",
    }
async def get_bot_config(_: bool = Depends(require_auth)):
    """Describe the bot-management configuration without revealing the key.

    Prefers ``system/bot_config`` in Firebase; falls back to the BOT_API_URL /
    DISCORD_ADMIN_KEY environment values when Firebase is unavailable or the
    read fails (the error text is then included as ``error``).
    """
    def from_environment(**extra):
        # Environment-variable view; ``extra`` is appended after the base keys.
        return {
            "bot_api_url": BOT_API_URL,
            "discord_admin_key_set": bool(DISCORD_ADMIN_KEY),
            "auto_configured": bool(BOT_API_URL and DISCORD_ADMIN_KEY),
            "stored_in_firebase": False,
            **extra,
        }

    try:
        if not firebase.db_ref:
            return from_environment()
        stored = firebase.db_ref.child('system/bot_config').get() or {}
        stored_url = stored.get('bot_api_url')
        stored_key = stored.get('discord_admin_key')
        return {
            "bot_api_url": stored_url,
            "discord_admin_key_set": bool(stored_key),
            "auto_configured": bool(stored_url and stored_key),
            "stored_in_firebase": True,
        }
    except Exception as e:
        return from_environment(error=str(e))
# Pydantic model for bot config
class BotConfigUpdate(BaseModel):
    """JSON body for saving the Discord bot connection settings."""

    # Base URL of the bot's HTTP API.
    bot_api_url: str
    # Credential sent to the bot as a Bearer token by the proxy endpoints.
    discord_admin_key: str
async def save_bot_config(config: BotConfigUpdate, _: bool = Depends(require_auth)):
    """Validate and store the bot connection settings in Firebase.

    Writes ``system/bot_config`` and records a ``bot_config_updated`` log entry.

    Raises:
        HTTPException: 500 if Firebase is unavailable or the write fails;
            400 for a non-http(s) URL or an admin key under 8 characters.
    """
    from urllib.parse import urlparse
    try:
        if not firebase.db_ref:
            raise HTTPException(status_code=500, detail="Firebase not available")
        # Validate exactly the value that will be stored. The URL is used as
        # the base of outgoing proxy requests, so only http(s) is accepted;
        # a trailing slash is dropped so "<url>/api/..." never contains "//".
        bot_api_url = config.bot_api_url.strip().rstrip('/')
        parsed_url = urlparse(bot_api_url)
        if parsed_url.scheme not in ('http', 'https') or not parsed_url.netloc:
            raise HTTPException(status_code=400, detail="Invalid bot API URL")
        discord_admin_key = config.discord_admin_key.strip()
        if len(discord_admin_key) < 8:
            raise HTTPException(status_code=400, detail="Discord admin key must be at least 8 characters")
        config_data = {
            'bot_api_url': bot_api_url,
            'discord_admin_key': discord_admin_key,
            'updated_at': datetime.now().isoformat(),
            'updated_by': 'admin'  # Could be enhanced to track actual user
        }
        firebase.db_ref.child('system/bot_config').set(config_data)
        # Log the configuration change (best effort; add_log_entry never raises).
        log_entry = {
            'timestamp': datetime.now().isoformat(),
            'action': 'bot_config_updated',
            'category': 'system',
            'displayName': 'Bot Configuration',
            'urn': 'system://bot_config'
        }
        firebase.add_log_entry(log_entry)
        return {"message": "Bot configuration saved successfully"}
    except HTTPException:
        raise
    except Exception as e:
        print(f"Error saving bot config: {e}")
        raise HTTPException(status_code=500, detail=f"Error saving configuration: {str(e)}")
# Bot Management Proxy Endpoints
async def proxy_reload_cache(_: bool = Depends(require_auth)):
    """Ask the Discord bot to reload its cache (POST /api/reload-cache on the bot)."""
    return await proxy_bot_request(endpoint="reload-cache", method="POST")
async def proxy_restart_bot(_: bool = Depends(require_auth)):
    """Ask the Discord bot to restart (POST /api/restart-bot on the bot)."""
    return await proxy_bot_request(endpoint="restart-bot", method="POST")
async def proxy_bot_status(_: bool = Depends(require_auth)):
    """Fetch the Discord bot's status (GET /api/status on the bot)."""
    return await proxy_bot_request(endpoint="status", method="GET")
async def proxy_bot_request(endpoint: str, method: str, timeout: float = 30.0):
    """Forward an authenticated request to the Discord bot's HTTP API.

    Reads ``bot_api_url`` and ``discord_admin_key`` from Firebase
    (``system/bot_config``) and calls ``{bot_api_url}/api/{endpoint}`` with
    the key as a Bearer token.

    Args:
        endpoint: Path segment under ``/api/`` on the bot, e.g. ``"status"``.
        method: ``"POST"`` sends a POST; any other value sends a GET.
        timeout: Total seconds allowed for the upstream request.

    Returns:
        The bot's decoded JSON body for any 2xx response.

    Raises:
        HTTPException: 500 if Firebase is unavailable or the call fails,
            400 if the bot configuration is missing, or the bot's own status
            code (with its response text) for non-2xx replies.
    """
    import aiohttp
    try:
        if not firebase.db_ref:
            raise HTTPException(status_code=500, detail="Firebase not available")
        config_data = firebase.db_ref.child('system/bot_config').get() or {}
        bot_api_url = config_data.get('bot_api_url')
        discord_admin_key = config_data.get('discord_admin_key')
        if not bot_api_url or not discord_admin_key:
            raise HTTPException(status_code=400, detail="Bot configuration not found in Firebase")
        # Tolerate a stored trailing slash. The query parameter is presumably
        # there to skip ngrok's browser-warning page for tunnelled bots.
        url = f"{bot_api_url.rstrip('/')}/api/{endpoint}?ngrok-skip-browser-warning=1"
        headers = {
            'Authorization': f'Bearer {discord_admin_key}',
            'Content-Type': 'application/json'
        }
        http_method = "POST" if method == "POST" else "GET"
        # Explicit timeout: otherwise an unreachable bot holds this request
        # open for aiohttp's default (5 minutes total).
        client_timeout = aiohttp.ClientTimeout(total=timeout)
        async with aiohttp.ClientSession(timeout=client_timeout) as session:
            async with session.request(http_method, url, headers=headers) as response:
                if 200 <= response.status < 300:
                    # content_type=None: accept JSON even when the bot omits
                    # the application/json Content-Type header.
                    return await response.json(content_type=None)
                error_text = await response.text()
                raise HTTPException(status_code=response.status, detail=error_text)
    except HTTPException:
        raise
    except Exception as e:
        print(f"Error proxying bot request to {endpoint}: {e}")
        raise HTTPException(status_code=500, detail=f"Error communicating with Discord bot: {str(e)}")
# Main Routes
async def read_root(authenticated: bool = Depends(verify_session)):
    """Serve the single-page UI.

    index.html is returned whatever ``authenticated`` is: the page's
    JavaScript handles authentication, which avoids a server-side redirect
    loop.
    """
    index_page = 'index.html'
    return FileResponse(index_page)
async def login_page():
    """Serve the standalone login page."""
    login_file = 'login.html'
    return FileResponse(login_file)
async def get_all_models():
    """Return every category's models wrapped as ``{"models": ...}``.

    NOTE(review): no auth dependency, so model data is publicly readable;
    confirm this is intended (e.g. for the bot).
    """
    return {"models": firebase.get_all_models()}
async def get_models_by_category(category: str):
    """Return the models, active count and cap for one category.

    Raises:
        HTTPException: 400 for a category other than pony/illustrious/sdxl.
    """
    allowed_categories = ("pony", "illustrious", "sdxl")
    if category not in allowed_categories:
        raise HTTPException(status_code=400, detail="Invalid category")
    return {
        "category": category,
        "models": firebase.get_models_by_category(category),
        "activeCount": firebase.count_active_models(category),
        "maxModels": 25,
    }
async def add_model(category: str, model: ModelCreate, _: bool = Depends(require_auth)):
    """Create a model in ``category`` from the request body.

    Raises:
        HTTPException: 400 for an unknown category or when storage refuses it.
    """
    if category not in ("pony", "illustrious", "sdxl"):
        raise HTTPException(status_code=400, detail="Invalid category")
    # Comma-separated input -> list of stripped, non-empty tags.
    tag_list = [t.strip() for t in model.tags.split(',') if t.strip()]
    record = {
        'displayName': model.displayName,
        'urn': model.urn,
        'category': category,
        'isActive': model.isActive,
        'tags': tag_list,
        'metadata': {'nsfw': model.nsfw},
    }
    ok, message = firebase.add_model(category, record)
    if ok:
        return {"message": message}
    raise HTTPException(status_code=400, detail=message)
async def update_model(category: str, model_id: str, update: ModelUpdate, _: bool = Depends(require_auth)):
    """Change one field of a model.

    Raises:
        HTTPException: 400 for an unknown category or when the update fails.
    """
    if category not in ("pony", "illustrious", "sdxl"):
        raise HTTPException(status_code=400, detail="Invalid category")
    ok, message = firebase.update_model(category, model_id, update.field, update.value)
    if ok:
        return {"message": message}
    raise HTTPException(status_code=400, detail=message)
async def delete_model(category: str, model_id: str, _: bool = Depends(require_auth)):
    """Remove a model from a category.

    Raises:
        HTTPException: 400 for an unknown category or when deletion fails.
    """
    if category not in ("pony", "illustrious", "sdxl"):
        raise HTTPException(status_code=400, detail="Invalid category")
    ok, message = firebase.delete_model(category, model_id)
    if ok:
        return {"message": message}
    raise HTTPException(status_code=400, detail=message)
async def add_log_entry(log_entry: LogEntry, _: bool = Depends(require_auth)):
    """Persist one activity-log entry sent by the UI.

    Raises:
        HTTPException: 400 when storage reports a failure.
    """
    print(f"π Received log entry: {log_entry.action} - {log_entry.category} - {log_entry.displayName}")
    # Copy the model's fields into a plain dict, in a fixed key order.
    field_names = ('timestamp', 'action', 'category', 'displayName', 'urn')
    log_data = {name: getattr(log_entry, name) for name in field_names}
    ok, message = firebase.add_log_entry(log_data)
    if ok:
        print(f"β Successfully added log entry")
        return {"message": message}
    print(f"β Failed to add log entry: {message}")
    raise HTTPException(status_code=400, detail=message)
async def get_logs(_: bool = Depends(require_auth)):
    """Return all log entries as a list (ids dropped) for the frontend.

    Raises:
        HTTPException: 500 when the entries cannot be read.
    """
    print("π Getting logs...")
    try:
        entries = [entry for _log_id, entry in firebase.get_logs().items()]
        print(f"π Returning {len(entries)} log entries")
        return {"logs": entries}
    except Exception as e:
        print(f"β Error getting logs: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error retrieving logs: {str(e)}")
async def test_log():
    """Write a fixed debug entry, then report how many log entries exist.

    NOTE(review): no auth dependency, so anyone who can reach the server can
    append entries to the activity log; consider gating or removing it.
    """
    print("π§ͺ Testing logging system...")
    probe_entry = {
        "timestamp": "2024-01-01T12:00:00.000Z",
        "action": "test",
        "category": "debug",
        "displayName": "Test Log Entry",
        "urn": "test://debug/endpoint"
    }
    added, message = firebase.add_log_entry(probe_entry)
    if not added:
        print(f"β Test log failed: {message}")
        return {"success": False, "message": message}
    print(f"β Test log succeeded: {message}")
    # Read back to exercise read access too.
    total = len(firebase.get_logs())
    return {
        "success": True,
        "message": f"Test log added successfully. Total logs: {total}",
        "logs_count": total
    }
async def test_get_logs():
    """Return the log count and up to five entries, without auth (debugging aid).

    NOTE(review): exposes log contents to unauthenticated callers.
    """
    print("π§ͺ Testing log retrieval...")
    try:
        all_logs = firebase.get_logs()
        return {
            "success": True,
            "logs_count": len(all_logs),
            "logs": list(all_logs.values())[:5],
        }
    except Exception as e:
        print(f"β Test log retrieval failed: {str(e)}")
        return {"success": False, "error": str(e)}
async def get_logs_debug():
    """Unauthenticated twin of the logs endpoint; errors are returned, not raised.

    NOTE(review): exposes all log entries without auth.
    """
    print("π Debug: Getting logs without auth...")
    try:
        entries = [entry for _log_id, entry in firebase.get_logs().items()]
        print(f"π Debug: Returning {len(entries)} log entries")
        return {"logs": entries}
    except Exception as e:
        print(f"π Debug: Error getting logs: {str(e)}")
        return {"error": str(e), "logs": []}
# ========================================
# BOT MANAGEMENT ENDPOINTS
# ========================================
class BotRestartRequest(BaseModel):
    """JSON body for the restart endpoint."""

    # Free-text reason attached to the restart request.
    reason: str = "Manual restart from web admin"
async def restart_bot(request: BotRestartRequest, authenticated: bool = Depends(require_auth)):
    """Record a bot-restart request; the restart itself is not implemented.

    The request is written to the activity log via ``add_log_entry``, using
    the same fields as other log entries plus a ``details`` map. Logging is
    best effort and never fails the request.
    """
    try:
        activity_log = {
            "timestamp": datetime.now().isoformat(),
            "action": "bot_restart",
            "category": "system",
            "displayName": "Discord Bot",
            "urn": "system://bot_restart",
            "details": {
                "reason": request.reason,
                "requested_by": "web_admin"
            }
        }
        # FirebaseManager exposes add_log_entry (there is no log_activity
        # method, whose AttributeError a bare except used to hide).
        try:
            logged, log_message = firebase.add_log_entry(activity_log)
            if not logged:
                print(f"Failed to log bot restart request: {log_message}")
        except Exception as log_error:
            print(f"Failed to log bot restart request: {log_error}")
        # In a production environment you would implement actual restart logic
        # here (process management, systemd service restart, etc.).
        return {
            "success": True,
            "message": "Bot restart requested successfully",
            "note": "β οΈ Bot restart functionality needs to be implemented based on your deployment method. "
                    "Consider using process managers like PM2, systemd, or container orchestration."
        }
    except Exception as e:
        return {"success": False, "error": str(e)}
async def get_bot_status(authenticated: bool = Depends(require_auth)):
    """Return a placeholder status; bot health monitoring is not implemented."""
    try:
        # A real implementation could query a health-check endpoint on the bot.
        return {
            "status": "unknown",
            "message": "Bot status monitoring not implemented yet",
            "last_seen": None,
            "uptime": None,
            "note": "Implement bot status monitoring by adding health check endpoints to your bot",
        }
    except Exception as e:
        return {"success": False, "error": str(e)}
async def reload_bot_cache(authenticated: bool = Depends(require_auth)):
    """Record a cache-reload request and return manual reload instructions.

    Nothing is sent to the bot here; the response tells the admin to run the
    bot's ``/reload_cache`` slash command. Logging is best effort.
    """
    try:
        activity_log = {
            "timestamp": datetime.now().isoformat(),
            "action": "cache_reload",
            "category": "system",
            "displayName": "Discord Bot Cache",
            "urn": "system://cache_reload",
            "details": {
                "requested_by": "web_admin",
                "target": "discord_bot"
            }
        }
        # FirebaseManager exposes add_log_entry (there is no log_activity
        # method, whose AttributeError a bare except used to hide).
        try:
            logged, log_message = firebase.add_log_entry(activity_log)
            if not logged:
                print(f"Failed to log cache reload request: {log_message}")
        except Exception as log_error:
            print(f"Failed to log cache reload request: {log_error}")
        return {
            "success": True,
            "message": "Cache reload requested",
            "instructions": [
                "1. Go to your Discord server where the bot is active",
                "2. Run the slash command: /reload_cache",
                "3. The bot will refresh its model cache from Firebase",
                "4. Note: New models won't appear in slash commands until bot restart"
            ]
        }
    except Exception as e:
        return {"success": False, "error": str(e)}
# Serve static files
app.mount("/static", StaticFiles(directory="static"), name="static")
# CORS
from fastapi.middleware.cors import CORSMiddleware
# allow_origins=["*"] together with allow_credentials=True makes Starlette
# reflect any requesting Origin on credentialed requests. Any website could
# then drive the authenticated API with the admin's session cookie (which is
# SameSite=None on HF Spaces). The UI is served from this same origin and
# needs no CORS. Trusted cross-origin frontends can be listed in
# CORS_ALLOW_ORIGINS (comma-separated); only those get credentialed access.
_cors_origins = [origin.strip() for origin in os.getenv('CORS_ALLOW_ORIGINS', '').split(',') if origin.strip()]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins or ["*"],
    allow_credentials=bool(_cors_origins),
    allow_methods=["*"],
    allow_headers=["*"],
)
if __name__ == "__main__":
    # Local entry point; the port comes from $PORT (7860 by default).
    import uvicorn
    listen_port = int(os.getenv('PORT', 7860))
    print(f"Starting CatGPT Model Manager on http://0.0.0.0:{listen_port}")
    uvicorn.run(app, host="0.0.0.0", port=listen_port)