garyuzair committed on
Commit
e927273
·
verified ·
1 Parent(s): 01a7460

Update src/app_hf_space_optimized.py

Browse files
Files changed (1) hide show
  1. src/app_hf_space_optimized.py +8 -6
src/app_hf_space_optimized.py CHANGED
@@ -35,11 +35,12 @@ def clear_torch():
35
  # --- Step 1: Generate JSON Story ---
36
  def generate_story(prompt: str, num_scenes: int):
37
  st.info("🧠 Generating story...")
38
- tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_ID)
39
  model = AutoModelForCausalLM.from_pretrained(
40
  LLM_MODEL_ID,
41
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
42
- device_map="auto"
 
43
  )
44
 
45
  sys_prompt = (
@@ -69,7 +70,8 @@ def generate_images(scenes):
69
  st.info("🎨 Generating images...")
70
  pipe = StableDiffusionPipeline.from_pretrained(
71
  IMG_MODEL_ID,
72
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
 
73
  )
74
  pipe.to("cuda" if torch.cuda.is_available() else "cpu")
75
  images = []
@@ -83,9 +85,9 @@ def generate_images(scenes):
83
  # --- Step 3: Generate TTS ---
84
  def generate_audios(scenes):
85
  st.info("🔊 Generating audio...")
86
- tts = ParlerTTSForConditionalGeneration.from_pretrained(TTS_MODEL_ID, device_map="auto")
87
- tokenizer = AutoTokenizer.from_pretrained(TTS_MODEL_ID)
88
- desc_tokenizer = AutoTokenizer.from_pretrained(tts.config.text_encoder._name_or_path)
89
 
90
  audio_paths = []
91
  for i, scene in enumerate(scenes):
 
35
  # --- Step 1: Generate JSON Story ---
36
  def generate_story(prompt: str, num_scenes: int):
37
  st.info("🧠 Generating story...")
38
+ tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_ID, cache_dir=CACHE_DIR)
39
  model = AutoModelForCausalLM.from_pretrained(
40
  LLM_MODEL_ID,
41
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
42
+ device_map="auto",
43
+ cache_dir=CACHE_DIR
44
  )
45
 
46
  sys_prompt = (
 
70
  st.info("🎨 Generating images...")
71
  pipe = StableDiffusionPipeline.from_pretrained(
72
  IMG_MODEL_ID,
73
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
74
+ cache_dir=CACHE_DIR
75
  )
76
  pipe.to("cuda" if torch.cuda.is_available() else "cpu")
77
  images = []
 
85
  # --- Step 3: Generate TTS ---
86
  def generate_audios(scenes):
87
  st.info("🔊 Generating audio...")
88
+ tts = ParlerTTSForConditionalGeneration.from_pretrained(TTS_MODEL_ID, device_map="auto", cache_dir=CACHE_DIR)
89
+ tokenizer = AutoTokenizer.from_pretrained(TTS_MODEL_ID, cache_dir=CACHE_DIR)
90
+ desc_tokenizer = AutoTokenizer.from_pretrained(tts.config.text_encoder._name_or_path, cache_dir=CACHE_DIR)
91
 
92
  audio_paths = []
93
  for i, scene in enumerate(scenes):