# HuggingFace Space backend: TTS worker orchestration + voice conversion + usage credits.
# Standard library
import atexit
import base64
import io
import json
import logging
import os
import random
import re
import threading
import time
import uuid
from datetime import datetime, timedelta
from itertools import cycle

# Third-party
import google.generativeai as genai
import requests
from flask import Flask, request, jsonify, render_template, send_file
from flask_cors import CORS
from huggingface_hub import HfApi, hf_hub_download
from huggingface_hub.utils import RepositoryNotFoundError, EntryNotFoundError
from pydub import AudioSegment
# --- CONFIGURATION & LOGGING ---
# Cache directory for Hugging Face downloads; /tmp is writable inside the Space.
CACHE_DIRECTORY = "/tmp/huggingface_cache_ezmary"
os.makedirs(CACHE_DIRECTORY, exist_ok=True)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
app = Flask(__name__)
CORS(app)  # allow cross-origin calls from the front-end
# --- WORKER POOL SETUP ---
# TTS (text-to-speech) worker Spaces that actually synthesize audio.
WORKER_URLS = [
    "https://hamed744-ttspro6.hf.space/generate",
    "https://hamed744-ttspro7.hf.space/generate",
    "https://hamed744-ttspro8.hf.space/generate",
]

# Endless round-robin iterator over the worker list.
worker_pool = cycle(WORKER_URLS)


def get_next_worker_url():
    """Return the next TTS worker URL in round-robin order."""
    url = next(worker_pool)
    return url
# --- VOICE CONVERSION (VC) SERVICE ADDRESS ---
# Base URL of the external Space that converts the TTS output to a custom voice.
VC_SPACE_URL = "https://ezmary-sada.hf.space"
# --- GLOBAL VARIABLES ---
tasks = {}  # podcast-script generation tasks, keyed by task_id
tasks_lock = threading.Lock()  # guards `tasks`
request_counter = 0  # NOTE(review): appears unused in this view — confirm before removing
request_counter_lock = threading.Lock()
DATASET_REPO = "opera8/Karbaran-rayegan-tedad"  # HF dataset holding usage records
DATASET_FILENAME = "usage_data.json"
USAGE_LIMIT = 5  # free generations per user per week
HF_TOKEN = os.environ.get("HF_TOKEN")
CLEANUP_INTERVAL_SECONDS = 6 * 30 * 24 * 60 * 60  # ~6 months, used for pruning old records
last_cleanup_time = time.time()
usage_data_cache = []  # in-memory copy of the usage dataset
cache_lock = threading.Lock()  # guards `usage_data_cache`
data_changed = threading.Event()  # set when the cache has unpersisted changes
api = None  # HfApi client; initialised below only when HF_TOKEN is present
# --- DATABASE LOGIC ---
# Without the token the service still runs, but usage data cannot be persisted.
if not HF_TOKEN:
    logging.error("CRITICAL: Secret 'HF_TOKEN' not found.")
else:
    api = HfApi(token=HF_TOKEN)
    logging.info("HfApi initialized.")
def load_initial_data():
    """Populate `usage_data_cache` from the HF dataset, if it exists.

    Missing repo/file is treated as "first run" and logged as a warning;
    any other failure is logged and leaves the cache empty.
    """
    global usage_data_cache
    with cache_lock:
        if not api:
            return
        try:
            downloaded_path = hf_hub_download(
                repo_id=DATASET_REPO,
                filename=DATASET_FILENAME,
                repo_type="dataset",
                token=HF_TOKEN,
                force_download=True,
                cache_dir=CACHE_DIRECTORY,
            )
            with open(downloaded_path, 'r', encoding='utf-8') as fh:
                raw = fh.read()
            if raw:
                usage_data_cache = json.loads(raw)
        except (RepositoryNotFoundError, EntryNotFoundError):
            logging.warning("Dataset not found, creating new.")
        except Exception as e:
            logging.error(f"Failed to load data: {e}")
