# Hugging Face Space application (the "Spaces: Running" lines were status-page
# residue from a web capture, not part of the program).
import asyncio
import json
import os
import random
import re
import tempfile
import warnings
from datetime import datetime, timezone

import edge_tts
import gradio as gr
import moviepy.editor as mpe
import nltk
import numpy as np
import pytz
import torch
from diffusers import StableDiffusionPipeline, DiffusionPipeline
from gradio_client import Client
from PIL import Image
from pydub import AudioSegment
from transformers import pipeline, AutoProcessor, AutoModelForCausalLM
# Silence noisy UserWarnings emitted by the ML libraries at import/inference time.
warnings.filterwarnings("ignore", category=UserWarning)

# Ensure NLTK data is downloaded (the 'punkt' sentence tokenizer).
nltk.download('punkt')

# Initialize clients
# Module-level cache for the remote story-generation client; created lazily
# by init_arxiv_client() on first use.
arxiv_client = None
def init_arxiv_client():
    """Return the shared Arxiv-RAG Gradio client, creating it on first use.

    The client is cached in the module-level ``arxiv_client`` so the remote
    connection is established at most once per process.
    """
    global arxiv_client
    if arxiv_client is not None:
        return arxiv_client
    arxiv_client = Client("awacke1/Arxiv-Paper-Search-And-QA-RAG-Pattern")
    return arxiv_client
# File I/O Functions
def generate_filename(prompt, timestamp=None):
    """Build a filesystem-safe story filename from a prompt and timestamp.

    Args:
        prompt: Free-form user text; sanitized before use in the name.
        timestamp: Optional pre-formatted "%Y%m%d_%H%M%S" string; defaults
            to the current UTC time.

    Returns:
        A name of the form "story_<timestamp>_<safe_prompt>.txt" containing
        no spaces or punctuation unsafe for filenames.
    """
    if timestamp is None:
        # Stdlib timezone.utc is equivalent to pytz.UTC for strftime output.
        timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
    # Drop unsafe characters, cap the length, then collapse whitespace runs
    # to underscores; fall back to 'untitled' when nothing survives.
    safe_prompt = re.sub(r'[^\w\s-]', '', prompt)[:50].strip()
    safe_prompt = re.sub(r'\s+', '_', safe_prompt) or 'untitled'
    return f"story_{timestamp}_{safe_prompt}.txt"
def save_story(story, prompt, filename=None):
    """Persist a story to disk with a one-line JSON metadata header.

    File layout (the format load_story expects): a JSON object on the first
    line, a "\\n---\\n" separator, then the story text.

    Args:
        story: The generated story text.
        prompt: The prompt that produced the story (stored in metadata).
        filename: Optional target path; derived from the prompt if omitted.

    Returns:
        The filename written, or None if writing failed (best-effort).
    """
    if filename is None:
        filename = generate_filename(prompt)
    try:
        metadata = {
            # Timezone-aware UTC timestamp, consistent with the UTC stamp
            # used by generate_filename (the original used naive local time).
            'timestamp': datetime.now(timezone.utc).isoformat(),
            'prompt': prompt,
            'type': 'story'
        }
        with open(filename, 'w', encoding='utf-8') as f:
            f.write(json.dumps(metadata) + '\n---\n' + story)
        return filename
    except Exception as e:
        # Best-effort: report and signal failure instead of crashing the UI.
        print(f"Error saving story: {e}")
        return None
def load_story(filename):
    """Load a story file written by save_story.

    Args:
        filename: Path to the story file.

    Returns:
        (metadata, story) on success; (None, full_content) when the file has
        no metadata header; (None, None) when the file cannot be read or the
        header is not valid JSON.
    """
    try:
        with open(filename, 'r', encoding='utf-8') as f:
            content = f.read()
        # Split on the FIRST separator only: the original unbounded split
        # produced >2 parts (and dropped the metadata) whenever the story
        # itself contained "\n---\n".
        parts = content.split('\n---\n', 1)
        if len(parts) == 2:
            metadata = json.loads(parts[0])
            return metadata, parts[1]
        return None, content
    except Exception as e:
        print(f"Error loading story: {e}")
        return None, None
# Story Generation Functions
def generate_story(prompt, model_choice):
    """Request a story from the remote Arxiv-RAG service.

    Args:
        prompt: Story concept supplied by the user.
        model_choice: Remote LLM identifier to use.

    Returns:
        The generated text, or an "Error..." string on failure (errors are
        returned, not raised, so the UI can display them).
    """
    try:
        service = init_arxiv_client()
        if service is None:
            return "Error: Story generation service is not available."
        return service.predict(
            prompt=prompt,
            llm_model_picked=model_choice,
            stream_outputs=True,
            api_name="/ask_llm",
        )
    except Exception as e:
        return f"Error generating story: {str(e)}"
