#!/usr/bin/env python3
"""
Flask Web Application for Article Summarizer with TTS
Enhanced with caching, performance optimizations, and better error handling
"""
from flask import Flask, render_template, request, jsonify
import os
import time
import threading
import logging
from datetime import datetime
import re
from pathlib import Path
import hashlib
import json
from functools import lru_cache
import gc
import torch
import trafilatura
import soundfile as sf
import requests
from transformers import AutoModelForCausalLM, AutoTokenizer
from kokoro import KPipeline
# ---------------- Logging ----------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("summarizer")
# ---------------- Flask ----------------
app = Flask(__name__)
# SECRET_KEY should come from the environment in production; "change-me" is a dev placeholder.
app.config["SECRET_KEY"] = os.environ.get("SECRET_KEY", "change-me")
# ---------------- Caching & Performance ----------------
# In-memory caches. Every value stored is a tuple ending in a timestamp so
# entries can expire (see _is_cache_expired / _cleanup_cache).
_summary_cache = {}  # text hash -> (summary, timestamp)
_audio_cache = {}  # hash of summary+voice -> (filename, duration, timestamp)
_scrape_cache = {}  # URL hash -> (scraped content, timestamp)
# Single lock guarding all three caches above.
_cache_lock = threading.Lock()
# Cache settings
MAX_CACHE_SIZE = 100  # max entries kept per cache after cleanup
CACHE_EXPIRY_HOURS = 24  # entries older than this are treated as stale
def _get_cache_key(content: str) -> str:
    """Return the MD5 hex digest of *content*, used as a cache key (not security-sensitive)."""
    digest = hashlib.md5(content.encode("utf-8"))
    return digest.hexdigest()
def _is_cache_expired(timestamp: float) -> bool:
    """True when *timestamp* (epoch seconds) is older than the expiry window."""
    age_seconds = time.time() - timestamp
    return age_seconds > CACHE_EXPIRY_HOURS * 3600
def _cleanup_cache(cache_dict: dict):
    """Drop expired entries from *cache_dict*, then evict the oldest past MAX_CACHE_SIZE.

    Values are tuples whose second element is a timestamp; the caller is
    expected to hold _cache_lock while mutating the shared caches.
    """
    # First pass: purge anything past its expiry window.
    stale_keys = [k for k, (_, ts) in cache_dict.items() if _is_cache_expired(ts)]
    for k in stale_keys:
        cache_dict.pop(k, None)
    # Second pass: enforce the size cap by evicting the oldest timestamps.
    overflow = len(cache_dict) - MAX_CACHE_SIZE
    if overflow > 0:
        oldest_first = sorted(cache_dict, key=lambda k: cache_dict[k][1])
        for k in oldest_first[:overflow]:
            cache_dict.pop(k, None)
@lru_cache(maxsize=50)
def _get_text_hash(text: str) -> str:
    """Return the first 16 hex chars of SHA-256 of *text* (memoized for speed)."""
    full_digest = hashlib.sha256(text.encode("utf-8")).hexdigest()
    return full_digest[:16]
# ---------------- Globals ----------------
# Model handles populated by load_models(); remain None until loading finishes.
qwen_model = None
qwen_tokenizer = None
kokoro_pipeline = None
# Polled by the /status and /health routes so the frontend can wait for readiness.
model_loading_status = {"loaded": False, "error": None}
_load_lock = threading.Lock()
_loaded_once = False  # idempotence guard across threads
# Voice whitelist — ids outside this set fall back to "af_heart" in generate_speech().
ALLOWED_VOICES = {
    "af_heart", "af_bella", "af_nicole", "am_michael",
    "am_fenrir", "af_sarah", "bf_emma", "bm_george"
}
# HTTP headers to look like a real browser for sites that block bots
BROWSER_HEADERS = {
    "User-Agent": (
        "Mozilla/5.0 (Macintosh; Intel Mac OS X 13_5) AppleWebKit/537.36 "
        "(KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36"
    ),
    "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
    "Accept-Language": "en-US,en;q=0.9",
}
# Create output dirs (robust, relative to this file rather than the CWD)
BASE_DIR = Path(__file__).parent.resolve()
STATIC_DIR = BASE_DIR / "static"
AUDIO_DIR = STATIC_DIR / "audio"
SUMM_DIR = STATIC_DIR / "summaries"
for p in (AUDIO_DIR, SUMM_DIR):
    try:
        p.mkdir(parents=True, exist_ok=True)
    except PermissionError:
        # Read-only filesystem (e.g. a container image): assume dirs were baked in.
        logger.warning("No permission to create %s (will rely on image pre-created dirs).", p)