def persist_data_to_hub():
    """Flush `usage_data_cache` to the HF dataset when it has changed.

    On each ~6-month boundary, first prunes records whose week_start is
    older than the cleanup interval. No-op when nothing changed or when
    the HfApi client is unavailable.
    """
    global last_cleanup_time, usage_data_cache
    with cache_lock:
        now = time.time()
        if (now - last_cleanup_time) > CLEANUP_INTERVAL_SECONDS:
            cutoff = now - CLEANUP_INTERVAL_SECONDS
            usage_data_cache = [
                record for record in usage_data_cache
                if record.get('week_start', 0) > cutoff
            ]
            last_cleanup_time = now
            data_changed.set()  # pruning itself must be persisted
        if not data_changed.is_set() or not api:
            return
        try:
            snapshot = list(usage_data_cache)
            tmp_path = os.path.join(CACHE_DIRECTORY, "temp_usage_data.json")
            with open(tmp_path, 'w', encoding='utf-8') as fh:
                json.dump(snapshot, fh, indent=2, ensure_ascii=False)
            api.upload_file(path_or_fileobj=tmp_path, path_in_repo=DATASET_FILENAME, repo_id=DATASET_REPO, repo_type="dataset", commit_message="Update")
            os.remove(tmp_path)
            data_changed.clear()
        except Exception as e:
            logging.error(f"Persist failed: {e}")
def background_persister():
    """Daemon loop: attempt to persist usage data every 10 seconds."""
    while True:
        time.sleep(10)
        persist_data_to_hub()
def get_user_ip():
    """Best-effort client IP: first X-Forwarded-For hop, else remote_addr."""
    forwarded = request.headers.getlist("X-Forwarded-For")
    if forwarded:
        return forwarded[0].split(',')[0].strip()
    return request.remote_addr
# --- TTS HELPER FUNCTIONS ---
def merge_audio_segments(audio_segments):
    """Concatenate pydub segments into one in-memory WAV buffer.

    Returns a BytesIO positioned at the start, or None when the input
    list is empty.
    """
    if not audio_segments:
        return None
    merged = sum(audio_segments, AudioSegment.empty())
    wav_buffer = io.BytesIO()
    merged.export(wav_buffer, format="wav")
    wav_buffer.seek(0)
    return wav_buffer
def call_worker(index, chunk_payload):
    """Synthesize one text chunk via the TTS worker pool.

    Tries workers round-robin, up to two full passes over the pool, until
    one returns audio.

    Args:
        index: caller-side chunk index, echoed back in the result.
        chunk_payload: dict with "text", "speaker", optional "temperature"
            and "is_custom".

    Returns:
        (index, AudioSegment) on success, (index, None) when every attempt failed.
    """
    text_length = len(chunk_payload.get("text", ""))
    # Short chunks use the "live" model; longer ones take the batch path.
    use_live = text_length <= 500
    target_speaker = chunk_payload.get("speaker")
    is_custom = chunk_payload.get("is_custom", False)
    # Custom voices are first synthesized with the stock "Charon" voice and
    # converted to the reference voice afterwards (see process_voice_conversion).
    actual_speaker_request = "Charon" if is_custom else target_speaker
    worker_payload = {
        "text": chunk_payload.get("text"),
        "speaker": actual_speaker_request,
        "temperature": chunk_payload.get("temperature", 0.9),
        "use_live_model": use_live,
        "retry_limit": 50,
        "fallback_to_live": True
    }
    total_workers = len(WORKER_URLS)
    for _ in range(total_workers * 2):
        worker_url = get_next_worker_url()
        try:
            logging.info(f"Chunk {index} (Len: {text_length}) -> Sending to {worker_url} (LiveMode: {use_live})")
            response = requests.post(worker_url, json=worker_payload, timeout=300)
            if response.status_code == 200:
                audio_data = io.BytesIO(response.content)
                audio_segment = AudioSegment.from_file(audio_data)
                return index, audio_segment
            else:
                logging.warning(f"Worker Error {worker_url}: {response.status_code}")
        except Exception as e:
            logging.warning(f"Worker Connection Fail {worker_url}: {e}")
    # All attempts exhausted.
    return index, None
