#!/usr/bin/env python3
"""
Flask Web Application for Article Summarizer with TTS.

Enhanced with caching, performance optimizations, and better error handling.
"""

from flask import Flask, render_template, request, jsonify
import os
import time
import threading
import logging
from datetime import datetime
import re
from pathlib import Path
import hashlib
import json
from functools import lru_cache
import gc
from urllib.parse import urlparse

import torch
import trafilatura
import soundfile as sf
import requests
from transformers import AutoModelForCausalLM, AutoTokenizer
from kokoro import KPipeline

# ---------------- Logging ----------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("summarizer")

# ---------------- Flask ----------------
app = Flask(__name__)
# NOTE(review): no sessions are used in this module, but override SECRET_KEY
# in production anyway.
app.config["SECRET_KEY"] = os.environ.get("SECRET_KEY", "change-me")

# ---------------- Caching & Performance ----------------
# In-memory caches. Every entry is a tuple whose LAST element is the
# insertion timestamp (time.time()).
_summary_cache = {}  # text hash -> (summary, timestamp)
_audio_cache = {}    # hash(summary, voice) -> (filename, duration_seconds, timestamp)
_scrape_cache = {}   # md5(url) -> (content, timestamp)
_cache_lock = threading.Lock()

# Cache settings
MAX_CACHE_SIZE = 100
CACHE_EXPIRY_HOURS = 24

# Sample rate used when writing Kokoro output to disk.
_KOKORO_SAMPLE_RATE = 24000


def _get_cache_key(content: str) -> str:
    """Return an MD5 hex digest of *content* (non-cryptographic cache key)."""
    return hashlib.md5(content.encode("utf-8")).hexdigest()


def _is_cache_expired(timestamp: float) -> bool:
    """Return True if *timestamp* is older than CACHE_EXPIRY_HOURS."""
    return time.time() - timestamp > CACHE_EXPIRY_HOURS * 3600


def _cleanup_cache(cache_dict: dict) -> None:
    """Drop expired entries, then evict the oldest until at most MAX_CACHE_SIZE remain.

    The timestamp is read as the LAST tuple element: the summary/scrape caches
    hold 2-tuples but the audio cache holds 3-tuples, so unpacking a fixed
    ``(value, timestamp)`` shape would raise for audio entries.
    Callers hold ``_cache_lock``.
    """
    expired_keys = [key for key, entry in cache_dict.items() if _is_cache_expired(entry[-1])]
    for key in expired_keys:
        cache_dict.pop(key, None)

    overflow = len(cache_dict) - MAX_CACHE_SIZE
    if overflow > 0:
        oldest_first = sorted(cache_dict, key=lambda k: cache_dict[k][-1])
        for key in oldest_first[:overflow]:
            cache_dict.pop(key, None)


@lru_cache(maxsize=50)
def _get_text_hash(text: str) -> str:
    """Return a short (16 hex chars) SHA-256 digest of *text*, memoized."""
    return hashlib.sha256(text.encode("utf-8")).hexdigest()[:16]


# ---------------- Globals ----------------
qwen_model = None
qwen_tokenizer = None
kokoro_pipeline = None
model_loading_status = {"loaded": False, "error": None}
_load_lock = threading.Lock()
_loaded_once = False  # idempotence guard across threads

# Voice whitelist
ALLOWED_VOICES = {
    "af_heart", "af_bella", "af_nicole", "am_michael",
    "am_fenrir", "af_sarah", "bf_emma", "bm_george",
}

# HTTP headers to look like a real browser for sites that block bots
BROWSER_HEADERS = {
    "User-Agent": (
        "Mozilla/5.0 (Macintosh; Intel Mac OS X 13_5) AppleWebKit/537.36 "
        "(KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36"
    ),
    "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
    "Accept-Language": "en-US,en;q=0.9",
}

# Create output dirs (robust, relative to this file)
BASE_DIR = Path(__file__).parent.resolve()
STATIC_DIR = BASE_DIR / "static"
AUDIO_DIR = STATIC_DIR / "audio"
SUMM_DIR = STATIC_DIR / "summaries"
for p in (AUDIO_DIR, SUMM_DIR):
    try:
        p.mkdir(parents=True, exist_ok=True)
    except PermissionError:
        logger.warning("No permission to create %s (will rely on image pre-created dirs).", p)