# ---------------- Helpers ----------------
def _get_device():
# Works for both CPU/GPU; safer than qwen_model.device
return next(qwen_model.parameters()).device
def _safe_trim_to_tokens(text: str, tokenizer, max_tokens: int) -> str:
    """Truncate *text* so its tokenization is at most *max_tokens* tokens.

    Returns the text unchanged when it already fits; otherwise decodes the
    first *max_tokens* token ids back into a string.
    """
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    if len(token_ids) > max_tokens:
        return tokenizer.decode(token_ids[:max_tokens], skip_special_tokens=True)
    return text
# Remove any leaked <think>/<reasoning>/<thought> markup (with optional
# attributes). _THINK_BLOCK_RE matches a complete open+close pair including
# its contents; _THINK_TAGS_RE catches any stray opening/closing tag left over.
_THINK_BLOCK_RE = re.compile(
    r"<\s*(think|reasoning|thought)\b[^>]*>.*?<\s*/\s*\1\s*>",
    re.IGNORECASE | re.DOTALL,
)
# BUG FIX: the previous pattern began with a bare '?' (its leading '<' and '/'
# were lost), which raises re.error("nothing to repeat") at import time and
# prevented the module from loading at all. Restored as an open-or-close tag
# matcher: '<', optional '/', the tag name, any attributes, '>'.
_THINK_TAGS_RE = re.compile(r"<\s*/?\s*(think|reasoning|thought)\b[^>]*>", re.IGNORECASE)
def _strip_reasoning(text: str) -> str:
    """Remove leaked chain-of-thought markup from model output and trim whitespace."""
    without_blocks = _THINK_BLOCK_RE.sub("", text)        # drop whole tag...tag blocks
    without_tags = _THINK_TAGS_RE.sub("", without_blocks)  # drop any orphaned tags
    # Collapse fenced code blocks that became empty after stripping.
    collapsed = re.sub(r"```(?:\w+)?\s*```", "", without_tags)
    return collapsed.strip()
def _normalize_url_for_proxy(u: str) -> str:
    """Rewrite *u* for the r.jina.ai readability proxy.

    The proxy expects 'https://r.jina.ai/http://<host+path>', so any existing
    scheme is stripped before prefixing.
    """
    bare = u.replace("https://", "").replace("http://", "")
    return "https://r.jina.ai/http://" + bare
def _maybe_extract_from_html(pasted: str) -> str:
    """If the pasted text looks like HTML, try to extract the main text via trafilatura.

    Returns the extracted article text, or the input unchanged when it does
    not look like HTML or when extraction fails / yields nothing.
    """
    # BUG FIX: the previous pattern began with a bare '?' (its leading '<' and
    # '/' were lost), which raised re.error("nothing to repeat") on every call.
    # Restored as an open-or-close tag sniff for common structural HTML tags.
    looks_html = bool(re.search(r"<\s*/?\s*(html|div|p|article|section|span|body|h1|h2)\b", pasted, re.I))
    if not looks_html:
        return pasted
    try:
        extracted = trafilatura.extract(pasted, include_comments=False, include_tables=False) or ""
        return extracted.strip() or pasted
    except Exception:
        # Best-effort: fall back to the raw paste on any extraction failure.
        return pasted
