JessicaChen1854 commited on
Commit
e11abed
·
verified ·
1 Parent(s): 5167f5d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -55
app.py CHANGED
@@ -9,36 +9,36 @@ import json
9
 
10
  # 初始化模型 (全部使用原始pipeline)
11
  @st.cache_resource
12
- def load_pipelines():
13
  try:
14
  # 1. 剧本生成(T5模型)
15
  script_pipe = pipeline(
16
- "text2text-generation",
17
  model="mrm8488/t5-base-finetuned-common_gen",
18
- tokenizer="t5-base"
19
  )
20
 
21
- # 2. 分镜生成(BART模型)
22
- storyboard_pipe = pipeline(
23
- "text2text-generation", # 使用text2text而非summarization
24
- model="philschmid/bart-large-cnn-samsum"
25
- )
26
-
27
- # 3. 配乐生成(MusicGen)
28
- # music_pipe = pipeline(
29
- # "text-to-audio",
30
- # model="facebook/musicgen-small",
31
- # device_map="auto"
32
  # )
33
 
34
- # 4. 分镜图片生成(Stable Diffusion
35
- image_pipe = StableDiffusionPipeline.from_pretrained(
36
- "prompthero/openjourney-v4",
37
- safety_checker=None # 禁用安全检查以加速
38
- )
 
39
 
40
- return script_pipe, storyboard_pipe, image_pipe
41
- # return script_pipe, storyboard_pipe, music_pipe, image_pipe
 
 
 
 
 
 
42
 
43
  except Exception as e:
44
  st.error(f"模型加载失败: {str(e)}")
@@ -59,41 +59,37 @@ if user_input:
59
  # 1. 生成剧本
60
  with st.status("🖋️ 剧本生成中...", expanded=True) as status:
61
  try:
62
- # 生成提示词
63
- prompt = f"""生成电影剧本(主题:{user_input})
64
- ### 要求:
65
- - Markdown格式
66
- - 3个场景(INT/EXT)
67
- - 包含动作描述和对话
68
-
69
- 示例:
70
- ### 场景1:实验室(INT.夜)
71
- [警报灯闪烁]
72
- 博士(擦汗):实验体要失控了!"""
73
-
74
- # 生成并处理
75
- script = script_pipe(
76
- prompt,
77
- max_length=400, # 缩短生成长度
78
- num_beams=1, # 禁用束搜索
79
- do_sample=True, # 启用随机采样
80
- temperature=0.7 # 平衡创造性与稳定性
81
- )[0]["generated_text"]
82
-
83
- script = re.sub(r"场景\d+", r"### \g<0>", script) # 统一标题格式
84
-
85
- # 更新状态
86
- status.update(label="✅ 生成完成", state="complete")
87
- st.subheader("生成剧本")
88
- st.markdown(f"```markdown\n{script}\n```")
89
 
90
- except torch.cuda.OutOfMemoryError:
91
- status.update(label="❌ 显存不足", state="error")
92
- st.error("请尝试更简短的输入")
93
- except Exception as e:
94
- status.update(label="❌ 生成错误", state="error")
95
- st.error(str(e))
96
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  # 2. 生成分镜
98
  # with st.status("🎥 正在转换分镜脚本...", expanded=True) as status:
99
  # try:
 
9
 
10
  # 初始化模型 (全部使用原始pipeline)
11
  @st.cache_resource
12
+ def load_model():
13
  try:
14
  # 1. 剧本生成(T5模型)
15
  script_pipe = pipeline(
16
+ "text2text-generation",
17
  model="mrm8488/t5-base-finetuned-common_gen",
18
+ device=0 if torch.cuda.is_available() else -1
19
  )
20
 
21
+ # # 2. 分镜生成(BART模型)
22
+ # storyboard_pipe = pipeline(
23
+ # "text2text-generation", # 使用text2text而非summarization
24
+ # model="philschmid/bart-large-cnn-samsum"
 
 
 
 
 
 
 
25
  # )
26
 
27
+ # # 3. 配乐生成(MusicGen
28
+ # # music_pipe = pipeline(
29
+ # # "text-to-audio",
30
+ # # model="facebook/musicgen-small",
31
+ # # device_map="auto"
32
+ # # )
33
 
34
+ # # 4. 分镜图片生成(Stable Diffusion)
35
+ # image_pipe = StableDiffusionPipeline.from_pretrained(
36
+ # "prompthero/openjourney-v4",
37
+ # safety_checker=None # 禁用安全检查以加速
38
+ # )
39
+
40
+ # return script_pipe, storyboard_pipe, image_pipe
41
+ # # return script_pipe, storyboard_pipe, music_pipe, image_pipe
42
 
43
  except Exception as e:
44
  st.error(f"模型加载失败: {str(e)}")
 
59
  # 1. 生成剧本
60
  with st.status("🖋️ 剧本生成中...", expanded=True) as status:
61
  try:
62
+ # 加载模型
63
+ pipe = load_model()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
+ # 生成内容
66
+ prompt = build_prompt(theme)
67
+ response = pipe(
68
+ prompt,
69
+ max_length=600,
70
+ temperature=0.6,
71
+ num_beams=2,
72
+ no_repeat_ngram_size=2
73
+ )[0]["generated_text"]
74
+
75
+ # 清洗与验证
76
+ cleaned = format_script(response)
77
+ if "###" not in cleaned:
78
+ raise ValueError("生成内容不符合格式要求")
79
+
80
+ # 显示结果
81
+ st.subheader("生成剧本")
82
+ st.markdown(f"```markdown\n{cleaned}\n```")
83
+ status.update(label="✅ 生成完成", state="complete")
84
+
85
+ except Exception as e:
86
+ status.update(label="❌ 生成失败", state="error")
87
+ st.error(f"错误详情:{str(e)}")
88
+ st.code(f"原始输出:{response[:200]}")
89
+ # 运行程序
90
+ if __name__ == "__main__":
91
+ main()
92
+
93
  # 2. 生成分镜
94
  # with st.status("🎥 正在转换分镜脚本...", expanded=True) as status:
95
  # try: