# NOTE: Hugging Face Spaces page residue ("Spaces: Sleeping") — not part of the app code.
| import gradio as gr | |
| import pandas as pd | |
| import genanki | |
| from pocket_tts import TTSModel | |
| import tempfile | |
| import os | |
| import shutil | |
| import random | |
| import zipfile | |
| import sqlite3 | |
| import re | |
| import time | |
| import json | |
| import torch | |
| import scipy.io.wavfile | |
| from pathlib import Path | |
| from concurrent.futures import ThreadPoolExecutor, as_completed | |
| from pydub import AudioSegment | |
# --- Configuration ---
MAX_WORKERS = 4  # Thread-pool size for TTS jobs. Keep low for HF Spaces (CPU/RAM constraint).
PREVIEW_LIMIT = 100  # UI safety cap: at most this many rows shown in the preview table.
PROGRESS_THROTTLE = 1.0  # Seconds between UI progress updates (avoids flooding the frontend)
| # --- Helpers --- | |
def clean_text_for_tts(text):
    """Prepare a card field for speech synthesis.

    Strips HTML tags, Anki ``[sound:...]`` tags, and mustache templates
    (``{{...}}``), returning the trimmed plain text. NaN/None yields "".
    """
    if pd.isna(text):
        return ""
    cleaned = str(text)
    # Patterns removed in order: HTML tags, Anki sound tags, mustache templates.
    for pattern in (r'<.*?>', r'\[sound:.*?\]', r'\{\{.*?\}\}'):
        cleaned = re.sub(pattern, '', cleaned)
    return cleaned.strip()
def has_existing_audio(text):
    """Return True when the field already carries an Anki ``[sound:...]`` tag."""
    if pd.isna(text):
        return False
    return re.search(r'\[sound:.*?\]', str(text)) is not None
| print("Loading TTS Model...") | |
| try: | |
| TTS_MODEL = TTSModel.load_model() | |
| print("Model Loaded Successfully.") | |
| except Exception as e: | |
| print(f"CRITICAL ERROR loading model: {e}") | |
| TTS_MODEL = None | |
| # Get default voice state | |
| VOICE_STATE = None | |
| if TTS_MODEL: | |
| try: | |
| VOICE_STATE = TTS_MODEL.get_state_for_audio_prompt("alba") # Default voice | |
| except Exception as e: | |
| print(f"Warning: Could not load default voice: {e}") | |
def wav_to_mp3(src_wav, dst_mp3):
    """Transcode a WAV file to 64 kbps MP3 (keeps the packaged deck small)."""
    audio = AudioSegment.from_wav(src_wav)
    audio.export(dst_mp3, format="mp3", bitrate="64k")
def _synthesize_field(text, wav_path, label, idx):
    """Render one field's text to `wav_path` via TTS.

    Falls back to 0.5s of silence when the cleaned text is empty, the model
    or voice state is unavailable, or synthesis raises — this keeps deck
    integrity (every generated card gets a playable file).
    `label` ("Q"/"A") and `idx` are used only for error logging.
    Always returns wav_path.
    """
    try:
        clean = clean_text_for_tts(text)
        if clean and TTS_MODEL and VOICE_STATE:
            # Generate audio using new API; tensor -> numpy -> WAV on disk.
            audio_tensor = TTS_MODEL.generate_audio(VOICE_STATE, clean)
            scipy.io.wavfile.write(wav_path, TTS_MODEL.sample_rate, audio_tensor.numpy())
        else:
            AudioSegment.silent(duration=500).export(wav_path, format="wav")
    except Exception as e:
        print(f"TTS Error {label} row {idx}: {e}")
        # Fallback to silence to keep deck integrity
        AudioSegment.silent(duration=500).export(wav_path, format="wav")
    return wav_path


def generate_audio_for_row(q_text, a_text, idx, tmpdir, mode):
    """Generate question/answer audio for one card.

    Returns (q_path, a_path), where each entry is either a WAV file path or
    the sentinel string "SKIP" (existing [sound:...] audio preserved when
    not in Overwrite mode).
    """
    # Mode "Generate all new audio (Overwrite)" regenerates everything;
    # both other modes preserve fields that already contain audio.
    overwrite = (mode == "Generate all new audio (Overwrite)")

    if not overwrite and has_existing_audio(q_text):
        q_out = "SKIP"
    else:
        q_out = _synthesize_field(q_text, os.path.join(tmpdir, f"q_{idx}.wav"), "Q", idx)

    if not overwrite and has_existing_audio(a_text):
        a_out = "SKIP"
    else:
        a_out = _synthesize_field(a_text, os.path.join(tmpdir, f"a_{idx}.wav"), "A", idx)

    return q_out, a_out
def strip_html_for_display(text):
    """Remove HTML markup and truncate to 50 chars for preview readability."""
    import html  # stdlib; local import avoids touching the top-of-file imports
    if pd.isna(text) or text == "":
        return ""
    text = str(text)
    # Remove HTML tags
    text = re.sub(r'<[^>]+>', '', text)
    # Decode ALL HTML entities (&quot;, &#39;, numeric refs, ...) instead of
    # the previous hand-rolled four-replacement chain. Map the non-breaking
    # space (U+00A0 from &nbsp;) to a plain space, matching the old behavior.
    text = html.unescape(text).replace('\xa0', ' ')
    # Limit length for display
    if len(text) > 50:
        text = text[:50] + '...'
    return text.strip()
def extract_unique_tags(df):
    """Extract all unique tags from the Tags column.

    Returns ["All"] followed by the sorted unique tags; just ["All"] when
    the dataframe is missing or has no Tags column.
    """
    if df is None or 'Tags' not in df.columns:
        return ["All"]
    all_tags = set()
    for tag_str in df['Tags']:
        # Guard against NaN cells: NaN is truthy, so the old `if tag_str:`
        # check let floats through and `.split()` raised AttributeError.
        if pd.notna(tag_str) and str(tag_str).strip():
            # Tags are space-separated, e.g., " MK_MathematicsKnowledge "
            all_tags.update(t for t in str(tag_str).split() if t)
    return ["All"] + sorted(all_tags)
def parse_file(file_obj):
    """Parse an uploaded CSV or APKG/ZIP into a (Question, Answer, Tags) dataframe.

    Returns a 6-tuple: (df, has_media, preview_df, status_msg, eta_str,
    extract_root). On failure df is None and status_msg explains why.
    extract_root is the persistent temp dir holding the unzipped APKG
    (None for CSV); the caller is responsible for cleaning it up later.
    """
    if file_obj is None:
        return None, None, None, "No file uploaded", "", None
    ext = Path(file_obj.name).suffix.lower()
    df = pd.DataFrame()
    extract_root = None  # Directory where we keep original media
    has_media = False
    try:
        if ext == ".csv":
            df = pd.read_csv(file_obj.name)
            # Retry without a header row when the first parse found < 2 columns
            if len(df.columns) < 2:
                df = pd.read_csv(file_obj.name, header=None)
                if len(df.columns) < 2:
                    return None, None, None, "CSV error: Need 2 columns", "", None
            df = df.iloc[:, :2]  # only the first two columns are used
            df.columns = ["Question", "Answer"]
            df['Tags'] = ""  # CSV files don't have tags
        elif ext == ".apkg" or ext == ".zip":
            # Extract to a PERSISTENT temp dir (passed to state)
            extract_root = tempfile.mkdtemp()
            with zipfile.ZipFile(file_obj.name, 'r') as z:
                z.extractall(extract_root)
            col_path = os.path.join(extract_root, "collection.anki2")
            if not os.path.exists(col_path):
                shutil.rmtree(extract_root)
                return None, None, None, "Invalid APKG: No collection.anki2", "", None
            # Anki stores note fields packed in notes.flds, separated by 0x1f
            conn = sqlite3.connect(col_path)
            cur = conn.cursor()
            cur.execute("SELECT flds, tags FROM notes")
            rows = cur.fetchall()
            data = []
            audio_count = 0  # Count cards with existing audio
            for r in rows:
                flds = r[0].split('\x1f')
                tags = r[1].strip() if len(r) > 1 else ""
                q = flds[0] if len(flds) > 0 else ""
                a = flds[1] if len(flds) > 1 else ""
                # Check if either field has audio tags
                if re.search(r'\[sound:.*?\]', q) or re.search(r'\[sound:.*?\]', a):
                    audio_count += 1
                data.append([q, a, tags])
            df = pd.DataFrame(data, columns=["Question", "Answer", "Tags"])
            conn.close()
            # has_media means existing AUDIO, not images
            has_media = audio_count > 0
        else:
            return None, None, None, "Unsupported file type", "", None
        df = df.fillna("")
        msg = f"✅ Loaded {len(df)} cards."
        if has_media:
            msg += f" 🎵 {audio_count} cards have existing audio."
        return df, has_media, df.head(PREVIEW_LIMIT), msg, estimate_time(len(df), has_media), extract_root
    except Exception as e:
        # Best-effort cleanup of the extraction dir before reporting the error
        if extract_root and os.path.exists(extract_root):
            shutil.rmtree(extract_root)
        return None, None, None, f"Error: {str(e)}", "", None