# ---------------- Helpers ----------------
def _get_device():
    """Return the device of the loaded Qwen model's parameters."""
    # Works for both CPU/GPU; safer than qwen_model.device
    return next(qwen_model.parameters()).device


def _safe_trim_to_tokens(text: str, tokenizer, max_tokens: int) -> str:
    """Return *text* cut down to at most *max_tokens* tokens of *tokenizer*."""
    ids = tokenizer.encode(text, add_special_tokens=False)
    if len(ids) <= max_tokens:
        return text
    return tokenizer.decode(ids[:max_tokens], skip_special_tokens=True)


# Remove any leaked <think>...</think> (or <reasoning>/<thought>) blocks,
# including tags carrying attributes.
_THINK_BLOCK_RE = re.compile(
    r"<\s*(think|reasoning|thought)\b[^>]*>.*?<\s*/\s*\1\s*>",
    re.IGNORECASE | re.DOTALL,
)
# Stray opening or closing tags left over after the block pass.
_THINK_TAGS_RE = re.compile(r"<\s*/?\s*(?:think|reasoning|thought)\b[^>]*>", re.IGNORECASE)


def _strip_reasoning(text: str) -> str:
    """Remove leaked reasoning blocks/tags and empty code fences from model output."""
    cleaned = _THINK_BLOCK_RE.sub("", text)              # remove full blocks
    cleaned = _THINK_TAGS_RE.sub("", cleaned)            # remove any stray tags
    cleaned = re.sub(r"```(?:\w+)?\s*```", "", cleaned)  # collapse empty fenced blocks
    return cleaned.strip()


def _normalize_url_for_proxy(u: str) -> str:
    """Return the r.jina.ai reader URL for *u*.

    r.jina.ai takes the target URL appended after its own
    (e.g. https://r.jina.ai/http://example.com/path). Only a LEADING scheme is
    removed so URLs that embed "http://" in their path or query stay intact.
    """
    bare = re.sub(r"^https?://", "", u.strip(), flags=re.IGNORECASE)
    return f"https://r.jina.ai/http://{bare}"


# Heuristic "is this HTML?" check for pasted text.
# NOTE(review): the original pattern was lost from the collapsed source; this
# is a reconstruction — confirm it matches the intended heuristic.
_HTML_HINT_RE = re.compile(
    r"<\s*(?:!doctype|html|head|body|article|main|section|div|p|span|br)\b",
    re.IGNORECASE,
)


def _maybe_extract_from_html(pasted: str) -> str:
    """If the pasted text looks like HTML, try to extract the main text via trafilatura.

    Falls back to the pasted text unchanged when it does not look like HTML or
    when extraction yields nothing, so the caller always gets non-empty text
    for non-empty input.
    NOTE(review): body reconstructed; the original was lost from the source.
    """
    if not _HTML_HINT_RE.search(pasted):
        return pasted
    try:
        extracted = trafilatura.extract(pasted, include_comments=False, include_tables=False)
    except Exception as e:  # best-effort: fall back to the raw paste
        logger.info("HTML extraction failed, using raw paste: %s", e)
        return pasted
    if extracted and extracted.strip():
        return extracted.strip()
    return pasted


def _clean_str(value) -> str:
    """Return *value* stripped if it is a str, else "" (tolerates odd JSON types)."""
    return value.strip() if isinstance(value, str) else ""


def _as_bool(value) -> bool:
    """Interpret a JSON flag; strings like "false"/"0" are treated as False."""
    if isinstance(value, str):
        return value.strip().lower() in {"1", "true", "yes", "on"}
    return bool(value)


