Spaces:
Sleeping
Sleeping
Commit ·
e9d14bd
1
Parent(s): 6a279e1
working project
Browse files- .gitignore +3 -0
- README.md +0 -0
- app.py +147 -0
- requirements.txt +8 -0
- src/model.py +22 -0
- src/preprocess.py +26 -0
- src/processing.py +95 -0
- src/storage.py +58 -0
.gitignore
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gcp_key.json
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.pyc
|
README.md
ADDED
|
File without changes
|
app.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import os
|
| 3 |
+
from src.processing import train_mode_cloud, predict_health_cloud
|
| 4 |
+
|
| 5 |
+
#st.set_page_config(page_title="Piranaware Cloud", page_icon="☁️", layout="wide")
|
| 6 |
+
#st.markdown("""<style>.stApp {background-color: #F0F2F6;}</style>""", unsafe_allow_html=True)
|
| 7 |
+
# --- PIRANAWARE COASTAL THEME (CSS) ---
# Injects the app-wide black/yellow theme as raw CSS; unsafe_allow_html is
# required so Streamlit renders the <style> tag instead of escaping it.
# NOTE(review): the .result-box-healthy / .result-box-anomaly classes are
# defined here but no visible code in this file emits them — confirm they
# are used elsewhere or remove.
st.markdown("""
<style>
/* 1. Main Background - Pure black for maximum contrast */
.stApp {
background-color: #000000;
}

/* 2. Text Color Fix - High visibility yellow */
.stApp, .stMarkdown, p, label {
color: #FFD700 !important; /* Bright safety yellow */
}

/* 3. Headers - Strong yellow, slightly warmer */
h1, h2, h3, h4, h5, h6 {
color: #FFEB3B !important; /* Vivid header yellow */
font-family: 'Helvetica Neue', sans-serif;
font-weight: 700;
}

/* 4. Tab Styling */
button[data-baseweb="tab"] {
color: #BDB76B !important; /* Muted yellow for inactive */
font-weight: 600;
}
button[data-baseweb="tab"][aria-selected="true"] {
color: #FFD700 !important;
border-bottom: 4px solid #FFD700 !important;
background-color: #111111 !important;
}

/* 5. Buttons - Black & yellow safety style */
div.stButton > button {
background-color: #000000;
color: #FFD700;
border: 3px solid #FFD700;
border-radius: 8px;
font-weight: bold;
}
div.stButton > button:hover {
background-color: #FFD700;
color: #000000;
box-shadow: 0 4px 14px rgba(255, 215, 0, 0.6);
border: 3px solid #FFD700;
}

/* 6. Input Labels */
.stAudioInput label, .stFileUploader label, .stSelectbox label, .stTextInput label {
color: #FFD700 !important;
font-weight: 700;
}

/* 7. Results Box Styling */
.result-box-healthy {
background-color: #111111;
border: 2px solid #00FF9C;
border-left: 6px solid #00C781;
padding: 15px; border-radius: 5px;
color: #00FF9C;
}
.result-box-anomaly {
background-color: #111111;
border: 2px solid #FF5252;
border-left: 6px solid #D32F2F;
padding: 15px; border-radius: 5px;
color: #FF5252;
}
</style>
""", unsafe_allow_html=True)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
TEMP_AUDIO_DIR = "temp_audio_uploads"
os.makedirs(TEMP_AUDIO_DIR, exist_ok=True)

def save_audio(audio_value):
    """Persist an uploaded/recorded audio buffer to a fixed temp file.

    Args:
        audio_value: a file-like object (Streamlit recorder/uploader
            result), or None when nothing was provided.

    Returns:
        The path of the written WAV file, or None if audio_value is None.
    """
    if audio_value is None:
        return None
    # Rewind first: the buffer's cursor may sit at EOF after a prior read.
    audio_value.seek(0)
    destination = os.path.join(TEMP_AUDIO_DIR, "input.wav")
    # NOTE(review): a fixed filename means concurrent sessions overwrite
    # each other's upload — confirm the single-user assumption holds.
    with open(destination, "wb") as handle:
        handle.write(audio_value.read())
    return destination
|
| 88 |
+
|
| 89 |
+
# --- LOGIN ---
# Sidebar "login": collects the boat ID that keys all cloud storage paths.
with st.sidebar:
    #st.image("https://img.icons8.com/color/96/speedboat.png", width=80)
    st.title("User Login")
    st.markdown("### Ensure to use your exact boat ID")

    # BOAT ID INPUT
    # Normalized to UPPER_SNAKE so the same boat always maps to the same
    # cloud folder regardless of how the ID was typed.
    boat_id = st.text_input("Enter Boat ID", value="DEMO_BOAT_01").upper().replace(" ", "_")
    st.caption("Training saved online on Google Cloud.")
    st.divider()
    st.info(f"Active Session:\n**{boat_id}**")
