Spaces:
Runtime error
Runtime error
| import spaces | |
| import gradio as gr | |
| import os | |
| import asyncio | |
| import torch | |
| import io | |
| import json | |
| import re | |
| import httpx | |
| import tempfile | |
| import wave | |
| import base64 | |
| import numpy as np | |
| import soundfile as sf | |
| import subprocess | |
| import shutil | |
| import requests | |
| import logging | |
| import random | |
| from datetime import datetime, timedelta | |
| from typing import List, Tuple, Dict, Optional | |
| from pathlib import Path | |
| from threading import Thread | |
| from dotenv import load_dotenv | |
| # PDF processing imports | |
| from langchain_community.document_loaders import PyPDFLoader | |
| # OpenAI imports | |
| from openai import OpenAI | |
| # Transformers imports (for legacy local mode) | |
| from transformers import ( | |
| AutoModelForCausalLM, | |
| AutoTokenizer, | |
| TextIteratorStreamer, | |
| BitsAndBytesConfig, | |
| ) | |
| # Llama CPP imports (for new local mode) | |
| try: | |
| from llama_cpp import Llama | |
| from llama_cpp_agent import LlamaCppAgent, MessagesFormatterType | |
| from llama_cpp_agent.providers import LlamaCppPythonProvider | |
| from llama_cpp_agent.chat_history import BasicChatHistory | |
| from llama_cpp_agent.chat_history.messages import Roles | |
| from huggingface_hub import hf_hub_download | |
| LLAMA_CPP_AVAILABLE = True | |
| except ImportError: | |
| LLAMA_CPP_AVAILABLE = False | |
| # Chatterbox TTS imports | |
| try: | |
| from chatterbox.src.chatterbox.tts import ChatterboxTTS | |
| CHATTERBOX_AVAILABLE = True | |
| print("โ Chatterbox TTS imported successfully from chatterbox.src.chatterbox.tts") | |
| except ImportError: | |
| try: | |
| from chatterbox.tts import ChatterboxTTS | |
| CHATTERBOX_AVAILABLE = True | |
| print("โ Chatterbox TTS imported successfully from chatterbox.tts") | |
| except ImportError: | |
| try: | |
| # ๋ค๋ฅธ ๊ฐ๋ฅํ ๊ฒฝ๋ก ์๋ | |
| import sys | |
| sys.path.append('/usr/local/lib/python3.10/site-packages') | |
| from chatterbox import ChatterboxTTS | |
| CHATTERBOX_AVAILABLE = True | |
| print("โ Chatterbox TTS imported successfully from chatterbox") | |
| except ImportError: | |
| CHATTERBOX_AVAILABLE = False | |
| print("โ Chatterbox TTS not available - falling back to text-only mode") | |
| # Import config and prompts | |
| from config_prompts import ( | |
| ConversationConfig, | |
| PromptBuilder, | |
| DefaultConversations, | |
| ) | |
load_dotenv()

# Run on the GPU whenever torch can see one; otherwise fall back to the CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"
print(f"๐ Running on device: {DEVICE}")

