# pragent/backend/blog_pipeline.py
"""Pipeline for turning a parsed paper into a social-media blog post.

Stage 1 (`generate_text_blog`) produces a structured draft from the paper
text; Stage 2 (`generate_final_post`) integrates figure descriptions and
adapts the draft to a target platform/format/language.  A single-prompt
baseline (`generate_baseline_post`) is provided for comparison experiments.
Several `ablation_mode` values disable individual pipeline stages.
"""
from tqdm.asyncio import tqdm
import asyncio
from pathlib import Path
from typing import Tuple, List, Dict, Optional
from openai import AsyncOpenAI
import re
import os
import json

# ADDED FOR OCR & CACHE SAFETY: New imports for OCR
import pytesseract
from PIL import Image

from pragent.backend.agents import (
    setup_client,
    BlogGeneratorAgent,
    FigureDescriberAgent,
    BlogIntegratorAgent,
    call_text_llm_api,
    call_text_llm_api_with_token_count,
)
from pragent.backend.data_loader import load_plain_text, load_paired_image_paths
from pragent.backend.text_processor import summarize_long_text
from .prompts import (
    TEXT_GENERATOR_PROMPT,
    TEXT_GENERATOR_PROMPT_CHINESE,
    TWITTER_RICH_TEXT_PROMPT_ENGLISH,
    TWITTER_TEXT_ONLY_PROMPT_ENGLISH,
    TWITTER_RICH_TEXT_PROMPT_CHINESE,
    TWITTER_TEXT_ONLY_PROMPT_CHINESE,
    XIAOHONGSHU_PROMPT_ENGLISH,
    XIAOHONGSHU_PROMPT_CHINESE,
    XIAOHONGSHU_TEXT_ONLY_PROMPT_ENGLISH,
    XIAOHONGSHU_TEXT_ONLY_PROMPT_CHINESE,
    BASELINE_PROMPT_ENGLISH,
    BASELINE_PROMPT_CHINESE,
    GENERIC_RICH_PROMPT_CHINESE,
    GENERIC_RICH_PROMPT_ENGLISH,
    GENERIC_TEXT_ONLY_PROMPT_CHINESE,
    GENERIC_TEXT_ONLY_PROMPT_ENGLISH,
    BASELINE_FEWSHOT_PROMPT_ENGLISH,
    BASELINE_FEWSHOT_PROMPT_CHINESE,
)

# Papers whose raw text exceeds this many characters are summarized (or
# truncated, under the relevant ablation) before drafting.
TOKEN_THRESHOLD = 8000

# (platform, format, language) -> prompt template for Stage 2 integration.
PROMPT_MAPPING = {
    ('twitter', 'rich', 'en'): TWITTER_RICH_TEXT_PROMPT_ENGLISH,
    ('twitter', 'text_only', 'en'): TWITTER_TEXT_ONLY_PROMPT_ENGLISH,
    ('twitter', 'rich', 'zh'): TWITTER_RICH_TEXT_PROMPT_CHINESE,
    ('twitter', 'text_only', 'zh'): TWITTER_TEXT_ONLY_PROMPT_CHINESE,
    ('xiaohongshu', 'rich', 'en'): XIAOHONGSHU_PROMPT_ENGLISH,
    ('xiaohongshu', 'rich', 'zh'): XIAOHONGSHU_PROMPT_CHINESE,
    ('xiaohongshu', 'text_only', 'en'): XIAOHONGSHU_TEXT_ONLY_PROMPT_ENGLISH,
    ('xiaohongshu', 'text_only', 'zh'): XIAOHONGSHU_TEXT_ONLY_PROMPT_CHINESE,
    ('generic', 'rich', 'en'): GENERIC_RICH_PROMPT_ENGLISH,
    ('generic', 'text_only', 'en'): GENERIC_TEXT_ONLY_PROMPT_ENGLISH,
    ('generic', 'rich', 'zh'): GENERIC_RICH_PROMPT_CHINESE,
    ('generic', 'text_only', 'zh'): GENERIC_TEXT_ONLY_PROMPT_CHINESE,
}