|
| 100 |
+
|
| 101 |
+
# --- MAIN APP ---
st.title("Piranaware Boat Engine AI")

tab_train, tab_test = st.tabs(["🛠️ Train Baseline", "🩺 Diagnostics"])

# Training tab: one column per engine mode; each records (or uploads) a
# healthy sample and trains/uploads that mode's baseline model.
with tab_train:
    st.info(f"Training models for: **{boat_id}**. Ensure engine is HEALTHY.")
    c1, c2, c3 = st.columns(3)

    for col, mode in [(c1, "idle"), (c2, "slow"), (c3, "fast")]:
        with col:
            st.markdown(f"### {mode.upper()}")
            # Prefer the native recorder; fall back to a file uploader on
            # Streamlit versions that lack st.audio_input.
            # FIX: was a bare `except:`, which also traps SystemExit and
            # KeyboardInterrupt — narrowed to Exception.
            try:
                audio = st.audio_input(f"Rec {mode}", key=f"rec_{mode}")
            except Exception:
                audio = st.file_uploader(f"Up {mode}", type=['wav'], key=f"rec_{mode}")

            if st.button(f"Train {mode.upper()}", key=f"btn_{mode}"):
                if audio:
                    path = save_audio(audio)
                    with st.spinner("Training & Uploading to Cloud..."):
                        res = train_mode_cloud(path, mode, boat_id)
                        st.success(res)
                else:
                    st.error("No Audio")
|
| 124 |
+
|
| 125 |
+
# Diagnostics tab: record/upload a sample, score it against the trained
# baseline, and color the report by the verdict substring.
with tab_test:
    st.divider()
    st.markdown(f"### Diagnostics for: **{boat_id}**")

    col_in, col_out = st.columns([1, 2])
    with col_in:
        mode = st.selectbox("Select Mode", ["idle", "slow", "fast"])
        # FIX: was a bare `except:` (also traps SystemExit and
        # KeyboardInterrupt) — narrowed to Exception. The fallback covers
        # Streamlit versions without st.audio_input.
        try:
            test_audio = st.audio_input("Record", key="test")
        except Exception:
            test_audio = st.file_uploader("Upload", type=['wav'], key="test")
        btn = st.button("Run Diagnostics")

    with col_out:
        if btn and test_audio:
            path = save_audio(test_audio)
            with st.spinner("Downloading Model & Analyzing..."):
                report = predict_health_cloud(path, mode, boat_id)

            # predict_health_cloud embeds "HEALTHY"/"ANOMALY" in the
            # report text; anything else (e.g. missing model) is a warning.
            if "HEALTHY" in report:
                st.success(report)
            elif "ANOMALY" in report:
                st.error(report)
            else:
                st.warning(report)
|
requirements.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
tensorflow
|
| 2 |
+
librosa
|
| 3 |
+
numpy>=2.0
|
| 4 |
+
matplotlib
|
| 5 |
+
scikit-learn
|
| 6 |
+
streamlit
|
| 7 |
+
google-cloud-storage
|
| 8 |
+
huggingface-hub
|
src/model.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tensorflow as tf
|
| 2 |
+
from tensorflow.keras import layers, models
|
| 3 |
+
|
| 4 |
+
def build_autoencoder(input_shape):
    """Build and compile a small convolutional autoencoder.

    Args:
        input_shape: shape of one spectrogram, (height, width, channels).

    Returns:
        A compiled Keras Model mapping the input to its reconstruction,
        optimized with Adam on mean-squared error.
    """
    inputs = layers.Input(shape=input_shape)

    # Encoder: two conv + 2x2 pool stages compress each spatial axis 4x.
    h = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(inputs)
    h = layers.MaxPooling2D((2, 2), padding='same')(h)
    h = layers.Conv2D(16, (3, 3), activation='relu', padding='same')(h)
    bottleneck = layers.MaxPooling2D((2, 2), padding='same')(h)

    # Decoder: mirror of the encoder, upsampling back to the input size;
    # sigmoid output matches spectrograms normalized into [0, 1].
    h = layers.Conv2D(16, (3, 3), activation='relu', padding='same')(bottleneck)
    h = layers.UpSampling2D((2, 2))(h)
    h = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(h)
    h = layers.UpSampling2D((2, 2))(h)
    outputs = layers.Conv2D(1, (3, 3), activation='sigmoid', padding='same')(h)

    model = models.Model(inputs, outputs)
    model.compile(optimizer='adam', loss='mse')
    return model
|
src/preprocess.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import librosa
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
SAMPLE_RATE = 22050
DURATION = 1.0
SAMPLES_PER_SLICE = int(SAMPLE_RATE * DURATION)
N_MELS = 128