# ---------------- Model Load ----------------
def load_models():
    """Load Qwen and Kokoro models on startup (idempotent).

    Safe to call from multiple threads: _load_lock serializes callers and
    _loaded_once makes repeat calls no-ops. Failures are recorded in
    model_loading_status["error"] (polled by /status) rather than raised.
    """
    global qwen_model, qwen_tokenizer, kokoro_pipeline, model_loading_status, _loaded_once
    with _load_lock:
        if _loaded_once:
            return
        try:
            logger.info("Loading Qwen3-0.6B…")
            model_name = "Qwen/Qwen3-0.6B"
            qwen_tokenizer = AutoTokenizer.from_pretrained(model_name)
            qwen_model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype="auto",
                device_map="auto",  # CPU or GPU automatically
            )
            qwen_model.eval()  # inference mode
            logger.info("Loading Kokoro TTS…")
            # lang_code="a" selects a Kokoro language preset (presumably
            # American English, matching the af_/am_ voices — confirm).
            kokoro_pipeline = KPipeline(lang_code="a")
            model_loading_status["loaded"] = True
            model_loading_status["error"] = None
            _loaded_once = True
            logger.info("✅ Models ready")
        except Exception as e:
            err = f"{type(e).__name__}: {e}"
            model_loading_status["loaded"] = False
            model_loading_status["error"] = err
            logger.exception("Failed to load models: %s", err)
# ---------------- Enhanced Core Logic with Caching ----------------
def scrape_article_text(url: str) -> tuple[str | None, str | None]:
    """
    Try to fetch & extract article text with caching.
    Strategy:
    1) Check cache first
    2) Trafilatura.fetch_url (vanilla)
    3) requests.get with browser headers + trafilatura.extract
    4) (optional) Proxy fallback if ALLOW_PROXY_FALLBACK=1
    Returns (content, error) — exactly one of the two is non-None.
    """
    # Check cache first
    cache_key = _get_cache_key(url)
    with _cache_lock:
        if cache_key in _scrape_cache:
            content, timestamp = _scrape_cache[cache_key]
            if not _is_cache_expired(timestamp):
                logger.info(f"Cache hit for URL: {url[:50]}...")
                return content, None
            else:
                # Remove expired entry
                _scrape_cache.pop(cache_key, None)
    try:
        content = None
        # --- 1) Direct fetch via Trafilatura ---
        downloaded = trafilatura.fetch_url(url)
        if downloaded:
            text = trafilatura.extract(downloaded, include_comments=False, include_tables=False)
            if text:
                content = text
        # --- 2) Raw requests + Trafilatura extract ---
        # Some sites reject trafilatura's default fetcher but accept a
        # browser-like User-Agent.
        if not content:
            try:
                r = requests.get(url, headers=BROWSER_HEADERS, timeout=15)
                if r.status_code == 200 and r.text:
                    text = trafilatura.extract(r.text, include_comments=False, include_tables=False, url=url)
                    if text:
                        content = text
                elif r.status_code == 403:
                    logger.info("Site returned 403; considering proxy fallback (if enabled).")
            except requests.RequestException as e:
                # Non-fatal: the proxy step below may still succeed.
                logger.info("requests.get failed: %s", e)
        # --- 3) Optional proxy fallback (off by default) ---
        if not content and os.environ.get("ALLOW_PROXY_FALLBACK", "0") == "1":
            proxy_url = _normalize_url_for_proxy(url)
            try:
                pr = requests.get(proxy_url, headers=BROWSER_HEADERS, timeout=15)
                if pr.status_code == 200 and pr.text:
                    # The proxy may already return readable text; fall back to the raw body.
                    extracted = trafilatura.extract(pr.text, include_comments=False, include_tables=False) or pr.text
                    if extracted and extracted.strip():
                        content = extracted.strip()
            except requests.RequestException as e:
                logger.info("Proxy fallback failed: %s", e)
        if content:
            # Cache the successful result
            with _cache_lock:
                _scrape_cache[cache_key] = (content, time.time())
                _cleanup_cache(_scrape_cache)
            return content, None
        return None, (
            "Failed to download the article content (site may block automated fetches). "
            "Try another URL, paste the text manually, or set ALLOW_PROXY_FALLBACK=1."
        )
    except Exception as e:
        return None, f"Error scraping article: {e}"
