Repcak00's picture
fix: model name
fb6584a verified
import io
import zipfile
from datetime import datetime
import streamlit as st
from rembg import new_session
from logger import get_logger
from utils import load_and_process_image
logger = get_logger("streamlit_app")
logger.info("App started")
st.set_page_config(page_title="Usuwanie tła", layout="wide")
st.title("🎨 Usuwanie tła i zmiana koloru")
st.write("Prześlij obrazy, usuń tło i zastąp je wybranym kolorem!")
# Opcje
col_a, col_b, col_c = st.columns(3)
with col_a:
bg_color = st.color_picker("Wybierz kolor tła", "#FFFFFF")
with col_b:
model_choice = st.selectbox(
"Model AI:",
[
"birefnet-general",
"birefnet-general-lite",
"birefnet-dis", # Best for objects with holes
"isnet-general-use",
"u2net",
"birefnet-massive",
"bria-rmbg",
],
index=0,
)
with col_c:
transparency_mode = st.radio(
"Przezroczystość:",
["Pozostaw przezroczyste", "Wypełnij kolorem"],
index=1,
)
with st.sidebar:
st.subheader("Opcje znaku wodnego")
add_watermark = st.checkbox("Dodaj znak wodny", value=True)
watermark_file = st.file_uploader("Prześlij znak wodny (PNG)", type=["png"], key="wm")
wm_opacity = st.slider("Przezroczystość znaku wodnego", 0.0, 1.0, 0.3)
wm_scale = st.slider("Rozmiar znaku wodnego (część szerokości)", 0.1, 1.0, 0.33)
wm_position = st.selectbox(
"Pozycja",
["prawy-dolny róg", "lewy-dolny róg", "prawy-górny róg", "lewy-górny róg"],
)
# Konwersja HEX na RGB
hex_color = bg_color.lstrip("#")
rgb_color = tuple(int(hex_color[i : i + 2], 16) for i in (0, 2, 4))
# Wgrywanie wielu plików
uploaded_files = st.file_uploader(
"Wybierz obrazy",
type=["png", "jpg", "jpeg"],
accept_multiple_files=True,
)
if uploaded_files:
logger.info(f"Uploaded files: {len(uploaded_files)}")
st.write(f"📁 Przesłano {len(uploaded_files)} obraz(ów)")
# Przycisk przetwarzania
if st.button("🚀 Przetwórz wszystkie obrazy", type="primary"):
# Kontenery postępu
progress_bar = st.progress(0)
status_text = st.empty()
# Utworzenie sesji z wybranym modelem (wspólna dla wszystkich obrazów)
status_text.text(f"Ładowanie modelu: {model_choice}...")
session = new_session(model_choice)
logger.info(f"Model {model_choice} loaded successfully")
processed_images = []
# Przetwarzanie każdego obrazu
for idx, uploaded_file in enumerate(uploaded_files):
# Aktualizacja postępu
progress = (idx + 1) / len(uploaded_files)
progress_bar.progress(progress)
status_text.text(
f"Przetwarzanie {idx + 1}/{len(uploaded_files)}: {uploaded_file.name}"
)
input_img, processed_img = load_and_process_image(
uploaded_file,
session=session,
transparency_mode=transparency_mode,
rgb_color=rgb_color,
add_watermark=add_watermark,
watermark_file=watermark_file if watermark_file else None,
wm_opacity=wm_opacity,
wm_scale=wm_scale,
wm_position=wm_position,
)
# Zapis przetworzonego obrazu
processed_images.append(
{"name": uploaded_file.name, "original": input_img, "processed": processed_img}
)
status_text.text("✅ Wszystkie obrazy zostały przetworzone!")
# Wyświetlanie wyników w siatce
st.markdown("---")
st.subheader("Podgląd wyników")
# Pokazuj obrazy parami
for img_data in processed_images:
col1, col2 = st.columns(2)
with col1:
st.text(f"Oryginał: {img_data['name']}")
st.image(img_data["original"], width=700) # width="content")
with col2:
st.text(f"Po przetworzeniu: {img_data['name']}")
st.image(img_data["processed"], width=700)
st.markdown("---")
# Utworzenie pliku ZIP ze wszystkimi przetworzonymi obrazami
zip_buffer = io.BytesIO()
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zip_file:
for img_data in processed_images:
img_buffer = io.BytesIO()
# Zapis jako PNG, aby zachować przezroczystość
img_data["processed"].save(img_buffer, format="PNG")
img_bytes = img_buffer.getvalue()
filename = f"przetworzony_{img_data['name'].rsplit('.', 1)[0]}.png"
zip_file.writestr(filename, img_bytes)
zip_buffer.seek(0)
# Przycisk pobierania ZIP
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
logger.info(
f"ZIP created with {len(processed_images)} images, filename=usuniete_tlo_{timestamp}.zip",
)
st.download_button(
label=f"📥 Pobierz wszystkie ({len(processed_images)} obrazów)",
data=zip_buffer.getvalue(),
file_name=f"usuniete_tlo_{timestamp}.zip",
mime="application/zip",
)
else:
st.info("👆 Prześlij jeden lub więcej obrazów, aby rozpocząć")
st.markdown("---")