import os
import gc
import torch
import streamlit as st
import tempfile
import json
import subprocess
import shutil
from datetime import datetime
from io import BytesIO
from transformers import AutoTokenizer, AutoModelForCausalLM
from parler_tts import ParlerTTSForConditionalGeneration
from diffusers import StableDiffusionPipeline
from PIL import Image
import soundfile as sf
# --- Config ---
st.set_page_config(layout="wide", page_title="⚡ POV Generator Pro")
LLM_MODEL_ID = "openai-community/gpt2-medium" # Slightly larger GPT-2 model
IMG_MODEL_ID = "CompVis/stable-diffusion-v1-4"
TTS_MODEL_ID = "parler-tts/parler-tts-mini-v1.1" # Make sure this matches your desired ParlerTTS model version
# Explicit cache directory for all Hugging Face downloads.
# On HF Spaces, /tmp is ephemeral but fine for a single session.
CACHE_DIR = os.path.join(tempfile.gettempdir(), "hf_cache_pov_generator")
os.makedirs(CACHE_DIR, exist_ok=True)
os.environ['HUGGINGFACE_HUB_CACHE'] = CACHE_DIR
os.environ['HF_HOME'] = CACHE_DIR # Also sets the general Hugging Face home
os.environ['TRANSFORMERS_CACHE'] = CACHE_DIR
os.environ['DIFFUSERS_CACHE'] = CACHE_DIR
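# Note: TRANSFORMERS_CACHE and DIFFUSERS_CACHE are older per-library variables;
# recent releases derive their cache location from HF_HOME / HUGGINGFACE_HUB_CACHE.
# Setting all four simply keeps older library versions pointed at the same dir.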
# --- Session State Initialization ---
if 'run_id' not in st.session_state:
    st.session_state.run_id = datetime.now().strftime("%Y%m%d_%H%M%S")
if 'story_data' not in st.session_state:
    st.session_state.story_data = None
if 'pil_images' not in st.session_state:
    st.session_state.pil_images = None
if 'image_paths_for_video' not in st.session_state:
    st.session_state.image_paths_for_video = None
if 'audio_paths' not in st.session_state:
    st.session_state.audio_paths = None
if 'video_path' not in st.session_state:
    st.session_state.video_path = None
if 'temp_base_dir' not in st.session_state:
    st.session_state.temp_base_dir = None
# --- Utility ---
def get_session_temp_dir():
    if st.session_state.temp_base_dir and os.path.exists(st.session_state.temp_base_dir):
        return st.session_state.temp_base_dir
    # One base directory holds every temporary file for this generation run,
    # so cleaning up a whole run is a single rmtree.
    base_dir = os.path.join(tempfile.gettempdir(), f"pov_generator_run_{st.session_state.run_id}")
    os.makedirs(base_dir, exist_ok=True)
    st.session_state.temp_base_dir = base_dir
    return base_dir
def cleanup_temp_files(specific_dir=None):
    """Removes a temp directory (the session's base dir by default)."""
    path_to_clean = specific_dir or st.session_state.get("temp_base_dir")
    if path_to_clean and os.path.exists(path_to_clean):
        try:
            shutil.rmtree(path_to_clean)
            if specific_dir is None:  # Only reset if cleaning the main session temp dir
                st.session_state.temp_base_dir = None
            print(f"Cleaned up temp directory: {path_to_clean}")
        except Exception as e:
            print(f"Error cleaning up temp directory {path_to_clean}: {e}")
    # All temp files for a run live under temp_base_dir, so removing it
    # cleans up everything from that run.
def clear_torch_cache():
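    # Run Python GC first so stale tensor references are dropped before
    # asking CUDA to release its cached blocks.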
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
# --- Model Loading (Cached) ---
@st.cache_resource
def load_llm_model_and_tokenizer(model_id):
    tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=CACHE_DIR)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device_map="auto",
        cache_dir=CACHE_DIR
    )
    if tokenizer.pad_token_id is None:  # GPT-2 has no pad token by default
        tokenizer.pad_token = tokenizer.eos_token
        model.config.pad_token_id = model.config.eos_token_id
    return model, tokenizer
@st.cache_resource
def load_sd_pipeline(model_id):
    pipe = StableDiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        cache_dir=CACHE_DIR
    )
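    # Optional, for low-VRAM setups: diffusers pipelines expose
    # pipe.enable_attention_slicing(), which trades some speed for a lower
    # peak memory footprint. Left disabled here by default.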
    if torch.cuda.is_available():
        pipe = pipe.to("cuda")
    return pipe
@st.cache_resource
def load_tts_model_and_tokenizers(model_id):
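    # Parler-TTS conditions on two text inputs: the transcript to speak
    # (prompt tokenizer) and a natural-language voice description that is fed
    # through the model's text encoder (description tokenizer).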
    tts_model = ParlerTTSForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device_map="auto",
        cache_dir=CACHE_DIR
    )
    prompt_tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=CACHE_DIR)
    # The description tokenizer belongs to the underlying text encoder; newer
    # configs expose name_or_path, older ones only the private _name_or_path.
    desc_tokenizer_path = (
        tts_model.config.text_encoder.name_or_path
        if hasattr(tts_model.config.text_encoder, 'name_or_path')
        else tts_model.config.text_encoder._name_or_path
    )
    desc_tokenizer = AutoTokenizer.from_pretrained(desc_tokenizer_path, cache_dir=CACHE_DIR)
    return tts_model, prompt_tokenizer, desc_tokenizer
# --- Step 1: Generate JSON Story ---
def generate_story(prompt: str, num_scenes: int):
    model, tokenizer = load_llm_model_and_tokenizer(LLM_MODEL_ID)
    # Refined prompt for better scene separation and count
    story_prompt = (
        f"Generate a compelling short POV story based on the following prompt: '{prompt}'. "
        f"The story should consist of exactly {num_scenes} distinct scenes. "
        f"Clearly separate each scene with the delimiter '###'. "
        f"Do not include any introductory or concluding text outside of the scenes and their separators. "
        f"Each scene should be a paragraph of 2-4 sentences."
    )
    input_ids = tokenizer.encode(story_prompt, return_tensors="pt").to(model.device)
    # Budget max_new_tokens against the model's context window. Every GPT-2
    # size, including gpt2-medium, has a 1024-token context (n_positions).
    max_model_tokens = getattr(model.config, 'n_positions', 1024)
    max_possible_new_tokens = max_model_tokens - input_ids.shape[1] - 20  # Safety buffer
    desired_tokens_per_scene = 75  # Avg tokens per scene
    desired_total_tokens = num_scenes * desired_tokens_per_scene
    # Cap generated tokens to prevent overly long outputs and stay within model limits
    max_new_tokens_val = min(desired_total_tokens, 700, max_possible_new_tokens)
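    # Worked example: 3 scenes -> 3 * 75 = 225 desired tokens; assuming a
    # ~120-token prompt, 1024 - 120 - 20 = 884 tokens remain in the context,
    # so max_new_tokens_val = min(225, 700, 884) = 225.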
    if max_new_tokens_val <= 0:
        st.error("Prompt is too long: no room left in the model's context window for new tokens.")
        return None
    output = model.generate(
        input_ids,
        max_new_tokens=max_new_tokens_val,
        do_sample=True,
        temperature=0.7,
        top_k=50,
        pad_token_id=tokenizer.eos_token_id
    )
    full_result = tokenizer.decode(output[0], skip_special_tokens=True)
    # GPT-2 echoes the input, so strip the prompt from the decoded output.
    if full_result.startswith(story_prompt):
        generated_text = full_result[len(story_prompt):].strip()
    else:
        # Fallback: the decoded text does not always echo the prompt verbatim.
        # Keeping the whole output is fragile, but the '###' split below
        # usually recovers the scenes anyway; good prompt engineering is key.
        generated_text = full_result
    scenes_raw = generated_text.split("###")
    processed_scenes = []
    for s in scenes_raw:
        s_clean = s.strip()
        if s_clean:  # Skip empty scenes
            processed_scenes.append(s_clean)
    final_scenes = processed_scenes
    # If the LLM produced more scenes than requested, take the first N;
    # if fewer, use whatever is available.
    if len(final_scenes) > num_scenes:
        final_scenes = final_scenes[:num_scenes]
        st.warning(f"LLM generated more scenes than requested. Using the first {num_scenes}.")
    elif len(final_scenes) < num_scenes:
        st.warning(f"LLM generated {len(final_scenes)} scenes, but {num_scenes} were requested. Using available scenes.")
    if not final_scenes:
        st.error("Failed to parse scenes from LLM output. The output was: " + generated_text)
        return None
    clear_torch_cache()
    return {"title": prompt[:60].capitalize(), "scenes": final_scenes}
# --- Step 2: Generate Images ---
def generate_images_for_scenes(scenes):
    pipe = load_sd_pipeline(IMG_MODEL_ID)
    pil_images = []
    # Create a directory for storing frame images for the video
    frames_dir = os.path.join(get_session_temp_dir(), "frames_for_video")
    os.makedirs(frames_dir, exist_ok=True)
    image_paths_for_video = []
    cols = st.columns(3)  # Adjust number of columns as preferred
    col_idx = 0
    for i, scene_text in enumerate(scenes):
        with st.spinner(f"Generating image for scene {i+1}..."):
            try:
                # Add a style modifier for better visual appeal; can be made user-configurable
                styled_prompt = f"{scene_text}, cinematic lighting, detailed, high quality"
                image = pipe(styled_prompt, num_inference_steps=30).images[0]  # Reduced steps for speed
                pil_images.append(image)
                # Save image for video creation
                img_path = os.path.join(frames_dir, f"frame_{i:03d}.png")
                image.save(img_path)
                image_paths_for_video.append(img_path)
                with cols[col_idx % len(cols)]:
                    st.image(image, caption=f"Scene {i+1}: {scene_text[:100]}...")
                    # Download button for individual image
                    img_byte_arr = BytesIO()
                    image.save(img_byte_arr, format='PNG')
                    st.download_button(
                        label=f"Download Scene {i+1} Image",
                        data=img_byte_arr.getvalue(),
                        file_name=f"scene_{i+1}_image.png",
                        mime="image/png",
                        key=f"download_img_{i}"
                    )
                col_idx += 1
            except Exception as e:
                st.error(f"Error generating image for scene {i+1}: {e}")
                pil_images.append(None)  # Placeholder for failed image
                image_paths_for_video.append(None)  # Placeholder
    clear_torch_cache()
    return pil_images, image_paths_for_video
# --- Step 3: Generate TTS ---
def generate_audios_for_scenes(scenes):
    tts_model, prompt_tokenizer, desc_tokenizer = load_tts_model_and_tokenizers(TTS_MODEL_ID)
    audio_dir = os.path.join(get_session_temp_dir(), "audio_files")
    os.makedirs(audio_dir, exist_ok=True)
    audio_paths = []
    cols = st.columns(3)  # Adjust number of columns
    col_idx = 0
    # User-configurable description, or keep it fixed
    tts_description = "A neutral and clear narrator voice."
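    # The description steers speaker identity and delivery (e.g. gender, pace,
    # recording quality); the Parler-TTS model card lists example phrasings.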
    for i, scene_text in enumerate(scenes):
        with st.spinner(f"Generating audio for scene {i+1}..."):
            try:
                desc_ids = desc_tokenizer(tts_description, return_tensors="pt").input_ids.to(tts_model.device)
                prompt_ids = prompt_tokenizer(scene_text, return_tensors="pt").input_ids.to(tts_model.device)
                # Generate audio. Extra generation kwargs (e.g. temperature)
                # can be passed to tts_model.generate if more variety is needed.
                generation_output = tts_model.generate(input_ids=desc_ids, prompt_input_ids=prompt_ids)
                # squeeze() drops the batch dimension so soundfile sees a 1-D mono waveform
                audio_waveform = generation_output.to(torch.float32).cpu().numpy().squeeze()
                file_path = os.path.join(audio_dir, f"audio_scene_{i+1}.wav")
                sf.write(file_path, audio_waveform, tts_model.config.sampling_rate)  # Use model's sampling rate
                audio_paths.append(file_path)
                with cols[col_idx % len(cols)]:
                    st.markdown(f"**Audio for Scene {i+1}**")
                    st.audio(file_path)
                    with open(file_path, "rb") as f_audio:
                        st.download_button(
                            label=f"Download Scene {i+1} Audio",
                            data=f_audio.read(),  # Read bytes for download
                            file_name=f"scene_{i+1}_audio.wav",
                            mime="audio/wav",
                            key=f"download_audio_{i}"
                        )
                col_idx += 1
            except Exception as e:
                st.error(f"Error generating audio for scene {i+1}: {e}")
                audio_paths.append(None)  # Placeholder
    clear_torch_cache()
    return audio_paths
# --- Step 4: Create Video ---
def create_video_from_scenes(image_file_paths, audio_file_paths, output_filename="final_pov_video.mp4"):
    if not image_file_paths or not audio_file_paths or len(image_file_paths) != len(audio_file_paths):
        st.error("Mismatch in number of images and audio files, or missing assets. Cannot create video.")
        return None
    # Ensure ffmpeg is installed and accessible
    try:
        subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
    except (subprocess.CalledProcessError, FileNotFoundError):
        st.error("FFMPEG is not installed or not found in PATH. Video creation is not possible.")
        st.markdown("Please install FFMPEG: `sudo apt update && sudo apt install ffmpeg` (Linux) or `brew install ffmpeg` (macOS).")
        return None
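    # On Hugging Face Spaces, ffmpeg can be installed at build time by listing
    # "ffmpeg" in the Space's packages.txt.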
    temp_clips_dir = os.path.join(get_session_temp_dir(), "temp_video_clips")
    os.makedirs(temp_clips_dir, exist_ok=True)
    video_clips_paths = []
    valid_scene_count = 0
    for i, (img_path, audio_path) in enumerate(zip(image_file_paths, audio_file_paths)):
        if img_path is None or audio_path is None:
            st.warning(f"Skipping scene {i+1} in video due to missing image or audio.")
            continue
        try:
            audio_info = sf.info(audio_path)
            audio_duration = audio_info.duration
            if audio_duration <= 0.1:  # Minimum practical duration
                st.warning(f"Audio for scene {i+1} is too short ({audio_duration:.2f}s). Using a minimum duration of 1s.")
                audio_duration = 1.0  # Enforce a minimum duration
            clip_path = os.path.join(temp_clips_dir, f"clip_{i:03d}.mp4")
            # Create individual clip: loop the still image over the audio, with -t
            # pinning the clip length. (-shortest is deliberately omitted: it would
            # cut at the actual audio length and override the 1s minimum above.)
            command = [
                "ffmpeg", "-y",
                "-loop", "1", "-i", img_path,  # Loop the image
                "-i", audio_path,  # Input audio
                "-c:v", "libx264", "-preset", "medium", "-tune", "stillimage",
                "-c:a", "aac", "-b:a", "192k",
                "-pix_fmt", "yuv420p",
                "-t", str(audio_duration),  # Duration of this clip
                clip_path
            ]
            process = subprocess.run(command, capture_output=True, text=True)
            if process.returncode != 0:
                st.error(f"FFMPEG error creating clip for scene {i+1}:\n{process.stderr}")
                continue  # Skip this clip
            video_clips_paths.append(clip_path)
            valid_scene_count += 1
        except Exception as e:
            st.error(f"Error processing scene {i+1} for video: {e}")
            continue
    if not video_clips_paths or valid_scene_count == 0:
        st.error("No valid video clips were generated. Cannot create final video.")
        cleanup_temp_files(temp_clips_dir)  # Clean up partial clips
        return None
    # Build the file list for ffmpeg's concat demuxer. Relative paths in the
    # list are resolved against the list file's directory, so basenames are
    # fine here: the list lives alongside the clips.
    concat_list_file = os.path.join(temp_clips_dir, "concat_list.txt")
    with open(concat_list_file, "w") as f:
        for clip_p in video_clips_paths:
            f.write(f"file '{os.path.basename(clip_p)}'\n")
    final_video_path = os.path.join(get_session_temp_dir(), output_filename)
    concat_command = [
        "ffmpeg", "-y",
        "-f", "concat", "-safe", "0", "-i", concat_list_file,
        "-c", "copy",  # Re-mux without re-encoding; all clips share the same codecs
        final_video_path
    ]
    process = subprocess.run(concat_command, capture_output=True, text=True, cwd=temp_clips_dir)  # Run from clips dir
    if process.returncode != 0:
        st.error(f"FFMPEG error concatenating video clips:\n{process.stderr}")
        cleanup_temp_files(temp_clips_dir)  # Clean up partial clips
        return None
    st.success("Video created successfully!")
    # cleanup_temp_files(temp_clips_dir)  # Optionally remove intermediate clips now;
    # by default everything is cleaned at the next run or via the sidebar button.
    return final_video_path
# --- Main App UI ---
st.title("⚡ POV Story Generator Pro")
st.markdown("Create engaging POV stories with AI-generated text, images, audio, and video.")
st.markdown("---")
# Sidebar for inputs
with st.sidebar:
    st.header("📝 Story Configuration")
    prompt = st.text_area(
        "Enter your POV story prompt:",
        st.session_state.get("user_prompt", "POV: You are a detective solving a mystery in a futuristic city."),
        height=100,
        key="user_prompt_input"
    )
    num_scenes = st.slider("Number of Scenes:", min_value=2, max_value=10, value=st.session_state.get("num_scenes_val", 3), key="num_scenes_slider")
    st.markdown("---")
    if st.button("🚀 Generate Full Story & Assets", type="primary", use_container_width=True):
        # Reset state for a new generation run
        st.session_state.run_id = datetime.now().strftime("%Y%m%d_%H%M%S")  # New unique ID for this run
        cleanup_temp_files()  # Remove any previous run's temp files
        st.session_state.story_data = None
        st.session_state.pil_images = None
        st.session_state.image_paths_for_video = None
        st.session_state.audio_paths = None
        st.session_state.video_path = None
        st.session_state.user_prompt = prompt  # Save current input values
        st.session_state.num_scenes_val = num_scenes
        # Flag picked up by the main content area below
        st.session_state.generate_all = True
    st.markdown("---")
    st.header("🛠️ Utilities")
    if st.button("🧹 Clear Cache & Temp Files & Restart", use_container_width=True):
        # Clear model caches
        st.cache_resource.clear()
        # Clear session state related to generated artifacts
        keys_to_clear = ['story_data', 'pil_images', 'image_paths_for_video',
                         'audio_paths', 'video_path', 'temp_base_dir', 'generate_all']
        for key in keys_to_clear:
            if key in st.session_state:
                del st.session_state[key]
        cleanup_temp_files()  # Ensure physical temp files are deleted
        st.session_state.run_id = datetime.now().strftime("%Y%m%d_%H%M%S")  # New ID after clear
        st.success("Caches and temporary files cleared. App will restart.")
        st.rerun()
# Main content area
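# The generate_all flag set in the sidebar drives one top-to-bottom pipeline
# pass below; each stage clears the flag on failure so later stages are skipped.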
if st.session_state.get("generate_all"):
    # --- 1. Generate Story ---
    with st.status("🧠 Generating story...", expanded=True) as status_story:
        try:
            st.session_state.story_data = generate_story(st.session_state.user_prompt, st.session_state.num_scenes_val)
            if st.session_state.story_data:
                status_story.update(label="Story generated successfully!", state="complete")
            else:
                status_story.update(label="Story generation failed.", state="error")
                st.session_state.generate_all = False  # Stop further processing
        except Exception as e:
            st.error(f"An unexpected error occurred during story generation: {e}")
            status_story.update(label="Story generation error.", state="error")
            st.session_state.generate_all = False
    # --- Display Story ---
    if st.session_state.story_data:
        st.subheader(f"🎬 Story: {st.session_state.story_data['title']}")
        for i, scene_text in enumerate(st.session_state.story_data['scenes']):
            st.markdown(f"**Scene {i+1}:** {scene_text}")
        story_json = json.dumps(st.session_state.story_data, indent=2)
        st.download_button(
            label="Download Story (JSON)",
            data=story_json,
            file_name=f"{st.session_state.story_data['title'].replace(' ', '_').lower()}_story.json",
            mime="application/json"
        )
        st.markdown("---")
    # --- 2. Generate Images (if story succeeded) ---
    if st.session_state.get("generate_all") and st.session_state.story_data:
        with st.status("🎨 Generating images for scenes...", expanded=True) as status_images:
            try:
                st.session_state.pil_images, st.session_state.image_paths_for_video = generate_images_for_scenes(st.session_state.story_data['scenes'])
                if all(img is not None for img in st.session_state.pil_images):  # Basic check
                    status_images.update(label="Images generated successfully!", state="complete")
                elif any(img is not None for img in st.session_state.pil_images):
                    status_images.update(label="Some images generated. Check for errors.", state="warning")
                else:
                    status_images.update(label="Image generation failed for all scenes.", state="error")
                    st.session_state.generate_all = False  # Stop further processing
            except Exception as e:
                st.error(f"An unexpected error occurred during image generation: {e}")
                status_images.update(label="Image generation error.", state="error")
                st.session_state.generate_all = False
        st.markdown("---")
    # --- 3. Generate Audio (if images succeeded, fully or partially) ---
    if st.session_state.get("generate_all") and st.session_state.story_data and st.session_state.pil_images:
        with st.status("🔊 Generating audio for scenes...", expanded=True) as status_audio:
            try:
                st.session_state.audio_paths = generate_audios_for_scenes(st.session_state.story_data['scenes'])
                if all(p is not None for p in st.session_state.audio_paths):  # Basic check
                    status_audio.update(label="Audio generated successfully!", state="complete")
                elif any(p is not None for p in st.session_state.audio_paths):
                    status_audio.update(label="Some audio files generated. Check for errors.", state="warning")
                else:
                    status_audio.update(label="Audio generation failed for all scenes.", state="error")
                    st.session_state.generate_all = False  # Stop further processing
            except Exception as e:
                st.error(f"An unexpected error occurred during audio generation: {e}")
                status_audio.update(label="Audio generation error.", state="error")
                st.session_state.generate_all = False
        st.markdown("---")
    # --- 4. Create Video (if audio succeeded, fully or partially) ---
    if st.session_state.get("generate_all") and st.session_state.image_paths_for_video and st.session_state.audio_paths:
        # Ensure there's at least one valid image/audio pair
        valid_assets = sum(1 for img, aud in zip(st.session_state.image_paths_for_video, st.session_state.audio_paths) if img and aud)
        if valid_assets > 0:
            with st.status("📹 Creating final video...", expanded=True) as status_video:
                try:
                    st.session_state.video_path = create_video_from_scenes(
                        st.session_state.image_paths_for_video,
                        st.session_state.audio_paths
                    )
                    if st.session_state.video_path:
                        status_video.update(label="Video created successfully!", state="complete")
                    else:
                        status_video.update(label="Video creation failed.", state="error")
                except Exception as e:
                    st.error(f"An unexpected error occurred during video creation: {e}")
                    status_video.update(label="Video creation error.", state="error")
            if st.session_state.video_path:
                st.subheader("🎞️ Final Video Presentation")
                st.video(st.session_state.video_path)
                with open(st.session_state.video_path, "rb") as f_video:
                    st.download_button(
                        label="Download Final Video",
                        data=f_video.read(),
                        file_name=os.path.basename(st.session_state.video_path),
                        mime="video/mp4"
                    )
                st.markdown("---")
        else:
            st.warning("Not enough valid image/audio pairs to create a video.")
    # Reset the generation trigger so subsequent reruns don't regenerate
    if "generate_all" in st.session_state:
        del st.session_state.generate_all
elif not st.session_state.get("user_prompt"):  # Show the initial hint if nothing has run yet
    st.info("Configure your story in the sidebar and click 'Generate Full Story & Assets' to begin!")
# --- Final Cleanup ---
# Streamlit has no reliable session-end hook, so cleanup is manual:
# cleanup_temp_files() runs at the start of each new generation, and the
# sidebar utility clears everything on demand.