| def estimate_time(num_cards, has_existing_media=False, mode="Smart Fill (Preserve Existing)"): | |
| """ | |
| Estimate based on benchmark: ~4.7s per card for full generation. | |
| Adjusts for Smart Fill mode when existing media is present. | |
| """ | |
| if num_cards == 0: | |
| return "0s" | |
| # Base benchmark: 4.7s per card for full audio generation | |
| seconds_per_card = 4.7 | |
| # If using Smart Fill with existing media, assume ~50% speedup (many cards already have audio) | |
| if has_existing_media and "Smart Fill" in mode: | |
| seconds_per_card *= 0.5 | |
| seconds = num_cards * seconds_per_card | |
| if seconds < 60: | |
| return f"~{int(seconds)}s" | |
| elif seconds < 3600: | |
| return f"~{int(seconds/60)} min" | |
| else: | |
| hours = int(seconds / 3600) | |
| mins = int((seconds % 3600) / 60) | |
| return f"~{hours}h {mins}m" if mins > 0 else f"~{hours}h" | |
def process_dataframe(df_full, search_term, extract_root, mode, search_in, selected_tag, progress=gr.Progress()):
    """Filter the deck, generate audio in parallel, and package an .apkg.

    Returns (output_path, status_msg); output_path is None on failure.
    Both work_dir and extract_root are removed in the finally block, so this
    is a one-shot operation per uploaded file.
    """
    if df_full is None or len(df_full) == 0:
        return None, "No data"
    # Start with full dataframe
    df = df_full.copy()
    # Apply tag filter first
    if selected_tag and selected_tag != "All":
        df = df[df['Tags'].str.contains(selected_tag, na=False, case=False)]
    # Apply text search filter
    if search_term:
        if search_in == "Question Only":
            mask = df['Question'].str.contains(search_term, case=False, na=False)
        elif search_in == "Answer Only":
            mask = df['Answer'].str.contains(search_term, case=False, na=False)
        else:  # Both
            mask = df.astype(str).apply(lambda x: x.str.contains(search_term, case=False, na=False)).any(axis=1)
        df = df[mask]
    if len(df) == 0:
        return None, "No matching cards"
    # Setup
    work_dir = tempfile.mkdtemp()
    media_files = []
    try:
        # --- Media Preservation Logic ---
        # The APKG 'media' file maps numeric on-disk names -> original
        # filenames; rename extracted files back so genanki can re-bundle them.
        if extract_root:
            media_map_path = os.path.join(extract_root, "media")
            if os.path.exists(media_map_path) and os.path.getsize(media_map_path) > 0:
                try:
                    with open(media_map_path, 'r') as f:
                        # Fix: Handle potentially malformed JSON gracefully
                        content = f.read().strip()
                        if content:
                            media_map = json.loads(content)  # {"0": "my_audio.mp3", ...}
                            # Rename files in extract_root back to original names
                            for k, v in media_map.items():
                                src = os.path.join(extract_root, k)
                                dst = os.path.join(extract_root, v)
                                if os.path.exists(src):
                                    # Rename enables genanki to find them by name
                                    os.rename(src, dst)
                                    media_files.append(dst)
                        else:
                            print("Warning: Media map file is empty.")
                except Exception as e:
                    print(f"Warning: Could not restore existing media: {e}")
        # --- Genanki Setup ---
        # Random 31-bit IDs per Anki convention to avoid model/deck collisions.
        model_id = random.randrange(1 << 30, 1 << 31)
        my_model = genanki.Model(
            model_id, 'PocketTTS Model',
            fields=[{'name': 'Question'}, {'name': 'Answer'}],
            templates=[{
                'name': 'Card 1',
                'qfmt': '{{Question}}',
                'afmt': '{{FrontSide}}<hr id="answer">{{Answer}}',
            }])
        my_deck = genanki.Deck(random.randrange(1 << 30, 1 << 31), 'Pocket TTS Deck')
        # --- Execution ---
        total = len(df)
        completed = 0
        last_update_time = 0
        with ThreadPoolExecutor(max_workers=MAX_WORKERS) as exe:
            futures = {}
            for idx, row in df.iterrows():
                f = exe.submit(generate_audio_for_row, str(row['Question']), str(row['Answer']), idx, work_dir, mode)
                futures[f] = idx
            for future in as_completed(futures):
                idx = futures[future]
                try:
                    q_res, a_res = future.result()
                    # --- Field Construction (Corrected) ---
                    q_original = str(df.loc[idx, 'Question'])
                    q_field = q_original
                    # Update Question: "SKIP" means existing audio was preserved
                    if q_res and q_res != "SKIP":
                        q_mp3 = str(Path(q_res).with_suffix('.mp3'))
                        wav_to_mp3(q_res, q_mp3)
                        os.remove(q_res)  # clean wav
                        media_files.append(q_mp3)
                        # Remove OLD sound tags first to avoid duplicates
                        q_field = re.sub(r'\[sound:.*?\]', '', q_field)
                        q_field = q_field.strip() + f"<br>[sound:{os.path.basename(q_mp3)}]"
                    # Update Answer
                    a_original = str(df.loc[idx, 'Answer'])
                    a_field = a_original
                    if a_res and a_res != "SKIP":
                        a_mp3 = str(Path(a_res).with_suffix('.mp3'))
                        wav_to_mp3(a_res, a_mp3)
                        os.remove(a_res)  # clean wav
                        media_files.append(a_mp3)
                        # Remove OLD sound tags first
                        a_field = re.sub(r'\[sound:.*?\]', '', a_field)
                        a_field = a_field.strip() + f"<br>[sound:{os.path.basename(a_mp3)}]"
                    # Add Note
                    note = genanki.Note(
                        model=my_model,
                        fields=[q_field, a_field]
                    )
                    my_deck.add_note(note)
                except Exception as e:
                    print(f"Row {idx} failed: {e}")
                # --- Throttled Progress ---
                # Always report the final card; otherwise rate-limit UI updates.
                completed += 1
                current_time = time.time()
                if completed == total or (current_time - last_update_time) > PROGRESS_THROTTLE:
                    progress(completed / total, desc=f"Processed {completed}/{total}")
                    last_update_time = current_time
        # --- Package ---
        package = genanki.Package(my_deck)
        # Deduplicate media files list
        package.media_files = list(set(media_files))
        raw_out = os.path.join(work_dir, "output.apkg")
        package.write_to_file(raw_out)
        # Copy out of work_dir before the finally-block deletes it.
        final_out = os.path.join(tempfile.gettempdir(), f"pocket_deck_{random.randint(1000,9999)}.apkg")
        shutil.copy(raw_out, final_out)
        return final_out, f"✅ Done! Packaged {len(package.media_files)} audio files."
    except Exception as e:
        return None, f"Critical Error: {str(e)}"
    finally:
        # --- Guaranteed Cleanup ---
        if os.path.exists(work_dir):
            shutil.rmtree(work_dir)
        # Also clean up the input extraction root if it exists
        if extract_root and os.path.exists(extract_root):
            shutil.rmtree(extract_root)
