"""Smart analysis - multi-model iterative processing."""
import os
import traceback

import gradio as gr
from google import genai

import utils
from core.state import get_global_results, set_global_results
from core.cache import get_cached_dataset, cache_dataset
from core.comparison import select_best_model_result
from ui.dashboard import generate_dashboard_outputs
from gemini_api import GeminiIntegrator

# (model name, progress label) for every stage of the pipeline. Stage 1 is the
# initial pass (fresh or recheck); stages 2-4 only revisit records that are
# still problematic after the previous stage.
_SMART_STAGES = (
    ("gemini-2.5-flash-lite", "Этап 1/4: Flash-Lite (першы праход)"),
    ("gemini-2.5-flash-lite", "Этап 2/4: Flash-Lite (другі праход)"),
    ("gemini-2.5-flash", "Этап 3/4: Flash"),
    ("gemini-3-flash-preview", "Этап 4/4: Gemini-3-Flash"),
)

# Progress layout: stage 1 reports within [0, 0.25]; each later stage owns the
# next 0.25-wide slice.
_FIRST_PASS_END = 0.25
_STAGE_PROGRESS_SIZE = 0.25

_DEFAULT_SIMILARITY_THRESHOLD = 90


def run_smart_analysis(
    api_key: str,
    dataset_name: str,
    limit_files: int,
    temperature: float,
    thinking_budget: int,
    similarity_threshold: int,
    recheck_problematic: bool = False,
    progress=gr.Progress()
):
    """Run the 4-stage multi-model transcription pipeline and refresh the dashboard.

    Stage 1 either transcribes the dataset from scratch or, with
    ``recheck_problematic``, re-transcribes the problematic records of the
    stored global results. Stages 2-4 re-transcribe, each with the next model
    in ``_SMART_STAGES``, only the records whose score is still below
    ``similarity_threshold`` and that are not marked ``'correct'``. Per record,
    the best result across all tried models (as chosen by
    ``select_best_model_result``) is kept when it improves the record.

    Args:
        api_key: Gemini API key; an empty value raises ``gr.Error``.
        dataset_name: Dataset identifier passed to the dataset cache and
            ``utils.load_hf_dataset``.
        limit_files: Fresh mode: maximum number of dataset rows to load.
            Recheck mode: maximum number of problematic records to revisit
            (in every stage). ``0`` or empty means no limit.
        temperature: Sampling temperature for every Gemini request.
        thinking_budget: Parsed but currently not forwarded to the request
            config (see the NOTE in the body).
        similarity_threshold: Score at or above which a transcription counts
            as correct; ``None``/empty falls back to 90.
        recheck_problematic: Reprocess the stored global results instead of
            starting a fresh analysis.
        progress: Gradio progress tracker.

    Returns:
        Whatever ``generate_dashboard_outputs(similarity_threshold)`` returns.

    Raises:
        gr.Error: On a missing API key or any failure during processing.
    """
    # Gradio may deliver numbers as floats or numeric strings. Only None/""
    # mean "not set", so an explicit threshold of 0 is respected.
    limit_files = _to_int(limit_files, 0)
    thinking_budget = _to_int(thinking_budget, 0)
    similarity_threshold = _to_int(similarity_threshold, _DEFAULT_SIMILARITY_THRESHOLD)
    temperature = float(temperature)

    if not api_key:
        raise gr.Error("Калі ласка, увядзіце Gemini API ключ.")

    try:
        gemini_tool = GeminiIntegrator(api_key=api_key)
        # NOTE(review): thinking_budget is parsed above but never added to the
        # request config; confirm whether it should be passed as a ThinkingConfig.
        gen_config = genai.types.GenerateContentConfig(temperature=temperature)

        # STEP 1: Initialization / First Pass
        first_model, first_desc = _SMART_STAGES[0]
        if recheck_problematic:
            # Later stages must stay within the records selected for this
            # recheck (honouring limit_files), not every problematic record.
            stage_scope = set(_recheck_candidates(
                get_global_results() or [], similarity_threshold, limit_files
            ))
            results = _smart_recheck_first_pass(
                gemini_tool, first_model, first_desc, dataset_name,
                limit_files, similarity_threshold, gen_config, progress
            )
        else:
            stage_scope = None
            results = _smart_fresh_first_pass(
                gemini_tool, first_model, first_desc, dataset_name,
                limit_files, similarity_threshold, gen_config, progress
            )

        # STEP 2-4: Iterative improvement of records that are still problematic.
        for stage_no, (model_name, step_desc) in enumerate(_SMART_STAGES[1:], start=1):
            stage_start = _FIRST_PASS_END + (stage_no - 1) * _STAGE_PROGRESS_SIZE
            indices = [
                i for i in _problematic_indices(results, similarity_threshold)
                if stage_scope is None or i in stage_scope
            ]
            if not indices:
                progress(stage_start + _STAGE_PROGRESS_SIZE,
                         desc=f"{step_desc}: няма праблемных запісаў, прапускаем...")
                continue
            _refine_stage(
                gemini_tool, model_name, step_desc, results, indices,
                similarity_threshold, gen_config, progress, stage_start
            )

        set_global_results(results)
        return generate_dashboard_outputs(similarity_threshold)
    except gr.Error:
        # Already user-facing; do not wrap it a second time.
        raise
    except Exception as e:
        traceback.print_exc()
        raise gr.Error(f"Памылка: {e}") from e


