import gradio as gr import pandas as pd import genanki from pocket_tts import TTSModel import tempfile import os import shutil import random import zipfile import sqlite3 import re import time import json import torch import scipy.io.wavfile from pathlib import Path from concurrent.futures import ThreadPoolExecutor, as_completed from pydub import AudioSegment # --- Configuration --- MAX_WORKERS = 4 # Keep low for HF Spaces (CPU/RAM constraint) PREVIEW_LIMIT = 100 # UI safety cap PROGRESS_THROTTLE = 1.0 # Seconds between UI updates # --- Helpers --- def clean_text_for_tts(text): """Deep cleaning for TTS input only.""" if pd.isna(text): return "" text = str(text) # Remove HTML tags text = re.sub(re.compile('<.*?>'), '', text) # Remove Anki sound tags text = re.sub(r'\[sound:.*?\]', '', text) # Remove mustache templates text = re.sub(r'\{\{.*?\}\}', '', text) return text.strip() def has_existing_audio(text): """Check if text already contains an Anki sound tag.""" if pd.isna(text): return False return bool(re.search(r'\[sound:.*?\]', str(text))) print("Loading TTS Model...") try: TTS_MODEL = TTSModel.load_model() print("Model Loaded Successfully.") except Exception as e: print(f"CRITICAL ERROR loading model: {e}") TTS_MODEL = None # Get default voice state VOICE_STATE = None if TTS_MODEL: try: VOICE_STATE = TTS_MODEL.get_state_for_audio_prompt("alba") # Default voice except Exception as e: print(f"Warning: Could not load default voice: {e}") def wav_to_mp3(src_wav, dst_mp3): AudioSegment.from_wav(src_wav).export(dst_mp3, format="mp3", bitrate="64k") def generate_audio_for_row(q_text, a_text, idx, tmpdir, mode): """ Generates audio. Returns (path_q, path_a). Returns 'SKIP' if audio exists and we are preserving it. """ q_out, a_out = None, None # Logic for handling modes # Mode 0: Smart Fill (Preserve Existing) # Mode 1: Overwrite All overwrite = (mode == "Generate all new audio (Overwrite)") # --- Question Processing --- if not overwrite and has_existing_audio(q_text): q_out = "SKIP" else: q_wav = os.path.join(tmpdir, f"q_{idx}.wav") try: clean = clean_text_for_tts(q_text) if clean and TTS_MODEL and VOICE_STATE: # Generate audio using new API audio_tensor = TTS_MODEL.generate_audio(VOICE_STATE, clean) # Convert tensor to numpy and save as wav scipy.io.wavfile.write(q_wav, TTS_MODEL.sample_rate, audio_tensor.numpy()) q_out = q_wav else: AudioSegment.silent(duration=500).export(q_wav, format="wav") q_out = q_wav except Exception as e: print(f"TTS Error Q row {idx}: {e}") # Fallback to silence to keep deck integrity AudioSegment.silent(duration=500).export(q_wav, format="wav") q_out = q_wav # --- Answer Processing --- if not overwrite and has_existing_audio(a_text): a_out = "SKIP" else: a_wav = os.path.join(tmpdir, f"a_{idx}.wav") try: clean = clean_text_for_tts(a_text) if clean and TTS_MODEL and VOICE_STATE: # Generate audio using new API audio_tensor = TTS_MODEL.generate_audio(VOICE_STATE, clean) # Convert tensor to numpy and save as wav scipy.io.wavfile.write(a_wav, TTS_MODEL.sample_rate, audio_tensor.numpy()) a_out = a_wav else: AudioSegment.silent(duration=500).export(a_wav, format="wav") a_out = a_wav except Exception as e: print(f"TTS Error A row {idx}: {e}") AudioSegment.silent(duration=500).export(a_wav, format="wav") a_out = a_wav return q_out, a_out def strip_html_for_display(text): """Remove HTML tags for preview readability.""" if pd.isna(text) or text == "": return "" text = str(text) # Remove HTML tags text = re.sub(r'<[^>]+>', '', text) # Decode HTML entities text = text.replace(' ', ' ').replace('>', '>').replace('<', '<').replace('&', '&') # Limit length for display if len(text) > 50: text = text[:50] + '...' return text.strip() def extract_unique_tags(df): """Extract all unique tags from the Tags column.""" if df is None or 'Tags' not in df.columns: return ["All"] all_tags = set() for tag_str in df['Tags']: if tag_str: # Tags are space-separated, e.g., " MK_MathematicsKnowledge " tags = [t.strip() for t in tag_str.split() if t.strip()] all_tags.update(tags) return ["All"] + sorted(list(all_tags)) def parse_file(file_obj): if file_obj is None: return None, None, None, "No file uploaded", "", None ext = Path(file_obj.name).suffix.lower() df = pd.DataFrame() extract_root = None # Directory where we keep original media has_media = False try: if ext == ".csv": df = pd.read_csv(file_obj.name) if len(df.columns) < 2: df = pd.read_csv(file_obj.name, header=None) if len(df.columns) < 2: return None, None, None, "CSV error: Need 2 columns", "", None df = df.iloc[:, :2] df.columns = ["Question", "Answer"] df['Tags'] = "" # CSV files don't have tags elif ext == ".apkg" or ext == ".zip": # Extract to a PERSISTENT temp dir (passed to state) extract_root = tempfile.mkdtemp() with zipfile.ZipFile(file_obj.name, 'r') as z: z.extractall(extract_root) col_path = os.path.join(extract_root, "collection.anki2") if not os.path.exists(col_path): shutil.rmtree(extract_root) return None, None, None, "Invalid APKG: No collection.anki2", "", None conn = sqlite3.connect(col_path) cur = conn.cursor() cur.execute("SELECT flds, tags FROM notes") rows = cur.fetchall() data = [] audio_count = 0 # Count cards with existing audio for r in rows: flds = r[0].split('\x1f') tags = r[1].strip() if len(r) > 1 else "" q = flds[0] if len(flds) > 0 else "" a = flds[1] if len(flds) > 1 else "" # Check if either field has audio tags if re.search(r'\[sound:.*?\]', q) or re.search(r'\[sound:.*?\]', a): audio_count += 1 data.append([q, a, tags]) df = pd.DataFrame(data, columns=["Question", "Answer", "Tags"]) conn.close() # has_media means existing AUDIO, not images has_media = audio_count > 0 else: return None, None, None, "Unsupported file type", "", None df = df.fillna("") msg = f"✅ Loaded {len(df)} cards." if has_media: msg += f" 🎵 {audio_count} cards have existing audio." return df, has_media, df.head(PREVIEW_LIMIT), msg, estimate_time(len(df), has_media), extract_root except Exception as e: if extract_root and os.path.exists(extract_root): shutil.rmtree(extract_root) return None, None, None, f"Error: {str(e)}", "", None def estimate_time(num_cards, has_existing_media=False, mode="Smart Fill (Preserve Existing)"): """ Estimate based on benchmark: ~4.7s per card for full generation. Adjusts for Smart Fill mode when existing media is present. """ if num_cards == 0: return "0s" # Base benchmark: 4.7s per card for full audio generation seconds_per_card = 4.7 # If using Smart Fill with existing media, assume ~50% speedup (many cards already have audio) if has_existing_media and "Smart Fill" in mode: seconds_per_card *= 0.5 seconds = num_cards * seconds_per_card if seconds < 60: return f"~{int(seconds)}s" elif seconds < 3600: return f"~{int(seconds/60)} min" else: hours = int(seconds / 3600) mins = int((seconds % 3600) / 60) return f"~{hours}h {mins}m" if mins > 0 else f"~{hours}h" def process_dataframe(df_full, search_term, extract_root, mode, search_in, selected_tag, progress=gr.Progress()): if df_full is None or len(df_full) == 0: return None, "No data" # Start with full dataframe df = df_full.copy() # Apply tag filter first if selected_tag and selected_tag != "All": df = df[df['Tags'].str.contains(selected_tag, na=False, case=False)] # Apply text search filter if search_term: if search_in == "Question Only": mask = df['Question'].str.contains(search_term, case=False, na=False) elif search_in == "Answer Only": mask = df['Answer'].str.contains(search_term, case=False, na=False) else: # Both mask = df.astype(str).apply(lambda x: x.str.contains(search_term, case=False, na=False)).any(axis=1) df = df[mask] if len(df) == 0: return None, "No matching cards" # Setup work_dir = tempfile.mkdtemp() media_files = [] try: # --- Media Preservation Logic --- if extract_root: media_map_path = os.path.join(extract_root, "media") if os.path.exists(media_map_path) and os.path.getsize(media_map_path) > 0: try: with open(media_map_path, 'r') as f: # Fix: Handle potentially malformed JSON gracefully content = f.read().strip() if content: media_map = json.loads(content) # {"0": "my_audio.mp3", ...} # Rename files in extract_root back to original names for k, v in media_map.items(): src = os.path.join(extract_root, k) dst = os.path.join(extract_root, v) if os.path.exists(src): # Rename enables genanki to find them by name os.rename(src, dst) media_files.append(dst) else: print("Warning: Media map file is empty.") except Exception as e: print(f"Warning: Could not restore existing media: {e}") # --- Genanki Setup --- model_id = random.randrange(1 << 30, 1 << 31) my_model = genanki.Model( model_id, 'PocketTTS Model', fields=[{'name': 'Question'}, {'name': 'Answer'}], templates=[{ 'name': 'Card 1', 'qfmt': '{{Question}}', 'afmt': '{{FrontSide}}