Spaces:
Sleeping
Sleeping
| import io | |
| import zipfile | |
| from datetime import datetime | |
| import streamlit as st | |
| from rembg import new_session | |
| from logger import get_logger | |
| from utils import load_and_process_image | |
# Application logger. The "App started" record is written on every execution
# of this script, because the call sits at module level.
APP_LOGGER_NAME = "streamlit_app"
logger = get_logger(APP_LOGGER_NAME)
logger.info("App started")

# Page configuration comes before any other Streamlit call in this script,
# followed by the heading and the one-line introduction.
PAGE_CONFIG = {"page_title": "Usuwanie tła", "layout": "wide"}
st.set_page_config(**PAGE_CONFIG)
st.title("🎨 Usuwanie tła i zmiana koloru")
st.write("Prześlij obrazy, usuń tło i zastąp je wybranym kolorem!")
# Options: three side-by-side columns holding the background colour picker,
# the model selector and the transparency switch.
MODEL_NAMES = [
    "birefnet-general",
    "birefnet-general-lite",
    "birefnet-dis",  # Best for objects with holes
    "isnet-general-use",
    "u2net",
    "birefnet-massive",
    "bria-rmbg",
]
TRANSPARENCY_CHOICES = ["Pozostaw przezroczyste", "Wypełnij kolorem"]

col_a, col_b, col_c = st.columns(3)

# Background colour as a "#RRGGBB" string; white is preselected.
bg_color = col_a.color_picker("Wybierz kolor tła", "#FFFFFF")

# Model name later handed to rembg's new_session(); first entry preselected.
model_choice = col_b.selectbox("Model AI:", MODEL_NAMES, index=0)

# Whether the removed background stays transparent or is filled with the
# chosen colour; the second entry (fill with colour) is preselected.
transparency_mode = col_c.radio("Przezroczystość:", TRANSPARENCY_CHOICES, index=1)
# Watermark options, all rendered in the sidebar.
WATERMARK_POSITIONS = [
    "prawy-dolny róg",
    "lewy-dolny róg",
    "prawy-górny róg",
    "lewy-górny róg",
]

st.sidebar.subheader("Opcje znaku wodnego")
add_watermark = st.sidebar.checkbox("Dodaj znak wodny", value=True)
# Optional PNG used as the watermark; None until the user uploads one.
watermark_file = st.sidebar.file_uploader(
    "Prześlij znak wodny (PNG)",
    type=["png"],
    key="wm",
)
# Slider arguments are (min, max, initial value).
wm_opacity = st.sidebar.slider("Przezroczystość znaku wodnego", 0.0, 1.0, 0.3)
wm_scale = st.sidebar.slider("Rozmiar znaku wodnego (część szerokości)", 0.1, 1.0, 0.33)
wm_position = st.sidebar.selectbox("Pozycja", WATERMARK_POSITIONS)
# Convert the picked "#RRGGBB" string into an (R, G, B) tuple of ints by
# parsing each two-character pair as a base-16 number.
hex_color = bg_color.lstrip("#")
red, green, blue = (int(hex_color[start : start + 2], 16) for start in range(0, 6, 2))
rgb_color = (red, green, blue)
# Multi-file image upload; only the listed extensions are accepted.
ACCEPTED_IMAGE_TYPES = ["png", "jpg", "jpeg"]
uploaded_files = st.file_uploader(
    "Wybierz obrazy",
    accept_multiple_files=True,
    type=ACCEPTED_IMAGE_TYPES,
)
# Session-state slot that keeps the last finished batch alive across reruns.
RESULTS_STATE_KEY = "processed_batch"


def _unique_zip_entry(upload_name, taken):
    """Return a ZIP entry name for ``upload_name`` that is not already taken.

    The first file with a given stem keeps the plain
    ``przetworzony_<stem>.png`` name. Later files that would collide with it
    (for example ``cat.jpg`` and ``cat.png``) get ``_2``, ``_3``, ... appended.
    Names are compared case-insensitively so the archive also extracts
    cleanly on case-insensitive file systems.

    Args:
        upload_name: File name as reported by the uploader.
        taken: Set of case-folded entry names already used; updated in place.

    Returns:
        The entry name to use inside the archive.
    """
    stem = upload_name.rsplit(".", 1)[0]
    candidate = f"przetworzony_{stem}.png"
    counter = 1
    while candidate.casefold() in taken:
        counter += 1
        candidate = f"przetworzony_{stem}_{counter}.png"
    taken.add(candidate.casefold())
    return candidate


if uploaded_files:
    logger.info(f"Uploaded files: {len(uploaded_files)}")
    st.write(f"📁 Przesłano {len(uploaded_files)} obraz(ów)")

    # Fingerprint of every input that influences the output. A stored batch
    # is shown only while this still matches, so changing files or settings
    # hides outdated results (as before), while a rerun caused by the
    # download button keeps them on screen.
    batch_signature = (
        tuple((f.name, f.size) for f in uploaded_files),
        (watermark_file.name, watermark_file.size) if watermark_file else None,
        model_choice,
        transparency_mode,
        rgb_color,
        add_watermark,
        wm_opacity,
        wm_scale,
        wm_position,
    )

    # Processing button
    if st.button("🚀 Przetwórz wszystkie obrazy", type="primary"):
        # Progress placeholders
        progress_bar = st.progress(0)
        status_text = st.empty()

        # One model session is created per run and shared by all images.
        status_text.text(f"Ładowanie modelu: {model_choice}...")
        session = new_session(model_choice)
        logger.info(f"Model {model_choice} loaded successfully")

        total = len(uploaded_files)
        processed_images = []
        for position, uploaded_file in enumerate(uploaded_files, start=1):
            progress_bar.progress(position / total)
            status_text.text(f"Przetwarzanie {position}/{total}: {uploaded_file.name}")

            # The same watermark upload is reused for every image. Rewind it
            # so the 2nd+ image does not read from an exhausted stream.
            # NOTE(review): load_and_process_image is defined elsewhere; this
            # is harmless if it already seeks to the start itself.
            if watermark_file is not None:
                watermark_file.seek(0)

            input_img, processed_img = load_and_process_image(
                uploaded_file,
                session=session,
                transparency_mode=transparency_mode,
                rgb_color=rgb_color,
                add_watermark=add_watermark,
                watermark_file=watermark_file if watermark_file else None,
                wm_opacity=wm_opacity,
                wm_scale=wm_scale,
                wm_position=wm_position,
            )
            processed_images.append(
                {"name": uploaded_file.name, "original": input_img, "processed": processed_img}
            )

        status_text.text("✅ Wszystkie obrazy zostały przetworzone!")

        # Pack every processed image into an in-memory ZIP. PNG is used so
        # that transparency survives.
        zip_buffer = io.BytesIO()
        taken_names = set()
        with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zip_file:
            for img_data in processed_images:
                img_buffer = io.BytesIO()
                img_data["processed"].save(img_buffer, format="PNG")
                entry_name = _unique_zip_entry(img_data["name"], taken_names)
                zip_file.writestr(entry_name, img_buffer.getvalue())

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        logger.info(
            f"ZIP created with {len(processed_images)} images, filename=usuniete_tlo_{timestamp}.zip",
        )

        # Persist the batch: clicking the download button reruns the script
        # with this button reported as not pressed, which would otherwise
        # wipe the preview and the download button itself.
        st.session_state[RESULTS_STATE_KEY] = {
            "signature": batch_signature,
            "images": processed_images,
            "zip_bytes": zip_buffer.getvalue(),
            "zip_name": f"usuniete_tlo_{timestamp}.zip",
        }

    batch = st.session_state.get(RESULTS_STATE_KEY)
    if batch is not None and batch["signature"] == batch_signature:
        # Result preview
        st.markdown("---")
        st.subheader("Podgląd wyników")

        # Original and processed image side by side, one pair per upload.
        for img_data in batch["images"]:
            col1, col2 = st.columns(2)
            with col1:
                st.text(f"Oryginał: {img_data['name']}")
                st.image(img_data["original"], width=700)
            with col2:
                st.text(f"Po przetworzeniu: {img_data['name']}")
                st.image(img_data["processed"], width=700)
            st.markdown("---")

        # ZIP download button
        st.download_button(
            label=f"📥 Pobierz wszystkie ({len(batch['images'])} obrazów)",
            data=batch["zip_bytes"],
            file_name=batch["zip_name"],
            mime="application/zip",
        )
else:
    # No uploads: drop any stored batch so its images can be released.
    st.session_state.pop(RESULTS_STATE_KEY, None)
    st.info("👆 Prześlij jeden lub więcej obrazów, aby rozpocząć")

# Trailing divider. NOTE(review): indentation was lost in this copy of the
# file; this is assumed to be a top-level statement rather than part of the
# else branch - confirm against the original.
st.markdown("---")