import streamlit as st
import os
import io
import concurrent.futures
from PIL import Image
from google import genai
# 로직 모듈 임포트
import logic_image
import logic_seo
import logic_tts
from config_style import STYLE_DEFINITIONS, THUMBNAIL_STRATEGIES
# ==========================================
# 1. 페이지 설정 및 CSS 디자인
# ==========================================
st.set_page_config(
page_title="Nano Banana Studio Pro",
page_icon="🍌",
layout="wide",
initial_sidebar_state="expanded"
)
# 세션 초기화
if 'scene_results' not in st.session_state:
st.session_state['scene_results'] = []
if 'final_audio' not in st.session_state:
st.session_state['final_audio'] = None
if 'shared_script' not in st.session_state:
st.session_state['shared_script'] = ""
# [CSS] 다크모드 + 오렌지 포인트 디자인
st.markdown("""
""", unsafe_allow_html=True)
# ==========================================
# 2. 사이드바 (통합 설정)
# ==========================================
with st.sidebar:
st.markdown("## ⚙️ 기본 설정 (Basic)")
api_key = st.text_input("Google API Key", type="password", placeholder="API 키를 입력하세요...", label_visibility="collapsed")
st.divider()
st.markdown("### 🎨 Model Engine")
model_choice = st.radio("사용할 모델", ["⚡ Fast (Gemini 2.0)", "🚀 Pro (Imagen 3)"], label_visibility="collapsed")
if "Fast" in model_choice:
image_model_id = "gemini-2.0-flash"
text_model_id = "gemini-2.0-flash"
else:
# [복구] 사용자님이 원하시던 그 모델 ID!
image_model_id = "gemini-3-pro-image-preview"
text_model_id = "gemini-3-pro-preview"
st.divider()
st.markdown("### 📐 Canvas Ratio")
ar_radio = st.radio("비율", ["16:9", "9:16"], label_visibility="collapsed", horizontal=True)
aspect_ratio = ar_radio
st.divider()
st.markdown("### 🛠️ 고급 설정 (TTS & Split)")
st.caption("TTS Model ID")
tts_model_id = st.text_input("TTS ID", value="gemini-2.5-pro-preview-tts", label_visibility="collapsed")
st.markdown("
", unsafe_allow_html=True)
st.caption("Image Scene Split (이미지 생성용)")
duration_per_scene = st.slider("장면당 시간(초)", 3, 30, 5)
split_criteria = duration_per_scene * 8
st.info(f"💡 {duration_per_scene}초 (약 {split_criteria}자) 단위로 장면을 나눕니다.")
st.caption("※ TTS는 설정과 무관하게 500자 단위로 최적화됩니다.")
st.divider()
st.markdown("## 🎨 스타일 및 캐릭터 설정")
st.caption("영상 전체의 분위기와 캐릭터를 결정합니다.")
st.markdown("
", unsafe_allow_html=True)
st.markdown("#### 🖌️ 화풍(Style) 선택")
style_options = list(STYLE_DEFINITIONS.keys()) + ["직접 입력"]
selected_style = st.selectbox("스타일 선택", style_options, label_visibility="collapsed")
custom_style_input = ""
current_prompt = ""
if selected_style == "직접 입력":
custom_style_input = st.text_input("스타일 프롬프트", placeholder="예: 지브리 스타일, 수채화풍")
current_prompt = custom_style_input if custom_style_input else "(프롬프트를 입력해주세요)"
else:
style_value = STYLE_DEFINITIONS[selected_style]
if isinstance(style_value, dict):
current_prompt = style_value.get('prompt', str(style_value))
else:
current_prompt = str(style_value)
st.markdown("#### 📜 적용될 프롬프트 (미리보기)")
st.text_area("프롬프트 미리보기", value=current_prompt, height=150, disabled=True, label_visibility="collapsed")
st.divider()
st.markdown("#### 👤 캐릭터 참조 (Reference)")
uploaded_file = st.file_uploader("이미지 업로드", type=["png", "jpg", "jpeg"], label_visibility="collapsed")
reference_image = None
if uploaded_file:
reference_image = Image.open(uploaded_file)
st.image(reference_image, caption="참조 이미지", use_container_width=True)
# ==========================================
# 3. 메인 콘텐츠
# ==========================================
st.title("Nano Banana Studio Pro", anchor=False)
st.markdown("
AI Powered All-in-One Workspace
", unsafe_allow_html=True)
tab1, tab2, tab3 = st.tabs(["🎬 장면 생성 (Scenes)", "📈 기획/SEO (Analysis)", "🎙️ 성우/TTS (Voice)"])
# ----------------------------------------------------------------
# [TAB 1] 장면/이미지 생성 (메인 기능)
# ----------------------------------------------------------------
with tab1:
st.markdown("
", unsafe_allow_html=True)
st.markdown("#### ❶ 대본 입력 (Script)")
script_input = st.text_area(
"대본",
height=300,
placeholder="대본을 줄글로 붙여넣으세요. 사이드바에서 설정한 시간(초) 단위로 AI가 장면을 자동으로 나눕니다.",
label_visibility="collapsed",
key="scene_script_input"
)
if script_input:
st.session_state['shared_script'] = script_input
st.markdown("
", unsafe_allow_html=True)
# 생성 버튼
if st.button("🚀 장면 생성 시작 (Generate Scenes)", type="primary", use_container_width=True):
if not api_key: st.error("⚠️ 사이드바에 API Key를 입력해주세요.")
elif not script_input: st.warning("⚠️ 대본을 입력해주세요.")
else:
# [핵심] 이미지 생성은 슬라이더 값(split_criteria)을 사용
raw_scenes = logic_tts.split_text_smartly(script_input, limit=split_criteria)
scenes_text = raw_scenes[:10]
st.toast(f"📜 {len(scenes_text)}개 장면으로 분할하여 생성을 시작합니다.")
client = genai.Client(api_key=api_key)
progress_text = "AI 화가가 그림을 그리는 중입니다..."
my_bar = st.progress(0, text=progress_text)
temp_results = [None] * len(scenes_text)
# 안정성 위해 max_workers=2
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
future_to_idx = {
executor.submit(logic_image.process_scene_task, i, {'text': text}, selected_style, custom_style_input, client, text_model_id, image_model_id, aspect_ratio, reference_image): i
for i, text in enumerate(scenes_text)
}
completed = 0
for future in concurrent.futures.as_completed(future_to_idx):
idx, prompt, img = future.result()
temp_results[idx] = (prompt, img, scenes_text[idx])
completed += 1
my_bar.progress(completed / len(scenes_text), text=f"Scene {idx+1} 완료!")
st.session_state['scene_results'] = temp_results
my_bar.empty()
st.rerun()
# 결과 표시
if st.session_state['scene_results']:
st.divider()
st.subheader("🎬 Scene Results")
for i, item in enumerate(st.session_state['scene_results']):
if item is None: continue
p_text, img_data, s_text = item
with st.container():
c1, c2 = st.columns([1, 2])
with c1:
if img_data:
try:
image = Image.open(io.BytesIO(img_data))
st.image(image, use_container_width=True)
st.download_button(
label="⬇️ 다운로드",
data=img_data,
file_name=f"scene_{i+1}.png",
mime="image/png",
key=f"dl_btn_{i}",
use_container_width=True
)
except: st.error("이미지 오류")
else: st.warning("이미지 없음")
with c2:
st.markdown(f"**Scene {i+1}**")
st.info(s_text)
if st.button(f"🔄 재생성", key=f"regen_{i}"):
if not api_key: st.error("API Key 필요")
else:
client = genai.Client(api_key=api_key)
with st.spinner("다시 그리는 중..."):
_, new_p, new_img = logic_image.process_scene_task(i, {'text': s_text}, selected_style, custom_style_input, client, text_model_id, image_model_id, aspect_ratio, reference_image)
st.session_state['scene_results'][i] = (new_p, new_img, s_text)
st.rerun()
st.divider()
# ----------------------------------------------------------------
# [TAB 2] 유튜브 기획 (SEO)
# ----------------------------------------------------------------
with tab2:
st.markdown("
", unsafe_allow_html=True)
st.markdown("#### 📈 유튜브 SEO 분석기")
st.caption("대본을 기반으로 제목, 태그, 설명란을 자동으로 생성합니다.")
default_seo_script = st.session_state.get('shared_script', "")
seo_script = st.text_area("분석할 대본 입력", value=default_seo_script, height=150, placeholder="SEO 분석을 원하는 대본을 입력하세요.", key="seo_input")
if st.button("✨ SEO 기획서 생성하기", key="seo_btn", type="primary"):
if not api_key: st.error("API Key 필요")
elif not seo_script: st.warning("대본 필요")
else:
with st.spinner("Brain 모델이 분석 중입니다..."):
client = genai.Client(api_key=api_key)
seo_result = logic_seo.generate_seo_content(client, text_model_id, seo_script)
st.session_state["seo_result"] = seo_result
st.success("분석 완료!")
c_seo1, c_seo2 = st.columns(2)
with c_seo1:
st.markdown("##### 📌 추천 제목")
for t in seo_result['titles']: st.info(f"{t}")
with c_seo2:
st.markdown("##### 🏷️ 추천 태그")
st.code(", ".join(seo_result['tags']))
st.markdown("##### 📝 설명란 (Description)")
st.text_area("설명란 결과", seo_result['description'], height=200)
st.divider()
st.subheader("🖼️ 썸네일 생성")
seo_cached = st.session_state.get("seo_result")
if not seo_cached:
st.info("SEO 분석을 먼저 실행하면 썸네일 생성 버튼이 나타납니다.")
else:
strat_keys = list(THUMBNAIL_STRATEGIES.keys())
sel_strat = st.selectbox("썸네일 전략 선택", strat_keys, index=0)
# 썸네일 텍스트(메인 타이틀) 후보: SEO 추천 제목 1개를 기본값으로
default_thumb_text = seo_cached["titles"][0] if seo_cached.get("titles") else ""
thumb_text = st.text_input("썸네일 메인 문구(원하면 수정)", value=default_thumb_text)
if st.button("🧠 썸네일 프롬프트 생성", key="thumb_prompt_btn"):
if not api_key:
st.error("API Key 필요")
else:
client = genai.Client(api_key=api_key)
strategy_block = THUMBNAIL_STRATEGIES[sel_strat]
prompt = f"""
너는 유튜브 썸네일 기획자다.
아래 '대본'과 '전략'을 참고해서, 이미지 생성 모델에 넣을 '썸네일 프롬프트'를 만든다.
[대본]
{seo_script[:8000]}
[전략]
{strategy_block}
[추가 조건]
- 썸네일 이미지 안에 텍스트를 직접 그려 넣지 마라(글자 생성 금지).
- 대신 '텍스트를 넣을 자리'를 구도로 확보해라(상단/하단 여백, 안전영역).
- 국가/국기/대통령/청와대/국회의사당 등 특정 국가 상징이 자동으로 나오지 않게,
정치 상징물은 추상적 은유(실루엣, 조명, 군중, 무대)로 처리해라.
- 화면비는 {aspect_ratio}.
- 출력은 "프롬프트 텍스트 1개"만. JSON/마크다운 금지.
[사용자가 넣을 메인 문구(참고만)]
{thumb_text}
""".strip()
res = client.models.generate_content(model=text_model_id, contents=prompt)
thumb_prompt = (getattr(res, "text", "") or "").strip()
st.session_state["thumb_prompt"] = thumb_prompt
st.success("썸네일 프롬프트 생성 완료!")
st.text_area("생성된 썸네일 프롬프트", value=thumb_prompt, height=180)
# 프롬프트가 있을 때만 이미지 생성 버튼 노출
thumb_prompt_cached = st.session_state.get("thumb_prompt", "")
if thumb_prompt_cached:
if st.button("🎨 썸네일 이미지 1장 생성", key="thumb_img_btn", type="primary"):
if not api_key:
st.error("API Key 필요")
else:
client = genai.Client(api_key=api_key)
# 이미지 생성 (generate_images 우선, 없으면 generate_content fallback)
img_bytes = None
try:
if hasattr(client.models, "generate_images"):
img_res = client.models.generate_images(
model=image_model_id,
prompt=thumb_prompt_cached
)
# logic_image에 있는 안전 추출 함수가 없다면 간단히 parts에서 뽑기
try:
cand = img_res.candidates[0]
part = cand.content.parts[0]
img_bytes = part.inline_data.data if part.inline_data else None
except:
img_bytes = None
else:
img_res = client.models.generate_content(
model=image_model_id,
contents=thumb_prompt_cached
)
try:
cand = img_res.candidates[0]
part = cand.content.parts[0]
img_bytes = part.inline_data.data if part.inline_data else None
except:
img_bytes = None
except Exception as e:
st.error(f"이미지 생성 실패: {e}")
if img_bytes:
st.image(Image.open(io.BytesIO(img_bytes)), use_container_width=True)
st.download_button(
"⬇️ 썸네일 다운로드",
data=img_bytes,
file_name="thumbnail.png",
mime="image/png",
use_container_width=True
)
else:
st.error("이미지 바이트를 추출하지 못했습니다. (모델 응답 구조 확인 필요)")
# ----------------------------------------------------------------
# [TAB 3] AI 성우 (TTS)
# ----------------------------------------------------------------
with tab3:
st.markdown("
", unsafe_allow_html=True)
st.markdown("#### 🎙️ AI 성우 스튜디오")
st.caption("장면 생성 없이 오디오만 필요할 때 사용하세요.")
default_tts_script = st.session_state.get('shared_script', "")
tts_script = st.text_area("낭독할 대본 입력", value=default_tts_script, height=150, placeholder="읽고 싶은 텍스트를 입력하세요.", key="tts_input")
voice_map = {
"Charon (남성/다큐)": "Charon", "Puck (남성/쾌활)": "Puck",
"Kore (여성/차분)": "Kore", "Fenrir (남성/강함)": "Fenrir",
"Aoede (여성/높음)": "Aoede", "Orus (남성/중후)": "Orus"
}
c_voice1, c_voice2 = st.columns([2, 1])
with c_voice1:
v_sel = st.selectbox("성우 선택", list(voice_map.keys()), key="tts_voice_sel")
voice_opt = voice_map[v_sel]
with c_voice2:
if st.button("▶ 미리듣기", key="tts_preview"):
if not api_key: st.error("API Key 필요")
else:
client = genai.Client(api_key=api_key)
prev_audio = logic_tts.generate_speech_chunk(client, tts_model_id, "안녕하세요, 제 목소리입니다.", voice_opt)
if isinstance(prev_audio, bytes):
if not prev_audio.startswith(b'RIFF') and hasattr(logic_tts, 'raw_pcm_to_wav'):
prev_audio = logic_tts.raw_pcm_to_wav(prev_audio)
st.audio(prev_audio, format="audio/wav", autoplay=True)
if st.button("🎙️ 전체 녹음 및 병합 시작", key="tts_full_btn", type="primary"):
if not api_key: st.error("API Key 필요")
elif not tts_script: st.warning("대본 필요")
else:
client = genai.Client(api_key=api_key)
# [핵심] TTS는 슬라이더 값(split_criteria)을 무시하고, 무조건 500자 단위로 고정
chunks = logic_tts.split_text_smartly(tts_script, limit=500)
audio_res = [None] * len(chunks)
with st.status("녹음 진행 중...", expanded=True):
with concurrent.futures.ThreadPoolExecutor() as executor:
f_map = {executor.submit(logic_tts.process_tts_task, i, c, client, tts_model_id, voice_opt): i for i, c in enumerate(chunks)}
for f in concurrent.futures.as_completed(f_map):
idx, dat = f.result()
if isinstance(dat, bytes):
audio_res[idx] = dat
st.write(f"✅ Part {idx+1} 완료")
final_wav = logic_tts.merge_wav_bytes(audio_res)
if final_wav:
st.success("오디오 생성 완료!")
st.audio(final_wav, format="audio/wav")
st.download_button("다운로드 (WAV)", final_wav, "full_audio.wav", "audio/wav")