Leo Liu commited on
Commit
7c3c59f
·
verified ·
1 Parent(s): 9ede58e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -7
app.py CHANGED
@@ -14,21 +14,52 @@ def text2story(text):
14
  pipe = pipeline("text-generation",
15
  model="pranavpsv/genre-story-generator-v2",
16
  max_new_tokens=100,
17
- truncation=True)
18
- story_text = pipe(text,
19
- do_sample=True,
20
- temperature=0.9,
21
- top_k=50,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  num_return_sequences=1)[0]['generated_text']
23
  last_punctuation = max(story_text.rfind("."), story_text.rfind("!"), story_text.rfind("?"))
24
  if last_punctuation != -1:
25
  story_text = story_text[:last_punctuation+1]
26
  return story_text
27
 
 
 
 
 
 
 
 
 
28
  # text2audio
29
  def text2audio(story_text):
30
- pipe = pipeline("text-to-audio", model="Matthijs/mms-tts-eng")
31
- audio_data = pipe(story_text)
 
 
 
 
 
 
 
 
 
32
  return audio_data
33
 
34
 
 
14
  pipe = pipeline("text-generation",
15
  model="pranavpsv/genre-story-generator-v2",
16
  max_new_tokens=100,
17
+ truncation=True)
18
+ prompt_for_children = f"""Generate a fairy tale for children aged 3-10 based on: {text}.
19
+ Requirements:
20
+ 1. Simple and interesting plot
21
+ 2. Include magical elements or talking animals
22
+ 3. Educational message
23
+ 4. Use simple vocabulary (under 500 Lexile)
24
+ 5. Short sentences (max 15 words)
25
+ 6. Add onomatopoeia like 'Boom!', 'Whoosh!' etc.
26
+
27
+ Story: Once upon a time,"""
28
+
29
+ story_text = pipe(prompt_for_children,
30
+ text,
31
+ max_new_tokens=200, # 增加故事长度
32
+ temperature=0.7, # 降低随机性
33
+ top_p=0.95, # 使用核采样
34
+ repetition_penalty=1.2,
35
+ do_sample=True,
36
  num_return_sequences=1)[0]['generated_text']
37
  last_punctuation = max(story_text.rfind("."), story_text.rfind("!"), story_text.rfind("?"))
38
  if last_punctuation != -1:
39
  story_text = story_text[:last_punctuation+1]
40
  return story_text
41
 
42
+ # Post-processing optimization
43
+ story_text = story_text.replace(prompt, "").strip()
44
+ for stop_word in ["\nThe end", "The end", "End of story"]:
45
+ if story_text.endswith(stop_word):
46
+ story_text = story_text[:-len(stop_word)]
47
+ last_punct = max(story_text.rfind("."), story_text.rfind("!"), story_text.rfind("?"))
48
+ return story_text[:last_punct+1] if last_punct != -1 else story_text
49
+
50
  # text2audio
51
  def text2audio(story_text):
52
+ pipe = pipeline("text-to-audio", model="facebook/mms-tts-eng",
53
+ config = {"speaker": "en_female"})
54
+ audio_data = pipe(
55
+ story_text,
56
+ generate_kwargs={
57
+ "tempo": 1.1, # 稍快的语速(1.0为基准)
58
+ "pitch": 4, # 提高音调(0-10范围)
59
+ "energy": 1.2, # 增强表现力
60
+ "vocal_tract_length": 1.15 # 更明亮的音色
61
+ }
62
+ )
63
  return audio_data
64
 
65