Spaces:
Sleeping
Sleeping
| """Smart analysis - multi-model iterative processing.""" | |
| import os | |
| import gradio as gr | |
| from google import genai | |
| import utils | |
| from core.state import get_global_results, set_global_results | |
| from core.cache import get_cached_dataset, cache_dataset | |
| from core.comparison import select_best_model_result | |
| from ui.dashboard import generate_dashboard_outputs | |
| from gemini_api import GeminiIntegrator | |
def run_smart_analysis(
    api_key: str,
    dataset_name: str,
    limit_files: int,
    temperature: float,
    thinking_budget: int,
    similarity_threshold: int,
    recheck_problematic: bool = False,
    progress=gr.Progress()
):
    """Run a 4-stage multi-model transcription analysis over a dataset.

    Stage 1 either re-checks previously problematic records (when
    ``recheck_problematic`` is set) or transcribes the whole dataset fresh;
    stages 2-4 re-run progressively stronger Gemini models only on records
    whose similarity score is still below ``similarity_threshold``.

    Args:
        api_key: Gemini API key; required, a ``gr.Error`` is raised if empty.
        dataset_name: Hugging Face dataset identifier to analyse.
        limit_files: Maximum number of files to process (0 means no limit).
        temperature: Sampling temperature for the generation config.
        thinking_budget: Normalised to int but currently not forwarded into
            the generation config — only ``temperature`` is used below.
        similarity_threshold: Score (apparently percent) at or above which a
            record counts as "correct".
        recheck_problematic: If True, start from existing global results
            instead of a fresh dataset pass.
        progress: Gradio progress reporter (injected by Gradio).

    Returns:
        Whatever ``generate_dashboard_outputs(similarity_threshold)`` returns.

    Raises:
        gr.Error: If the API key is missing or any processing step fails.
    """
    # Gradio number inputs may arrive as str or float; normalise defensively.
    limit_files = int(float(limit_files)) if limit_files else 0
    thinking_budget = int(float(thinking_budget)) if thinking_budget else 0
    similarity_threshold = int(float(similarity_threshold)) if similarity_threshold else 90
    temperature = float(temperature)

    if not api_key:
        raise gr.Error("Калі ласка, увядзіце Gemini API ключ.")

    # (model_id, human-readable stage label) pairs, in execution order.
    models = [
        ("gemini-2.5-flash-lite", "Этап 1/4: Flash-Lite (першы праход)"),
        ("gemini-2.5-flash-lite", "Этап 2/4: Flash-Lite (другі праход)"),
        ("gemini-2.5-flash", "Этап 3/4: Flash"),
        ("gemini-3-flash-preview", "Этап 4/4: Gemini-3-Flash"),
    ]

    try:
        gemini_tool = GeminiIntegrator(api_key=api_key)
        gen_config = genai.types.GenerateContentConfig(temperature=temperature)

        # STEP 1: first pass — either recheck existing results or run fresh.
        model_name, step_desc = models[0]
        if recheck_problematic:
            results = _smart_recheck_first_pass(
                gemini_tool, model_name, step_desc, dataset_name,
                limit_files, similarity_threshold, gen_config, progress
            )
        else:
            results = _smart_fresh_first_pass(
                gemini_tool, model_name, step_desc, dataset_name,
                limit_files, similarity_threshold, gen_config, progress
            )

        # STEPS 2-4: iterative improvement on still-problematic records only.
        base_progress = 0.25
        step_progress_size = 0.25
        for step_idx in range(1, len(models)):
            model_name, step_desc = models[step_idx]

            # Records that are still below threshold AND not verified correct.
            problematic_indices = [
                i for i, r in enumerate(results)
                if r['score'] < similarity_threshold
                and r.get('verification_status') != 'correct'
            ]
            if not problematic_indices:
                progress(base_progress + step_idx * step_progress_size,
                         desc=f"{step_desc}: няма праблемных запісаў, прапускаем...")
                continue

            progress(base_progress + (step_idx - 1) * step_progress_size,
                     desc=f"{step_desc}: пераапрацоўка {len(problematic_indices)} праблемных запісаў...")

            for j, res_idx in enumerate(problematic_indices):
                progress(base_progress + (step_idx - 1) * step_progress_size + (j + 1) / len(problematic_indices) * step_progress_size,
                         desc=f"{step_desc}: запіс {j+1}/{len(problematic_indices)}")

                result = results[res_idx]
                audio_data = result.get('audio_array')
                sampling_rate = result.get('sampling_rate')
                ref_text = result.get('ref_text', "")
                # Cannot re-transcribe without audio; leave the record as-is.
                if audio_data is None or len(audio_data) == 0:
                    continue

                hyp_text = gemini_tool.transcribe_audio(model_name, audio_data, sampling_rate, config=gen_config)
                score, norm_ref, norm_hyp = utils.calculate_similarity(ref_text, hyp_text)

                # Record this model's attempt alongside earlier attempts.
                results[res_idx].setdefault('model_results', {})[model_name] = {
                    "hyp_text": hyp_text,
                    "score": score,
                    "norm_ref": norm_ref,
                    "norm_hyp": norm_hyp
                }

                # Pick the best attempt across all models tried so far.
                best_model, best_result = select_best_model_result(
                    results[res_idx]['model_results'],
                    similarity_threshold
                )
                # Apply only when the best attempt improves the score or
                # crosses the threshold outright.
                if best_result and (best_result['score'] > result['score'] or best_result['score'] >= similarity_threshold):
                    new_status = "correct" if best_result['score'] >= similarity_threshold else "incorrect"
                    print(f"✅ UPDATE APPLIED [Idx={res_idx}]: {result.get('path')} | Best model: {best_model} | Score: {result['score']} -> {best_result['score']}")
                    results[res_idx].update({
                        "hyp_text": best_result['hyp_text'],
                        "score": best_result['score'],
                        "norm_ref": best_result['norm_ref'],
                        "norm_hyp": best_result['norm_hyp'],
                        "model_used": best_model,
                        "verification_status": new_status
                    })
                else:
                    print(f"⏭️ SKIP UPDATE [Idx={res_idx}]: Best score {best_result['score'] if best_result else 'N/A'} is not better than {result.get('score')} and not meeting threshold {similarity_threshold}")

        set_global_results(results)
        return generate_dashboard_outputs(similarity_threshold)
    except Exception as e:
        # Surface any failure to the UI, preserving the original chain.
        raise gr.Error(f"Памылка: {e}") from e