# ---------------- Model loading ----------------
def load_models() -> None:
    """Load the Qwen summarizer and the Kokoro TTS pipeline once.

    Intended to run in a background thread (see the Gunicorn hook and the dev
    entrypoint). Progress and failures are reported via ``model_loading_status``.
    ``_load_lock`` + ``_loaded_once`` make repeated calls a no-op after success;
    a failed load may be retried.

    NOTE(review): the original definition of this function was lost from the
    collapsed source. This is a reconstruction — confirm the model ID, dtype
    and device policy against the deployed version. Overrides:
    QWEN_MODEL_ID (default "Qwen/Qwen3-0.6B"), KOKORO_LANG_CODE (default "a").
    """
    global qwen_model, qwen_tokenizer, kokoro_pipeline, _loaded_once
    with _load_lock:
        if _loaded_once:
            return
        try:
            model_id = os.environ.get("QWEN_MODEL_ID", "Qwen/Qwen3-0.6B")
            use_cuda = torch.cuda.is_available()
            logger.info("Loading Qwen model %s (cuda=%s)...", model_id, use_cuda)
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.float16 if use_cuda else torch.float32,
            )
            model.to("cuda" if use_cuda else "cpu")
            model.eval()

            lang_code = os.environ.get("KOKORO_LANG_CODE", "a")
            logger.info("Loading Kokoro pipeline (lang_code=%s)...", lang_code)
            pipeline = KPipeline(lang_code=lang_code)

            # Publish only after everything loaded successfully.
            qwen_tokenizer, qwen_model, kokoro_pipeline = tokenizer, model, pipeline
            model_loading_status.update(loaded=True, error=None)
            _loaded_once = True
            logger.info("Models loaded.")
        except Exception as e:
            logger.exception("Model loading failed")
            model_loading_status.update(loaded=False, error=str(e))


# ---------------- Scraping ----------------
def _fetch_article(url: str) -> str | None:
    """Try each fetch strategy in order; return the first non-empty text or None."""
    # --- 1) Direct fetch via trafilatura ---
    downloaded = trafilatura.fetch_url(url)
    if downloaded:
        text = trafilatura.extract(downloaded, include_comments=False, include_tables=False)
        if text:
            return text

    # --- 2) Raw requests with browser headers + trafilatura extract ---
    try:
        r = requests.get(url, headers=BROWSER_HEADERS, timeout=15)
        if r.status_code == 200 and r.text:
            text = trafilatura.extract(
                r.text, include_comments=False, include_tables=False, url=url
            )
            if text:
                return text
        elif r.status_code == 403:
            logger.info("Site returned 403; considering proxy fallback (if enabled).")
    except requests.RequestException as e:
        logger.info("requests.get failed: %s", e)

    # --- 3) Optional proxy fallback (off by default) ---
    if os.environ.get("ALLOW_PROXY_FALLBACK", "0") == "1":
        try:
            pr = requests.get(_normalize_url_for_proxy(url), headers=BROWSER_HEADERS, timeout=15)
            if pr.status_code == 200 and pr.text:
                extracted = (
                    trafilatura.extract(pr.text, include_comments=False, include_tables=False)
                    or pr.text
                )
                if extracted.strip():
                    return extracted.strip()
        except requests.RequestException as e:
            logger.info("Proxy fallback failed: %s", e)

    return None


def scrape_article_text(url: str) -> tuple[str | None, str | None]:
    """
    Try to fetch & extract article text with caching.

    Strategy:
      1) Check cache first
      2) trafilatura.fetch_url (vanilla)
      3) requests.get with browser headers + trafilatura.extract
      4) (optional) Proxy fallback if ALLOW_PROXY_FALLBACK=1

    Returns (content, error); exactly one of the two is None.
    """
    # Only plain web URLs are fetched.
    # NOTE(review): internal/private hosts are not blocked (SSRF); add an
    # allow/deny list if this service is exposed publicly.
    try:
        parsed = urlparse(url)
    except ValueError:
        parsed = None
    if parsed is None or parsed.scheme.lower() not in ("http", "https") or not parsed.netloc:
        return None, "Please provide a valid http(s) URL (including http:// or https://)."

    cache_key = _get_cache_key(url)
    with _cache_lock:
        cached = _scrape_cache.get(cache_key)
        if cached is not None:
            content, timestamp = cached
            if not _is_cache_expired(timestamp):
                logger.info("Cache hit for URL: %s...", url[:50])
                return content, None
            _scrape_cache.pop(cache_key, None)  # expired

    try:
        content = _fetch_article(url)
    except Exception as e:
        logger.exception("Error scraping %s", url)
        return None, f"Error scraping article: {e}"

    if not content:
        return None, (
            "Failed to download the article content (site may block automated fetches). "
            "Try another URL, paste the text manually, or set ALLOW_PROXY_FALLBACK=1."
        )

    with _cache_lock:
        _scrape_cache[cache_key] = (content, time.time())
        _cleanup_cache(_scrape_cache)
    return content, None