# --- AI PODCAST SCRIPT LOGIC ---
def generate_podcast_in_background(task_id, system_prompt, safety_settings):
    """Generate a podcast script with Gemini and record the result in `tasks`.

    Runs in a background thread. Retries up to 50 times, picking a random
    API key each attempt; on success stores the parsed JSON under
    tasks[task_id], otherwise marks the task failed.
    """
    try:
        keys_str = os.environ.get("ALL_GEMINI_API_KEYS")
        keys_list = [k.strip() for k in keys_str.split(',') if k.strip()] if keys_str else []
        if not keys_list: raise ValueError("No AI Keys")
        MAX_ATTEMPTS = 50
        for attempt in range(MAX_ATTEMPTS):
            # Random key per attempt spreads quota usage across the pool.
            key = random.choice(keys_list)
            try:
                genai.configure(api_key=key)
                model = genai.GenerativeModel('gemini-2.5-flash')
                res = model.generate_content(system_prompt, safety_settings=safety_settings)
                raw_text = res.text
                json_string = None
                # Prefer a fenced ```json block; otherwise fall back to the
                # outermost {...} span in the raw response.
                match = re.search(r"```json\s*(\{.*?\})\s*```", raw_text, re.DOTALL)
                if match: json_string = match.group(1)
                else:
                    s_idx = raw_text.find('{')
                    e_idx = raw_text.rfind('}')
                    if s_idx != -1 and e_idx != -1: json_string = raw_text[s_idx:e_idx+1]
                if not json_string: raise ValueError("No JSON found")
                data = json.loads(json_string)
                if "script" in data:
                    for t in data["script"]:
                        if "dialogue" in t:
                            # Strip stage directions like [laugh] / (sigh).
                            t["dialogue"] = re.sub(r'\[.*?\]|\(.*?\)', '', t["dialogue"]).strip()
                with tasks_lock:
                    tasks[task_id].update({'status': 'completed', 'data': data})
                return
            except Exception as e:
                logging.warning(f"AI Attempt {attempt} failed: {e}")
                time.sleep(1)
        with tasks_lock: tasks[task_id].update({'status': 'failed', 'error': 'Max retries reached'})
    except Exception as e:
        with tasks_lock: tasks[task_id].update({'status': 'failed', 'error': str(e)})
# --- VC LOGIC (aligned with the new VC Space API) ---
def process_voice_conversion(tts_audio_io, ref_audio_base64):
    """Convert TTS output to the user's reference voice via the VC Space.

    Uploads both audio streams, polls the remote job until completion
    (up to ~8 minutes), then downloads the converted audio.

    Args:
        tts_audio_io: BytesIO containing the WAV produced by the TTS step.
        ref_audio_base64: reference voice sample, base64 (optionally a data URL).

    Returns:
        BytesIO with the converted audio, or None on any failure.
    """
    try:
        tts_audio_io.seek(0)
        # The reference may arrive as a data URL; keep only the base64 payload.
        if "," in ref_audio_base64:
            ref_audio_base64 = ref_audio_base64.split(",")[1]
        ref_bytes = base64.b64decode(ref_audio_base64)
        files = {
            'source_audio': ('source.wav', tts_audio_io, 'audio/wav'),
            'ref_audio': ('ref.wav', io.BytesIO(ref_bytes), 'audio/wav')
        }
        # 1. Upload both files to the VC service.
        logging.info(f"VC: Uploading to {VC_SPACE_URL}/upload")
        res = requests.post(f"{VC_SPACE_URL}/upload", files=files, timeout=120)
        if res.status_code != 200:
            raise Exception(f"VC Upload Failed: {res.text}")
        # Full job description (including chunk info) returned by the service.
        job_data = res.json()
        # 2. Poll for completion. Long window: the custom model is slow.
        for _ in range(120):  # up to ~8 minutes
            time.sleep(4)
            # The service expects the whole job_data object back on every poll.
            chk = requests.post(f"{VC_SPACE_URL}/check_status", json=job_data, timeout=30)
            if chk.status_code == 200:
                stat = chk.json()
                if stat.get("status") == "completed":
                    filename = stat.get("filename")
                    # 3. Download the finished file.
                    # BUGFIX: the URL must include the reported filename
                    # (previously a literal placeholder, so the GET always 404'd).
                    # Also add a timeout like every other request here.
                    dl = requests.get(f"{VC_SPACE_URL}/download/{filename}", timeout=120)
                    if dl.status_code == 200:
                        return io.BytesIO(dl.content)
                    else:
                        raise Exception("VC Download Failed")
                elif stat.get("status") == "failed":
                    detail = stat.get("detail", "Unknown error")
                    raise Exception(f"VC Remote Failed: {detail}")
            # Any other status ("processing") -> keep polling.
        raise Exception("VC Timeout (Processing took too long)")
    except Exception as e:
        logging.error(f"VC Error: {e}")
        return None