def _smart_recheck_first_pass(
    gemini_tool, model_name, step_desc, dataset_name, limit_files,
    similarity_threshold, gen_config, progress
):
    """First pass for recheck mode: re-transcribe problematic stored results.

    Works on, mutates, and returns the list from ``get_global_results()``.
    Only the first ``limit_files`` problematic records are processed (all of
    them when ``limit_files <= 0``). Records without cached audio get it from
    the dataset, matched by full path, file name, or ``id`` used as a row
    index. A record's fields are replaced only when the best model result
    improves its score or reaches the threshold.

    Returns:
        The (mutated) global results, or ``[]`` when there are none.
    """
    results = get_global_results()
    if not results:
        gr.Warning("Няма вынікаў для пераправеркі.")
        return []

    indices = _recheck_candidates(results, similarity_threshold, limit_files)
    if not indices:
        gr.Info("Няма праблемных файлаў для пераправеркі.")
        return results

    # Iterating (and decoding) the whole dataset is expensive, so only do it
    # when at least one selected record lacks cached audio.
    ds, audio_map = None, {}
    if any(_has_no_audio(results[i].get('audio_array')) for i in indices):
        ds = _load_dataset_cached(dataset_name, None, progress, 0.03, "Датасет закэшаваны")
        audio_map = _build_audio_map(ds)

    total = len(indices)
    progress(0.05, desc=f"{step_desc}: пераправерка {total} запісаў...")
    for j, res_idx in enumerate(indices, start=1):
        progress(0.05 + j / total * 0.20, desc=f"{step_desc}: запіс {j}/{total}")
        result = results[res_idx]
        audio_data = result.get('audio_array')
        sampling_rate = result.get('sampling_rate')

        if _has_no_audio(audio_data):
            # `or ''`: stored paths may be None and os.path.basename(None) raises.
            path = result.get('path') or ''
            item = _find_dataset_item(ds, audio_map, path, result.get('id'))
            if not item:
                print(f"Smart Analysis Recheck: Skipping index {res_idx}, path '{path}', id {result.get('id')}: Audio not found.")
                continue
            audio_data = item['audio']['array']
            sampling_rate = item['audio']['sampling_rate']
            result['audio_array'] = audio_data
            result['sampling_rate'] = sampling_rate

        hyp_text = gemini_tool.transcribe_audio(model_name, audio_data, sampling_rate, config=gen_config)
        score, norm_ref, norm_hyp = utils.calculate_similarity(result.get('ref_text', ""), hyp_text)
        print(f"🔄 Smart Updated (Step 1): {result.get('path')} | Score: {result.get('score')} -> {score} | Text: {hyp_text}")
        _record_model_result(result, model_name, hyp_text, score, norm_ref, norm_hyp)
        # Same rule as stages 2-4: never replace a record with a worse result.
        _apply_if_better(res_idx, result, similarity_threshold)

    return results