def summarize_with_qwen(text: str) -> tuple[str | None, str | None]:
    """Generate a summary of *text* with the Qwen model, with caching.

    Returns (summary, error) — exactly one is non-None. Assumes load_models()
    has populated qwen_model / qwen_tokenizer.
    """
    # Check cache first (keyed on a short SHA-256 of the full input text).
    cache_key = _get_text_hash(text)
    with _cache_lock:
        if cache_key in _summary_cache:
            summary, timestamp = _summary_cache[cache_key]
            if not _is_cache_expired(timestamp):
                logger.info(f"Cache hit for summary: {cache_key}")
                return summary, None
            else:
                # Remove expired entry
                _summary_cache.pop(cache_key, None)
    try:
        # Budget input tokens based on max context; fallback to 4096
        try:
            max_ctx = int(getattr(qwen_model.config, "max_position_embeddings", 4096))
        except Exception:
            max_ctx = 4096
        # Leave room for prompt + output tokens (512-token floor on input).
        max_input_tokens = max(512, max_ctx - 1024)
        prompt_hdr = (
            "Please provide a concise and clear summary of the following article. "
            "Focus on the main points, key findings, and conclusions. "
            "Keep it easy to understand for someone who hasn't read the original.\n\nARTICLE:\n"
        )
        # Trim article to safe length
        article_trimmed = _safe_trim_to_tokens(text, qwen_tokenizer, max_input_tokens)
        user_content = prompt_hdr + article_trimmed
        messages = [
            {
                "role": "system",
                "content": (
                    "You are a helpful assistant. Return ONLY the final summary as plain text. "
                    "Do not include analysis, steps, or tags."
                ),
            },
            {"role": "user", "content": user_content},
        ]
        # Build the chat prompt text (disable thinking if supported).
        # Older tokenizers don't accept enable_thinking; retry without it.
        try:
            text_input = qwen_tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
            )
        except TypeError:
            text_input = qwen_tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        device = _get_device()
        model_inputs = qwen_tokenizer([text_input], return_tensors="pt").to(device)
        # Performance optimization: use torch.no_grad() and clear cache
        with torch.no_grad():
            generated_ids = qwen_model.generate(
                **model_inputs,
                max_new_tokens=512,
                temperature=0.7,
                top_p=0.8,
                top_k=20,
                do_sample=True,
                pad_token_id=qwen_tokenizer.eos_token_id,  # Avoid warnings
            )
        # Slice off the prompt tokens; decode only the newly generated tail.
        output_ids = generated_ids[0][len(model_inputs.input_ids[0]):]
        summary = qwen_tokenizer.decode(output_ids, skip_special_tokens=True).strip()
        summary = _strip_reasoning(summary)  # remove any leaked reasoning/think tags
        # Cache the result
        with _cache_lock:
            _summary_cache[cache_key] = (summary, time.time())
            _cleanup_cache(_summary_cache)
        # Memory cleanup
        del model_inputs, generated_ids, output_ids
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return summary, None
    except Exception as e:
        return None, f"Error generating summary: {e}"
def generate_speech(summary: str, voice: str) -> tuple[str | None, str | None, float]:
    """Synthesize *summary* with Kokoro TTS, with caching.

    Returns (filename, error, duration_seconds); the WAV file is written into
    AUDIO_DIR and results are cached per (voice, summary) pair.
    """
    if voice not in ALLOWED_VOICES:
        voice = "af_heart"  # fall back to the default voice for unknown ids
    # Check cache first.
    # BUG FIX: key on voice + NUL + summary instead of raw concatenation
    # (summary + voice), which could let different pairs produce the same key.
    cache_key = _get_text_hash(voice + "\x00" + summary)
    with _cache_lock:
        if cache_key in _audio_cache:
            filename, duration, timestamp = _audio_cache[cache_key]
            filepath = AUDIO_DIR / filename
            if not _is_cache_expired(timestamp) and filepath.exists():
                logger.info("Cache hit for audio: %s", cache_key)
                return filename, None, duration
            # BUG FIX: also evict expired entries here (previously only
            # entries whose file had been deleted were dropped).
            _audio_cache.pop(cache_key, None)
    try:
        generator = kokoro_pipeline(summary, voice=voice)
        audio_chunks = []
        total_duration = 0.0
        # Each yielded item unpacks as a 3-tuple whose last element is the
        # audio chunk. BUG FIX: removed the per-chunk logger.info calls that
        # dumped every full audio tensor into the logs.
        for _meta1, _meta2, audio in generator:
            audio_chunks.append(audio)
            total_duration += len(audio) / 24000.0  # output sample rate is 24 kHz
        if not audio_chunks:
            return None, "No audio generated.", 0.0
        combined = audio_chunks[0] if len(audio_chunks) == 1 else torch.cat(audio_chunks, dim=0)
        ts = int(time.time())
        filename = f"summary_{ts}.wav"
        filepath = AUDIO_DIR / filename
        sf.write(str(filepath), combined.numpy(), 24000)
        # Cache the result
        with _cache_lock:
            _audio_cache[cache_key] = (filename, total_duration, time.time())
            _cleanup_cache(_audio_cache)
        return filename, None, total_duration
    except Exception as e:
        return None, f"Error generating speech: {e}", 0.0