# ADDED FOR OCR & CACHE SAFETY: Asynchronous OCR helper function
async def ocr_image_to_text(image_path: str) -> str:
    """
    Performs OCR on an image file to extract text asynchronously.

    Returns the stripped OCR text, or "" if the file is missing or OCR fails.
    """
    if not Path(image_path).exists():
        return ""
    try:
        # pytesseract is a blocking library, so we run it in a thread pool
        loop = asyncio.get_running_loop()
        text = await loop.run_in_executor(
            None, lambda: pytesseract.image_to_string(Image.open(image_path))
        )
        return text.strip()
    except Exception as e:
        tqdm.write(f"[!] OCR failed for {image_path}: {e}")
        return ""


async def generate_text_blog(
    txt_path: str,
    api_key: str,
    text_api_base: str,
    model: str,
    language: str,
    disable_qwen_thinking: bool = False,
    ablation_mode: str = "none",
) -> Tuple[str, str]:
    """
    Generates a structured, factual blog DRAFT in the specified language. (Stage 1)

    Returns a (draft, source_text) pair; on failure the first element is an
    "Error: ..." string and the second is None.
    """
    async with setup_client(api_key, text_api_base) as client:
        if not client:
            return "Error: API client configuration failed.", None
        paper_text = await load_plain_text(txt_path)
        if not paper_text:
            return "Error: Could not load text file.", None

        text_for_generation = ""
        if len(paper_text) > TOKEN_THRESHOLD:
            if ablation_mode == 'no_hierarchical_summary':
                tqdm.write(f"[*] ABLATION (no_hierarchical_summary): Truncating text to {TOKEN_THRESHOLD} characters.")
                text_for_generation = paper_text[:TOKEN_THRESHOLD]
            else:
                summarized_text = await summarize_long_text(
                    paper_text, model, client, disable_qwen_thinking=disable_qwen_thinking
                )
                # Fall back to plain truncation if summarization failed.
                if summarized_text.startswith("Error:"):
                    summarized_text = paper_text[:TOKEN_THRESHOLD]
                text_for_generation = summarized_text
        else:
            text_for_generation = paper_text

        # In these ablations the (possibly summarized) text doubles as the draft.
        if ablation_mode in ['no_logical_draft', 'stage2']:
            ablation_reason = "no_logical_draft" if ablation_mode != 'stage2' else 'stage2'
            tqdm.write(f"[*] ABLATION ({ablation_reason}): Skipping structured draft generation.")
            return text_for_generation, text_for_generation

        draft_prompt = TEXT_GENERATOR_PROMPT_CHINESE if language == 'zh' else TEXT_GENERATOR_PROMPT
        generator = BlogGeneratorAgent(draft_prompt, model)
        generated_blog_draft = await generator.run(
            client, text_for_generation, disable_qwen_thinking=disable_qwen_thinking
        )
        return generated_blog_draft, text_for_generation


