Spaces:
Running
on
L4
Running
on
L4
| import time | |
| import json | |
| from pathlib import Path | |
| import torch.multiprocessing as mp | |
| from mm_story_agent.modality_agents.story_agent import QAOutlineStoryWriter | |
| from mm_story_agent.modality_agents.speech_agent import CosyVoiceAgent | |
| from mm_story_agent.modality_agents.sound_agent import AudioLDM2Agent | |
| from mm_story_agent.modality_agents.music_agent import MusicGenAgent | |
| from mm_story_agent.modality_agents.image_agent import StoryDiffusionAgent | |
| from mm_story_agent.video_compose_agent import VideoComposeAgent | |
class MMStoryAgent:
    """Orchestrate story writing, per-modality asset generation and video composition.

    ``call`` is the end-to-end entry point: it writes the story pages, runs one
    worker process per modality (image, sound, speech, music) and then hands the
    pages to the video compose agent.
    """

    def __init__(self) -> None:
        self.modalities = ["image", "sound", "speech", "music"]
        self.modality_agent_class = {
            "image": StoryDiffusionAgent,
            "sound": AudioLDM2Agent,
            "speech": CosyVoiceAgent,
            "music": MusicGenAgent,
        }
        # Device string passed to each agent's ``call`` in the parallel path.
        self.modality_devices = {
            "image": "cuda:0",
            "sound": "cuda:1",
            "music": "cuda:2",
            "speech": "cuda:3",
        }
        self.agents = {}

    def call_modality_agent(self, agent, device, pages, save_path, return_dict):
        """Run one agent and publish its result (worker-process target).

        Args:
            agent: Object exposing ``call(pages, device, save_path)``.
            device: Device string forwarded to the agent.
            pages: Story pages forwarded to the agent.
            save_path: Output directory forwarded to the agent.
            return_dict: Mapping that receives the result, keyed by the
                ``"modality"`` entry of the result itself.
        """
        result = agent.call(pages, device, save_path)
        return_dict[result["modality"]] = result

    def write_story(self, config):
        """Generate the story pages from ``config["story_setting"]``.

        The writer is configured with ``config["story_gen_config"]``; whatever
        its ``call`` returns is returned unchanged.
        """
        story_writer = QAOutlineStoryWriter(config["story_gen_config"])
        return story_writer.call(config["story_setting"])

    def _generate_single(self, agent_class, modality, config, pages):
        """Create ``<story_dir>/<modality>`` and run one agent into it, in-process.

        NOTE(review): this path calls ``agent.call(pages, save_dir)`` with two
        arguments, while ``call_modality_agent`` passes
        ``(pages, device, save_path)`` to the same agent classes. The agent
        signature is not visible here -- confirm which form is current.
        """
        save_dir = Path(config["story_dir"]) / modality
        save_dir.mkdir(exist_ok=True, parents=True)
        agent = agent_class(config[modality + "_generation"])
        agent.call(pages, save_dir)

    def generate_speech(self, config, pages):
        """Run only the speech agent, writing into ``<story_dir>/speech``."""
        self._generate_single(CosyVoiceAgent, "speech", config, pages)

    def generate_sound(self, config, pages):
        """Run only the sound agent, writing into ``<story_dir>/sound``."""
        self._generate_single(AudioLDM2Agent, "sound", config, pages)

    def generate_music(self, config, pages):
        """Run only the music agent, writing into ``<story_dir>/music``."""
        self._generate_single(MusicGenAgent, "music", config, pages)

    def generate_image(self, config, pages):
        """Run only the image agent, writing into ``<story_dir>/image``."""
        self._generate_single(StoryDiffusionAgent, "image", config, pages)

    def generate_modality_assets(self, config, pages):
        """Run every modality agent in its own process and record the prompts used.

        Writes ``<story_dir>/script_data.json`` containing the story text of
        each page plus whatever image/sound/music prompts the agents reported.

        Args:
            config: Mapping with ``"story_dir"`` and one
                ``"<modality>_generation"`` entry per modality.
            pages: Story pages as returned by ``write_story``.

        Returns:
            The image agent's ``"generation_results"``, or ``None`` when no
            image result was produced (e.g. the image worker failed).
        """
        script_data = {"pages": [{"story": page} for page in pages]}
        story_dir = Path(config["story_dir"])
        for sub_dir in self.modalities:
            (story_dir / sub_dir).mkdir(exist_ok=True, parents=True)

        agents = {
            modality: self.modality_agent_class[modality](config[modality + "_generation"])
            for modality in self.modalities
        }

        # The manager owns a helper process; the ``with`` block guarantees it
        # is shut down instead of lingering after generation.
        with mp.Manager() as manager:
            return_dict = manager.dict()
            processes = []
            for modality in self.modalities:
                process = mp.Process(
                    target=self.call_modality_agent,
                    args=(
                        agents[modality],
                        self.modality_devices[modality],
                        pages,
                        story_dir / modality,
                        return_dict,
                    ),
                    daemon=False,
                )
                processes.append(process)
                process.start()
            for process in processes:
                process.join()
            # Snapshot into a plain dict: the proxy is unusable once the
            # manager has shut down.
            results = dict(return_dict.items())

        for modality, process in zip(self.modalities, processes):
            if process.exitcode != 0:
                print(
                    f"Error occurred during generation: {modality} worker "
                    f"exited with code {process.exitcode}"
                )

        # Stays None when the image worker produced nothing, so the return
        # below can no longer raise UnboundLocalError.
        images = None
        for result in results.values():
            # Best effort: a malformed result must not prevent the script
            # data of the other modalities from being written.
            try:
                kind = result["modality"]
                if kind == "image":
                    images = result["generation_results"]
                    for idx, page_entry in enumerate(script_data["pages"]):
                        page_entry["image_prompt"] = result["prompts"][idx]
                elif kind == "sound":
                    for idx, page_entry in enumerate(script_data["pages"]):
                        page_entry["sound_prompt"] = result["prompts"][idx]
                elif kind == "music":
                    script_data["music_prompt"] = result["prompt"]
            except Exception as e:
                print(f"Error occurred during generation: {e}")

        # ensure_ascii=False emits raw non-ASCII text, so pin the encoding
        # rather than depending on the platform default.
        with open(story_dir / "script_data.json", "w", encoding="utf-8") as writer:
            json.dump(script_data, writer, ensure_ascii=False, indent=4)
        return images

    def compose_storytelling_video(self, config, pages):
        """Pass the pages and config to a fresh ``VideoComposeAgent``."""
        video_compose_agent = VideoComposeAgent()
        video_compose_agent.call(pages, config)

    def call(self, config):
        """Run the full pipeline: write story, generate assets, compose video."""
        pages = self.write_story(config)
        self.generate_modality_assets(config, pages)
        self.compose_storytelling_video(config, pages)