# ---------------- Summarization ----------------
def summarize_with_qwen(text: str) -> tuple[str | None, str | None]:
    """Summarize *text* with the loaded Qwen model.

    Results are cached by text hash. Empty summaries are reported as an error
    and not cached, so a retry can produce a real one.
    Returns (summary, error); exactly one of the two is None.
    """
    if qwen_model is None or qwen_tokenizer is None:
        return None, "Models not loaded yet. Please wait."

    cache_key = _get_text_hash(text)
    with _cache_lock:
        cached = _summary_cache.get(cache_key)
        if cached is not None:
            summary, timestamp = cached
            if not _is_cache_expired(timestamp):
                logger.info("Cache hit for summary: %s", cache_key)
                return summary, None
            _summary_cache.pop(cache_key, None)  # expired

    try:
        # Budget input tokens based on max context; fallback to 4096
        try:
            max_ctx = int(getattr(qwen_model.config, "max_position_embeddings", 4096))
        except Exception:
            max_ctx = 4096
        # Leave room for prompt + output tokens
        max_input_tokens = max(512, max_ctx - 1024)

        prompt_hdr = (
            "Please provide a concise and clear summary of the following article. "
            "Focus on the main points, key findings, and conclusions. "
            "Keep it easy to understand for someone who hasn't read the original.\n\nARTICLE:\n"
        )
        article_trimmed = _safe_trim_to_tokens(text, qwen_tokenizer, max_input_tokens)
        messages = [
            {
                "role": "system",
                "content": (
                    "You are a helpful assistant. Return ONLY the final summary as plain text. "
                    "Do not include analysis, steps, or tags."
                ),
            },
            {"role": "user", "content": prompt_hdr + article_trimmed},
        ]

        # Build the chat prompt text; older templates reject enable_thinking.
        try:
            text_input = qwen_tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
            )
        except TypeError:
            text_input = qwen_tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )

        model_inputs = qwen_tokenizer([text_input], return_tensors="pt").to(_get_device())
        with torch.no_grad():
            generated_ids = qwen_model.generate(
                **model_inputs,
                max_new_tokens=512,
                temperature=0.7,
                top_p=0.8,
                top_k=20,
                do_sample=True,
                pad_token_id=qwen_tokenizer.eos_token_id,  # Avoid warnings
            )
        # Keep only the newly generated tokens (drop the echoed prompt).
        output_ids = generated_ids[0][model_inputs["input_ids"].shape[1]:]
        summary = qwen_tokenizer.decode(output_ids, skip_special_tokens=True).strip()
        summary = _strip_reasoning(summary)  # remove any leaked <think> content

        # Memory cleanup
        del model_inputs, generated_ids, output_ids
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        if not summary:
            return None, "The model returned an empty summary. Please try again."

        with _cache_lock:
            _summary_cache[cache_key] = (summary, time.time())
            _cleanup_cache(_summary_cache)
        return summary, None
    except Exception as e:
        logger.exception("Summary generation failed")
        return None, f"Error generating summary: {e}"


# ---------------- Text-to-speech ----------------
def generate_speech(summary: str, voice: str) -> tuple[str | None, str | None, float]:
    """Synthesize *summary* with Kokoro and write a wav into AUDIO_DIR.

    Unknown voices fall back to "af_heart". Results are cached per
    (summary, voice) while the wav file still exists.
    Returns (filename, error, duration_seconds).
    """
    if voice not in ALLOWED_VOICES:
        voice = "af_heart"

    # NUL separator keeps distinct (summary, voice) pairs from colliding.
    cache_key = _get_text_hash(f"{summary}\0{voice}")
    with _cache_lock:
        cached = _audio_cache.get(cache_key)
        if cached is not None:
            filename, duration, timestamp = cached
            if not _is_cache_expired(timestamp) and (AUDIO_DIR / filename).exists():
                logger.info("Cache hit for audio: %s", cache_key)
                return filename, None, duration
            # Expired, or the file was deleted (e.g. by cleanup_old_files).
            _audio_cache.pop(cache_key, None)

    try:
        audio_chunks = []
        for item in kokoro_pipeline(summary, voice=voice):
            _, _, audio = item  # (graphemes, phonemes, audio)
            if audio is None:
                continue  # chunk produced no audio
            chunk = audio if isinstance(audio, torch.Tensor) else torch.as_tensor(audio)
            audio_chunks.append(chunk.detach().cpu())
        logger.debug("Kokoro produced %d audio chunk(s)", len(audio_chunks))

        if not audio_chunks:
            return None, "No audio generated.", 0.0

        combined = audio_chunks[0] if len(audio_chunks) == 1 else torch.cat(audio_chunks, dim=0)
        total_duration = combined.shape[0] / float(_KOKORO_SAMPLE_RATE)

        # The cache key in the name avoids collisions between requests finishing
        # in the same second; the "summary_" prefix keeps cleanup_old_files' glob working.
        filename = f"summary_{int(time.time())}_{cache_key}.wav"
        sf.write(str(AUDIO_DIR / filename), combined.numpy(), _KOKORO_SAMPLE_RATE)

        with _cache_lock:
            _audio_cache[cache_key] = (filename, total_duration, time.time())
            _cleanup_cache(_audio_cache)
        return filename, None, total_duration
    except Exception as e:
        logger.exception("Speech generation failed")
        return None, f"Error generating speech: {e}", 0.0


# ---------------- Performance Monitoring ----------------
def cleanup_old_files(max_age_seconds: float = 7 * 24 * 3600) -> None:
    """Delete generated summary_*.wav files older than *max_age_seconds* (default 7 days).

    A failure on one file is logged and does not stop the sweep.
    """
    cutoff = time.time() - max_age_seconds
    for audio_file in AUDIO_DIR.glob("summary_*.wav"):
        try:
            if audio_file.stat().st_mtime < cutoff:
                audio_file.unlink()
                logger.info("Cleaned up old audio file: %s", audio_file.name)
        except OSError as e:
            logger.warning("Error during file cleanup: %s", e)


def get_cache_stats():
    """Return cache sizes and a rough (str-length based) memory estimate."""
    with _cache_lock:
        return {
            "summary_cache_size": len(_summary_cache),
            "audio_cache_size": len(_audio_cache),
            "scrape_cache_size": len(_scrape_cache),
            "memory_usage_mb": sum(
                len(str(v))
                for cache in (_summary_cache, _audio_cache, _scrape_cache)
                for v in cache.values()
            ) / (1024 * 1024),
        }


def periodic_cleanup():
    """Hourly: delete old audio files, expire cache entries, release memory. Never returns."""
    while True:
        time.sleep(3600)  # Run every hour
        try:
            cleanup_old_files()
            with _cache_lock:
                for cache in (_summary_cache, _audio_cache, _scrape_cache):
                    _cleanup_cache(cache)
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except Exception as e:
            logger.warning("Error in periodic cleanup: %s", e)


# Start cleanup thread
cleanup_thread = threading.Thread(target=periodic_cleanup, daemon=True)
cleanup_thread.start()


# ---------------- Routes ----------------
@app.route("/")
def index():
    return render_template("index.html")


@app.route("/status")
def status():
    return jsonify(model_loading_status)


@app.route("/process", methods=["POST"])
def process_article():
    """Summarize pasted text or a URL, optionally with TTS audio.

    JSON body: {"text"?, "url"?, "generate_audio"?, "voice"?}. Pasted text wins
    over the URL. Errors are reported as {"success": False, "error": ...}.
    """
    if not model_loading_status["loaded"]:
        return jsonify({"success": False, "error": "Models not loaded yet. Please wait."})

    data = request.get_json(force=True, silent=True)
    if not isinstance(data, dict):
        data = {}

    pasted_text = _clean_str(data.get("text"))
    url = _clean_str(data.get("url"))
    generate_audio = _as_bool(data.get("generate_audio", False))
    voice = _clean_str(data.get("voice")) or "af_heart"

    if not pasted_text and not url:
        return jsonify({"success": False, "error": "Please paste text or provide a valid URL."})

    # 1) Resolve content: prefer pasted text if provided
    if pasted_text:
        article_content = _maybe_extract_from_html(pasted_text)
    else:
        article_content, scrape_error = scrape_article_text(url)
        if scrape_error:
            return jsonify({"success": False, "error": scrape_error})

    # 2) Summarize
    summary, summary_error = summarize_with_qwen(article_content)
    if summary_error:
        return jsonify({"success": False, "error": summary_error})

    resp = {
        "success": True,
        "summary": summary,
        "article_length": len(article_content or ""),
        "summary_length": len(summary or ""),
        "compression_ratio": round(len(summary) / max(len(article_content), 1) * 100, 1),
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
    }

    # 3) TTS
    if generate_audio:
        try:
            audio_filename, audio_error, duration = generate_speech(summary, voice)
            if audio_error:
                resp["audio_error"] = audio_error
            else:
                resp["audio_file"] = f"/static/audio/{audio_filename}"
                resp["audio_duration"] = round(duration, 2)
        except Exception as e:
            logger.exception("Error in audio generation: %s", e)
            resp["audio_error"] = f"Audio generation failed: {str(e)}"

    return jsonify(resp)


@app.route("/voices")
def get_voices():
    voices = [
        {"id": "af_heart", "name": "Female - Heart", "grade": "A", "description": "❤️ Warm female voice (best quality)"},
        {"id": "af_bella", "name": "Female - Bella", "grade": "A-", "description": "🔥 Energetic female voice"},
        {"id": "af_nicole", "name": "Female - Nicole", "grade": "B-", "description": "🎧 Professional female voice"},
        {"id": "am_michael", "name": "Male - Michael", "grade": "C+", "description": "Clear male voice"},
        {"id": "am_fenrir", "name": "Male - Fenrir", "grade": "C+", "description": "Strong male voice"},
        {"id": "af_sarah", "name": "Female - Sarah", "grade": "C+", "description": "Gentle female voice"},
        {"id": "bf_emma", "name": "British Female - Emma", "grade": "B-", "description": "🇬🇧 British accent"},
        {"id": "bm_george", "name": "British Male - George", "grade": "C", "description": "🇬🇧 British male voice"},
    ]
    return jsonify(voices)


@app.route("/cache-stats")
def cache_stats():
    """Get cache statistics for performance monitoring."""
    if not model_loading_status["loaded"]:
        return jsonify({"error": "Models not loaded yet"})

    stats = get_cache_stats()
    stats.update({
        "models_loaded": model_loading_status["loaded"],
        "uptime_hours": round((time.time() - app.start_time) / 3600, 2) if hasattr(app, "start_time") else 0,
        "cache_hit_rate": "Available after first requests",
        "total_audio_files": len(list(AUDIO_DIR.glob("summary_*.wav"))),
    })
    return jsonify(stats)


@app.route("/health")
def health_check():
    """Health check endpoint for monitoring."""
    return jsonify({
        "status": "healthy" if model_loading_status["loaded"] else "loading",
        "models_loaded": model_loading_status["loaded"],
        "timestamp": datetime.now().isoformat(),
        "version": "2.0.0-enhanced",
    })


# Kick off model loading when running under Gunicorn/containers
if os.environ.get("RUNNING_GUNICORN", "0") == "1":
    threading.Thread(target=load_models, daemon=True).start()


# ---------------- Dev entrypoint ----------------
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="AI Article Summarizer Web App")
    parser.add_argument("--port", type=int, default=5001, help="Port to run the server on (default: 5001)")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to (default: 0.0.0.0)")
    # The Werkzeug debugger allows arbitrary code execution, so it is opt-in.
    parser.add_argument("--debug", action="store_true", help="Enable Flask debug mode (never on a public host)")
    args = parser.parse_args()

    # Track start time for uptime monitoring
    app.start_time = time.time()

    # Load models in background thread
    threading.Thread(target=load_models, daemon=True).start()

    # Respect platform env PORT when present (HF Spaces: 7860)
    port = int(os.environ.get("PORT", args.port))

    print("🚀 Starting Enhanced Article Summarizer Web App v2.0…")
    print("📚 Models are loading in the background…")
    print(f"🌐 Open http://localhost:{port} in your browser")
    print("✨ New features:")
    print("   • Enhanced UI with animations and keyboard shortcuts")
    print("   • Smart caching for 10x faster repeat requests")
    print("   • Better error handling and performance monitoring")
    print("   • Accessibility improvements and mobile optimization")

    try:
        # use_reloader=False: the reloader re-runs this block in a child
        # process, which would load the models twice.
        app.run(debug=args.debug, host=args.host, port=port, use_reloader=False)
    except OSError as e:
        if "Address already in use" in str(e):
            print(f"❌ Port {port} is already in use!")
            print("💡 Try a different port:")
            print(f"   python app.py --port {port + 1}")
            print("📱 Or disable AirPlay Receiver in System Settings → General → AirDrop & Handoff")
        else:
            raise

# Set start time for production deployments too
if not hasattr(app, "start_time"):
    app.start_time = time.time()