# Brave Search API settings. The search helpers return nothing when the key is unset.
BRAVE_ENDPOINT = "https://api.search.brave.com/res/v1/web/search"
BRAVE_KEY = os.getenv("BSEARCH_API")
def set_seed(seed: int):
    """
    Seed every random number generator used by the app.

    Python's ``random``, NumPy's global RNG and torch's CPU RNG are always
    seeded. The CUDA RNGs are seeded as well when ``DEVICE`` is "cuda".

    Args:
        seed: Seed value applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if DEVICE != "cuda":
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
def split_text_into_chunks(text: str, max_chars: int = 250) -> list[str]:
    """
    Split text on sentence boundaries into chunks of at most ``max_chars`` characters.

    Sentences are split after ``.``, ``!`` or ``?`` when followed by whitespace.
    They are packed greedily into chunks joined by single spaces. A sentence
    longer than ``max_chars`` on its own is broken on whitespace. Any single
    word longer than ``max_chars`` is hard-split, so no chunk ever exceeds the
    limit.

    Args:
        text: Text to split. Leading and trailing whitespace is ignored.
        max_chars: Maximum length of each chunk. Must be at least 1.

    Returns:
        The chunks, in order. Blank input yields an empty list.

    Raises:
        ValueError: If ``max_chars`` is less than 1.
    """
    if max_chars < 1:
        raise ValueError("max_chars must be at least 1")

    sentences = re.split(r'(?<=[.!?])\s+', text.strip())
    chunks: list[str] = []
    current_chunk = ""

    for sentence in sentences:
        # Extend the current chunk while the sentence still fits (+1 for the joining space).
        if len(current_chunk) + len(sentence) + 1 <= max_chars:
            current_chunk = f"{current_chunk} {sentence}" if current_chunk else sentence
            continue

        # The sentence does not fit, so flush what we have and start a new chunk.
        if current_chunk:
            chunks.append(current_chunk)

        if len(sentence) <= max_chars:
            current_chunk = sentence
        else:
            # Oversized sentence: emit all full pieces, but keep the last piece
            # open so the following sentences can still be appended to it.
            pieces = _split_long_sentence(sentence, max_chars)
            chunks.extend(pieces[:-1])
            current_chunk = pieces[-1] if pieces else ""

    if current_chunk:
        chunks.append(current_chunk)
    return chunks


def _split_long_sentence(sentence: str, max_chars: int) -> list[str]:
    """Greedily pack the words of *sentence* into pieces of at most *max_chars* characters."""
    pieces: list[str] = []
    current = ""
    for word in sentence.split():
        # A word longer than the limit is cut into max_chars-sized slices.
        while len(word) > max_chars:
            if current:
                pieces.append(current)
                current = ""
            pieces.append(word[:max_chars])
            word = word[max_chars:]
        if len(current) + len(word) + 1 <= max_chars:
            current = f"{current} {word}" if current else word
        else:
            if current:
                pieces.append(current)
            current = word
    if current:
        pieces.append(current)
    return pieces
def brave_search(query: str, count: int = 8, freshness_days: int | None = None):
    """
    Query the Brave Web Search API and return simplified result records.

    Args:
        query: Search query string.
        count: Number of results to request; also the maximum number returned.
        freshness_days: When set, only return results discovered within the
            last ``freshness_days`` days.

    Returns:
        A list of dicts with ``title``, ``url``, ``snippet`` and ``host`` keys.
        Returns ``[]`` when no API key is configured or the request fails
        (failures are logged).
    """
    from datetime import timezone

    if not BRAVE_KEY:
        return []

    params = {"q": query, "count": str(count)}
    if freshness_days:
        # Brave accepts a preset (pd/pw/pm/py) or an explicit
        # "YYYY-MM-DDtoYYYY-MM-DD" range. A bare start date is rejected,
        # which previously made every freshness-filtered search come back empty.
        today = datetime.now(timezone.utc).date()
        start = today - timedelta(days=freshness_days)
        params["freshness"] = f"{start:%Y-%m-%d}to{today:%Y-%m-%d}"

    try:
        response = requests.get(
            BRAVE_ENDPOINT,
            headers={"Accept": "application/json", "X-Subscription-Token": BRAVE_KEY},
            params=params,
            timeout=15,
        )
        # Surface HTTP errors (bad key, rate limit, invalid params) instead of
        # silently parsing an error body as "no results".
        response.raise_for_status()
        raw = (response.json().get("web") or {}).get("results") or []

        results = []
        for item in raw[:count]:
            url = item.get("url", item.get("link", ""))
            results.append({
                "title": item.get("title", ""),
                "url": url,
                "snippet": item.get("description", item.get("text", "")),
                "host": re.sub(r"https?://(www\.)?", "", url).split("/")[0],
            })
        return results
    except Exception as e:
        logging.error(f"Brave search error: {e}")
        return []
def format_search_results(query: str, for_keyword: bool = False) -> str:
    """
    Run a Brave search for *query* and render the top hits as prompt context.

    Keyword mode fetches 5 results with no freshness filter and keeps up to
    four. Each is rendered as a bold title, a snippet cut at 200 characters,
    and its source host. The default mode fetches 3 results from the last 7
    days and keeps two, each as a one-line bullet with a 100-character snippet.

    Returns:
        The rendered blocks separated by blank lines, with a trailing newline,
        or "" when the search yields nothing.
    """
    if for_keyword:
        fetch_count, freshness, keep, limit = 5, None, 4, 200
    else:
        fetch_count, freshness, keep, limit = 3, 7, 2, 100

    hits = brave_search(query, fetch_count, freshness_days=freshness)
    if not hits:
        return ""

    def shorten(snippet: str) -> str:
        # Truncate long snippets and mark the cut with an ellipsis.
        return snippet if len(snippet) <= limit else snippet[:limit] + "..."

    if for_keyword:
        blocks = [
            f"**{hit['title']}**\n{shorten(hit['snippet'])}\nSource: {hit['host']}"
            for hit in hits[:keep]
        ]
    else:
        blocks = [f"- {hit['title']}: {shorten(hit['snippet'])}" for hit in hits[:keep]]
    return "\n\n".join(blocks) + "\n"
def extract_keywords_for_search(text: str, language: str = "English") -> List[str]:
    """
    Pick a single search keyword from the start of *text*.

    Only the first 500 characters are scanned. Each whitespace-separated token
    is stripped of surrounding punctuation, quotes and brackets. Among the
    tokens that are longer than four characters and start with an uppercase
    letter, the longest is returned; the first one wins on ties.

    Args:
        text: Source text.
        language: Currently unused; kept for API compatibility.

    Returns:
        A one-element list with the keyword, or ``[]`` if no token qualifies.
    """
    # Only scan the beginning so very long documents stay cheap to process.
    text_sample = text[:500]
    strip_chars = ".,!?;:\"'()[]{}"

    keywords = []
    for raw_word in text_sample.split():
        # Strip first, so trailing marks ("Hi!!!") cannot push a short word past
        # the length cutoff and leading quotes/brackets don't hide capitals.
        word = raw_word.strip(strip_chars)
        if len(word) > 4 and word[0].isupper():
            keywords.append(word)

    if keywords:
        return [max(keywords, key=len)]
    return []
def search_and_compile_content(keyword: str, language: str = "English") -> str:
    """
    Build a research brief about *keyword* for keyword-mode podcasts.

    With a Brave API key, the function:
    - runs six Brave queries (news, explainer, trends, pros/cons, how-to,
      expert opinions);
    - keeps the top three hits of each query;
    - appends generic filler text when the collected snippets total fewer
      than 1000 characters.

    Without a key, it returns a generic template mentioning *keyword*.

    Args:
        keyword: Topic to research.
        language: Currently unused; kept for API compatibility.

    Returns:
        Source text for the conversation prompt.
    """
    if not BRAVE_KEY:
        # No search API available: fall back to a generic template.
        return f"""
Comprehensive information about '{keyword}':
{keyword} is a significant topic in modern society.
This subject impacts our lives in various ways and has been
gaining increasing attention recently.
Key aspects:
1. Technological advancement and innovation
2. Social impact and changes
3. Future prospects and possibilities
4. Practical applications
5. Global trends and developments
Experts predict that {keyword} will become even more important,
and it's crucial to develop a deep understanding of this topic.
"""

    # Use the current year (not a hard-coded one) so "latest news" stays current.
    current_year = datetime.now().year
    queries = [
        f"{keyword} latest news {current_year}",
        f"{keyword} explained comprehensive",
        f"{keyword} trends forecast",
        f"{keyword} advantages disadvantages",
        f"{keyword} how to use",
        f"{keyword} expert opinions"
    ]

    all_content = []
    total_content_length = 0
    for query in queries:
        # Request 5 results per query but keep only the top 3.
        results = brave_search(query, count=5)
        for r in results[:3]:
            all_content.append(f"**{r['title']}**\n{r['snippet']}\nSource: {r['host']}\n")
            total_content_length += len(r['snippet'])

    # Pad with generic material when searches returned under ~1000 snippet characters.
    if total_content_length < 1000:
        additional_content = f"""