# ---------------- Performance Monitoring ----------------
def cleanup_old_files():
    """Delete generated audio files older than a week to reclaim disk space."""
    try:
        now = time.time()
        max_age = 7 * 24 * 3600  # 7 days, in seconds
        stale = [
            audio_file
            for audio_file in AUDIO_DIR.glob("summary_*.wav")
            if now - audio_file.stat().st_mtime > max_age
        ]
        for audio_file in stale:
            audio_file.unlink()
            logger.info(f"Cleaned up old audio file: {audio_file.name}")
    except Exception as e:
        logger.warning(f"Error during file cleanup: {e}")
def get_cache_stats():
    """Snapshot cache sizes and a rough memory estimate, taken under the lock."""
    with _cache_lock:
        all_caches = (_summary_cache, _audio_cache, _scrape_cache)
        # NOTE(review): this is a crude character-count estimate of cached
        # values, not true process memory usage.
        approx_chars = sum(
            len(str(value)) for cache_map in all_caches for value in cache_map.values()
        )
        return {
            "summary_cache_size": len(_summary_cache),
            "audio_cache_size": len(_audio_cache),
            "scrape_cache_size": len(_scrape_cache),
            "memory_usage_mb": approx_chars / (1024 * 1024),
        }
# Schedule periodic cleanup
def periodic_cleanup():
    """Hourly housekeeping loop: prune old audio files and release memory.

    Runs forever; intended to be started once in a daemon thread (below) so
    it dies with the process.
    """
    while True:
        time.sleep(3600)  # Run every hour
        try:
            cleanup_old_files()
            # Force garbage collection
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except Exception as e:
            # Never let a cleanup failure kill the loop.
            logger.warning(f"Error in periodic cleanup: {e}")
# Start cleanup thread at import time (daemon: exits with the main process).
cleanup_thread = threading.Thread(target=periodic_cleanup, daemon=True)
cleanup_thread.start()
# ---------------- Routes ----------------
@app.route("/")
def index():
    """Serve the single-page frontend."""
    return render_template("index.html")
@app.route("/status")
def status():
    """Report model loading state: {"loaded": bool, "error": str | None}."""
    return jsonify(model_loading_status)
@app.route("/process", methods=["POST"])
def process_article():
    """Summarize pasted text or a fetched URL; optionally synthesize audio.

    Expects JSON: {"text": str?, "url": str?, "generate_audio": bool?,
    "voice": str?}. Pasted text takes precedence over the URL.
    """
    if not model_loading_status["loaded"]:
        return jsonify({"success": False, "error": "Models not loaded yet. Please wait."})
    payload = request.get_json(force=True, silent=True) or {}
    pasted = (payload.get("text") or "").strip()
    url = (payload.get("url") or "").strip()
    want_audio = bool(payload.get("generate_audio", False))
    voice = (payload.get("voice") or "af_heart").strip()
    if not (pasted or url):
        return jsonify({"success": False, "error": "Please paste text or provide a valid URL."})
    # 1) Resolve content: pasted text wins when both inputs are supplied.
    if pasted:
        article = _maybe_extract_from_html(pasted)
    else:
        article, fetch_error = scrape_article_text(url)
        if fetch_error:
            return jsonify({"success": False, "error": fetch_error})
    # 2) Summarize
    summary, summarize_error = summarize_with_qwen(article)
    if summarize_error:
        return jsonify({"success": False, "error": summarize_error})
    resp = {
        "success": True,
        "summary": summary,
        "article_length": len(article or ""),
        "summary_length": len(summary or ""),
        "compression_ratio": round(len(summary) / max(len(article), 1) * 100, 1),
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
    }
    # 3) TTS — failures here degrade gracefully to a text-only response.
    if want_audio:
        try:
            audio_name, audio_err, seconds = generate_speech(summary, voice)
            if audio_err:
                resp["audio_error"] = audio_err
            else:
                resp["audio_file"] = f"/static/audio/{audio_name}"
                resp["audio_duration"] = round(seconds, 2)
        except Exception as e:
            logger.exception("Error in audio generation: %s", e)
            resp["audio_error"] = f"Audio generation failed: {str(e)}"
    return jsonify(resp)