# --- UI ---
# Gradio Blocks layout plus event wiring. Handlers are defined inline so
# they close over the component references declared above them.
with gr.Blocks(title="Pocket TTS Anki") as app:
    gr.Markdown("## 🎴 Pocket TTS Anki Generator")
    gr.Markdown("Offline Neural Audio. Supports CSV and APKG (smart media preservation).")
    # State variables
    full_df_state = gr.State()
    extract_root_state = gr.State()  # Holds path to unzipped APKG
    with gr.Row():
        file_input = gr.File(label="Upload (CSV/APKG)", file_types=[".csv", ".apkg", ".zip"])
        status = gr.Textbox(label="Status", interactive=False)
        eta_box = gr.Textbox(label="Est. Time", interactive=False)
    with gr.Row():
        search_box = gr.Textbox(label="Search Text", placeholder="Enter text to search...")
        search_field = gr.Radio(
            choices=["Both", "Question Only", "Answer Only"],
            value="Both",
            label="Search In"
        )
        tag_dropdown = gr.Dropdown(
            label="Filter by Tag",
            choices=["All"],
            value="All",
            interactive=True
        )
    with gr.Row():
        # New 3-Way Toggle
        mode_radio = gr.Radio(
            choices=[
                "Smart Fill (Preserve Existing)",
                "Generate all new audio (Overwrite)",
                "Only generate missing (Same as Smart Fill)"
            ],
            value="Smart Fill (Preserve Existing)",
            label="Generation Mode"
        )
    preview_table = gr.Dataframe(
        label="Preview (First 100)",
        interactive=False,
        column_widths=["30%", "45%", "25%"]
    )
    with gr.Row():
        btn = gr.Button("🚀 Generate Deck", variant="primary")
        dl = gr.File(label="Download")
        result_lbl = gr.Textbox(label="Result", interactive=False)
    has_media_state = gr.State(False)
    def on_upload(file):
        """Parse the uploaded file and populate state, preview, and tag filter."""
        # Returns: df, has_media, preview, msg, eta, extract_path
        df, has_media, preview, msg, eta, ext_path = parse_file(file)
        # Extract tags and create cleaned preview
        tag_choices = extract_unique_tags(df)
        if df is not None:
            # Show HTML-stripped text in the table; full HTML stays in state.
            display_df = df.copy()
            display_df['Question'] = display_df['Question'].apply(strip_html_for_display)
            display_df['Answer'] = display_df['Answer'].apply(strip_html_for_display)
            clean_preview = display_df.head(PREVIEW_LIMIT)
        else:
            clean_preview = preview
        return (
            df,             # full_df_state
            has_media,      # has_media_state
            clean_preview,  # preview_table
            msg,            # status
            eta,            # eta_box
            ext_path,       # extract_root_state
            gr.Dropdown(choices=tag_choices, value="All")  # tag_dropdown
        )
    file_input.upload(on_upload, inputs=file_input,
        outputs=[full_df_state, has_media_state, preview_table, status, eta_box, extract_root_state, tag_dropdown])
    def on_clear():
        """Reset all fields when file is cleared."""
        return (
            None,   # full_df_state
            False,  # has_media_state
            None,   # preview_table
            "",     # status
            "",     # eta_box
            None,   # extract_root_state
            gr.Dropdown(choices=["All"], value="All"),  # tag_dropdown
            "",     # search_box
            None,   # dl (download file)
            ""      # result_lbl
        )
    file_input.clear(on_clear, inputs=[],
        outputs=[full_df_state, has_media_state, preview_table, status, eta_box, extract_root_state, tag_dropdown, search_box, dl, result_lbl])
    def on_search(term, df, has_media, mode, search_in, selected_tag):
        """Re-filter the preview table and refresh the ETA on any filter change."""
        if df is None: return None, "No data"
        filtered_df = df.copy()
        # Apply tag filter first
        if selected_tag and selected_tag != "All":
            filtered_df = filtered_df[filtered_df['Tags'].str.contains(selected_tag, na=False, case=False)]
        # Apply text search
        if term:
            if search_in == "Question Only":
                mask = filtered_df['Question'].str.contains(term, case=False, na=False)
            elif search_in == "Answer Only":
                mask = filtered_df['Answer'].str.contains(term, case=False, na=False)
            else:  # Both
                mask = filtered_df.astype(str).apply(lambda x: x.str.contains(term, case=False, na=False)).any(axis=1)
            filtered_df = filtered_df[mask]
        # Create cleaned display version
        display_df = filtered_df.copy()
        display_df['Question'] = display_df['Question'].apply(strip_html_for_display)
        display_df['Answer'] = display_df['Answer'].apply(strip_html_for_display)
        return display_df.head(PREVIEW_LIMIT), estimate_time(len(filtered_df), has_media, mode)
    # All four filter controls share the same handler and outputs.
    search_box.change(on_search, inputs=[search_box, full_df_state, has_media_state, mode_radio, search_field, tag_dropdown],
        outputs=[preview_table, eta_box])
    search_field.change(on_search, inputs=[search_box, full_df_state, has_media_state, mode_radio, search_field, tag_dropdown],
        outputs=[preview_table, eta_box])
    tag_dropdown.change(on_search, inputs=[search_box, full_df_state, has_media_state, mode_radio, search_field, tag_dropdown],
        outputs=[preview_table, eta_box])
    mode_radio.change(on_search, inputs=[search_box, full_df_state, has_media_state, mode_radio, search_field, tag_dropdown],
        outputs=[preview_table, eta_box])
    btn.click(process_dataframe,
        inputs=[full_df_state, search_box, extract_root_state, mode_radio, search_field, tag_dropdown],
        outputs=[dl, result_lbl])
| if __name__ == "__main__": | |
| app.queue(max_size=2).launch( | |
| server_name="0.0.0.0", | |
| server_port=7860, | |
| ssr_mode=False | |
| ) | |