def _smart_recheck_first_pass(
    gemini_tool, model_name, step_desc, dataset_name,
    limit_files, similarity_threshold, gen_config, progress
):
    """First pass for recheck mode: re-transcribe still-problematic records.

    Starts from the existing global results, finds records below the
    similarity threshold that are not verified correct, re-fetches missing
    audio from the (fully loaded) dataset when needed, re-transcribes with
    ``model_name`` and keeps the best per-model result.

    Args:
        gemini_tool: GeminiIntegrator instance used for transcription.
        model_name: Model identifier for this pass.
        step_desc: Human-readable stage label for progress messages.
        dataset_name: Hugging Face dataset identifier (used to recover audio).
        limit_files: Cap on how many problematic records to recheck (0 = all).
        similarity_threshold: Score at/above which a record is "correct".
        gen_config: Generation config passed through to the model.
        progress: Gradio progress reporter.

    Returns:
        The (mutated) list of result dicts, or [] when there is nothing to
        recheck.
    """
    global_results = get_global_results()
    if not global_results:
        gr.Warning("Няма вынікаў для пераправеркі.")
        return []
    results = global_results

    # Identify the start set: only problematic, non-verified items.
    problematic_indices = [
        i for i, r in enumerate(results)
        if r['score'] < similarity_threshold
        and r.get('verification_status') != 'correct'
    ]
    if limit_files > 0:
        problematic_indices = problematic_indices[:limit_files]
    if not problematic_indices:
        gr.Info("Няма праблемных файлаў для пераправеркі.")
        return results

    # Load the FULL dataset (limit=None) so audio can be recovered for any
    # record whose audio array was dropped from the cached results.
    limit = None
    cached_ds = get_cached_dataset(dataset_name, limit)
    if cached_ds is not None:
        progress(0, desc=f"Выкарыстоўваю закэшаваны датасет '{dataset_name}'...")
        ds = cached_ds
    else:
        progress(0, desc=f"Загрузка датасета '{dataset_name}'...")
        ds = utils.load_hf_dataset(dataset_name, limit=limit)
        cache_dataset(dataset_name, limit, ds)
        progress(0.03, desc=f"Датасет закэшаваны")

    # Map both the full path and the basename to the dataset item so lookups
    # succeed whichever form the stored result's 'path' uses.
    audio_map = {}
    for item in ds:
        path = item['audio']['path']
        if path:
            fname = os.path.basename(path)
            audio_map[fname] = item
            audio_map[path] = item

    progress(0.05, desc=f"{step_desc}: пераправерка {len(problematic_indices)} запісаў...")
    for j, res_idx in enumerate(problematic_indices):
        progress(0.05 + (j + 1) / len(problematic_indices) * 0.20, desc=f"{step_desc}: запіс {j+1}/{len(problematic_indices)}")
        result = results[res_idx]
        audio_data = result.get('audio_array')
        sampling_rate = result.get('sampling_rate')
        ref_text = result.get('ref_text', "")

        # If audio is missing, try to fetch it back from the dataset: first
        # by path/basename, then by positional id as a last resort.
        if audio_data is None or len(audio_data) == 0:
            path = result.get('path', '')
            item = audio_map.get(path) or audio_map.get(os.path.basename(path))
            if not item:
                rec_id = result.get('id')
                if rec_id is not None:
                    try:
                        rec_id = int(rec_id)
                        if 0 <= rec_id < len(ds):
                            item = ds[rec_id]
                    except (TypeError, ValueError):
                        # Non-numeric id — fall through to the skip below.
                        pass
            if item:
                audio_data = item['audio']['array']
                sampling_rate = item['audio']['sampling_rate']
                # Cache the recovered audio back onto the result record.
                results[res_idx]['audio_array'] = audio_data
                results[res_idx]['sampling_rate'] = sampling_rate
            else:
                print(f"Smart Analysis Recheck: Skipping index {res_idx}, path '{path}', id {result.get('id')}: Audio not found.")
                continue

        hyp_text = gemini_tool.transcribe_audio(model_name, audio_data, sampling_rate, config=gen_config)
        score, norm_ref, norm_hyp = utils.calculate_similarity(ref_text, hyp_text)
        print(f"🔄 Smart Updated (Step 1): {result.get('path')} | Score: {result.get('score')} -> {score} | Text: {hyp_text}")

        # Record this model's attempt alongside any earlier attempts.
        if 'model_results' not in results[res_idx]:
            results[res_idx]['model_results'] = {}
        results[res_idx]['model_results'][model_name] = {
            "hyp_text": hyp_text,
            "score": score,
            "norm_ref": norm_ref,
            "norm_hyp": norm_hyp
        }

        # Adopt the best attempt across all models tried so far.
        best_model, best_result = select_best_model_result(
            results[res_idx]['model_results'],
            similarity_threshold
        )
        if best_result:
            results[res_idx].update({
                "hyp_text": best_result['hyp_text'],
                "score": best_result['score'],
                "norm_ref": best_result['norm_ref'],
                "norm_hyp": best_result['norm_hyp'],
                "model_used": best_model,
                "verification_status": "correct" if best_result['score'] >= similarity_threshold else "incorrect"
            })
    return results