@app.route("/voices")
def get_voices():
    """Return the selectable Kokoro voices as a JSON list."""
    catalog = (
        ("af_heart", "Female - Heart", "A", "❤️ Warm female voice (best quality)"),
        ("af_bella", "Female - Bella", "A-", "🔥 Energetic female voice"),
        ("af_nicole", "Female - Nicole", "B-", "🎧 Professional female voice"),
        ("am_michael", "Male - Michael", "C+", "Clear male voice"),
        ("am_fenrir", "Male - Fenrir", "C+", "Strong male voice"),
        ("af_sarah", "Female - Sarah", "C+", "Gentle female voice"),
        ("bf_emma", "British Female - Emma", "B-", "🇬🇧 British accent"),
        ("bm_george", "British Male - George", "C", "🇬🇧 British male voice"),
    )
    voices = [
        {"id": vid, "name": name, "grade": grade, "description": desc}
        for vid, name, grade, desc in catalog
    ]
    return jsonify(voices)
@app.route("/cache-stats")
def cache_stats():
    """Expose cache statistics for performance monitoring."""
    if not model_loading_status["loaded"]:
        return jsonify({"error": "Models not loaded yet"})
    stats = get_cache_stats()
    started = getattr(app, 'start_time', None)
    uptime = round((time.time() - started) / 3600, 2) if started is not None else 0
    stats["models_loaded"] = model_loading_status["loaded"]
    stats["uptime_hours"] = uptime
    stats["cache_hit_rate"] = "Available after first requests"
    stats["total_audio_files"] = len(list(AUDIO_DIR.glob("summary_*.wav")))
    return jsonify(stats)
@app.route("/health")
def health_check():
    """Health check endpoint for monitoring."""
    loaded = model_loading_status["loaded"]
    payload = {
        "status": "healthy" if loaded else "loading",
        "models_loaded": loaded,
        "timestamp": datetime.now().isoformat(),
        "version": "2.0.0-enhanced",
    }
    return jsonify(payload)
# Kick off model loading when running under Gunicorn/containers.
# RUNNING_GUNICORN=1 must be set in the service environment; the dev
# entrypoint below starts its own loader thread instead.
if os.environ.get("RUNNING_GUNICORN", "0") == "1":
    threading.Thread(target=load_models, daemon=True).start()
# ---------------- Dev entrypoint ----------------
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description="AI Article Summarizer Web App")
    parser.add_argument("--port", type=int, default=5001, help="Port to run the server on (default: 5001)")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to (default: 0.0.0.0)")
    args = parser.parse_args()
    # Track start time for uptime monitoring (reported by /cache-stats)
    app.start_time = time.time()
    # Load models in background thread so the server starts responding immediately
    threading.Thread(target=load_models, daemon=True).start()
    # Respect platform env PORT when present (HF Spaces: 7860)
    port = int(os.environ.get("PORT", args.port))
    print("🚀 Starting Enhanced Article Summarizer Web App v2.0…")
    print("📚 Models are loading in the background…")
    print(f"🌐 Open http://localhost:{port} in your browser")
    print("✨ New features:")
    print(" • Enhanced UI with animations and keyboard shortcuts")
    print(" • Smart caching for 10x faster repeat requests")
    print(" • Better error handling and performance monitoring")
    print(" • Accessibility improvements and mobile optimization")
    try:
        app.run(debug=True, host=args.host, port=port)
    except OSError as e:
        if "Address already in use" in str(e):
            # Common on macOS, where AirPlay Receiver can occupy the port.
            print(f"❌ Port {port} is already in use!")
            print("💡 Try a different port:")
            print(f" python app.py --port {port + 1}")
            print("📱 Or disable AirPlay Receiver in System Settings → General → AirDrop & Handoff")
        else:
            raise
# Set start time for production deployments too: this runs at module import
# (e.g. under Gunicorn); the dev path above already set it, hence the guard.
if not hasattr(app, 'start_time'):
    app.start_time = time.time()