async def generate_final_post(
    blog_draft: str,
    source_paper_text: str,
    assets_dir: Optional[str],
    text_api_key: str,
    vision_api_key: str,
    text_api_base: str,
    vision_api_base: str,
    vision_model: str,
    text_model: str,
    platform: str,
    language: str,
    post_format: str,
    description_cache_dir: Optional[str] = None,
    pdf_hash: Optional[str] = None,
    disable_qwen_thinking: bool = False,
    ablation_mode: str = "none",
) -> Optional[Tuple[str, Optional[List[Dict]]]]:
    """
    Integrates figure descriptions into the draft and adapts it to the target
    platform/format/language. (Stage 2)

    Returns (final_markdown, assets_for_packaging) where the asset list maps
    source image paths to their packaged filenames (None when no assets are
    emitted); on failure the first element is an "Error: ..." string.
    """
    effective_platform = platform
    if ablation_mode == 'no_platform_adaptation':
        tqdm.write(f"[*] ABLATION (no_platform_adaptation): Using generic prompts instead of '{platform}' specific prompts.")
        effective_platform = 'generic'

    # 'description_only' reuses the 'rich' prompt; figures are rendered as alt text.
    prompt_format = 'rich' if post_format == 'description_only' else post_format
    prompt_key = (effective_platform, prompt_format, language)
    selected_prompt = PROMPT_MAPPING.get(prompt_key)
    if not selected_prompt:
        tqdm.write(f"[!] Warning: No prompt found for configuration: {prompt_key}. Falling back to generic prompt.")
        generic_fallback_key = ('generic', prompt_format, language)
        selected_prompt = PROMPT_MAPPING.get(generic_fallback_key)
        if not selected_prompt:
            return f"Error: No prompt found for configuration: {prompt_key} or generic fallback.", None

    tqdm.write(f"\n--- Generating final post for: Platform='{effective_platform}', Format='{post_format}', Language='{language}' ---")

    items_with_descriptions = []
    if post_format in ['rich', 'description_only'] and assets_dir and Path(assets_dir).is_dir():
        all_items = load_paired_image_paths(Path(assets_dir))
        all_items = all_items[:50]  # Limit to first 50 items to avoid overloading the model
        if all_items:
            # Descriptions are cached per (pdf_hash, vision_model) so repeated
            # runs skip the expensive vision calls.
            cache_file_path = None
            if description_cache_dir and pdf_hash:
                sanitized_model_name = re.sub(r'[\\/:"*?<>|]', '_', vision_model)
                cache_dir = Path(description_cache_dir) / pdf_hash
                cache_dir.mkdir(parents=True, exist_ok=True)
                cache_file_path = cache_dir / f"{sanitized_model_name}.json"

            if cache_file_path and cache_file_path.exists() and ablation_mode not in ['no_visual_analysis', 'stage2']:
                tqdm.write(f"[✓] Cache hit! Loading all descriptions from {cache_file_path}")
                with cache_file_path.open('r', encoding='utf-8') as f:
                    items_with_descriptions = json.load(f)
            else:
                # MODIFIED: Trigger this ablation also for 'stage2'
                if ablation_mode in ['no_visual_analysis', 'stage2']:
                    ablation_reason = "no_visual_analysis" if ablation_mode != 'stage2' else 'stage2'
                    tqdm.write(f"[*] ABLATION ({ablation_reason}): Using OCR on caption images instead of vision model.")
                    temp_items_with_desc = []
                    ocr_tasks = [ocr_image_to_text(item['caption_path']) for item in all_items]
                    ocr_results = await asyncio.gather(*ocr_tasks)
                    for i, item in enumerate(all_items):
                        caption_content = ocr_results[i]
                        if caption_content:
                            item['description'] = caption_content
                            temp_items_with_desc.append(item)
                    items_with_descriptions = temp_items_with_desc
                else:
                    # Full pipeline: use vision model
                    tqdm.write(f"--- Cache miss. Describing {len(all_items)} new figures using model '{vision_model}'... ---")
                    async with setup_client(vision_api_key, vision_api_base) as vision_client:
                        if not vision_client:
                            return "Error: Vision API client configuration failed.", None
                        describer = FigureDescriberAgent(model=vision_model)
                        description_tasks = [
                            describer.run(
                                vision_client,
                                item['item_path'],
                                item['caption_path'],
                                disable_qwen_thinking=disable_qwen_thinking,
                            )
                            for item in all_items
                        ]
                        descriptions = await asyncio.gather(*description_tasks)
                        # Keep only items the vision model described successfully.
                        temp_items_with_desc = []
                        for i, item in enumerate(all_items):
                            if not descriptions[i].startswith("Error:"):
                                item['description'] = descriptions[i]
                                temp_items_with_desc.append(item)
                        items_with_descriptions = temp_items_with_desc

                # MODIFIED: Prevent caching for 'stage2' as well
                if cache_file_path and ablation_mode not in ['no_visual_analysis', 'stage2']:
                    tqdm.write(f"[*] Saving all descriptions to cache file: {cache_file_path}")
                    with cache_file_path.open('w', encoding='utf-8') as f:
                        json.dump(items_with_descriptions, f, ensure_ascii=False, indent=4)
                elif cache_file_path and ablation_mode in ['no_visual_analysis', 'stage2']:
                    ablation_reason = "no_visual_analysis" if ablation_mode != 'stage2' else 'stage2'
                    tqdm.write(f"[*] ABLATION ({ablation_reason}): Description caching is disabled for this mode to avoid saving OCR results.")

    # Cap figures passed to the integrator.
    items_with_descriptions = items_with_descriptions[:20]
    if post_format in ['rich', 'description_only'] and not items_with_descriptions:
        return f"Error: '{post_format}' format requires images, but none were found/described.", None

    async with setup_client(text_api_key, text_api_base) as text_client:
        if not text_client:
            return "Error: Text API client configuration failed.", None

        if ablation_mode in ['no_visual_integration', 'stage2'] and post_format in ['rich', 'description_only']:
            # Ablation: generate a text-only post, then append every figure at
            # the end instead of letting the model place them contextually.
            ablation_reason = "no_visual_integration" if ablation_mode != 'stage2' else 'stage2'
            tqdm.write(f"[*] ABLATION ({ablation_reason}): Generating text first, then appending all figures at the end.")
            integrator = BlogIntegratorAgent(selected_prompt, model=text_model)
            text_only_post = await integrator.run(
                local_client=text_client,
                blog_text=blog_draft,
                items_with_descriptions=[],
                source_text=source_paper_text,
                disable_qwen_thinking=disable_qwen_thinking,
            )
            if not text_only_post or text_only_post.startswith("Error:"):
                return f"Blog integration failed for text-only part: {text_only_post}", None
            final_blog_content = text_only_post
            assets_for_packaging = []
            for i, item_data in enumerate(items_with_descriptions):
                if post_format == 'rich':
                    new_asset_filename = f"img_{i}{Path(item_data['item_path']).suffix}"
                    alt_text = f"Figure {i}"
                    new_markdown_tag = f"\n\n![{alt_text}](./img/{new_asset_filename})"
                    assets_for_packaging.append({'src_path': item_data['item_path'], 'dest_name': new_asset_filename, 'new_index': i})
                    final_blog_content += new_markdown_tag
                elif post_format == 'description_only':
                    alt_text_description = item_data.get('description', f'Figure {i}').strip().replace('\n', ' ')
                    new_markdown_tag = f"\n\n![{alt_text_description}]()"
                    final_blog_content += new_markdown_tag
            return final_blog_content, assets_for_packaging if assets_for_packaging else None

        # Full pipeline: the integrator emits [FIGURE_PLACEHOLDER_i] markers
        # which we replace with markdown image tags below.
        integrator = BlogIntegratorAgent(selected_prompt, model=text_model)
        final_post_with_placeholders = await integrator.run(
            local_client=text_client,
            blog_text=blog_draft,
            items_with_descriptions=items_with_descriptions,
            source_text=source_paper_text,
            disable_qwen_thinking=disable_qwen_thinking,
        )
        if not final_post_with_placeholders or final_post_with_placeholders.startswith("Error:"):
            return f"Blog integration failed: {final_post_with_placeholders}", None

        found_indices = re.findall(r'\[FIGURE_PLACEHOLDER_(\d+)\]', final_post_with_placeholders)
        final_blog_content = final_post_with_placeholders
        assets_for_packaging = []
        if found_indices:
            items_map = {i: item for i, item in enumerate(items_with_descriptions)}
            # Assets are renumbered in order of appearance in the post, which
            # may differ from their original indices.
            for new_index, original_index_str in enumerate(found_indices):
                original_index = int(original_index_str)
                item_data = items_map.get(original_index)
                if not item_data:
                    continue
                placeholder_to_replace = f"[FIGURE_PLACEHOLDER_{original_index}]"
                if post_format == 'rich':
                    new_asset_filename = f"img_{new_index}{Path(item_data['item_path']).suffix}"
                    alt_text = f"Figure {new_index}"
                    new_markdown_tag = f"![{alt_text}](./img/{new_asset_filename})"
                    assets_for_packaging.append({'src_path': item_data['item_path'], 'dest_name': new_asset_filename, 'new_index': new_index})
                elif post_format == 'description_only':
                    alt_text_description = item_data.get('description', f'Figure {new_index}').strip().replace('\n', ' ')
                    new_markdown_tag = f"![{alt_text_description}]()"
                else:
                    new_markdown_tag = ""
                final_blog_content = final_blog_content.replace(placeholder_to_replace, new_markdown_tag, 1)

        # Strip any placeholders that referenced unknown figure indices.
        final_blog_content = re.sub(r'\[FIGURE_PLACEHOLDER_(\d+)\]', '', final_blog_content)
        if post_format == 'rich':
            return final_blog_content, assets_for_packaging
        else:
            return final_blog_content, None


