# Story_Poem / app.py
# Author: UmaGeeth — commit 4e40230 (verified)
import gradio as gr
from PIL import Image
import torch
from transformers import (
BlipProcessor, BlipForConditionalGeneration,
AutoTokenizer, AutoModelForCausalLM
)
from gtts import gTTS
import tempfile
import os
# Set device: prefer GPU when available; all models and tensors are moved here.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load BLIP model for image captioning (used by generate_caption).
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)
# Load Falcon model for story/poem generation (used by generate_text).
# trust_remote_code is required because falcon-rw-1b ships custom modeling code.
gpt_tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-rw-1b", trust_remote_code=True)
gpt_model = AutoModelForCausalLM.from_pretrained("tiiuae/falcon-rw-1b", trust_remote_code=True).to(device)
# Generate image caption
# Generate image caption
def generate_caption(image):
    """Return a natural-language caption for *image* using BLIP.

    Args:
        image: a PIL.Image (Gradio hands these over with type="pil").

    Returns:
        str: the decoded caption, special tokens stripped.
    """
    # BLIP preprocessing expects 3-channel RGB; Gradio uploads may be
    # RGBA (PNG) or grayscale, which would otherwise fail in the processor.
    if image.mode != "RGB":
        image = image.convert("RGB")
    inputs = blip_processor(image, return_tensors="pt").to(device)
    # Inference only — no_grad avoids building an autograd graph.
    with torch.no_grad():
        out = blip_model.generate(**inputs)
    return blip_processor.decode(out[0], skip_special_tokens=True)
# Generate story or poem from caption, theme, characters
# Generate story or poem from caption, theme, characters
def generate_text(caption, theme, characters, content_type):
    """Generate a story or poem seeded by an image caption.

    Args:
        caption: BLIP-generated description of the uploaded image.
        theme: user-supplied theme string.
        characters: optional character names; skipped when falsy.
        content_type: "Story" or "Poem" (case-insensitive).

    Returns:
        str: the generated continuation with the prompt stripped off.
    """
    if content_type.lower() == "story":
        prompt = f"{caption}. This inspired a story about {theme.lower()}"
        if characters:
            prompt += f" involving {characters}"
        prompt += ". It begins like this:\n"
    else:
        prompt = f"{caption}. A poem themed around '{theme}'"
        if characters:
            prompt += f", mentioning {characters}"
        prompt += ":\n"
    # Tokenize with the full processor call so attention_mask is included —
    # encode() alone drops it, triggering a transformers warning and possibly
    # degraded generation when padding is involved.
    inputs = gpt_tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        output_ids = gpt_model.generate(
            **inputs,
            # max_new_tokens budgets only the continuation; the original
            # max_length=300 counted prompt tokens too, so a long prompt
            # would leave little or nothing to generate.
            max_new_tokens=300,
            do_sample=True,
            temperature=0.9,
            top_k=50,
            top_p=0.95,
            eos_token_id=gpt_tokenizer.eos_token_id,
            # Falcon has no pad token; fall back to EOS.
            pad_token_id=gpt_tokenizer.pad_token_id or gpt_tokenizer.eos_token_id,
        )
    output = gpt_tokenizer.decode(output_ids[0], skip_special_tokens=True)
    # Drop the echoed prompt; keep only the newly generated text.
    return output[len(prompt):].strip()
# Main function
# Main function
def generate_output(image, theme, characters, content_type):
    """End-to-end pipeline: image -> caption -> story/poem -> text file + audio.

    Args:
        image: PIL.Image from the Gradio upload, or None if nothing uploaded.
        theme, characters, content_type: forwarded to generate_text.

    Returns:
        tuple (display_text, txt_file_path_or_None, audio_path_or_None).
        On audio failure the text file is still returned; on missing image a
        friendly message is returned instead of a traceback.
    """
    # Guard: clicking Generate without an upload used to crash in BLIP.
    if image is None:
        return "Please upload an image first.", None, None
    caption = generate_caption(image)
    generated_text = generate_text(caption, theme, characters, content_type)
    # Save text to a .txt file for the download widget; context manager
    # guarantees the handle is closed even if write() raises.
    with tempfile.NamedTemporaryFile(
        delete=False, suffix=".txt", mode="w", encoding="utf-8"
    ) as txt_file:
        txt_file.write(generated_text)
        txt_path = txt_file.name
    # Generate audio with gTTS (English only). gTTS needs network access,
    # so failures are reported rather than crashing the whole request.
    audio_path = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3").name
    try:
        tts = gTTS(text=generated_text, lang="en")
        tts.save(audio_path)
    except Exception as e:
        return f"Audio generation error: {str(e)}", txt_path, None
    return generated_text, txt_path, audio_path
# Gradio UI
# ---------------- Gradio interface ----------------
# Component creation order determines the on-screen layout, so the widgets
# below appear in exactly the order they are built.
with gr.Blocks(title="AI Story & Poem Generator") as demo:
    gr.Markdown("## 🎭 AI Story & Poem Generator")
    gr.Markdown("Upload an image, enter a theme and characters, and get a creative story or poem with audio!")

    # Inputs: image on its own row, then theme / characters / type together.
    with gr.Row():
        image_input = gr.Image(type="pil", label="πŸ–ΌοΈ Upload Image")
    with gr.Row():
        theme_input = gr.Textbox(label="🎨 Theme (e.g., Adventure, Friendship, Dreams)")
        characters_input = gr.Textbox(label="πŸ§‘β€πŸ€β€πŸ§‘ Characters (Optional)")
        content_choice = gr.Radio(["Poem", "Story"], label="πŸ“ Choose Content Type")

    run_button = gr.Button("✨ Generate")

    # Outputs: generated text, a downloadable .txt, and the spoken audio.
    text_output = gr.Textbox(label="πŸ“œ Generated Output", lines=10)
    download_output = gr.File(label="πŸ“„ Download Text")
    audio_output = gr.Audio(label="πŸ”Š Listen to Audio")

    run_button.click(
        fn=generate_output,
        inputs=[image_input, theme_input, characters_input, content_choice],
        outputs=[text_output, download_output, audio_output],
    )

if __name__ == "__main__":
    demo.launch()