def _smart_fresh_first_pass(
    gemini_tool, model_name, step_desc, dataset_name,
    limit_files, similarity_threshold, gen_config, progress
):
    """First pass for fresh analysis: transcribe every dataset record.

    Loads the dataset (from cache when available), transcribes each record
    with ``model_name`` and scores it against the reference text, returning
    a fully populated result dict per record.
    """
    limit = int(limit_files) if limit_files > 0 else None

    # Prefer a previously cached dataset; otherwise load and cache it.
    ds = get_cached_dataset(dataset_name, limit)
    if ds is not None:
        progress(0, desc=f"Выкарыстоўваю закэшаваны датасет '{dataset_name}'...")
    else:
        progress(0, desc=f"Загрузка датасета '{dataset_name}'...")
        ds = utils.load_hf_dataset(dataset_name, limit=limit)
        cache_dataset(dataset_name, limit, ds)
        progress(0.05, desc=f"Датасет закэшаваны для паўторнага выкарыстання")

    records = []
    total = len(ds)
    progress(0.05, desc=f"{step_desc}: апрацоўка ўсіх {total} запісаў...")
    for position, sample in enumerate(ds):
        progress(0.05 + (position + 1) / total * 0.20, desc=f"{step_desc}: файл {position+1}/{total}")

        audio = sample['audio']
        waveform = audio['array']
        rate = audio['sampling_rate']
        # The reference transcript may live under any of these column names.
        reference = next(
            (sample.get(key) for key in ('sentence', 'text', 'transcription', 'transcript') if sample.get(key)),
            ""
        )

        hypothesis = gemini_tool.transcribe_audio(model_name, waveform, rate, config=gen_config)
        score, norm_ref, norm_hyp = utils.calculate_similarity(reference, hypothesis)

        attempt = {
            "hyp_text": hypothesis,
            "score": score,
            "norm_ref": norm_ref,
            "norm_hyp": norm_hyp
        }
        records.append({
            "id": position,
            "path": audio['path'],
            "ref_text": reference,
            "hyp_text": hypothesis,
            "score": score,
            "norm_ref": norm_ref,
            "norm_hyp": norm_hyp,
            "audio_array": waveform,
            "sampling_rate": rate,
            "model_used": model_name,
            "verification_status": "correct" if score >= similarity_threshold else "incorrect",
            # Seed the per-model history with this first attempt.
            "model_results": {model_name: attempt}
        })
    return records