async def generate_baseline_post(
    paper_text: str,
    api_key: str,
    api_base: str,
    model: str,
    platform: str,
    language: str,
    disable_qwen_thinking: bool = False,
    mode: str = 'original',
    assets_dir: Optional[str] = None,
) -> Tuple[str, List[Dict], int]:
    """
    Generates a post using a simple, single-prompt baseline method.

    Modes: 'original' (plain prompt), 'fewshot' (few-shot prompt), and
    'with_figure' (plain prompt plus the first 3 figures/tables appended).
    Returns (post_text, assets_for_packaging, think_token_count).
    """
    tqdm.write(f"\n--- Generating baseline post (mode: {mode}) for: Platform='{platform}', Language='{language}' ---")
    async with setup_client(api_key, api_base) as client:
        if not client:
            return "Error: API client configuration failed.", [], 0

        if mode == 'fewshot':
            prompt_template = BASELINE_FEWSHOT_PROMPT_CHINESE if language == 'zh' else BASELINE_FEWSHOT_PROMPT_ENGLISH
        else:
            prompt_template = BASELINE_PROMPT_CHINESE if language == 'zh' else BASELINE_PROMPT_ENGLISH
        user_prompt = prompt_template.format(paper_text=paper_text[:20000], platform=platform.capitalize())
        system_prompt = "You are an assistant that summarizes academic papers for social media."
        text_post, think_token_count = await call_text_llm_api_with_token_count(
            local_client=client,
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            model=model,
            disable_qwen_thinking=disable_qwen_thinking,
        )
        if text_post.startswith("Error:"):
            return text_post, [], think_token_count

        final_post = text_post
        assets_for_packaging = []
        if mode == 'with_figure' and assets_dir and Path(assets_dir).is_dir():
            tqdm.write(f"[*] Attaching top 3 figures/tables for 'with_figure' baseline...")
            paired_item_dirs = [
                d for d in Path(assets_dir).rglob('paired_*')
                if d.is_dir() and (d.name.startswith('paired_figure_') or d.name.startswith('paired_table_'))
            ]

            def get_global_sort_key(dir_path: Path):
                # Sort by (page number, item index) extracted from the
                # page_<n>/paired_<type>_<m> directory layout; -1 sorts
                # unparseable entries first.
                page_num = -1
                item_type = ''
                item_index = -1
                try:
                    page_match = re.search(r'page_(\d+)', dir_path.parts[-2])
                    if page_match:
                        page_num = int(page_match.group(1))
                except (IndexError, ValueError):
                    pass
                item_match = re.search(r'paired_(figure|table)_(\d+)', dir_path.name)
                if item_match:
                    item_type = item_match.group(1)
                    item_index = int(item_match.group(2))
                return (page_num, item_index)

            sorted_dirs = sorted(paired_item_dirs, key=get_global_sort_key)
            all_items = []
            for item_dir in sorted_dirs:
                # The item image shares the dir's type prefix; skip captions.
                item_type = 'figure' if 'figure' in item_dir.name else 'table'
                item_file = next(
                    (f for f in item_dir.iterdir() if f.is_file() and f.name.startswith(item_type) and 'caption' not in f.name),
                    None,
                )
                if item_file:
                    all_items.append(item_file)

            selected_items = all_items[:3]
            if selected_items:
                final_post += "\n\n--- Key Figures & Tables ---\n"
                for i, item_path in enumerate(selected_items):
                    new_asset_filename = f"img_{i}{item_path.suffix}"
                    alt_text = "Table" if "table" in item_path.parent.name else "Figure"
                    alt_text += f" {i+1}"
                    final_post += f"\n![{alt_text}](./img/{new_asset_filename})"
                    assets_for_packaging.append({'src_path': str(item_path), 'dest_name': new_asset_filename})
                tqdm.write(f"[✓] Appended {len(selected_items)} items (figures/tables) to the post.")
            else:
                tqdm.write("[!] Warning: 'with_figure' mode was selected, but no paired items were found.")
        return final_post, assets_for_packaging, think_token_count