def audio_to_spectrograms(file_path):
    """Convert an audio file into normalized 1-second mel spectrograms.

    Args:
        file_path: path to an audio file readable by librosa.

    Returns:
        ndarray of shape (num_slices, N_MELS, frames, 1) with values in
        [0, 1], or None when the clip is shorter than one slice or when
        loading/processing fails.
    """
    try:
        signal, rate = librosa.load(file_path, sr=SAMPLE_RATE)
        total = len(signal) // SAMPLES_PER_SLICE
        if total < 1:
            return None

        slices = []
        for idx in range(total):
            start = idx * SAMPLES_PER_SLICE
            segment = signal[start : start + SAMPLES_PER_SLICE]
            mel = librosa.feature.melspectrogram(y=segment, sr=rate, n_mels=N_MELS)
            db = librosa.power_to_db(mel, ref=np.max)
            # Map the usual [-80, 0] dB range into [0, 1] for the model.
            scaled = np.clip((db + 80) / 80, 0, 1)
            slices.append(scaled[..., np.newaxis])

        return np.array(slices)
    except Exception as e:
        # Best-effort: report and signal failure to the caller via None.
        print(f"Error: {e}")
        return None
|
src/processing.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import os
|
| 3 |
+
import json
|
| 4 |
+
import tensorflow as tf
|
| 5 |
+
from src.preprocess import audio_to_spectrograms
|
| 6 |
+
from src.model import build_autoencoder
|
| 7 |
+
from src.storage import upload_file, download_file
|
| 8 |
+
|
| 9 |
+
# Local scratch directory for model/threshold files before upload
# (and after download in predict_health_cloud).
TEMP_DIR = "temp_models"
if not os.path.exists(TEMP_DIR):
    os.makedirs(TEMP_DIR)

def train_mode_cloud(audio_path, mode_name, boat_id):
    """Train a per-mode baseline autoencoder on healthy audio and upload it.

    Args:
        audio_path: path to a WAV recording of a healthy engine.
        mode_name: engine mode label ("idle", "slow" or "fast"); becomes
            part of the stored filenames.
        boat_id: identifier used as the cloud storage folder for this boat.

    Returns:
        A human-readable status string (shown directly in the UI).
    """
    # 1. Preprocess
    X_train = audio_to_spectrograms(audio_path)
    if X_train is None: return "❌ Audio too short (min 1 sec)."

    # 2. Train
    autoencoder = build_autoencoder(X_train.shape[1:])
    autoencoder.fit(X_train, X_train, epochs=40, batch_size=4, verbose=0)

    # 3. Calculate Threshold (THE FIX)
    reconstructions = autoencoder.predict(X_train)
    # Per-slice reconstruction error: one scalar per 1-second spectrogram.
    mse = np.mean(np.power(X_train - reconstructions, 2), axis=(1, 2, 3))

    # Anomaly threshold = mean + 2 std of the healthy-training errors;
    # cast to float so it serializes cleanly to JSON.
    threshold = float(np.mean(mse) + (2 * np.std(mse)))

    # 4. Save Locally
    model_filename = f"{mode_name}_model.h5"
    meta_filename = f"{mode_name}_meta.json"
    local_model_path = os.path.join(TEMP_DIR, model_filename)
    local_meta_path = os.path.join(TEMP_DIR, meta_filename)

    # Optimizer state is not needed for inference, so it is stripped.
    autoencoder.save(local_model_path, save_format='h5', include_optimizer=False)
    with open(local_meta_path, 'w') as f:
        json.dump({"threshold": threshold}, f)

    # 5. Upload
    u1 = upload_file(local_model_path, boat_id, model_filename)
    u2 = upload_file(local_meta_path, boat_id, meta_filename)

    if u1 and u2:
        return f"✅ Calibrated {mode_name.upper()} | Threshold: {threshold:.5f}"
    else:
        return "⚠️ Trained locally, but Cloud Upload Failed."