async def generate_speech(text, voice="en-US-AriaNeural"):
    """Generate speech from text.

    Synthesizes *text* with edge-tts into a temporary MP3 file.

    Args:
        text: The text to narrate.
        voice: Edge TTS voice name (defaults to en-US-AriaNeural).

    Returns:
        Path to the generated .mp3 file, or None on failure.
    """
    try:
        communicate = edge_tts.Communicate(text, voice)
        # Reserve a temp path with delete=False so the file survives the
        # context manager; edge-tts then writes the audio to that path.
        # NOTE(review): original indentation was lost in transit — assumes
        # save() runs after the NamedTemporaryFile handle is closed; confirm.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
            tmp_path = tmp_file.name
        await communicate.save(tmp_path)
        return tmp_path
    except Exception as e:
        print(f"Error in text2speech: {str(e)}")
        return None
def process_story_and_audio(prompt, model_choice):
    """Run the full story pipeline: text, saved file, and TTS narration.

    Returns a (story, audio_path, filename) triple; on failure the first
    element is an "Error..." string and the others are None.
    """
    try:
        story = generate_story(prompt, model_choice)
        # Upstream failures arrive as "Error..." strings rather than raising.
        if isinstance(story, str) and story.startswith("Error"):
            return story, None, None
        saved_name = save_story(story, prompt)
        narration_path = asyncio.run(generate_speech(story))
        return story, narration_path, saved_name
    except Exception as e:
        return f"Error: {str(e)}", None, None
# Main App Code (your existing code remains here)
# LLM Inference Class and other existing classes remain unchanged
class LLMInferenceNode:
    """Placeholder for the LLM inference node.

    NOTE(review): intentionally empty in this file — the real implementation
    is expected to live elsewhere in the project.
    """
    # Your existing LLMInferenceNode implementation
    pass
# Initialize models (your existing initialization code remains here)
# Prefer GPU with half precision; fall back to CPU with full precision.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32

# Story generator
# NOTE(review): downloads gpt2-large on first run; device=0 selects the
# first CUDA device, -1 forces CPU inference.
story_generator = pipeline(
    'text-generation',
    model='gpt2-large',
    device=0 if device == 'cuda' else -1
)

# Stable Diffusion model
# Loaded eagerly at import time and moved to the selected device.
sd_pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch_dtype
).to(device)
# Create the enhanced Gradio interface
with gr.Blocks() as demo:
    # NOTE(review): the original header/buttons contained mojibake ("๐จ", "๐")
    # from a bad encoding round-trip; restored to the likely intended emoji.
    gr.Markdown("""# 🎨 AI Creative Suite
    Generate videos, stories, and more with AI!
    """)

    with gr.Tabs():
        # Video generation tab.
        with gr.Tab("Video Generation"):
            with gr.Row():
                with gr.Column():
                    prompt_input = gr.Textbox(label="Enter a Prompt", lines=2)
                    generate_button = gr.Button("Generate Video")
                with gr.Column():
                    video_output = gr.Video(label="Generated Video")
            # NOTE(review): process_pipeline is not defined in this file — it
            # must come from the omitted video-generation code; confirm.
            generate_button.click(fn=process_pipeline, inputs=prompt_input, outputs=video_output)

        # New story generation tab.
        with gr.Tab("Story Generation"):
            with gr.Row():
                with gr.Column():
                    story_prompt = gr.Textbox(
                        label="Story Concept",
                        placeholder="Enter your story idea...",
                        lines=3
                    )
                    model_choice = gr.Dropdown(
                        label="Model",
                        choices=[
                            "mistralai/Mixtral-8x7B-Instruct-v0.1",
                            "mistralai/Mistral-7B-Instruct-v0.2"
                        ],
                        value="mistralai/Mixtral-8x7B-Instruct-v0.1"
                    )
                    generate_story_btn = gr.Button("Generate Story")
            with gr.Row():
                story_output = gr.Textbox(
                    label="Generated Story",
                    lines=10,
                    interactive=False
                )
            with gr.Row():
                audio_output = gr.Audio(
                    label="Story Narration",
                    type="filepath"
                )
                filename_output = gr.Textbox(
                    label="Saved Filename",
                    interactive=False
                )
            generate_story_btn.click(
                fn=process_story_and_audio,
                inputs=[story_prompt, model_choice],
                outputs=[story_output, audio_output, filename_output]
            )

            # File management section
            with gr.Row():
                file_list = gr.Dropdown(
                    label="Saved Stories",
                    choices=[f for f in os.listdir() if f.startswith("story_") and f.endswith(".txt")],
                    interactive=True
                )
                refresh_btn = gr.Button("🔄 Refresh")

            def refresh_files():
                # Re-scan the working directory for files written by save_story.
                return gr.Dropdown(choices=[f for f in os.listdir() if f.startswith("story_") and f.endswith(".txt")])

            refresh_btn.click(fn=refresh_files, outputs=[file_list])
# Launch the app
if __name__ == "__main__":
    # debug=True surfaces tracebacks in the Gradio UI — useful on a Space.
    demo.launch(debug=True)