Spaces:
Sleeping
Sleeping
| import os | |
| import logging | |
| import io | |
| import asyncio | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
| from google.auth.transport.requests import Request | |
| from google.oauth2.credentials import Credentials | |
| from googleapiclient.discovery import build | |
| from googleapiclient.http import MediaFileUpload, MediaIoBaseDownload | |
| from googleapiclient.errors import HttpError | |
| from core.database import DB_FILENAME | |
logger = logging.getLogger(__name__)


class DriveService:
    """Back up and restore the local database file through Google Drive.

    Authenticates with an OAuth refresh token taken from the environment and
    stores the file named ``DB_FILENAME`` inside a Drive folder named
    ``FOLDER_NAME``. The public methods return ``True`` on success and
    ``False`` on failure (after logging the error) rather than raising.
    """

    SCOPES = [
        'https://www.googleapis.com/auth/gmail.send',
        'https://www.googleapis.com/auth/drive.file',
    ]
    FOLDER_NAME = "apigateway"
    # DB_FILENAME is imported from core.database at module level.

    def __init__(self):
        self.creds = None
        self.service = None
        # Server-side credentials for the Drive API; the SERVER_* variables
        # take precedence over the generic GOOGLE_* ones.
        self.client_id = os.getenv('SERVER_GOOGLE_CLIENT_ID') or os.getenv('GOOGLE_CLIENT_ID')
        self.client_secret = os.getenv('SERVER_GOOGLE_CLIENT_SECRET') or os.getenv('GOOGLE_CLIENT_SECRET')
        self.refresh_token = os.getenv('SERVER_GOOGLE_REFRESH_TOKEN') or os.getenv('GOOGLE_REFRESH_TOKEN')

    @staticmethod
    def _escape_query_value(value):
        """Escape *value* for use inside a single-quoted Drive query literal.

        Backslashes are escaped first so the backslash added in front of each
        single quote is not itself doubled.
        """
        return str(value).replace('\\', '\\\\').replace("'", "\\'")

    def authenticate(self):
        """Build the Drive v3 client from the stored refresh token.

        Returns:
            True if the client was built; False if any credential is missing
            or the token refresh / client construction fails (error logged).
        """
        if not all([self.client_id, self.client_secret, self.refresh_token]):
            logger.error("Missing Google API credentials for Drive Service")
            return False
        try:
            self.creds = Credentials(
                None,
                refresh_token=self.refresh_token,
                token_uri="https://oauth2.googleapis.com/token",
                client_id=self.client_id,
                client_secret=self.client_secret,
                scopes=self.SCOPES,
            )
            # The credentials start without an access token and without an
            # expiry, so ``expired`` is False but ``valid`` is also False.
            # Refresh eagerly so an unusable refresh token is reported here
            # instead of surfacing later from an unrelated API call.
            if not self.creds.valid:
                self.creds.refresh(Request())
            self.service = build('drive', 'v3', credentials=self.creds)
            return True
        except Exception as e:
            logger.error("Failed to authenticate with Drive API: %s", e)
            return False

    def _find_folder(self):
        """Return the ID of a non-trashed folder named FOLDER_NAME, or None.

        The query does not restrict the parent, so a matching folder anywhere
        in the visible Drive space is accepted; the first result is used.
        Only HttpError is handled here; other exceptions propagate.
        """
        try:
            name = self._escape_query_value(self.FOLDER_NAME)
            query = f"mimeType='application/vnd.google-apps.folder' and name='{name}' and trashed=false"
            results = self.service.files().list(q=query, spaces='drive', fields='files(id, name)').execute()
            files = results.get('files', [])
            return files[0]['id'] if files else None
        except HttpError as error:
            logger.error("An error occurred searching for folder: %s", error)
            return None

    def _create_folder(self):
        """Create a folder named FOLDER_NAME and return its ID (None on HttpError)."""
        try:
            file_metadata = {
                'name': self.FOLDER_NAME,
                'mimeType': 'application/vnd.google-apps.folder',
            }
            file = self.service.files().create(body=file_metadata, fields='id').execute()
            folder_id = file.get('id')
            logger.info("Created folder with ID: %s", folder_id)
            return folder_id
        except HttpError as error:
            logger.error("An error occurred creating folder: %s", error)
            return None

    def _get_folder_id(self):
        """Return the folder ID, creating the folder if it does not exist yet."""
        return self._find_folder() or self._create_folder()

    def upload_db(self):
        """Upload the local database file to the Drive folder.

        Updates the existing Drive file of the same name in the folder when
        present, otherwise creates it (creating the folder first if needed).

        Returns:
            True on success; False if authentication fails, the local file is
            missing, or any Drive call fails (the error is logged).
        """
        if not self.service and not self.authenticate():
            return False
        if not os.path.exists(DB_FILENAME):
            logger.warning("Database file %s not found for upload.", DB_FILENAME)
            return False
        try:
            # Inside the try so non-HttpError failures during folder lookup or
            # creation are reported as False instead of propagating.
            folder_id = self._get_folder_id()
            if not folder_id:
                return False
            name = self._escape_query_value(DB_FILENAME)
            query = f"name='{name}' and '{folder_id}' in parents and trashed=false"
            results = self.service.files().list(q=query, spaces='drive', fields='files(id)').execute()
            files = results.get('files', [])
            # NOTE(review): the file is uploaded as it currently is on disk; if
            # the database can be written concurrently the copy may be
            # inconsistent — confirm callers upload from a quiescent state.
            media = MediaFileUpload(DB_FILENAME, mimetype='application/x-sqlite3', resumable=True)
            if files:
                file_id = files[0]['id']
                self.service.files().update(fileId=file_id, media_body=media).execute()
                logger.info("Updated database file %s in Drive (ID: %s)", DB_FILENAME, file_id)
            else:
                file_metadata = {'name': DB_FILENAME, 'parents': [folder_id]}
                self.service.files().create(body=file_metadata, media_body=media, fields='id').execute()
                logger.info("Uploaded new database file %s to Drive", DB_FILENAME)
            return True
        except HttpError as error:
            logger.error("An error occurred uploading DB: %s", error)
            return False
        except Exception as e:
            logger.exception("Unexpected error uploading DB: %s", e)
            return False

    async def upload_db_async(self):
        """Run upload_db in a worker thread via asyncio.to_thread and return its result."""
        return await asyncio.to_thread(self.upload_db)

    @staticmethod
    def _download_to_path(request, dest_path):
        """Download *request* into *dest_path*, replacing it only on success.

        Data is written to a sibling ``.download`` file, which is moved over
        *dest_path* with ``os.replace`` once complete. A failed or interrupted
        download therefore leaves any existing *dest_path* untouched; the
        partial file is removed and the exception re-raised.
        """
        tmp_path = f"{dest_path}.download"
        try:
            with io.FileIO(tmp_path, 'wb') as fh:
                downloader = MediaIoBaseDownload(fh, request)
                done = False
                while not done:
                    _status, done = downloader.next_chunk()
            os.replace(tmp_path, dest_path)
        except BaseException:
            try:
                os.remove(tmp_path)
            except OSError:
                pass  # best-effort cleanup; the original error is more useful
            raise

    def download_db(self):
        """Download the database file from the Drive folder into DB_FILENAME.

        The folder is never created here; if it or the file is missing, the
        local state is left as-is.

        Returns:
            True if the file was downloaded; False if authentication fails,
            nothing is stored in Drive, or any step fails (the error is logged).
        """
        if not self.service and not self.authenticate():
            return False
        try:
            folder_id = self._find_folder()  # Don't create if not found, just return
            if not folder_id:
                logger.info("No 'apigateway' folder found in Drive. Starting with fresh DB.")
                return False
            name = self._escape_query_value(DB_FILENAME)
            query = f"name='{name}' and '{folder_id}' in parents and trashed=false"
            results = self.service.files().list(q=query, spaces='drive', fields='files(id)').execute()
            files = results.get('files', [])
            if not files:
                logger.info("No %s found in Drive folder. Starting with fresh DB.", DB_FILENAME)
                return False
            request = self.service.files().get_media(fileId=files[0]['id'])
            # NOTE(review): presumably called before the database is opened;
            # replacing the file under an open connection is not safe — confirm.
            self._download_to_path(request, DB_FILENAME)
            logger.info("Successfully downloaded %s from Drive.", DB_FILENAME)
            return True
        except HttpError as error:
            logger.error("An error occurred downloading DB: %s", error)
            return False
        except Exception as e:
            logger.exception("Unexpected error downloading DB: %s", e)
            return False