def _smart_fresh_first_pass(
    gemini_tool, model_name, step_desc, dataset_name, limit_files,
    similarity_threshold, gen_config, progress
):
    """First pass for fresh analysis: transcribe every loaded dataset row.

    Loads at most ``limit_files`` rows (all when ``limit_files <= 0``) and
    builds one result dict per row, with that row's enumeration index as
    ``id``. The first model's result is also stored under ``model_results``.

    Returns:
        The list of new result dicts, in dataset order.
    """
    limit = int(limit_files) if limit_files > 0 else None
    ds = _load_dataset_cached(
        dataset_name, limit, progress, 0.05, "Датасет закэшаваны для паўторнага выкарыстання"
    )

    total = len(ds)
    results = []
    progress(0.05, desc=f"{step_desc}: апрацоўка ўсіх {total} запісаў...")
    for idx, item in enumerate(ds):
        progress(0.05 + (idx + 1) / total * 0.20, desc=f"{step_desc}: файл {idx+1}/{total}")
        audio = item['audio']
        audio_data = audio['array']
        sampling_rate = audio['sampling_rate']
        ref_text = _reference_text(item)

        hyp_text = gemini_tool.transcribe_audio(model_name, audio_data, sampling_rate, config=gen_config)
        score, norm_ref, norm_hyp = utils.calculate_similarity(ref_text, hyp_text)

        results.append({
            "id": idx,
            "path": audio['path'],
            "ref_text": ref_text,
            "hyp_text": hyp_text,
            "score": score,
            "norm_ref": norm_ref,
            "norm_hyp": norm_hyp,
            "audio_array": audio_data,
            "sampling_rate": sampling_rate,
            "model_used": model_name,
            "verification_status": "correct" if score >= similarity_threshold else "incorrect",
            "model_results": {
                model_name: {
                    "hyp_text": hyp_text,
                    "score": score,
                    "norm_ref": norm_ref,
                    "norm_hyp": norm_hyp,
                }
            },
        })

    return results


def _refine_stage(
    gemini_tool, model_name, step_desc, results, indices,
    similarity_threshold, gen_config, progress, stage_start
):
    """Re-transcribe ``results[i]`` for each i in *indices* with *model_name*.

    Records without cached audio are skipped. Progress is reported within
    ``[stage_start, stage_start + _STAGE_PROGRESS_SIZE]``.
    """
    total = len(indices)
    progress(stage_start, desc=f"{step_desc}: пераапрацоўка {total} праблемных запісаў...")
    for j, res_idx in enumerate(indices, start=1):
        progress(stage_start + j / total * _STAGE_PROGRESS_SIZE,
                 desc=f"{step_desc}: запіс {j}/{total}")
        result = results[res_idx]
        audio_data = result.get('audio_array')
        if _has_no_audio(audio_data):
            continue

        hyp_text = gemini_tool.transcribe_audio(
            model_name, audio_data, result.get('sampling_rate'), config=gen_config
        )
        score, norm_ref, norm_hyp = utils.calculate_similarity(result.get('ref_text', ""), hyp_text)
        _record_model_result(result, model_name, hyp_text, score, norm_ref, norm_hyp)
        _apply_if_better(res_idx, result, similarity_threshold)


def _to_int(value, default):
    """Convert a Gradio numeric input (number or numeric string) to ``int``.

    ``None`` and ``""`` return *default*; every other value, including 0,
    goes through ``int(float(value))``.
    """
    if value is None or value == "":
        return default
    return int(float(value))


def _has_no_audio(audio_data):
    """Return True when *audio_data* is missing (``None``) or has length 0."""
    return audio_data is None or len(audio_data) == 0