Additional insights:
Recent developments in {keyword} show rapid advancement in this field.
Many experts are actively researching this topic, and its practical
applications continue to expand.
Key points to note:
- Accelerating technological innovation
- Improving user experience
- Enhanced accessibility
- Increased cost efficiency
- Growing global market
These factors are making the future of {keyword} increasingly promising.
"""
        all_content.append(additional_content)

    compiled = "\n\n".join(all_content)
    intro = f"### Comprehensive information and latest trends about '{keyword}':\n\n"
    return intro + compiled
class UnifiedAudioConverter:
    """
    Turn source text into a podcast-style conversation and synthesize it to audio.

    Conversation scripts can come from three backends:
    - a llama.cpp model;
    - a legacy Hugging Face transformers model (used as a fallback);
    - the Together API, reached through the OpenAI client.

    Audio is synthesized with Chatterbox TTS.
    """

    # Matches a JSON object with at most one level of nested objects, e.g.
    # {"conversation": [{"speaker": ..., "text": ...}, ...]}. Deeper nesting is not matched.
    _JSON_OBJECT_RE = re.compile(r"\{(?:[^{}]|(?:\{[^{}]*\}))*\}")

    def __init__(self, config: "ConversationConfig"):
        """
        Args:
            config: Model names/repos, token limits and the URL prefix used by the converter.
        """
        self.config = config
        # Together API client; set by initialize_api_mode().
        self.llm_client = None
        # Legacy transformers model and tokenizer; loaded lazily as a fallback.
        self.legacy_local_model = None
        self.legacy_tokenizer = None
        # llama.cpp model and the filename it was loaded from.
        self.local_llm = None
        self.local_llm_model = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Builds prompts and chat messages for every LLM backend.
        self.prompt_builder = PromptBuilder()

    def initialize_api_mode(self, api_key: str):
        """Create an OpenAI-compatible client pointed at the Together API."""
        self.llm_client = OpenAI(api_key=api_key, base_url="https://api.together.xyz/v1")

    def initialize_local_mode(self):
        """
        Load the llama.cpp model named by ``config.local_model_name``, or reuse it.

        The model file is downloaded from ``config.local_model_repo`` into
        ``./models``. A model already loaded for the same filename is reused.

        Raises:
            RuntimeError: If the llama.cpp packages are missing or loading fails.
        """
        if not LLAMA_CPP_AVAILABLE:
            raise RuntimeError("Llama CPP dependencies not available. Please install llama-cpp-python and llama-cpp-agent.")
        if self.local_llm is not None and self.local_llm_model == self.config.local_model_name:
            return
        try:
            hf_hub_download(
                repo_id=self.config.local_model_repo,
                filename=self.config.local_model_name,
                local_dir="./models"
            )
            model_path_local = os.path.join("./models", self.config.local_model_name)
            if not os.path.exists(model_path_local):
                raise RuntimeError(f"Model file not found at {model_path_local}")
            self.local_llm = Llama(
                model_path=model_path_local,
                flash_attn=True,
                # Offload layers to the GPU when CUDA is present. 81 is presumably
                # meant to cover every layer of the model; confirm per model.
                n_gpu_layers=81 if torch.cuda.is_available() else 0,
                n_batch=1024,
                n_ctx=16384,
            )
            self.local_llm_model = self.config.local_model_name
            print(f"Local LLM initialized: {model_path_local}")
        except Exception as e:
            print(f"Failed to initialize local LLM: {e}")
            raise RuntimeError(f"Failed to initialize local LLM: {e}") from e

    def initialize_legacy_local_mode(self):
        """Lazily load the 4-bit quantized legacy transformers model and its tokenizer."""
        if self.legacy_local_model is None:
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16
            )
            self.legacy_local_model = AutoModelForCausalLM.from_pretrained(
                self.config.legacy_local_model_name,
                quantization_config=quantization_config
            )
            self.legacy_tokenizer = AutoTokenizer.from_pretrained(
                self.config.legacy_local_model_name,
                revision='8ab73a6800796d84448bc936db9bac5ad9f984ae'
            )

    def fetch_text(self, url: str) -> str:
        """
        Fetch the text of *url* through ``config.prefix_url``.

        Raises:
            ValueError: If the URL is empty or does not start with http(s)://.
            RuntimeError: If the HTTP request fails.
        """
        if not url:
            raise ValueError("URL cannot be empty")
        if not url.startswith("http://") and not url.startswith("https://"):
            raise ValueError("URL must start with 'http://' or 'https://'")
        full_url = f"{self.config.prefix_url}{url}"
        try:
            response = httpx.get(full_url, timeout=60.0)
            response.raise_for_status()
            return response.text
        except httpx.HTTPError as e:
            raise RuntimeError(f"Failed to fetch URL: {e}") from e

    def extract_text_from_pdf(self, pdf_file) -> str:
        """
        Extract and concatenate the text of every page of a PDF.

        Args:
            pdf_file: A filesystem path (``str`` or ``os.PathLike``, as Gradio
                supplies) or a readable binary file object. A file object is
                copied to a temporary file, which is always removed afterwards.

        Raises:
            RuntimeError: If reading or parsing the PDF fails.
        """
        temp_path = None
        try:
            if isinstance(pdf_file, (str, os.PathLike)):
                pdf_path = os.fspath(pdf_file)
            else:
                with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp_file:
                    # Record the path first so cleanup still happens if the write fails.
                    temp_path = tmp_file.name
                    tmp_file.write(pdf_file.read())
                pdf_path = temp_path

            pages = PyPDFLoader(pdf_path).load()
            return "\n".join(page.page_content for page in pages)
        except Exception as e:
            raise RuntimeError(f"Failed to extract text from PDF: {e}") from e
        finally:
            if temp_path and os.path.exists(temp_path):
                os.unlink(temp_path)

    def _get_messages_formatter_type(self, model_name):
        """Return the llama-cpp-agent chat template to use for *model_name*."""
        if "Mistral" in model_name or "BitSix" in model_name:
            return MessagesFormatterType.CHATML
        else:
            return MessagesFormatterType.LLAMA_3

    @staticmethod
    def _build_search_context(text: str, language: str) -> str:
        """
        Return Brave search snippets about the main keyword of *text*.

        Returns "" when no API key is set, when *text* is keyword-mode content
        (which was already compiled from searches), or when the search fails.
        """
        if not BRAVE_KEY or text.startswith("Keyword-based content:"):
            return ""
        try:
            keywords = extract_keywords_for_search(text, language)
            if keywords:
                search_query = f"{keywords[0]} latest news"
                search_context = format_search_results(search_query)
                print(f"Search context added for: {search_query}")
                return search_context
        except Exception as e:
            print(f"Search failed, continuing without context: {e}")
        return ""

    @staticmethod
    def _extract_json(raw: str, error_message: str = "No valid JSON found in response") -> Dict:
        """
        Parse the first JSON object found in an LLM response.

        Raises:
            ValueError: With *error_message* if no JSON object is found.
        """
        match = UnifiedAudioConverter._JSON_OBJECT_RE.search(raw or "")
        if not match:
            raise ValueError(error_message)
        return json.loads(match.group())

    @staticmethod
    def _join_with_silence(segments: List[np.ndarray], gap_samples: int) -> np.ndarray:
        """Concatenate *segments*, inserting *gap_samples* zeros between consecutive ones."""
        if not segments:
            return np.zeros(0, dtype=np.float32)
        gap = np.zeros(gap_samples, dtype=segments[0].dtype)
        pieces = []
        for index, segment in enumerate(segments):
            if index:
                pieces.append(gap)
            pieces.append(segment)
        return np.concatenate(pieces)

    def extract_conversation_local(self, text: str, language: str = "English", progress=None) -> Dict:
        """
        Generate a conversation with the llama.cpp model.

        On any failure, falls back to extract_conversation_legacy_local(),
        which in turn falls back to a default conversation.

        Args:
            text: Source material for the conversation.
            language: Passed to the prompt builder.
            progress: Unused; accepted for API compatibility.

        Returns:
            The parsed conversation JSON.
        """
        search_context = ""
        try:
            search_context = self._build_search_context(text, language)

            self.initialize_local_mode()
            chat_template = self._get_messages_formatter_type(self.config.local_model_name)
            provider = LlamaCppPythonProvider(self.local_llm)

            # English-only system message.
            system_message = (
                f"You are a professional podcast scriptwriter creating high-quality, "
                f"insightful discussions in English. Create exactly 12 conversation exchanges "
                f"with professional expertise. All dialogue must be in English. "
                f"Respond only in JSON format."
            )
            agent = LlamaCppAgent(
                provider,
                system_prompt=system_message,
                predefined_messages_formatter_type=chat_template,
                debug_output=False
            )

            settings = provider.get_provider_default_settings()
            settings.temperature = 0.75
            settings.top_k = 40
            settings.top_p = 0.95
            settings.max_tokens = self.config.max_tokens
            settings.repeat_penalty = 1.1
            settings.stream = False

            messages = BasicChatHistory()
            prompt = self.prompt_builder.build_prompt(text, language, search_context)
            response = agent.get_chat_response(
                prompt,
                llm_sampling_settings=settings,
                chat_history=messages,
                returns_streaming_generator=False,
                print_output=False
            )
            return self._extract_json(response, "No valid JSON found in local LLM response")
        except Exception as e:
            print(f"Local LLM failed: {e}, falling back to legacy local method")
            return self.extract_conversation_legacy_local(text, language, progress, search_context)

    def extract_conversation_legacy_local(self, text: str, language: str = "English", progress=None, search_context: str = "") -> Dict:
        """
        Generate a conversation with the legacy transformers model.

        Returns the default English conversation if anything fails.

        Args:
            text: Source material for the conversation.
            language: Passed to the prompt builder.
            progress: Unused; accepted for API compatibility.
            search_context: Optional search snippets to include in the prompt.
        """
        try:
            self.initialize_legacy_local_mode()

            messages = self.prompt_builder.build_messages_for_local(text, language, search_context)
            # Stop on either the generic EOS token or the Llama-3 end-of-turn token.
            terminators = [
                self.legacy_tokenizer.eos_token_id,
                self.legacy_tokenizer.convert_tokens_to_ids("<|eot_id|>")
            ]
            chat_messages = self.legacy_tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            model_inputs = self.legacy_tokenizer([chat_messages], return_tensors="pt").to(self.device)
            streamer = TextIteratorStreamer(
                self.legacy_tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
            )
            generate_kwargs = dict(
                model_inputs,
                streamer=streamer,
                max_new_tokens=self.config.max_new_tokens,
                do_sample=True,
                temperature=0.75,
                eos_token_id=terminators,
            )

            # generate() feeds the streamer from a worker thread; drain it here.
            worker = Thread(target=self.legacy_local_model.generate, kwargs=generate_kwargs)
            worker.start()
            partial_text = "".join(streamer)
            worker.join()

            return self._extract_json(partial_text, "No valid JSON found in legacy local response")
        except Exception as e:
            print(f"Legacy local model also failed: {e}")
            return DefaultConversations.get_conversation("English")

    def extract_conversation_api(self, text: str, language: str = "English") -> Dict:
        """
        Generate a conversation through the Together API.

        Raises:
            RuntimeError: If the client is not initialized or the request or parse fails.
        """
        if not self.llm_client:
            raise RuntimeError("API mode not initialized")
        try:
            search_context = self._build_search_context(text, language)
            messages = self.prompt_builder.build_messages_for_local(text, language, search_context)
            chat_completion = self.llm_client.chat.completions.create(
                messages=messages,
                model=self.config.api_model_name,
                temperature=0.75,
            )
            return self._extract_json(
                chat_completion.choices[0].message.content,
                "No valid JSON found in response",
            )
        except Exception as e:
            raise RuntimeError(f"Failed to extract conversation: {e}") from e

    def parse_conversation_text(self, conversation_text: str) -> Dict:
        """
        Parse "Speaker: text" lines back into conversation JSON.

        Each line is split on its first colon. Lines without a colon are ignored.
        """
        conversation_data = {"conversation": []}
        for line in conversation_text.strip().split('\n'):
            if ':' in line:
                speaker, text = line.split(':', 1)
                conversation_data["conversation"].append({
                    "speaker": speaker.strip(),
                    "text": text.strip()
                })
        return conversation_data

    def generate_tts_audio_gpu(
        self,
        conversation_json: Dict,
        audio_prompt_path_input: str,
        exaggeration_input: float = 0.5,
        temperature_input: float = 0.8,
        seed_num_input: int = 0,
        cfgw_input: float = 0.5,
        chunk_size_input: int = 250
    ) -> tuple[int, np.ndarray]:
        """
        Synthesize every conversation turn with Chatterbox TTS.

        Turns of at most 300 characters are generated in one call. Longer turns
        are split with split_text_into_chunks() and joined with 0.1 s of
        silence. Consecutive turns are separated by 0.5 s of silence. A turn
        that fails is replaced by 2 s of silence.

        Args:
            conversation_json: Dict with a "conversation" list of {"text": ...} turns.
            audio_prompt_path_input: Reference audio used for voice cloning.
            exaggeration_input: Chatterbox ``exaggeration``.
            temperature_input: Chatterbox ``temperature``.
            seed_num_input: RNG seed; 0 leaves the RNGs unseeded.
            cfgw_input: Chatterbox ``cfg_weight``.
            chunk_size_input: Maximum characters per chunk for long turns.

        Returns:
            ``(sample_rate, audio)``.

        Raises:
            RuntimeError: If Chatterbox is unavailable, the model fails to load,
                or no turn produced any audio.
        """
        if not CHATTERBOX_AVAILABLE:
            raise RuntimeError("Chatterbox TTS not available. Please install chatterbox package.")
        try:
            # Load the model inside this call so it is created on DEVICE in the calling context.
            model = ChatterboxTTS.from_pretrained(DEVICE)
            print(f"โ Chatterbox TTS model loaded on {DEVICE}")
        except Exception as e:
            raise RuntimeError(f"Failed to load Chatterbox TTS model: {e}") from e

        if seed_num_input != 0:
            set_seed(int(seed_num_input))

        def render(piece: str) -> np.ndarray:
            """Generate one piece of speech as a NumPy array."""
            wav = model.generate(
                piece,
                audio_prompt_path=audio_prompt_path_input,
                exaggeration=exaggeration_input,
                temperature=temperature_input,
                cfg_weight=cfgw_input,
            )
            # squeeze(0) drops the leading dim; move to host memory before converting.
            return wav.squeeze(0).detach().cpu().numpy()

        audio_segments = []
        successful_turns = 0
        for i, turn in enumerate(conversation_json["conversation"]):
            text = turn["text"]
            if not text.strip():
                continue
            # "Speaker {i+1}" is the turn index, not the speaker name.
            print(f"๐๏ธ ์์ฑ ์ค: Speaker {i+1} - '{text[:50]}...'")
            try:
                if len(text) <= 300:
                    audio_segments.append(render(text))
                    successful_turns += 1
                else:
                    chunks = split_text_into_chunks(text, max_chars=chunk_size_input)
                    print(f"๐ ํ ์คํธ๋ฅผ {len(chunks)}๊ฐ ์ฒญํฌ๋ก ๋ถํ ")
                    chunk_audio_segments = []
                    for j, chunk in enumerate(chunks):
                        print(f" ๐ ์ฒญํฌ {j+1}/{len(chunks)} ์์ฑ ์ค...")
                        chunk_audio_segments.append(render(chunk))
                    if chunk_audio_segments:
                        audio_segments.append(
                            self._join_with_silence(chunk_audio_segments, int(0.1 * model.sr))
                        )
                        successful_turns += 1
            except Exception as e:
                print(f"โ Speaker {i+1} ์์ฑ ์ค ์ค๋ฅ ๋ฐ์: {e}")
                # Keep the timeline intact with 2 s of silence in place of the failed turn.
                audio_segments.append(np.zeros(int(2.0 * model.sr), dtype=np.float32))

        # Count real successes: failed turns add placeholder silence, so checking
        # audio_segments alone would never detect a total failure.
        if successful_turns == 0:
            raise RuntimeError("๋ชจ๋ ์ค๋์ค ์์ฑ์ ์คํจํ์ต๋๋ค.")

        concatenated_audio = self._join_with_silence(audio_segments, int(0.5 * model.sr))
        print(f"๐ ์ค๋์ค ์์ฑ ์๋ฃ! ์ด ๊ธธ์ด: {len(concatenated_audio) / model.sr:.2f}์ด")
        return (model.sr, concatenated_audio)

    def _create_output_directory(self) -> str:
        """Create a randomly named directory under the working directory and return its name."""
        random_bytes = os.urandom(8)
        folder_name = base64.urlsafe_b64encode(random_bytes).decode("utf-8")
        os.makedirs(folder_name, exist_ok=True)
        return folder_name
# Single converter shared by the synthesize() and regenerate_audio() callbacks.
converter = UnifiedAudioConverter(config=ConversationConfig())
async def synthesize(article_input, input_type: str = "URL", mode: str = "Local"):
    """
    Turn a URL, a PDF or a keyword into a podcast conversation script.

    Args:
        article_input: A URL string, an uploaded PDF (path or file object),
            or a keyword string, depending on *input_type*.
        input_type: "URL", "PDF", or anything else for keyword mode.
        mode: "Local" tries the on-device LLM first and falls back to the
            Together API. Any other value tries the API first (when
            TOGETHER_API_KEY is set) and falls back to local.

    Returns:
        ``(conversation_text, None)`` on success, or ``(message, None)`` for
        invalid input or errors. The second element is an unused audio slot.
    """
    try:
        # Extract source text based on the input type.
        if input_type == "URL":
            if not article_input or not isinstance(article_input, str):
                return "Please provide a valid URL.", None
            text = converter.fetch_text(article_input)
        elif input_type == "PDF":
            if not article_input:
                return "Please upload a PDF file.", None
            text = converter.extract_text_from_pdf(article_input)
        else:  # Keyword
            if not article_input or not isinstance(article_input, str):
                return "Please provide a keyword or topic.", None
            text = search_and_compile_content(article_input, "English")
            text = f"Keyword-based content:\n{text}"

        # Truncate to the configured word budget.
        max_words = converter.config.max_words
        words = text.split()
        if len(words) > max_words:
            text = " ".join(words[:max_words])

        api_key = os.environ.get("TOGETHER_API_KEY")
        if mode == "Local":
            try:
                conversation_json = converter.extract_conversation_local(text, "English")
            except Exception as e:
                print(f"Local mode failed: {e}, trying API fallback")
                if not api_key:
                    raise RuntimeError("Local mode failed and no API key available for fallback") from e
                converter.initialize_api_mode(api_key)
                conversation_json = converter.extract_conversation_api(text, "English")
        elif not api_key:
            print("API key not found, falling back to local mode")
            conversation_json = converter.extract_conversation_local(text, "English")
        else:
            try:
                converter.initialize_api_mode(api_key)
                conversation_json = converter.extract_conversation_api(text, "English")
            except Exception as e:
                print(f"API mode failed: {e}, falling back to local mode")
                conversation_json = converter.extract_conversation_local(text, "English")

        # Validate the LLM output shape instead of surfacing a bare KeyError.
        turns = conversation_json.get("conversation") if isinstance(conversation_json, dict) else None
        if not isinstance(turns, list) or not turns:
            return "Error: The generated conversation was empty or malformed.", None

        conversation_text = "\n".join(
            f"{turn.get('speaker', f'Speaker {i+1}')}: {turn.get('text', '')}"
            for i, turn in enumerate(turns)
        )
        return conversation_text, None
    except Exception as e:
        return f"Error: {str(e)}", None
async def regenerate_audio(
    conversation_text: str,
    ref_audio_path: str,
    exaggeration: float = 0.5,
    temperature: float = 0.8,
    seed_num: int = 0,
    cfg_weight: float = 0.5,
    chunk_size: int = 250
):
    """
    Synthesize audio for an (optionally edited) conversation script with Chatterbox TTS.

    The audio is written to ``podcast_audio.wav`` inside a new, randomly named
    directory created by the shared converter.

    Returns:
        ``(status_message, output_path)``. ``output_path`` is None on failure.
    """
    # Guard against None as well as blank text; .strip() on None would crash outside the try.
    if not conversation_text or not conversation_text.strip():
        return "Please provide conversation text.", None

    if not CHATTERBOX_AVAILABLE:
        return "Chatterbox TTS not available. Please check the installation.", None

    try:
        conversation_json = converter.parse_conversation_text(conversation_text)
        if not conversation_json["conversation"]:
            return "No valid conversation found in the text.", None

        try:
            sr, audio = converter.generate_tts_audio_gpu(
                conversation_json,
                ref_audio_path,
                exaggeration,
                temperature,
                seed_num,
                cfg_weight,
                chunk_size
            )
            output_dir = converter._create_output_directory()
            output_file = os.path.join(output_dir, "podcast_audio.wav")
            sf.write(output_file, audio, sr)
            return "๐ Audio generated successfully!", output_file
        except Exception as e:
            # Map common failure modes to user-facing hints.
            error_msg = str(e)
            if "Chatterbox TTS not available" in error_msg:
                return "โ Chatterbox TTS is not properly installed. Please check the requirements.", None
            elif "CUDA" in error_msg or "GPU" in error_msg:
                return f"โ GPU error: {error_msg}. Please try reducing chunk size or use CPU.", None
            else:
                return f"โ Audio generation error: {error_msg}", None
    except Exception as e:
        return f"โ Error processing conversation: {str(e)}", None
def synthesize_sync(article_input, input_type: str = "URL", mode: str = "Local"):
    """Blocking wrapper for Gradio: run synthesize() on a fresh event loop and return its result."""
    coroutine = synthesize(article_input, input_type=input_type, mode=mode)
    return asyncio.run(coroutine)
def regenerate_audio_sync(conversation_text: str, ref_audio_path: str, exaggeration: float, temperature: float, seed_num: int, cfg_weight: float, chunk_size: int):
    """Blocking wrapper for Gradio: run regenerate_audio() on a fresh event loop and return its result."""
    coroutine = regenerate_audio(
        conversation_text,
        ref_audio_path,
        exaggeration=exaggeration,
        temperature=temperature,
        seed_num=seed_num,
        cfg_weight=cfg_weight,
        chunk_size=chunk_size,
    )
    return asyncio.run(coroutine)
def toggle_input_visibility(input_type):
    """
    Show exactly one input widget for the selected source type.

    Returns three ``gr.update`` objects, for the URL textbox, the PDF upload and
    the keyword textbox, in that order. Any value other than "URL" or "PDF" is
    treated as "Keyword".
    """
    url_visible = input_type == "URL"
    pdf_visible = input_type == "PDF"
    keyword_visible = not url_visible and not pdf_visible
    return (
        gr.update(visible=url_visible),
        gr.update(visible=pdf_visible),
        gr.update(visible=keyword_visible),
    )
def update_char_count(text, chunk_size):
    """
    Describe how the TTS step will handle *text*.

    Texts of at most 300 characters are reported as a single generation.
    Longer texts report the chunk count from split_text_into_chunks() and a
    rough time estimate of 3 seconds per chunk.
    """
    char_len = len(text)
    if char_len > 300:
        chunk_count = len(split_text_into_chunks(text, max_chars=chunk_size))
        return f"{char_len} characters, {chunk_count} chunks (estimated time: ~{chunk_count * 3}s)"
    return f"{char_len} characters (single generation)"
| # ๋ชจ๋ธ ์ด๊ธฐํ (์ฑ ์์ ์) | |
# Pre-download the local llama.cpp model file into ./models at start-up (only
# when the llama.cpp packages imported successfully). Failures are printed and
# ignored, because initialize_local_mode() calls hf_hub_download again on demand.
if LLAMA_CPP_AVAILABLE:
    try:
        model_path = hf_hub_download(
            repo_id=converter.config.local_model_repo,
            filename=converter.config.local_model_name,
            local_dir="./models"
        )
        print(f"Model downloaded to: {model_path}")
    except Exception as e:
        print(f"Failed to download model at startup: {e}")
| # Gradio Interface | |
| with gr.Blocks(theme='soft', title="AI Podcast Generator", css=""" | |
| .container {max-width: 1200px; margin: auto; padding: 20px;} | |
| .header-text {text-align: center; margin-bottom: 30px;} | |
| .input-group {background: #f7f7f7; padding: 20px; border-radius: 10px; margin-bottom: 20px;} | |
| .output-group {background: #f0f0f0; padding: 20px; border-radius: 10px;} | |
| .status-box {background: #e8f4f8; padding: 15px; border-radius: 8px; margin-top: 10px;} | |
| """) as demo: | |
| with gr.Column(elem_classes="container"): | |
| # ํค๋ | |
| with gr.Row(elem_classes="header-text"): | |
| gr.Markdown(""" | |
| # ๐๏ธ LIVE Podcast Generator with Chatterbox TTS | |
| ### Convert any article, blog, PDF document, or topic into an engaging professional podcast conversation! | |
| """) | |
| with gr.Row(elem_classes="discord-badge"): | |
| gr.HTML(""" | |
| <p style="text-align: center;"> | |
| <a href="https://discord.gg/openfreeai" target="_blank" style="display: inline-block; margin-right: 10px;"> | |
| <img src="https://img.shields.io/static/v1?label=Discord&message=Openfree%20AI&color=%230000ff&labelColor=%23800080&logo=discord&logoColor=white&style=for-the-badge" alt="badge"> | |
| </a> | |
| <a href="https://open.spotify.com/show/36GtIP7iqJxCwp7FfXmTYK?si=KsIsUJq7SJiiudPTaMsXAA" target="_blank" style="display: inline-block;"> | |
| <img src="https://img.shields.io/static/v1?label=Spotify&message=Podcast&color=%230000ff&labelColor=%23000080&logo=Spotify&logoColor=white&style=for-the-badge" alt="badge"> | |
| </a> | |
| <a href="https://huggingface.co/spaces/openfree/AI-Podcast" target="_blank" style="display: inline-block;"> | |
| <img src="https://img.shields.io/static/v1?label=Huggingface&message=AI%20Podcast&color=%230000ff&labelColor=%23ffa500&logo=huggingface&logoColor=white&style=for-the-badge" alt="badge"> | |
| </a> | |
| </p> | |
| """) | |
| # ์ํ ํ์ ์น์ | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| gr.Markdown(f""" | |
| #### ๐ค System Status | |
| - **LLM**: {converter.config.local_model_name.split('.')[0]} | |
| - **Fallback**: {converter.config.api_model_name.split('/')[-1]} | |
| - **Llama CPP**: {"โ Ready" if LLAMA_CPP_AVAILABLE else "โ Not Available"} | |
| - **Chatterbox TTS**: {"โ Ready" if CHATTERBOX_AVAILABLE else "โ Not Available"} | |
| - **Search**: {"โ Brave API" if BRAVE_KEY else "โ No API"} | |
| """) | |
| with gr.Column(scale=1): | |
| gr.Markdown(""" | |
| #### ๐๏ธ Chatterbox TTS Features | |
| - **High Quality**: Neural voice synthesis | |
| - **Voice Cloning**: Upload your reference audio | |
| - **Unlimited Length**: Automatic text chunking | |
| - **Professional Style**: Expert podcast discussions | |
| """) | |
| # Main input section |
| with gr.Group(elem_classes="input-group"): |
| with gr.Row(): |
| # Left: input options |
| with gr.Column(scale=2): |
| # Input type selection. Its default "URL" matches the fact that only url_input |
| # starts visible; toggle_input_visibility swaps the visible widget on change. |
| input_type_selector = gr.Radio( |
| choices=["URL", "PDF", "Keyword"], |
| value="URL", |
| label="๐ฅ Input Type", |
| info="Choose your content source" |
| ) |
| # URL input (visible by default) |
| url_input = gr.Textbox( |
| label="๐ Article URL", |
| placeholder="Enter the article URL here...", |
| value="", |
| visible=True, |
| lines=2 |
| ) |
| # PDF upload (hidden until "PDF" is selected); restricted to .pdf files |
| pdf_input = gr.File( |
| label="๐ Upload PDF", |
| file_types=[".pdf"], |
| visible=False |
| ) |
| # Keyword input (hidden until "Keyword" is selected) |
| keyword_input = gr.Textbox( |
| label="๐ Topic/Keyword", |
| placeholder="Enter a topic (e.g., 'AI trends 2024', 'quantum computing')", |
| value="", |
| visible=False, |
| info="System will search and compile latest information", |
| lines=2 |
| ) |
| # Right: settings options |
| with gr.Column(scale=1): |
| # Processing mode, forwarded to synthesize_sync as its `mode` argument |
| mode_selector = gr.Radio( |
| choices=["Local", "API"], |
| value="Local", |
| label="โ๏ธ Processing Mode", |
| info="Local: On-device | API: Cloud" |
| ) |
| # Generate button (its click handler is wired further below) |
| with gr.Row(): |
| convert_btn = gr.Button( |
| "๐ฏ Generate Professional Conversation", |
| variant="primary", |
| size="lg", |
| scale=1 |
| ) |
| # TTS settings section |
| with gr.Group(elem_classes="input-group"): |
| gr.Markdown("### ๐๏ธ Chatterbox TTS Settings") |
| with gr.Row(): |
| # Wider column (scale=2): reference-audio input used for voice cloning. |
| with gr.Column(scale=2): |
# Reference voice for Chatterbox voice cloning. It can be uploaded or recorded,
# is delivered to handlers as a file path (type="filepath"), and defaults to a
# hosted sample clip.
# gr.Audio, unlike Textbox/Slider/Radio, accepts no `info` argument. Passing one
# raises TypeError while the UI is being built, so the upload hint is carried by
# the label instead.
ref_audio = gr.Audio(
    sources=["upload", "microphone"],
    type="filepath",
    label="Reference Audio File (Upload your voice)",
    value="https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_shadowheart4.flac",
)
| # Second (scale=1) column: TTS generation controls. For gr.Slider the two positional |
| # arguments are (minimum, maximum). These values are passed to regenerate_audio_sync |
| # by the "Generate Audio" click handler. |
| with gr.Column(scale=1): |
| exaggeration = gr.Slider( |
| 0.25, 2, step=.05, |
| label="Exaggeration (Neutral = 0.5)", |
| value=.5 |
| ) |
| cfg_weight = gr.Slider( |
| 0.2, 1, step=.05, |
| label="CFG/Pace", |
| value=0.5 |
| ) |
| # chunk_size also feeds the "Text Information" summary via update_char_count. |
| chunk_size = gr.Slider( |
| 100, 300, step=50, |
| label="Chunk Size (characters)", |
| value=250, |
| info="Text chunking for long conversations" |
| ) |
| # Less common sampling controls, collapsed by default. |
| with gr.Accordion("Advanced Options", open=False): |
| # Per its label, 0 means "random seed"; presumably resolved inside |
| # regenerate_audio_sync (not visible here) - verify. |
| seed_num = gr.Number(value=0, label="Random seed (0 for random)") |
| temperature = gr.Slider(0.05, 5, step=.05, label="Temperature", value=.8) |
| # Output section |
| with gr.Group(elem_classes="output-group"): |
| with gr.Row(): |
| # Left: conversation text |
| with gr.Column(scale=3): |
| # Editable script: filled by the convert/example handlers and read back by the |
| # "Generate Audio" handler, so user edits are what gets voiced. |
| conversation_output = gr.Textbox( |
| label="๐ฌ Generated Professional Conversation (Editable)", |
| lines=25, |
| max_lines=50, |
| interactive=True, |
| placeholder="Professional podcast conversation will appear here...", |
| info="Edit the conversation as needed. Format: 'Speaker Name: Text'" |
| ) |
| # Text length display (read-only; refreshed by update_char_count) |
| char_count = gr.Textbox( |
| label="Text Information", |
| value="0 characters", |
| interactive=False |
| ) |
| # Audio generation button |
| with gr.Row(): |
| generate_audio_btn = gr.Button( |
| "๐๏ธ Generate Audio with Chatterbox TTS", |
| variant="secondary", |
| size="lg" |
| ) |
| # Right: audio output and status |
| with gr.Column(scale=2): |
| audio_output = gr.Audio( |
| label="๐ง Professional Podcast Audio", |
| type="filepath", |
| interactive=False |
| ) |
| # Shared status box: both the convert and the audio handlers write to it. |
| status_output = gr.Textbox( |
| label="๐ Status", |
| interactive=False, |
| lines=3, |
| elem_classes="status-box" |
| ) |
| # Help |
| gr.Markdown(""" |
| #### ๐ก Quick Tips: |
| - **URL**: Paste any article link |
| - **PDF**: Upload documents directly |
| - **Keyword**: Enter topics for AI research |
| - **Voice Cloning**: Upload reference audio |
| - Edit conversation before audio generation |
| - Longer text automatically chunked |
| """) |
# Examples section
with gr.Accordion("๐ Examples", open=False):
    # Each row is (URL, keyword, input type, mode). Keyword examples must land in
    # keyword_input: the Generate button reads only the widget that matches the
    # selected type (see get_article_input). A keyword written into the hidden URL
    # box would be ignored and an empty topic submitted.
    gr.Examples(
        examples=[
            ["https://huggingface.co/blog/openfreeai/cycle-navigator", "", "URL", "Local"],
            ["", "quantum computing breakthroughs", "Keyword", "Local"],
            ["https://huggingface.co/papers/2505.14810", "", "URL", "Local"],
            ["", "artificial intelligence ethics", "Keyword", "Local"],
        ],
        inputs=[url_input, keyword_input, input_type_selector, mode_selector],
        outputs=[conversation_output, status_output],
        # Only invoked if examples are run or cached. Like the Generate button, it
        # forwards the value of the widget that matches the chosen input type.
        fn=lambda url, keyword, input_type, mode: synthesize_sync(
            keyword if input_type == "Keyword" else url, input_type, mode
        ),
        cache_examples=False,
    )
# Input type change handler: show only the widget for the chosen content source.
input_type_selector.change(
    toggle_input_visibility,
    inputs=[input_type_selector],
    outputs=[url_input, pdf_input, keyword_input],
)
# Keep the character-count box current. Recompute it whenever the conversation
# text or the chunk size changes: one listener per trigger, registered in the
# same order (conversation text first, then chunk size).
for _count_trigger in (conversation_output, chunk_size):
    _count_trigger.change(
        update_char_count,
        inputs=[conversation_output, chunk_size],
        outputs=[char_count],
    )
# Event wiring
def get_article_input(input_type, url_input, pdf_input, keyword_input):
    """Return the value of the input widget selected by ``input_type``.

    Args:
        input_type: The input-type radio choice ("URL", "PDF" or "Keyword").
        url_input: Current value of the URL textbox.
        pdf_input: Current value of the PDF upload component.
        keyword_input: Current value of the topic/keyword textbox.

    Returns:
        ``url_input`` for "URL" and ``pdf_input`` for "PDF". Any other value
        (the "Keyword" choice or anything unexpected) returns ``keyword_input``.
    """
    if input_type == "URL":
        return url_input
    return pdf_input if input_type == "PDF" else keyword_input
| # Generate the conversation script. The lambda receives the values of all three source |
| # widgets and forwards only the one matching the selected input type (via |
| # get_article_input) to synthesize_sync. Its parameter names shadow the component |
| # variables of the same name; inside the lambda they are plain values. |
| # NOTE(review): Gradio derives API endpoint metadata from the callable (the app is |
| # launched with show_api=True) - verify before renaming or replacing this lambda. |
| convert_btn.click( |
| fn=lambda input_type, url_input, pdf_input, keyword_input, mode: synthesize_sync( |
| get_article_input(input_type, url_input, pdf_input, keyword_input), input_type, mode |
| ), |
| inputs=[input_type_selector, url_input, pdf_input, keyword_input, mode_selector], |
| outputs=[conversation_output, status_output] |
| ) |
| # Voice the (possibly edited) conversation using the TTS settings. Outputs are |
| # (status, audio): status comes first here, second in the convert handler above. |
| generate_audio_btn.click( |
| fn=regenerate_audio_sync, |
| inputs=[conversation_output, ref_audio, exaggeration, temperature, seed_num, cfg_weight, chunk_size], |
| outputs=[status_output, audio_output] |
| ) |
# Entry point: enable request queuing, then serve the app on all interfaces.
if __name__ == "__main__":
    # Blocks.queue() configures the queue and returns the same Blocks object,
    # so launching `demo` afterwards is equivalent to chaining the two calls.
    demo.queue(default_concurrency_limit=10, api_open=True)
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_api=True,
    )