|
| 47 |
+
|
| 48 |
+
def predict_health_cloud(audio_path, mode_name, boat_id):
    """Score a recording against the boat's trained baseline for one mode.

    Downloads the per-boat/per-mode model and threshold from cloud
    storage, computes per-slice reconstruction errors, and reports a
    health verdict.

    Args:
        audio_path: path to the WAV file to diagnose.
        mode_name: engine mode ("idle", "slow" or "fast").
        boat_id: cloud storage folder of the boat.

    Returns:
        A human-readable report string. The caller (app.py) keys off the
        substrings "HEALTHY" / "ANOMALY" to choose the UI styling.
    """
    model_filename = f"{mode_name}_model.h5"
    meta_filename = f"{mode_name}_meta.json"
    local_model_path = os.path.join(TEMP_DIR, model_filename)
    local_meta_path = os.path.join(TEMP_DIR, meta_filename)

    # 1. Download
    d1 = download_file(boat_id, model_filename, local_model_path)
    d2 = download_file(boat_id, meta_filename, local_meta_path)

    if not (d1 and d2):
        return f"⚠️ No trained model found in cloud for Boat: {boat_id} (Mode: {mode_name})"

    # 2. Load
    with open(local_meta_path, 'r') as f:
        threshold = json.load(f)["threshold"]

    # compile=False: only inference is needed, and the optimizer state was
    # stripped at save time anyway.
    model = tf.keras.models.load_model(local_model_path, compile=False)

    # 3. Predict
    X_test = audio_to_spectrograms(audio_path)
    if X_test is None: return "Error: Audio too short."

    reconstructions = model.predict(X_test)
    # Calculate error for each second of audio
    mse = np.mean(np.power(X_test - reconstructions, 2), axis=(1, 2, 3))

    # 4. Analysis
    # Count slices whose error exceeds the calibrated threshold; the score
    # is the percentage of slices that stayed under it.
    anomalies = np.sum(mse > threshold)
    health_score = 100 * (1 - (anomalies / len(mse)))

    # 5. Debug Data (Shows you WHY it decided what it decided)
    avg_error = np.mean(mse)
    max_error = np.max(mse)

    # Verdict: healthy when more than 85% of slices are under threshold.
    status = "🟢 HEALTHY" if health_score > 85 else "🔴 ANOMALY DETECTED"

    # Return detailed report
    return f"""
STATUS: {status}
Confidence Score: {health_score:.1f}%

--- TECHNICAL TELEMETRY ---
Threshold Limit : {threshold:.5f}
Your Avg Error : {avg_error:.5f}
Your Max Error : {max_error:.5f}
Anomalous Secs : {anomalies} / {len(mse)}
"""
|
src/storage.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import streamlit as st
|
| 4 |
+
from google.cloud import storage
|
| 5 |
+
from google.oauth2 import service_account
|
| 6 |
+
|
| 7 |
+
# ✅ UPDATED WITH YOUR BUCKET NAME
BUCKET_NAME = "piranaware20251227841ph"

def get_storage_client():
    """
    Authenticates with Google Cloud.
    CHECKS LOCAL FILE FIRST (to prevent crashes in Codespaces),
    then checks Secrets (for Hugging Face).

    Returns a storage.Client, or None when no credentials are found
    (callers must check for None before using the client).
    """
    # 1. Local Dev: Check for local file FIRST
    if os.path.exists("gcp_key.json"):
        return storage.Client.from_service_account_json("gcp_key.json")

    # 2. Production: Check Streamlit Secrets
    # Accessing st.secrets can itself raise when no secrets file exists,
    # hence the try/except around the lookup.
    try:
        if "gcp_service_account" in st.secrets:
            creds_dict = dict(st.secrets["gcp_service_account"])
            creds = service_account.Credentials.from_service_account_info(creds_dict)
            return storage.Client(credentials=creds)
    except Exception:
        # If secrets don't exist, we just move on
        pass

    # 3. If neither works
    st.error("⚠️ No Google Cloud credentials found. Cannot save models.")
    return None
|
| 33 |
+
|
| 34 |
+
def upload_file(local_path, boat_id, filename):
    """Upload a local file into the boat's folder in the GCS bucket.

    Args:
        local_path: path of the file on disk.
        boat_id: folder (object-name prefix) for this boat.
        filename: object name within the boat's folder.

    Returns:
        True on success, False when no credentials are available.
    """
    client = get_storage_client()
    if not client:
        return False

    bucket = client.bucket(BUCKET_NAME)
    # Creates folder structure: boat_id/filename.
    # FIX: the blob name must include the per-file name; a constant name
    # made every upload for a boat overwrite the same object, so the
    # model and meta files clobbered each other.
    blob_name = f"{boat_id}/{filename}"
    blob = bucket.blob(blob_name)

    blob.upload_from_filename(local_path)
    return True
|
| 45 |
+
|
| 46 |
+
def download_file(boat_id, filename, local_dest):
    """Download one object from the boat's folder in the GCS bucket.

    Args:
        boat_id: folder (object-name prefix) for this boat.
        filename: object name within the boat's folder.
        local_dest: local path to write the downloaded file to.

    Returns:
        True on success; False when credentials are missing or the
        object does not exist.
    """
    client = get_storage_client()
    if not client:
        return False

    bucket = client.bucket(BUCKET_NAME)
    # FIX: look up boat_id/filename so the name matches what upload_file
    # stored; a constant blob name pointed every download at one object
    # regardless of which file was requested.
    blob_name = f"{boat_id}/{filename}"
    blob = bucket.blob(blob_name)

    if not blob.exists():
        return False

    blob.download_to_filename(local_dest)
    return True
|