# --- ROUTES ---
# NOTE(review): no @app.route decorators are visible in this view of the file;
# they may have been lost in formatting. Verify every route below is actually
# registered with Flask.
def index():
    """Serve the single-page front-end."""
    return render_template('index.html')
def check_credit():
    """Report remaining weekly free credits for a fingerprint (or known IP).

    Resets the weekly window in-place when it has elapsed. Responds with
    credits_remaining, limit_reached, and (when exhausted) reset_timestamp.
    """
    data = request.get_json()
    fingerprint = data.get('fingerprint')
    if not fingerprint:
        return jsonify({"status": "error"}), 400
    with cache_lock:
        ip = get_user_ip()
        now = time.time()
        one_week = 7 * 24 * 60 * 60
        week_ago = now - one_week
        # Match by fingerprint first, then fall back to a previously seen IP.
        user = next((u for u in usage_data_cache if u.get('fingerprint') == fingerprint), None)
        if user is None:
            user = next((u for u in usage_data_cache if ip in u.get('ips', [])), None)
        limit_reached = False
        remaining = USAGE_LIMIT
        reset_ts = 0
        if user:
            # Weekly window elapsed -> reset the counter.
            if user.get('week_start', 0) < week_ago:
                user['count'] = 0
                user['week_start'] = now
                data_changed.set()
            remaining = USAGE_LIMIT - user.get('count', 0)
            if remaining <= 0:
                limit_reached = True
                remaining = 0
                reset_ts = user.get('week_start', now) + one_week
        return jsonify({"credits_remaining": remaining, "limit_reached": limit_reached, "reset_timestamp": reset_ts})
def use_credit():
    """Consume one weekly credit for the given fingerprint (or its IP).

    Creates a new usage record on first use; returns 429 when the weekly
    limit is already reached, otherwise the remaining credit count.
    """
    data = request.get_json()
    fingerprint = data.get('fingerprint')
    with cache_lock:
        ip = get_user_ip()
        now = time.time()
        week_ago = now - (7*24*60*60)
        user = next((u for u in usage_data_cache if u.get('fingerprint') == fingerprint), None)
        user = user or next((u for u in usage_data_cache if ip in u.get('ips', [])), None)
        if user:
            if user.get('week_start', 0) < week_ago:
                user['count'] = 0
                user['week_start'] = now
            # FIX: read keys defensively — records loaded from the dataset may
            # lack 'count'/'ips' (check_credit already uses .get for them).
            if user.get('count', 0) >= USAGE_LIMIT:
                return jsonify({"status": "limit"}), 429
            user['count'] = user.get('count', 0) + 1
            known_ips = user.setdefault('ips', [])
            if ip not in known_ips:
                known_ips.append(ip)
        else:
            user = {"fingerprint": fingerprint, "ips": [ip], "count": 1, "week_start": now}
            usage_data_cache.append(user)
        data_changed.set()
        return jsonify({"status": "success", "credits_remaining": USAGE_LIMIT - user['count']})
