import os.path
from typing import List

import firebase_admin
from firebase_admin import credentials
from firebase_admin import firestore
from google.cloud.firestore_v1 import CollectionReference

from scripts.mo.data.storage import Storage, map_dict_to_record, map_record_to_dict
from scripts.mo.environment import env
from scripts.mo.models import Record

# Name under which this module registers its own Firebase app.
FIREBASE_APP_NAME = "sd-model-organizer-app"


def _filter_download(record: Record, show_downloaded, show_not_downloaded):
    """Return whether *record* passes the downloaded / not-downloaded filter.

    A record counts as downloaded when it has a non-empty ``location`` that
    exists on the local filesystem. The record is kept when it is downloaded
    and ``show_downloaded`` is truthy, or when it is not downloaded and
    ``show_not_downloaded`` is truthy.
    """
    is_downloaded = bool(record.location) and os.path.exists(record.location)
    return (show_downloaded and is_downloaded) or (show_not_downloaded and not is_downloaded)


class FirebaseStorage(Storage):
    """``Storage`` implementation backed by the Firestore ``records`` collection."""

    def __init__(self):
        """Attach to (or create) the named Firebase app and open a Firestore client.

        Credentials are read from ``service-account-file.json`` in
        ``env.script_dir`` the first time the app is created.
        """
        # Look the app up by its own name instead of checking whether *any*
        # Firebase app exists: an unrelated (e.g. default) app being present
        # must not make get_app() raise for our name.
        try:
            self.app = firebase_admin.get_app(name=FIREBASE_APP_NAME)
        except ValueError:
            cred = credentials.Certificate(os.path.join(env.script_dir, "service-account-file.json"))
            self.app = firebase_admin.initialize_app(cred, name=FIREBASE_APP_NAME)
        self.firestore_client = firestore.client(app=self.app)

    def _records(self) -> CollectionReference:
        """Return a reference to the Firestore collection holding all records."""
        return self.firestore_client.collection('records')

    def _stream_records(self, query) -> List[Record]:
        """Run *query* (a collection or query reference) and map every document to a ``Record``."""
        return [map_dict_to_record(snapshot.id, snapshot.to_dict()) for snapshot in query.stream()]

    def get_all_records(self) -> List[Record]:
        """Return every record stored in the collection."""
        return self._stream_records(self._records())

    def query_records(self, name_query=None, groups=None, model_types=None, show_downloaded=None,
                      show_not_downloaded=None) -> List[Record]:
        """Return the records that satisfy every supplied filter.

        :param name_query: case-insensitive substring the record name must
            contain; ignored when empty or None.
        :param groups: every group listed here must be present in the record's
            groups; ignored when empty or None.
        :param model_types: the record's ``model_type`` must be one of these
            (filtered server-side by Firestore); ignored when empty or None.
        :param show_downloaded: keep records whose location exists locally.
        :param show_not_downloaded: keep records whose location is unset or
            missing locally. When both download flags are falsy (including
            the ``None`` defaults) every record is filtered out.
        :return: the matching records.
        """
        query = self._records()
        if model_types:
            # NOTE(review): Firestore caps the number of values an 'in' filter
            # accepts; confirm model_types can never exceed that limit.
            query = query.where('model_type', 'in', model_types)

        records = self._stream_records(query)

        # The remaining filters are applied client-side.
        if name_query:
            needle = name_query.lower()
            records = [record for record in records if needle in record.name.lower()]

        if groups:
            records = [record for record in records if all(group in record.groups for group in groups)]

        return [record for record in records
                if _filter_download(record, show_downloaded, show_not_downloaded)]

    def get_record_by_id(self, _id) -> Record:
        """Fetch the record stored under document id *_id*."""
        snapshot = self._records().document(_id).get()
        # NOTE(review): for a missing document the snapshot's to_dict() is
        # None; confirm map_dict_to_record handles that case as callers expect.
        return map_dict_to_record(snapshot.id, snapshot.to_dict())

    def add_record(self, record: Record):
        """Store *record* as a new document with a Firestore-generated id."""
        self._records().add(map_record_to_dict(record))

    def update_record(self, record: Record):
        """Overwrite the fields of the existing document whose id is ``record.id_``."""
        # Firestore's update() expects the document to already exist.
        self._records().document(record.id_).update(map_record_to_dict(record))

    def remove_record(self, _id):
        """Delete the document with id *_id*."""
        self._records().document(_id).delete()

    def get_available_groups(self) -> List:
        """Return the distinct group names used by any record (order unspecified)."""
        return list({group for record in self.get_all_records() for group in record.groups})

    def get_records_by_group(self, group: str) -> List[Record]:
        """Return every record whose groups include *group* (exact match).

        Filtering happens client-side on ``record.groups``, the same attribute
        ``query_records`` filters on. The previous server-side query used the
        field name 'group' and a SQL-style '%...%' pattern; Firestore has no
        wildcards, so that query never matched a real group.
        """
        return [record for record in self.get_all_records() if group in (record.groups or ())]

    def get_all_records_locations(self) -> List:
        """Return the distinct non-empty ``location`` values of all records (order unspecified)."""
        return list({record.location for record in self.get_all_records() if record.location})