"""Gradio app: caption an image with BLIP, spin the caption into a story or
poem with Falcon, and narrate the result with gTTS."""

import os
import tempfile

import gradio as gr
import torch
from gtts import gTTS
from PIL import Image
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BlipForConditionalGeneration,
    BlipProcessor,
)

# Use the GPU when available; every model below is moved to this device.
device = "cuda" if torch.cuda.is_available() else "cpu"

# BLIP model for image captioning.
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-base"
).to(device)

# Falcon causal LM for story/poem generation.
gpt_tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-rw-1b", trust_remote_code=True)
gpt_model = AutoModelForCausalLM.from_pretrained(
    "tiiuae/falcon-rw-1b", trust_remote_code=True
).to(device)


def generate_caption(image):
    """Return a BLIP-generated English caption for *image* (a PIL image)."""
    inputs = blip_processor(image, return_tensors="pt").to(device)
    out = blip_model.generate(**inputs)
    return blip_processor.decode(out[0], skip_special_tokens=True)


def generate_text(caption, theme, characters, content_type):
    """Generate a story or poem seeded by *caption*.

    Args:
        caption: Image caption used to open the prompt.
        theme: Free-text theme woven into the prompt.
        characters: Optional character names; empty string to omit.
        content_type: "Story" selects the story prompt, anything else a poem.

    Returns:
        The newly generated text with the prompt stripped off.
    """
    if content_type.lower() == "story":
        prompt = f"{caption}. This inspired a story about {theme.lower()}"
        if characters:
            prompt += f" involving {characters}"
        prompt += ". It begins like this:\n"
    else:
        prompt = f"{caption}. A poem themed around '{theme}'"
        if characters:
            prompt += f", mentioning {characters}"
        prompt += ":\n"

    input_ids = gpt_tokenizer.encode(prompt, return_tensors="pt").to(device)

    # BUG FIX: `pad_token_id or eos_token_id` would wrongly discard a
    # legitimate pad id of 0; compare against None explicitly.
    pad_id = gpt_tokenizer.pad_token_id
    if pad_id is None:
        pad_id = gpt_tokenizer.eos_token_id

    output_ids = gpt_model.generate(
        input_ids,
        max_length=300,
        do_sample=True,
        temperature=0.9,
        top_k=50,
        top_p=0.95,
        eos_token_id=gpt_tokenizer.eos_token_id,
        pad_token_id=pad_id,
    )

    # BUG FIX: strip the prompt at the *token* level. Re-decoding is not
    # guaranteed to reproduce the prompt byte-for-byte, so slicing the
    # decoded string with `output[len(prompt):]` could cut mid-word or
    # leave prompt fragments behind.
    new_tokens = output_ids[0][input_ids.shape[-1]:]
    return gpt_tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


def generate_output(image, theme, characters, content_type):
    """Run the full pipeline for the Gradio UI.

    Returns a 3-tuple of (text shown in the textbox, path to a .txt download,
    path to an .mp3 narration or None when TTS fails).
    """
    caption = generate_caption(image)
    generated_text = generate_text(caption, theme, characters, content_type)

    # Persist the text so the UI can offer it as a download.
    # BUG FIX: use a context manager so the handle is closed (and flushed)
    # even if the write raises.
    with tempfile.NamedTemporaryFile(
        delete=False, suffix=".txt", mode="w", encoding="utf-8"
    ) as txt_file:
        txt_file.write(generated_text)
    txt_path = txt_file.name

    # BUG FIX: the original kept only `.name` from the NamedTemporaryFile,
    # leaking the open file descriptor; close it before gTTS writes there.
    audio_tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3")
    audio_path = audio_tmp.name
    audio_tmp.close()
    try:
        # gTTS performs a network request; keep the broad catch so a TTS
        # failure degrades to text-only output instead of crashing the UI.
        tts = gTTS(text=generated_text, lang="en")
        tts.save(audio_path)
    except Exception as e:
        return f"Audio generation error: {str(e)}", txt_path, None

    return generated_text, txt_path, audio_path


# Gradio UI wiring.
with gr.Blocks(title="AI Story & Poem Generator") as demo:
    gr.Markdown("## 🎭 AI Story & Poem Generator")
    gr.Markdown("Upload an image, enter a theme and characters, and get a creative story or poem with audio!")

    with gr.Row():
        image = gr.Image(type="pil", label="🖼️ Upload Image")
    with gr.Row():
        theme = gr.Textbox(label="🎨 Theme (e.g., Adventure, Friendship, Dreams)")
        characters = gr.Textbox(label="🧑‍🤝‍🧑 Characters (Optional)")

    content_type = gr.Radio(["Poem", "Story"], label="📝 Choose Content Type")
    generate_btn = gr.Button("✨ Generate")

    output_text = gr.Textbox(label="📜 Generated Output", lines=10)
    txt_file = gr.File(label="📄 Download Text")
    audio_file = gr.Audio(label="🔊 Listen to Audio")

    generate_btn.click(
        fn=generate_output,
        inputs=[image, theme, characters, content_type],
        outputs=[output_text, txt_file, audio_file],
    )


if __name__ == "__main__":
    demo.launch()