TLH01 commited on
Commit
23ad0fc
·
verified ·
1 Parent(s): c67a65e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -12
app.py CHANGED
@@ -26,7 +26,7 @@ def load_image_model():
26
  logger.info("Stage 1 model loaded")
27
  return processor, model
28
  except Exception as e:
29
- st.error("❌ Image model failed to load")
30
  raise
31
 
32
  def stage1_generate_caption(uploaded_file):
@@ -34,7 +34,7 @@ def stage1_generate_caption(uploaded_file):
34
  processor, model = load_image_model()
35
  try:
36
  img = Image.open(uploaded_file).convert("RGB")
37
- img.thumbnail((512, 512))
38
  inputs = processor(images=img, return_tensors="pt", padding=True)
39
  outputs = model.generate(**inputs, max_length=30)
40
  return processor.decode(outputs[0], skip_special_tokens=True)
@@ -47,14 +47,14 @@ def stage1_generate_caption(uploaded_file):
47
  # ======================
48
  @st.cache_resource
49
  def load_story_model():
50
- """Load story generation model"""
51
  try:
52
- tokenizer = AutoTokenizer.from_pretrained("pranavpsv/gpt-genre-story-generator")
53
- model = AutoModelForCausalLM.from_pretrained("pranavpsv/gpt-genre-story-generator", use_auth_token=True)
54
  logger.info("Stage 2 model loaded")
55
  return tokenizer, model
56
  except Exception as e:
57
- st.error("❌ Story model failed to load")
58
  raise
59
 
60
  def stage2_generate_story(keyword):
@@ -62,27 +62,28 @@ def stage2_generate_story(keyword):
62
  tokenizer, model = load_story_model()
63
 
64
  # Optimized prompt template
65
- prompt = f"""Generate a children's story with:
66
  - Theme: {keyword}
67
  - Characters: Animals
68
- - Word count: 100 words
69
 
70
- Story: Once upon a time, a little bear named Honey discovered"""
71
 
72
  try:
73
  inputs = tokenizer(prompt, return_tensors="pt", max_length=100, truncation=True)
74
  outputs = model.generate(
75
  inputs.input_ids,
76
  max_length=300,
77
- temperature=0.9,
78
  top_k=50,
79
- repetition_penalty=1.2
 
80
  )
81
  full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
82
  return full_text.replace(prompt, "").strip()
83
  except Exception as e:
84
  st.error(f"Story generation failed: {str(e)}")
85
- return "The animals had a wonderful adventure!"
86
 
87
  # ======================
88
  # Stage 3: Text-to-Speech
 
26
  logger.info("Stage 1 model loaded")
27
  return processor, model
28
  except Exception as e:
29
+ st.error("❌ Failed to load image model")
30
  raise
31
 
32
  def stage1_generate_caption(uploaded_file):
 
34
  processor, model = load_image_model()
35
  try:
36
  img = Image.open(uploaded_file).convert("RGB")
37
+ img.thumbnail((512, 512)) # Optimize image size
38
  inputs = processor(images=img, return_tensors="pt", padding=True)
39
  outputs = model.generate(**inputs, max_length=30)
40
  return processor.decode(outputs[0], skip_special_tokens=True)
 
47
  # ======================
48
  @st.cache_resource
49
  def load_story_model():
50
+ """Load reliable story generation model"""
51
  try:
52
+ tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
53
+ model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
54
  logger.info("Stage 2 model loaded")
55
  return tokenizer, model
56
  except Exception as e:
57
+ st.error("❌ Failed to load story model")
58
  raise
59
 
60
  def stage2_generate_story(keyword):
 
62
  tokenizer, model = load_story_model()
63
 
64
  # Optimized prompt template
65
+ prompt = f"""Write a children's story with:
66
  - Theme: {keyword}
67
  - Characters: Animals
68
+ - Length: 100 words
69
 
70
+ Story: Once upon a time, a little bear named Honey found"""
71
 
72
  try:
73
  inputs = tokenizer(prompt, return_tensors="pt", max_length=100, truncation=True)
74
  outputs = model.generate(
75
  inputs.input_ids,
76
  max_length=300,
77
+ temperature=0.85,
78
  top_k=50,
79
+ repetition_penalty=1.2,
80
+ pad_token_id=tokenizer.eos_token_id
81
  )
82
  full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
83
  return full_text.replace(prompt, "").strip()
84
  except Exception as e:
85
  st.error(f"Story generation failed: {str(e)}")
86
+ return "The animals had a wonderful day playing together!"
87
 
88
  # ======================
89
  # Stage 3: Text-to-Speech