def create_full_podcast():
    """Start asynchronous podcast-script generation; respond with a task id."""
    try:
        data = request.get_json()
        prompt = data.get('prompt')
        speakers = data.get('available_speakers')
        if not prompt or not speakers:
            return jsonify({"error": "Bad request"}), 400
        spk_text = "\n".join(f"- {s['id']}: {s['name']}" for s in speakers)
        sys_prompt = f"""Act as a Podcast Producer.
Topic: "{prompt}"
Speakers Available:
{spk_text}
Output ONLY valid JSON.
Format: {{"selected_speakers": ["id1", "id2"], "script": [{{"speaker_id": "id1", "dialogue": "..."}}]}}
Dialogue rules: No stage directions like [laugh], (sigh). Just spoken words."""
        task_id = str(uuid.uuid4())
        with tasks_lock:
            tasks[task_id] = {'status': 'pending'}
        # Disable all safety blocking categories for the generation call.
        harm_categories = [
            "HARM_CATEGORY_HARASSMENT",
            "HARM_CATEGORY_HATE_SPEECH",
            "HARM_CATEGORY_SEXUALLY_EXPLICIT",
            "HARM_CATEGORY_DANGEROUS_CONTENT",
        ]
        safety = [{"category": c, "threshold": "BLOCK_NONE"} for c in harm_categories]
        threading.Thread(target=generate_podcast_in_background, args=(task_id, sys_prompt, safety)).start()
        return jsonify({"task_id": task_id}), 202
    except Exception as e:
        return jsonify({"error": str(e)}), 500
def podcast_status(task_id):
    """Return the current state of a podcast-script generation task."""
    with tasks_lock:
        state = tasks.get(task_id, {'status': 'not_found'})
        return jsonify(state), 200
def generate_audio_route():
    """Generate audio for `text`; run voice conversion for custom voices.

    Expects JSON with "text", "speaker", optional "temperature" and
    "ref_audio_base64". Returns a WAV attachment, or a JSON error with an
    appropriate status code.
    """
    try:
        data = request.get_json()
        if not data: return jsonify({"error": "No data"}), 400
        text = data.get("text", "")
        speaker = data.get("speaker")
        temperature = data.get("temperature", 0.9)
        ref_base64 = data.get("ref_audio_base64")
        if not text: return jsonify({"error": "Text empty"}), 400
        # FIX: `speaker` may be absent; calling .startswith on None raised
        # AttributeError and surfaced as a generic 500. Validate explicitly.
        if not speaker: return jsonify({"error": "Speaker missing"}), 400
        is_custom = bool(speaker.startswith("custom_") and ref_base64)
        payload = {
            "text": text,
            "speaker": speaker,
            "temperature": temperature,
            "is_custom": is_custom
        }
        # Initial TTS pass.
        idx, audio_seg = call_worker(0, payload)
        if not audio_seg:
            return jsonify({"error": "Worker generation failed"}), 503
        final_buffer = io.BytesIO()
        audio_seg.export(final_buffer, format="wav")
        final_buffer.seek(0)
        # Custom voice: run the voice-conversion (VC) step on the TTS output.
        if is_custom:
            logging.info("Starting Custom VC...")
            vc_out = process_voice_conversion(final_buffer, ref_base64)
            if vc_out:
                return send_file(vc_out, mimetype="audio/wav", as_attachment=True, download_name=f"vc_{uuid.uuid4()}.wav")
            else:
                return jsonify({"error": "Voice Conversion failed"}), 500
        return send_file(final_buffer, mimetype="audio/wav", as_attachment=True, download_name=f"gen_{uuid.uuid4()}.wav")
    except Exception as e:
        logging.error(f"Generate route error: {e}")
        return jsonify({"error": str(e)}), 500
# --- STARTUP ---
# Load persisted usage data, start the periodic persister daemon, and make a
# final best-effort persist on interpreter exit.
load_initial_data()
threading.Thread(target=background_persister, daemon=True).start()
atexit.register(persist_data_to_hub)

if __name__ == '__main__':
    # HuggingFace Spaces conventionally expose port 7860.
    port = int(os.environ.get('PORT', 7860))
    app.run(host='0.0.0.0', port=port)