Update app.py
Browse files
app.py
CHANGED
|
@@ -26,7 +26,7 @@ def load_image_model():
|
|
| 26 |
logger.info("Stage 1 model loaded")
|
| 27 |
return processor, model
|
| 28 |
except Exception as e:
|
| 29 |
-
st.error("❌
|
| 30 |
raise
|
| 31 |
|
| 32 |
def stage1_generate_caption(uploaded_file):
|
|
@@ -34,7 +34,7 @@ def stage1_generate_caption(uploaded_file):
|
|
| 34 |
processor, model = load_image_model()
|
| 35 |
try:
|
| 36 |
img = Image.open(uploaded_file).convert("RGB")
|
| 37 |
-
img.thumbnail((512, 512))
|
| 38 |
inputs = processor(images=img, return_tensors="pt", padding=True)
|
| 39 |
outputs = model.generate(**inputs, max_length=30)
|
| 40 |
return processor.decode(outputs[0], skip_special_tokens=True)
|
|
@@ -47,14 +47,14 @@ def stage1_generate_caption(uploaded_file):
|
|
| 47 |
# ======================
|
| 48 |
@st.cache_resource
|
| 49 |
def load_story_model():
|
| 50 |
-
"""Load story generation model"""
|
| 51 |
try:
|
| 52 |
-
tokenizer = AutoTokenizer.from_pretrained("
|
| 53 |
-
model = AutoModelForCausalLM.from_pretrained("
|
| 54 |
logger.info("Stage 2 model loaded")
|
| 55 |
return tokenizer, model
|
| 56 |
except Exception as e:
|
| 57 |
-
st.error("❌
|
| 58 |
raise
|
| 59 |
|
| 60 |
def stage2_generate_story(keyword):
|
|
@@ -62,27 +62,28 @@ def stage2_generate_story(keyword):
|
|
| 62 |
tokenizer, model = load_story_model()
|
| 63 |
|
| 64 |
# Optimized prompt template
|
| 65 |
-
prompt = f"""
|
| 66 |
- Theme: {keyword}
|
| 67 |
- Characters: Animals
|
| 68 |
-
-
|
| 69 |
|
| 70 |
-
Story: Once upon a time, a little bear named Honey
|
| 71 |
|
| 72 |
try:
|
| 73 |
inputs = tokenizer(prompt, return_tensors="pt", max_length=100, truncation=True)
|
| 74 |
outputs = model.generate(
|
| 75 |
inputs.input_ids,
|
| 76 |
max_length=300,
|
| 77 |
-
temperature=0.
|
| 78 |
top_k=50,
|
| 79 |
-
repetition_penalty=1.2
|
|
|
|
| 80 |
)
|
| 81 |
full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 82 |
return full_text.replace(prompt, "").strip()
|
| 83 |
except Exception as e:
|
| 84 |
st.error(f"Story generation failed: {str(e)}")
|
| 85 |
-
return "The animals had a wonderful
|
| 86 |
|
| 87 |
# ======================
|
| 88 |
# Stage 3: Text-to-Speech
|
|
|
|
| 26 |
logger.info("Stage 1 model loaded")
|
| 27 |
return processor, model
|
| 28 |
except Exception as e:
|
| 29 |
+
st.error("❌ Failed to load image model")
|
| 30 |
raise
|
| 31 |
|
| 32 |
def stage1_generate_caption(uploaded_file):
|
|
|
|
| 34 |
processor, model = load_image_model()
|
| 35 |
try:
|
| 36 |
img = Image.open(uploaded_file).convert("RGB")
|
| 37 |
+
img.thumbnail((512, 512)) # Optimize image size
|
| 38 |
inputs = processor(images=img, return_tensors="pt", padding=True)
|
| 39 |
outputs = model.generate(**inputs, max_length=30)
|
| 40 |
return processor.decode(outputs[0], skip_special_tokens=True)
|
|
|
|
| 47 |
# ======================
|
| 48 |
@st.cache_resource
|
| 49 |
def load_story_model():
|
| 50 |
+
"""Load reliable story generation model"""
|
| 51 |
try:
|
| 52 |
+
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
|
| 53 |
+
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
|
| 54 |
logger.info("Stage 2 model loaded")
|
| 55 |
return tokenizer, model
|
| 56 |
except Exception as e:
|
| 57 |
+
st.error("❌ Failed to load story model")
|
| 58 |
raise
|
| 59 |
|
| 60 |
def stage2_generate_story(keyword):
|
|
|
|
| 62 |
tokenizer, model = load_story_model()
|
| 63 |
|
| 64 |
# Optimized prompt template
|
| 65 |
+
prompt = f"""Write a children's story with:
|
| 66 |
- Theme: {keyword}
|
| 67 |
- Characters: Animals
|
| 68 |
+
- Length: 100 words
|
| 69 |
|
| 70 |
+
Story: Once upon a time, a little bear named Honey found"""
|
| 71 |
|
| 72 |
try:
|
| 73 |
inputs = tokenizer(prompt, return_tensors="pt", max_length=100, truncation=True)
|
| 74 |
outputs = model.generate(
|
| 75 |
inputs.input_ids,
|
| 76 |
max_length=300,
|
| 77 |
+
temperature=0.85,
|
| 78 |
top_k=50,
|
| 79 |
+
repetition_penalty=1.2,
|
| 80 |
+
pad_token_id=tokenizer.eos_token_id
|
| 81 |
)
|
| 82 |
full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 83 |
return full_text.replace(prompt, "").strip()
|
| 84 |
except Exception as e:
|
| 85 |
st.error(f"Story generation failed: {str(e)}")
|
| 86 |
+
return "The animals had a wonderful day playing together!"
|
| 87 |
|
| 88 |
# ======================
|
| 89 |
# Stage 3: Text-to-Speech
|