def _problematic_indices(results, similarity_threshold):
    """Indices of results scoring below the threshold and not verified correct."""
    return [
        i for i, r in enumerate(results)
        if r['score'] < similarity_threshold and r.get('verification_status') != 'correct'
    ]


def _recheck_candidates(results, similarity_threshold, limit_files):
    """Problematic indices selected for a recheck, capped at *limit_files* when > 0."""
    indices = _problematic_indices(results, similarity_threshold)
    return indices[:limit_files] if limit_files > 0 else indices


def _load_dataset_cached(dataset_name, limit, progress, cached_progress, cached_desc):
    """Return the dataset from the cache, loading and caching it on a miss.

    On a miss, ``progress(cached_progress, desc=cached_desc)`` is reported
    after the dataset has been stored in the cache.
    """
    cached_ds = get_cached_dataset(dataset_name, limit)
    if cached_ds is not None:
        progress(0, desc=f"Выкарыстоўваю закэшаваны датасет '{dataset_name}'...")
        return cached_ds

    progress(0, desc=f"Загрузка датасета '{dataset_name}'...")
    ds = utils.load_hf_dataset(dataset_name, limit=limit)
    cache_dataset(dataset_name, limit, ds)
    progress(cached_progress, desc=cached_desc)
    return ds


def _build_audio_map(ds):
    """Map both the full audio path and its file name to the dataset row.

    Rows whose audio path is empty/None are not indexed.
    """
    audio_map = {}
    for item in ds:
        path = item['audio']['path']
        if path:
            audio_map[os.path.basename(path)] = item
            audio_map[path] = item
    return audio_map


def _find_dataset_item(ds, audio_map, path, rec_id):
    """Locate the dataset row for a stored result, or return None.

    Tries the full *path*, then its file name, then *rec_id* interpreted as a
    row index into *ds*.
    """
    item = audio_map.get(path) or audio_map.get(os.path.basename(path))
    if item or ds is None or rec_id is None:
        return item
    try:
        row = int(rec_id)
    except (TypeError, ValueError):
        return None
    if not 0 <= row < len(ds):
        return None
    try:
        return ds[row]
    except Exception as exc:
        # Best effort, as before: a row that cannot be read counts as
        # "not found" — but the failure is logged instead of swallowed.
        print(f"Smart Analysis Recheck: could not read dataset row {row}: {exc}")
        return None


def _reference_text(item):
    """Return the first non-empty transcript column of a dataset row, else ""."""
    return (
        item.get('sentence')
        or item.get('text')
        or item.get('transcription')
        or item.get('transcript')
        or ""
    )


def _record_model_result(result, model_name, hyp_text, score, norm_ref, norm_hyp):
    """Store one model's transcription under ``result['model_results'][model_name]``."""
    result.setdefault('model_results', {})[model_name] = {
        "hyp_text": hyp_text,
        "score": score,
        "norm_ref": norm_ref,
        "norm_hyp": norm_hyp,
    }


def _apply_if_better(res_idx, result, similarity_threshold):
    """Copy the best of ``result['model_results']`` into *result* if it helps.

    The update is applied when the best score beats the record's current
    score or reaches the threshold; otherwise the record is left untouched.
    Either way, the decision is logged.
    """
    best_model, best_result = select_best_model_result(
        result['model_results'], similarity_threshold
    )
    if best_result and (best_result['score'] > result['score'] or best_result['score'] >= similarity_threshold):
        print(f"✅ UPDATE APPLIED [Idx={res_idx}]: {result.get('path')} | Best model: {best_model} | Score: {result['score']} -> {best_result['score']}")
        result.update({
            "hyp_text": best_result['hyp_text'],
            "score": best_result['score'],
            "norm_ref": best_result['norm_ref'],
            "norm_hyp": best_result['norm_hyp'],
            "model_used": best_model,
            "verification_status": "correct" if best_result['score'] >= similarity_threshold else "incorrect",
        })
    else:
        print(f"⏭️ SKIP UPDATE [Idx={res_idx}]: Best score {best_result['score'] if best_result else 'N/A'} is not better than {result.get('score')} and not meeting threshold {similarity_threshold}")