""" |
|
|
Unified Thematic Word Generator using WordFreq + SentenceTransformers |
|
|
|
|
|
Eliminates vocabulary redundancy by using WordFreq as the single vocabulary source |
|
|
for both word lists and frequency data, with all-mpnet-base-v2 for embeddings. |
|
|
|
|
|
Features: |
|
|
- Single vocabulary source (WordFreq 319K words vs previous 3 separate sources) |
|
|
- Unified filtering for crossword-suitable words |
|
|
- 10-tier frequency classification system |
|
|
- Compatible with crossword backend services |
|
|
- Comprehensive modern vocabulary with proper frequency data |
|
|
- Environment variable configuration for cache paths and settings |
|
|
|
|
|
Environment Variables: |
|
|
- CACHE_DIR: Cache directory for all thematic service files (default: ./model_cache) |
|
|
- THEMATIC_VOCAB_SIZE_LIMIT: Maximum vocabulary size (default: 100000) |
|
|
- MAX_VOCABULARY_SIZE: Fallback vocab size limit (used if THEMATIC_VOCAB_SIZE_LIMIT not set) |
|
|
- THEMATIC_MODEL_NAME: Sentence transformer model to use (default: all-mpnet-base-v2)
- VOCAB_SOURCE: Vocabulary source, "norvig" or "wordfreq" (default: norvig)
|
|
|
|
|
Cache Structure: |
|
|
- {cache_dir}/vocabulary_{size}.pkl - Processed vocabulary words |
|
|
- {cache_dir}/frequencies_{size}.pkl - Word frequency data |
|
|
- {cache_dir}/embeddings_{model}_{source}_{size}.pt - Word embeddings (PyTorch tensor)
|
|
- {cache_dir}/sentence-transformers/ - Hugging Face model cache |
|
|
|
|
|
Usage: |
|
|
# Use environment variables for production |
|
|
export CACHE_DIR=/app/cache |
|
|
export THEMATIC_VOCAB_SIZE_LIMIT=50000 |
|
|
|
|
|
# Or pass directly to constructor for development |
|
|
service = ThematicWordService(cache_dir="/custom/path", vocab_size_limit=25000) |
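    # Async backend usage (illustrative sketch; assumes a running asyncio event loop)
    result = await service.find_words_for_crossword(["animals"], "medium", requested_words=10)
    words = result["words"]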
|
|
""" |
|
|
|
|
|
import os |
|
|
import csv |
|
|
import pickle |
|
|
import numpy as np |
|
|
import logging |
|
|
import asyncio |
|
|
import random |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from typing import List, Tuple, Optional, Dict, Set, Any |
|
|
from sentence_transformers import SentenceTransformer |
|
|
from sklearn.metrics.pairwise import cosine_similarity |
|
|
from sklearn.cluster import KMeans |
|
|
from datetime import datetime |
|
|
import time |
|
|
from collections import Counter |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
try: |
|
|
from wordfreq import word_frequency, zipf_frequency, top_n_list |
|
|
WORDFREQ_AVAILABLE = True |
|
|
except ImportError: |
|
|
logger.warning("WordFreq not available, using Norvig vocabulary only") |
|
|
WORDFREQ_AVAILABLE = False |
|
|
|
|
|
|
|
|
from .norvig_vocabulary_manager import NorgivVocabularyManager |
|
|
|
|
|
|
|
|
def get_timestamp(): |
|
|
return datetime.now().strftime("%H:%M:%S") |
|
|
|
|
|
def get_datetimestamp(): |
|
|
return datetime.now().strftime("%Y-%m-%d %H:%M:%S") |
|
|
|
|
|
|
|
|
class VocabularyManager: |
|
|
""" |
|
|
Centralized vocabulary management supporting both WordFreq and Norvig sources. |
|
|
Handles loading, filtering, caching, and frequency data generation. |
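
    Example (illustrative):
        manager = VocabularyManager(cache_dir="./model_cache", vocab_size_limit=50000)
        words, frequencies = manager.load_vocabulary()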
|
|
""" |
|
|
|
|
|
def __init__(self, cache_dir: Optional[str] = None, vocab_size_limit: Optional[int] = None): |
|
|
"""Initialize vocabulary manager. |
|
|
|
|
|
Args: |
|
|
cache_dir: Directory for caching vocabulary and embeddings |
|
|
vocab_size_limit: Maximum vocabulary size (None for full vocabulary) |
|
|
""" |
|
|
if cache_dir is None: |
|
|
|
|
|
cache_dir = os.getenv("CACHE_DIR") |
|
|
if cache_dir is None: |
|
|
cache_dir = os.path.join(os.path.dirname(__file__), 'model_cache') |
|
|
|
|
|
self.cache_dir = Path(cache_dir) |
|
|
self.cache_dir.mkdir(parents=True, exist_ok=True) |
|
|
|
|
|
|
|
|
self.vocab_size_limit = vocab_size_limit or int(os.getenv("THEMATIC_VOCAB_SIZE_LIMIT", |
|
|
os.getenv("MAX_VOCABULARY_SIZE", "100000"))) |
|
|
|
|
|
|
|
|
self.vocab_source = os.getenv("VOCAB_SOURCE", "norvig").lower() |
|
|
logger.info(f"📚 Vocabulary source: {self.vocab_source}") |
|
|
|
|
|
|
|
|
if self.vocab_source == "norvig": |
|
|
self.vocab_manager = NorgivVocabularyManager(cache_dir, vocab_size_limit) |
|
|
elif self.vocab_source == "wordfreq" and WORDFREQ_AVAILABLE: |
|
|
self.vocab_manager = None |
|
|
else: |
|
|
if not WORDFREQ_AVAILABLE: |
|
|
logger.warning("⚠️ WordFreq not available, falling back to Norvig") |
|
|
self.vocab_source = "norvig" |
|
|
self.vocab_manager = NorgivVocabularyManager(cache_dir, vocab_size_limit) |
|
|
else: |
|
|
logger.warning(f"⚠️ Unknown vocab source '{self.vocab_source}', falling back to Norvig") |
|
|
self.vocab_source = "norvig" |
|
|
self.vocab_manager = NorgivVocabularyManager(cache_dir, vocab_size_limit) |
|
|
|
|
|
|
|
|
source_suffix = f"_{self.vocab_source}" if self.vocab_source != "wordfreq" else "" |
|
|
self.vocab_cache_path = self.cache_dir / f"vocabulary{source_suffix}_{self.vocab_size_limit}.pkl" |
|
|
self.frequency_cache_path = self.cache_dir / f"frequencies{source_suffix}_{self.vocab_size_limit}.pkl" |
|
|
|
|
|
|
|
|
self.vocabulary: List[str] = [] |
|
|
self.word_frequencies: Counter = Counter() |
|
|
self.is_loaded = False |
|
|
|
|
|
def load_vocabulary(self) -> Tuple[List[str], Counter]: |
|
|
"""Load vocabulary and frequency data, with caching.""" |
|
|
if self.is_loaded: |
|
|
return self.vocabulary, self.word_frequencies |
|
|
|
|
|
|
|
|
if self.vocab_manager is not None: |
|
|
self.vocabulary, self.word_frequencies = self.vocab_manager.load_vocabulary() |
|
|
self.is_loaded = True |
|
|
return self.vocabulary, self.word_frequencies |
|
|
|
|
|
|
|
|
|
|
|
if self._load_from_cache(): |
|
|
logger.info(f"✅ Loaded vocabulary from cache: {len(self.vocabulary):,} words") |
|
|
self.is_loaded = True |
|
|
return self.vocabulary, self.word_frequencies |
|
|
|
|
|
|
|
|
logger.info("🔄 Generating vocabulary from WordFreq...") |
|
|
self._generate_vocabulary_from_wordfreq() |
|
|
|
|
|
|
|
|
self._save_to_cache() |
|
|
|
|
|
self.is_loaded = True |
|
|
return self.vocabulary, self.word_frequencies |
|
|
|
|
|
def _load_from_cache(self) -> bool: |
|
|
"""Load vocabulary and frequencies from cache.""" |
|
|
try: |
|
|
if self.vocab_cache_path.exists() and self.frequency_cache_path.exists(): |
|
|
logger.info(f"📦 Loading vocabulary from cache...") |
|
|
logger.info(f" Vocab cache: {self.vocab_cache_path}") |
|
|
logger.info(f" Freq cache: {self.frequency_cache_path}") |
|
|
|
|
|
|
|
|
if not os.access(self.vocab_cache_path, os.R_OK): |
|
|
logger.warning(f"⚠️ Vocabulary cache file not readable: {self.vocab_cache_path}") |
|
|
return False |
|
|
|
|
|
if not os.access(self.frequency_cache_path, os.R_OK): |
|
|
logger.warning(f"⚠️ Frequency cache file not readable: {self.frequency_cache_path}") |
|
|
return False |
|
|
|
|
|
with open(self.vocab_cache_path, 'rb') as f: |
|
|
self.vocabulary = pickle.load(f) |
|
|
|
|
|
with open(self.frequency_cache_path, 'rb') as f: |
|
|
self.word_frequencies = pickle.load(f) |
|
|
|
|
|
|
|
|
if not self.vocabulary or not self.word_frequencies: |
|
|
logger.warning("⚠️ Cache files contain empty data") |
|
|
return False |
|
|
|
|
|
logger.info(f"✅ Loaded {len(self.vocabulary):,} words and {len(self.word_frequencies):,} frequencies from cache") |
|
|
return True |
|
|
else: |
|
|
missing = [] |
|
|
if not self.vocab_cache_path.exists(): |
|
|
missing.append(f"vocabulary ({self.vocab_cache_path})") |
|
|
if not self.frequency_cache_path.exists(): |
|
|
missing.append(f"frequency ({self.frequency_cache_path})") |
|
|
logger.info(f"📂 Cache files missing: {', '.join(missing)}") |
|
|
return False |
|
|
except Exception as e: |
|
|
logger.warning(f"⚠️ Cache loading failed: {e}") |
|
|
|
|
|
return False |
|
|
|
|
|
def _save_to_cache(self): |
|
|
"""Save vocabulary and frequencies to cache.""" |
|
|
try: |
|
|
logger.info("💾 Saving vocabulary to cache...") |
|
|
|
|
|
with open(self.vocab_cache_path, 'wb') as f: |
|
|
pickle.dump(self.vocabulary, f) |
|
|
|
|
|
with open(self.frequency_cache_path, 'wb') as f: |
|
|
pickle.dump(self.word_frequencies, f) |
|
|
|
|
|
logger.info("✅ Vocabulary cached successfully") |
|
|
except Exception as e: |
|
|
logger.warning(f"⚠️ Cache saving failed: {e}") |
|
|
|
|
|
def _generate_vocabulary_from_wordfreq(self): |
|
|
"""Generate filtered vocabulary from WordFreq database.""" |
|
|
if not WORDFREQ_AVAILABLE: |
|
|
raise ImportError("WordFreq is not available, cannot generate vocabulary") |
|
|
|
|
|
logger.info(f"📚 Fetching top {self.vocab_size_limit:,} words from WordFreq...") |
|
|
|
|
|
|
|
|
raw_words = top_n_list('en', self.vocab_size_limit * 2, wordlist='large') |
|
|
logger.info(f"📥 Retrieved {len(raw_words):,} raw words from WordFreq") |
|
|
|
|
|
|
|
|
filtered_words = [] |
|
|
frequency_data = Counter() |
|
|
|
|
|
logger.info("🔍 Applying crossword filtering...") |
|
|
for word in raw_words: |
|
|
if self._is_crossword_suitable(word): |
|
|
filtered_words.append(word.lower()) |
|
|
|
|
|
|
|
|
try: |
|
|
freq = word_frequency(word, 'en', wordlist='large') |
|
|
if freq > 0: |
|
|
|
|
|
frequency_data[word.lower()] = int(freq * 1e9) |
|
|
                except Exception:
|
|
frequency_data[word.lower()] = 1 |
|
|
|
|
|
if len(filtered_words) >= self.vocab_size_limit: |
|
|
break |
|
|
|
|
|
|
|
|
self.vocabulary = sorted(list(set(filtered_words))) |
|
|
self.word_frequencies = frequency_data |
|
|
|
|
|
logger.info(f"✅ Generated filtered vocabulary: {len(self.vocabulary):,} words") |
|
|
logger.info(f"📊 Frequency data coverage: {len(self.word_frequencies):,} words") |
|
|
|
|
|
def _is_crossword_suitable(self, word: str) -> bool: |
|
|
"""Check if word is suitable for crosswords.""" |
|
|
word = word.lower().strip() |
|
|
|
|
|
|
|
|
if len(word) < 3 or len(word) > 12: |
|
|
return False |
|
|
|
|
|
|
|
|
if not word.isalpha(): |
|
|
return False |
|
|
|
|
|
|
|
|
boring_words = { |
|
|
'the', 'and', 'for', 'are', 'but', 'not', 'you', 'all', 'this', 'that', |
|
|
'with', 'from', 'they', 'were', 'been', 'have', 'their', 'said', 'each', |
|
|
'which', 'what', 'there', 'will', 'more', 'when', 'some', 'like', 'into', |
|
|
'time', 'very', 'only', 'has', 'had', 'who', 'its', 'now', 'find', 'long', |
|
|
'down', 'day', 'did', 'get', 'come', 'made', 'may', 'part' |
|
|
} |
|
|
|
|
|
if word in boring_words: |
|
|
return False |
|
|
|
|
|
|
|
|
if len(word) > 4 and word.endswith('s') and not word.endswith(('ss', 'us', 'is')): |
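            # Likely plural form: skip it, but keep words ending in 'ss', 'us', 'is'
            # (e.g. "glass", "campus", "tennis").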
|
|
return False |
|
|
|
|
|
|
|
|
if len(set(word)) < len(word) * 0.6: |
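            # Too much letter repetition (fewer than 60% unique letters), e.g. "banana"
            # has 3 unique letters out of 6 and is rejected.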
|
|
return False |
|
|
|
|
|
return True |
|
|
|
|
|
|
|
|
class ThematicWordService: |
|
|
""" |
|
|
Unified thematic word generator using WordFreq vocabulary and all-mpnet-base-v2 embeddings. |
|
|
|
|
|
Compatible with both hack tools and crossword backend services. |
|
|
Eliminates vocabulary redundancy by using single source for everything. |
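
    Example (illustrative):
        service = ThematicWordService()
        service.initialize()
        words = service.generate_thematic_words("space exploration", num_words=25)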
|
|
""" |
|
|
|
|
|
def __init__(self, cache_dir: Optional[str] = None, model_name: str = 'all-mpnet-base-v2', |
|
|
vocab_size_limit: Optional[int] = None): |
|
|
"""Initialize the unified thematic word generator. |
|
|
|
|
|
Args: |
|
|
cache_dir: Directory to cache model and embeddings |
|
|
model_name: Sentence transformer model to use |
|
|
vocab_size_limit: Maximum vocabulary size (None for 100K default) |
|
|
""" |
|
|
if cache_dir is None: |
|
|
|
|
|
cache_dir = os.getenv("CACHE_DIR") |
|
|
if cache_dir is None: |
|
|
cache_dir = os.path.join(os.path.dirname(__file__), 'model_cache') |
|
|
|
|
|
self.cache_dir = Path(cache_dir) |
|
|
self.cache_dir.mkdir(parents=True, exist_ok=True) |
|
|
|
|
|
|
|
|
self.model_name = os.getenv("THEMATIC_MODEL_NAME", model_name) |
|
|
|
|
|
|
|
|
self.vocab_size_limit = (vocab_size_limit or |
|
|
int(os.getenv("THEMATIC_VOCAB_SIZE_LIMIT", |
|
|
os.getenv("MAX_VOCABULARY_SIZE", "100000")))) |
|
|
|
|
|
|
|
|
self.vocab_source = os.getenv("VOCAB_SOURCE", "norvig").lower() |
|
|
logger.info(f"📚 Vocabulary source: {self.vocab_source}") |
|
|
|
|
|
|
|
|
if self.vocab_source == "norvig": |
|
|
from .norvig_vocabulary_manager import NorgivVocabularyManager |
|
|
self.vocab_manager = NorgivVocabularyManager(str(self.cache_dir), self.vocab_size_limit) |
|
|
elif self.vocab_source == "wordfreq" and WORDFREQ_AVAILABLE: |
|
|
self.vocab_manager = None |
|
|
else: |
|
|
if not WORDFREQ_AVAILABLE: |
|
|
logger.warning("⚠️ WordFreq not available, falling back to Norvig") |
|
|
self.vocab_source = "norvig" |
|
|
from .norvig_vocabulary_manager import NorgivVocabularyManager |
|
|
self.vocab_manager = NorgivVocabularyManager(str(self.cache_dir), self.vocab_size_limit) |
|
|
else: |
|
|
logger.warning(f"⚠️ Unknown vocab source '{self.vocab_source}', falling back to Norvig") |
|
|
self.vocab_source = "norvig" |
|
|
from .norvig_vocabulary_manager import NorgivVocabularyManager |
|
|
self.vocab_manager = NorgivVocabularyManager(str(self.cache_dir), self.vocab_size_limit) |
|
|
|
|
|
|
|
|
self.similarity_temperature = float(os.getenv("SIMILARITY_TEMPERATURE", "0.2")) |
|
|
self.use_softmax_selection = os.getenv("USE_SOFTMAX_SELECTION", "true").lower() == "true" |
|
|
self.difficulty_weight = float(os.getenv("DIFFICULTY_WEIGHT", "0.5")) |
|
|
self.thematic_pool_size = int(os.getenv("THEMATIC_POOL_SIZE", "150")) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.enable_distribution_normalization = os.getenv("ENABLE_DISTRIBUTION_NORMALIZATION", "false").lower() == "true" |
|
|
self.normalization_method = os.getenv("NORMALIZATION_METHOD", "similarity_range").lower() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.multi_topic_method = os.getenv("MULTI_TOPIC_METHOD", "soft_minimum").lower() |
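        # Supported MULTI_TOPIC_METHOD values: "averaging", "soft_minimum", "geometric_mean",
        # "harmonic_mean" (see _compute_multi_topic_similarities for how each one scores words).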
|
|
self.soft_min_beta = float(os.getenv("SOFT_MIN_BETA", "10.0")) |
|
|
|
|
|
|
|
|
self.soft_min_adaptive = os.getenv("SOFT_MIN_ADAPTIVE", "true").lower() == "true" |
|
|
self.soft_min_min_words = int(os.getenv("SOFT_MIN_MIN_WORDS", "15")) |
|
|
self.soft_min_max_retries = int(os.getenv("SOFT_MIN_MAX_RETRIES", "5")) |
|
|
self.soft_min_beta_decay = float(os.getenv("SOFT_MIN_BETA_DECAY", "0.7")) |
|
|
|
|
|
|
|
|
self.enable_debug_tab = os.getenv("ENABLE_DEBUG_TAB", "false").lower() == "true" |
|
|
|
|
|
|
|
|
|
|
|
self.model: Optional[SentenceTransformer] = None |
|
|
|
|
|
|
|
|
self.vocabulary: List[str] = [] |
|
|
self.word_frequencies: Counter = Counter() |
|
|
self.vocab_embeddings: Optional[torch.Tensor] = None |
|
|
self.frequency_tiers: Dict[str, str] = {} |
|
|
self.tier_descriptions: Dict[str, str] = {} |
|
|
self.device = None |
|
|
self.word_percentiles: Dict[str, float] = {} |
|
|
|
|
|
|
|
|
vocab_hash = f"{self.model_name.replace('/', '_')}_{self.vocab_source}_{self.vocab_size_limit}" |
|
|
self.embeddings_cache_path = self.cache_dir / f"embeddings_{vocab_hash}.pt" |
|
|
|
|
|
self.is_initialized = False |
|
|
|
|
|
def initialize(self): |
|
|
"""Initialize the generator (synchronous version).""" |
|
|
if self.is_initialized: |
|
|
return |
|
|
|
|
|
start_time = time.time() |
|
|
logger.info(f"🚀 Initializing Thematic Word Service...") |
|
|
logger.info(f"📁 Cache directory: {self.cache_dir}") |
|
|
logger.info(f"🤖 Model: {self.model_name}") |
|
|
logger.info(f"📊 Vocabulary size limit: {self.vocab_size_limit:,}") |
|
|
logger.info(f"🔗 Multi-topic method: {self.multi_topic_method}") |
|
|
if self.multi_topic_method == "soft_minimum": |
|
|
logger.info(f"📐 Soft minimum beta: {self.soft_min_beta}") |
|
|
if self.soft_min_adaptive: |
|
|
logger.info(f"🔄 Adaptive beta enabled: min_words={self.soft_min_min_words}, max_retries={self.soft_min_max_retries}, decay={self.soft_min_beta_decay}") |
|
|
else: |
|
|
logger.info(f"🔒 Adaptive beta disabled (using fixed beta)") |
|
|
logger.info(f"🎲 Softmax selection: {self.use_softmax_selection} (T={self.similarity_temperature})") |
|
|
logger.info(f"⚖️ Difficulty weight: {self.difficulty_weight}") |
|
|
|
|
|
|
|
|
if not self.cache_dir.exists(): |
|
|
logger.warning(f"⚠️ Cache directory does not exist, creating: {self.cache_dir}") |
|
|
try: |
|
|
self.cache_dir.mkdir(parents=True, exist_ok=True) |
|
|
except Exception as e: |
|
|
logger.error(f"❌ Failed to create cache directory: {e}") |
|
|
raise |
|
|
|
|
|
|
|
|
vocab_start = time.time() |
|
|
self.vocabulary, self.word_frequencies = self.vocab_manager.load_vocabulary() |
|
|
vocab_time = time.time() - vocab_start |
|
|
logger.info(f"✅ Vocabulary loaded in {vocab_time:.2f}s: {len(self.vocabulary):,} words") |
|
|
|
|
|
|
|
|
self.frequency_tiers = self._create_frequency_tiers() |
|
|
|
|
|
|
|
|
logger.info(f"🤖 Loading embedding model: {self.model_name}") |
|
|
logger.info(f"📂 Cache directory: {self.cache_dir}") |
|
|
logger.info(f"📂 Cache dir exists: {os.path.exists(self.cache_dir)}") |
|
|
if os.path.exists(self.cache_dir): |
|
|
logger.info(f"📂 Cache dir writable: {os.access(self.cache_dir, os.W_OK)}") |
|
|
|
|
|
|
|
|
try: |
|
|
cache_contents = os.listdir(self.cache_dir) |
|
|
logger.info(f"📂 Cache contents ({len(cache_contents)} items): {cache_contents[:5]}...") |
|
|
except Exception as e: |
|
|
logger.error(f"❌ Cannot list cache directory: {e}") |
|
|
|
|
|
|
|
|
model_path = f'sentence-transformers/{self.model_name}' |
|
|
logger.info(f"🔍 Full model path to load: {model_path}") |
|
|
logger.info(f"🔍 Model name from env THEMATIC_MODEL_NAME: {os.getenv('THEMATIC_MODEL_NAME')}") |
|
|
|
|
|
model_start = time.time() |
|
|
|
|
|
try: |
|
|
|
|
|
import torch |
|
|
logger.info(f"🔍 PyTorch CUDA available: {torch.cuda.is_available()}") |
|
|
if torch.cuda.is_available(): |
|
|
logger.info(f"🔍 CUDA device count: {torch.cuda.device_count()}") |
|
|
logger.info(f"🔍 CUDA device name: {torch.cuda.get_device_name(0)}") |
|
|
device = 'cuda' |
|
|
else: |
|
|
logger.info(f"🔍 CUDA not available - checking why...") |
|
|
logger.info(f"🔍 PyTorch version: {torch.__version__}") |
|
|
logger.info(f"🔍 CUDA built: {torch.version.cuda}") |
|
|
logger.info(f"🔍 CUDNN version: {torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else 'Not available'}") |
|
|
device = 'cpu' |
|
|
|
|
|
logger.info(f"🖥️ Using device: {device}") |
|
|
self.device = device |
|
|
|
|
|
self.model = SentenceTransformer( |
|
|
model_path, |
|
|
cache_folder=str(self.cache_dir), |
|
|
device=device |
|
|
) |
|
|
model_time = time.time() - model_start |
|
|
logger.info(f"✅ Model loaded successfully in {model_time:.2f}s") |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"❌ Failed to load SentenceTransformer model: {e}") |
|
|
logger.error(f"🔍 Error type: {type(e).__name__}") |
|
|
logger.error(f"🔍 Model name used: {self.model_name}") |
|
|
logger.error(f"🔍 Constructed path: {model_path}") |
|
|
logger.error(f"🔍 Cache folder: {self.cache_dir}") |
|
|
|
|
|
|
|
|
model_cache_path = self.cache_dir / "models--sentence-transformers--all-mpnet-base-v2" |
|
|
logger.error(f"🔍 Checking model cache path: {model_cache_path}") |
|
|
|
|
|
if model_cache_path.exists(): |
|
|
logger.error(f"📂 Model cache directory exists") |
|
|
try: |
|
|
|
|
|
for root, dirs, files in os.walk(model_cache_path): |
|
|
rel_path = os.path.relpath(root, model_cache_path) |
|
|
if rel_path == ".": |
|
|
rel_path = "root" |
|
|
logger.error(f" 📁 {rel_path}: {len(files)} files - {files[:3]}...") |
|
|
if len(dirs) > 0: |
|
|
logger.error(f" 📂 subdirs: {dirs}") |
|
|
|
|
|
if len(files) > 10: |
|
|
break |
|
|
except Exception as walk_e: |
|
|
logger.error(f"❌ Cannot walk model cache: {walk_e}") |
|
|
else: |
|
|
logger.error(f"📂 Model cache directory does not exist") |
|
|
|
|
|
|
|
|
try: |
|
|
all_items = os.listdir(self.cache_dir) |
|
|
st_items = [item for item in all_items if 'sentence' in item.lower() or 'transform' in item.lower()] |
|
|
logger.error(f"📂 SentenceTransformer-related items in cache: {st_items}") |
|
|
except Exception as list_e: |
|
|
logger.error(f"❌ Cannot check for related items: {list_e}") |
|
|
|
|
|
raise |
|
|
|
|
|
|
|
|
embeddings = self._load_or_create_embeddings() |
|
|
|
|
|
|
|
|
self.vocab_embeddings = embeddings.float().to(self.device) |
|
|
logger.info(f"🚀 Loaded {self.vocab_embeddings.shape[0]} embeddings on {self.device}") |
|
|
|
|
|
if self.device == 'cuda': |
|
|
logger.info(f"💾 GPU memory allocated: {torch.cuda.memory_allocated()/1024**2:.1f}MB") |
|
|
|
|
|
|
|
|
logger.info(f"✅ Embeddings device: {self.vocab_embeddings.device}") |
|
|
|
|
|
self.is_initialized = True |
|
|
total_time = time.time() - start_time |
|
|
logger.info(f"🎉 Unified generator initialized in {total_time:.2f}s") |
|
|
logger.info(f"📊 Vocabulary: {len(self.vocabulary):,} words") |
|
|
logger.info(f"📈 Frequency data: {len(self.word_frequencies):,} words") |
|
|
logger.info(f"🎲 Softmax selection: {'ENABLED' if self.use_softmax_selection else 'DISABLED'}") |
|
|
if self.use_softmax_selection: |
|
|
logger.info(f"🌡️ Similarity temperature: {self.similarity_temperature}") |
|
|
logger.info(f"🎯 Distribution normalization: {'ENABLED' if self.enable_distribution_normalization else 'DISABLED'}") |
|
|
if self.enable_distribution_normalization: |
|
|
logger.info(f"🔧 Normalization method: {self.normalization_method}") |
|
|
|
|
|
async def initialize_async(self): |
|
|
"""Initialize the generator (async version for backend compatibility).""" |
|
|
return self.initialize() |
|
|
|
|
|
def _load_or_create_embeddings(self) -> torch.Tensor: |
|
|
"""Load embeddings from cache or create them.""" |
|
|
|
|
|
if self.embeddings_cache_path.exists(): |
|
|
try: |
|
|
logger.info(f"📦 Loading embeddings from cache: {self.embeddings_cache_path}") |
|
|
|
|
|
|
|
|
if not os.access(self.embeddings_cache_path, os.R_OK): |
|
|
logger.warning(f"⚠️ Embeddings cache file not readable: {self.embeddings_cache_path}") |
|
|
return self._create_embeddings_from_scratch() |
|
|
|
|
|
embeddings = torch.load(self.embeddings_cache_path, map_location='cpu', weights_only=True) |
|
|
|
|
|
|
|
|
if embeddings.shape[0] != len(self.vocabulary): |
|
|
logger.warning(f"⚠️ Embeddings shape mismatch: cache={embeddings.shape[0]}, vocab={len(self.vocabulary)}") |
|
|
logger.warning("🔄 Vocabulary size changed, recreating embeddings...") |
|
|
return self._create_embeddings_from_scratch() |
|
|
|
|
|
logger.info(f"✅ Loaded embeddings from cache: {embeddings.shape}") |
|
|
return embeddings |
|
|
except Exception as e: |
|
|
logger.warning(f"⚠️ Embeddings cache loading failed: {e}") |
|
|
return self._create_embeddings_from_scratch() |
|
|
else: |
|
|
logger.info(f"📂 Embeddings cache not found: {self.embeddings_cache_path}") |
|
|
return self._create_embeddings_from_scratch() |
|
|
|
|
|
def _create_embeddings_from_scratch(self) -> torch.Tensor: |
|
|
|
|
|
|
|
|
logger.info("🔄 Creating embeddings for vocabulary...") |
|
|
start_time = time.time() |
|
|
|
|
|
|
|
|
batch_size = 512 |
|
|
all_embeddings = [] |
|
|
|
|
|
for i in range(0, len(self.vocabulary), batch_size): |
|
|
batch_words = self.vocabulary[i:i + batch_size] |
|
|
batch_embeddings = self.model.encode( |
|
|
batch_words, |
|
|
convert_to_tensor=True, |
|
|
show_progress_bar=i == 0 |
|
|
).cpu() |
|
|
all_embeddings.append(batch_embeddings) |
|
|
|
|
|
if i % (batch_size * 10) == 0: |
|
|
logger.info(f"📊 Embeddings progress: {i:,}/{len(self.vocabulary):,}") |
|
|
|
|
|
embeddings = torch.cat(all_embeddings, dim=0) |
|
|
embedding_time = time.time() - start_time |
|
|
logger.info(f"✅ Created embeddings in {embedding_time:.2f}s: {embeddings.shape}") |
|
|
|
|
|
|
|
|
try: |
|
|
torch.save(embeddings, self.embeddings_cache_path) |
|
|
logger.info("💾 Embeddings cached successfully") |
|
|
except Exception as e: |
|
|
logger.warning(f"⚠️ Embeddings cache saving failed: {e}") |
|
|
|
|
|
return embeddings |
|
|
|
|
|
def _create_frequency_tiers(self) -> Dict[str, str]: |
|
|
"""Create 10-tier frequency classification system and calculate word percentiles.""" |
|
|
if not self.word_frequencies: |
|
|
return {} |
|
|
|
|
|
logger.info("📊 Creating frequency tiers and percentiles...") |
|
|
|
|
|
tiers = {} |
|
|
percentiles = {} |
|
|
|
|
|
|
|
|
all_counts = list(self.word_frequencies.values()) |
|
|
all_counts.sort(reverse=True) |
|
|
|
|
|
|
|
|
|
|
|
count_to_rank = {} |
|
|
for rank, count in enumerate(all_counts): |
|
|
if count not in count_to_rank: |
|
|
count_to_rank[count] = rank |
|
|
|
|
|
|
|
|
tier_definitions = [ |
|
|
("tier_1_ultra_common", 0.999, "Ultra Common (Top 0.1%)"), |
|
|
("tier_2_extremely_common", 0.995, "Extremely Common (Top 0.5%)"), |
|
|
("tier_3_very_common", 0.99, "Very Common (Top 1%)"), |
|
|
("tier_4_highly_common", 0.97, "Highly Common (Top 3%)"), |
|
|
("tier_5_common", 0.92, "Common (Top 8%)"), |
|
|
("tier_6_moderately_common", 0.85, "Moderately Common (Top 15%)"), |
|
|
("tier_7_somewhat_uncommon", 0.70, "Somewhat Uncommon (Top 30%)"), |
|
|
("tier_8_uncommon", 0.50, "Uncommon (Top 50%)"), |
|
|
("tier_9_rare", 0.25, "Rare (Top 75%)"), |
|
|
("tier_10_very_rare", 0.0, "Very Rare (Bottom 25%)") |
|
|
] |
|
|
|
|
|
|
|
|
thresholds = [] |
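        # For a tier at percentile p, the count threshold is the value at rank (1 - p) * N in the
        # descending count list; e.g. with N = 100,000 counts, tier_1 (p = 0.999) uses roughly the
        # 100th-highest count as its cutoff.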
|
|
for tier_name, percentile, description in tier_definitions: |
|
|
if percentile > 0: |
|
|
idx = int((1 - percentile) * len(all_counts)) |
|
|
threshold = all_counts[min(idx, len(all_counts) - 1)] |
|
|
else: |
|
|
threshold = 0 |
|
|
thresholds.append((tier_name, threshold, description)) |
|
|
|
|
|
|
|
|
self.tier_descriptions = {name: desc for name, _, desc in thresholds} |
|
|
|
|
|
|
|
|
for word, count in self.word_frequencies.items(): |
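            # percentile = 1 - rank / N: the most frequent word scores close to 1.0, the rarest
            # close to 0.0. These percentiles drive the difficulty Gaussians in
            # _compute_composite_score.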
|
|
|
|
|
rank = count_to_rank.get(count, len(all_counts) - 1) |
|
|
percentile = 1.0 - (rank / len(all_counts)) |
|
|
percentiles[word] = percentile |
|
|
|
|
|
|
|
|
assigned = False |
|
|
for tier_name, threshold, description in thresholds: |
|
|
if count >= threshold: |
|
|
tiers[word] = tier_name |
|
|
assigned = True |
|
|
break |
|
|
|
|
|
if not assigned: |
|
|
tiers[word] = "tier_10_very_rare" |
|
|
|
|
|
|
|
|
for word in self.vocabulary: |
|
|
if word not in tiers: |
|
|
tiers[word] = "tier_10_very_rare" |
|
|
percentiles[word] = 0.0 |
|
|
|
|
|
|
|
|
self.word_percentiles = percentiles |
|
|
|
|
|
|
|
|
tier_counts = Counter(tiers.values()) |
|
|
logger.info(f"✅ Created frequency tiers:") |
|
|
for tier_name, count in sorted(tier_counts.items()): |
|
|
desc = self.tier_descriptions.get(tier_name, tier_name) |
|
|
logger.info(f" {desc}: {count:,} words") |
|
|
|
|
|
|
|
|
percentile_values = list(percentiles.values()) |
|
|
if percentile_values: |
|
|
avg_percentile = np.mean(percentile_values) |
|
|
logger.info(f"📈 Percentile statistics: avg={avg_percentile:.3f}, range=0.000-1.000") |
|
|
|
|
|
return tiers |
|
|
|
|
|
def generate_thematic_words(self, |
|
|
inputs, |
|
|
num_words: int = 100, |
|
|
min_similarity: float = 0.3, |
|
|
multi_theme: bool = False, |
|
|
difficulty: str = "medium") -> List[Tuple[str, float, str]]: |
|
|
"""Generate thematically related words from input seeds. |
|
|
|
|
|
Args: |
|
|
inputs: Single string, or list of words/sentences as theme seeds |
|
|
num_words: Number of words to return |
|
|
min_similarity: Minimum similarity threshold |
|
|
multi_theme: Whether to detect and use multiple themes |
|
|
difficulty: Difficulty level ("easy", "medium", "hard") for frequency-aware selection |
|
|
|
|
|
Returns: |
|
|
List of (word, similarity_score, frequency_tier) tuples |
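
        Example (illustrative; output values are hypothetical):
            results = service.generate_thematic_words(["ocean", "marine life"], num_words=20)
            # e.g. [("coral", 0.62, "tier_6_moderately_common"), ...]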
|
|
""" |
|
|
if not self.is_initialized: |
|
|
self.initialize() |
|
|
|
|
|
|
|
|
if self.device == 'cuda': |
|
|
logger.info(f"📾 GPU memory before generation: {torch.cuda.memory_allocated()/1024**2:.1f}MB / {torch.cuda.max_memory_allocated()/1024**2:.1f}MB max") |
|
|
|
|
|
logger.info(f"🎯 Generating {num_words} thematic words") |
|
|
|
|
|
|
|
|
if isinstance(inputs, str): |
|
|
inputs = [inputs] |
|
|
|
|
|
if not inputs: |
|
|
return [] |
|
|
|
|
|
|
|
|
clean_inputs = [inp.strip().lower() for inp in inputs if inp.strip()] |
|
|
if not clean_inputs: |
|
|
return [] |
|
|
|
|
|
logger.info(f"📝 Input themes: {clean_inputs}") |
|
|
logger.info(f"📊 Difficulty level: {difficulty} (using frequency-aware selection)") |
|
|
|
|
|
|
|
|
|
|
|
auto_multi_theme = len(clean_inputs) > 2 |
|
|
final_multi_theme = multi_theme or auto_multi_theme |
|
|
|
|
|
logger.info(f"🔍 Multi-theme detection: {final_multi_theme} (auto: {auto_multi_theme}, manual: {multi_theme})") |
|
|
|
|
|
if final_multi_theme: |
|
|
theme_vectors = self._detect_multiple_themes(clean_inputs) |
|
|
logger.info(f"📊 Detected {len(theme_vectors)} themes") |
|
|
else: |
|
|
theme_vectors = [self._compute_theme_vector(clean_inputs)] |
|
|
logger.info("📊 Using single theme vector") |
|
|
|
|
|
|
|
|
if len(theme_vectors) > 1 and self.multi_topic_method != "averaging": |
|
|
logger.info(f"🔗 Using {self.multi_topic_method} method for {len(theme_vectors)} topic vectors") |
|
|
if self.multi_topic_method == "soft_minimum": |
|
|
logger.info(f"📐 Soft minimum beta parameter: {self.soft_min_beta}") |
|
|
all_similarities_np, effective_threshold = self._compute_multi_topic_similarities(theme_vectors, self.vocab_embeddings, min_similarity) |
|
|
|
|
|
all_similarities = torch.from_numpy(all_similarities_np).float().to(self.vocab_embeddings.device) |
|
|
else: |
|
|
|
|
|
logger.info(f"🔗 Using averaging method for {len(theme_vectors)} topic vectors") |
|
|
all_similarities = torch.zeros(len(self.vocabulary), device=self.vocab_embeddings.device) |
|
|
for theme_vector in theme_vectors: |
|
|
|
|
|
similarities = self._compute_similarities_torch(theme_vector).flatten() |
|
|
all_similarities += similarities / len(theme_vectors) |
|
|
effective_threshold = min_similarity |
|
|
|
|
|
logger.info("✅ Computed semantic similarities") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
top_indices = torch.argsort(all_similarities, descending=True) |
|
|
|
|
|
|
|
|
results = [] |
|
|
input_words_set = set(clean_inputs) |
|
|
logger.info(f"{clean_inputs=}") |
|
|
|
|
|
|
|
|
|
|
|
for idx in top_indices: |
|
|
idx_item = idx.item() |
|
|
similarity_score = all_similarities[idx].item() |
|
|
word = self.vocabulary[idx_item] |
|
|
|
|
|
|
|
|
if similarity_score < effective_threshold: |
|
|
break |
|
|
|
|
|
|
|
|
if len(results) >= num_words: |
|
|
break |
|
|
|
|
|
|
|
|
if word.lower() in input_words_set: |
|
|
continue |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
word_tier = self.frequency_tiers.get(word, "tier_10_very_rare") |
|
|
|
|
|
results.append((word, similarity_score, word_tier)) |
|
|
|
|
|
|
|
|
|
|
|
results.sort(key=lambda x: x[1], reverse=True) |
|
|
final_results = results[:num_words] |
|
|
|
|
|
logger.info(f"✅ Generated {len(final_results)} thematic words (deterministic)") |
|
|
words_by_similarity = '\n'.join([result[0] for result in final_results]) |
|
|
logger.info(f"Sorted by similarity: \n{words_by_similarity}") |
|
|
return final_results |
|
|
|
|
|
def _compute_theme_vector(self, inputs: List[str]) -> np.ndarray: |
|
|
"""Compute semantic centroid from input words/sentences.""" |
|
|
logger.info(f"🎯 Computing theme vector for {len(inputs)} inputs") |
|
|
|
|
|
|
|
|
input_embeddings_tensor = self.model.encode(inputs, convert_to_tensor=True, show_progress_bar=False) |
|
|
logger.info(f"✅ Encoded {len(inputs)} inputs") |
|
|
|
|
|
|
|
|
theme_vector_tensor = torch.mean(input_embeddings_tensor, dim=0) |
|
|
|
|
|
|
|
|
theme_vector = theme_vector_tensor.cpu().numpy() |
|
|
|
|
|
return theme_vector.reshape(1, -1) |
|
|
|
|
|
def _compute_similarities(self, query_vectors: np.ndarray) -> np.ndarray: |
|
|
"""Compute cosine similarities using PyTorch (works on both CPU and GPU). |
|
|
|
|
|
Args: |
|
|
query_vectors: Query vectors of shape (n_queries, dim) |
|
|
|
|
|
Returns: |
|
|
Similarity matrix of shape (n_vocab, n_queries) as numpy array for backward compatibility |
|
|
""" |
|
|
|
|
|
query_tensor = torch.from_numpy(query_vectors).float().to(self.vocab_embeddings.device) |
|
|
|
|
|
|
|
|
query_norm = F.normalize(query_tensor, p=2, dim=1) |
|
|
vocab_norm = F.normalize(self.vocab_embeddings, p=2, dim=1) |
|
|
|
|
|
|
|
|
similarities = torch.mm(vocab_norm, query_norm.T) |
|
|
|
|
|
|
|
|
return similarities.cpu().numpy() |
|
|
|
|
|
def _compute_similarities_torch(self, query_vectors: np.ndarray) -> torch.Tensor: |
|
|
"""Compute cosine similarities using PyTorch, return PyTorch tensor. |
|
|
|
|
|
Args: |
|
|
query_vectors: Query vectors of shape (n_queries, dim) |
|
|
|
|
|
Returns: |
|
|
Similarity matrix of shape (n_vocab, n_queries) as torch tensor |
|
|
""" |
|
|
|
|
|
query_tensor = torch.from_numpy(query_vectors).float().to(self.vocab_embeddings.device) |
|
|
|
|
|
|
|
|
query_norm = F.normalize(query_tensor, p=2, dim=1) |
|
|
vocab_norm = F.normalize(self.vocab_embeddings, p=2, dim=1) |
|
|
|
|
|
|
|
|
similarities = torch.mm(vocab_norm, query_norm.T) |
|
|
|
|
|
|
|
|
return similarities |
|
|
|
|
|
    def _compute_multi_topic_similarities(self, topic_vectors: List[np.ndarray], vocab_embeddings: np.ndarray, min_similarity: float = 0.3) -> Tuple[np.ndarray, float]:
|
|
""" |
|
|
Compute word similarities using configurable multi-topic intersection methods. |
|
|
|
|
|
This method replaces simple averaging with more sophisticated intersection approaches |
|
|
that find words genuinely relevant to ALL topics, not just diluted combinations. |
|
|
|
|
|
Based on experimental results from docs/multi_vector_word_finding.md: |
|
|
- Simple averaging promotes problematic words like "ethology", "guns" for Art+Books |
|
|
- Soft minimum successfully filters these while promoting true intersections like "literature" |
|
|
- Geometric/harmonic means provide intermediate approaches |
|
|
|
|
|
Args: |
|
|
topic_vectors: List of topic embedding vectors (each is 1×embedding_dim) |
|
|
            vocab_embeddings: Vocabulary embeddings matrix (vocab_size × embedding_dim)
            min_similarity: Base similarity threshold used to derive the effective threshold
|
|
|
|
|
Returns: |
|
|
Tuple of (similarity_scores, effective_threshold) where: |
|
|
- similarity_scores: Array of similarity scores for each vocabulary word using the configured method |
|
|
- effective_threshold: The threshold that should be used for filtering (adjusted for adaptive beta) |
|
|
""" |
|
|
method = self.multi_topic_method |
|
|
vocab_size = vocab_embeddings.shape[0] |
|
|
|
|
|
if method == "averaging": |
|
|
|
|
|
all_similarities = np.zeros(vocab_size) |
|
|
for theme_vector in topic_vectors: |
|
|
similarities = cosine_similarity(theme_vector, vocab_embeddings)[0] |
|
|
all_similarities += similarities / len(topic_vectors) |
|
|
return all_similarities, min_similarity |
|
|
|
|
|
elif method == "soft_minimum": |
|
|
|
|
|
|
|
|
beta = self.soft_min_beta |
|
|
|
|
|
|
|
|
topic_matrix = np.vstack([tv.reshape(-1) for tv in topic_vectors]) |
|
|
similarities_matrix = self._compute_similarities(topic_matrix) |
|
|
|
|
|
|
|
|
if self.soft_min_adaptive: |
|
|
logger.info(f"🔄 Adaptive beta enabled: initial={beta:.1f}, min_words={self.soft_min_min_words}") |
|
|
|
|
|
|
|
|
final_adjusted_threshold = min_similarity |
|
|
|
|
|
for attempt in range(self.soft_min_max_retries): |
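                # Smooth minimum across topics: softmin_beta(x) = -log(sum_i exp(-beta * x_i)) / beta.
                # As beta -> infinity this approaches min_i(x_i) (strict intersection); smaller beta
                # relaxes toward an average-like score with a -log(n_topics) / beta offset, which is
                # why the similarity threshold below is rescaled by (beta / base_beta).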
|
|
|
|
|
|
|
|
|
|
|
|
|
|
soft_min_scores = -np.log(np.sum(np.exp(-beta * similarities_matrix), axis=1)) / beta |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
base_beta = 10.0 |
|
|
adjusted_threshold = min_similarity * (beta / base_beta) |
|
|
|
|
|
|
|
|
num_valid_words = np.sum(soft_min_scores > adjusted_threshold) |
|
|
|
|
|
|
|
|
score_stats = { |
|
|
'min': float(np.min(soft_min_scores)), |
|
|
'max': float(np.max(soft_min_scores)), |
|
|
'mean': float(np.mean(soft_min_scores)), |
|
|
'threshold': adjusted_threshold, |
|
|
'orig_threshold': min_similarity, |
|
|
'above_threshold': int(num_valid_words) |
|
|
} |
|
|
logger.info(f"🔍 Beta={beta:.1f}: scores[{score_stats['min']:.3f}, {score_stats['max']:.3f}], mean={score_stats['mean']:.3f}, adj_threshold={score_stats['threshold']:.3f} (orig={score_stats['orig_threshold']:.3f}), valid={score_stats['above_threshold']}") |
|
|
|
|
|
if num_valid_words >= self.soft_min_min_words: |
|
|
|
|
|
final_adjusted_threshold = adjusted_threshold |
|
|
if attempt > 0: |
|
|
logger.info(f"✅ Adaptive beta converged: beta={beta:.1f}, valid_words={num_valid_words} (attempt {attempt+1})") |
|
|
else: |
|
|
logger.info(f"✅ Initial beta sufficient: beta={beta:.1f}, valid_words={num_valid_words}") |
|
|
break |
|
|
|
|
|
|
|
|
if attempt < self.soft_min_max_retries - 1: |
|
|
old_beta = beta |
|
|
beta = beta * self.soft_min_beta_decay |
|
|
logger.info(f"🔄 Retry {attempt+1}: Relaxing beta {old_beta:.1f}→{beta:.1f} (only {num_valid_words} valid words)") |
|
|
else: |
|
|
logger.warning(f"⚠️ Max retries reached: beta={beta:.1f}, valid_words={num_valid_words}") |
|
|
|
|
|
return soft_min_scores, final_adjusted_threshold |
|
|
else: |
|
|
|
|
|
soft_min_scores = -np.log(np.sum(np.exp(-beta * similarities_matrix), axis=1)) / beta |
|
|
return soft_min_scores, min_similarity |
|
|
|
|
|
elif method == "geometric_mean": |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
topic_matrix = np.vstack([tv.reshape(-1) for tv in topic_vectors]) |
|
|
similarities_matrix = self._compute_similarities(topic_matrix) |
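            # Clamp similarities to a small positive floor below so the log in the geometric
            # mean never sees zero or negative values.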
|
|
|
|
|
|
|
|
similarities_matrix = np.maximum(similarities_matrix, 0.001) |
|
|
|
|
|
|
|
|
geo_means = np.exp(np.mean(np.log(similarities_matrix), axis=1)) |
|
|
|
|
|
return geo_means, min_similarity |
|
|
|
|
|
elif method == "harmonic_mean": |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
topic_matrix = np.vstack([tv.reshape(-1) for tv in topic_vectors]) |
|
|
similarities_matrix = self._compute_similarities(topic_matrix) |
|
|
|
|
|
|
|
|
similarities_matrix = np.maximum(similarities_matrix, 0.001) |
|
|
|
|
|
|
|
|
harmonic_means = similarities_matrix.shape[1] / np.sum(1.0 / similarities_matrix, axis=1) |
|
|
|
|
|
return harmonic_means, min_similarity |
|
|
|
|
|
else: |
|
|
|
|
|
logger.warning(f"⚠️ Unknown multi-topic method '{method}', falling back to averaging") |
|
|
all_similarities = np.zeros(vocab_size) |
|
|
for theme_vector in topic_vectors: |
|
|
similarities = cosine_similarity(theme_vector, vocab_embeddings)[0] |
|
|
all_similarities += similarities / len(topic_vectors) |
|
|
return all_similarities, min_similarity |
|
|
|
|
|
def _compute_composite_score(self, similarity: float, word: str, difficulty: str = "medium") -> float: |
|
|
""" |
|
|
Combine semantic similarity with frequency-based difficulty alignment using ML feature engineering. |
|
|
|
|
|
This is the core of the difficulty-aware selection system. It creates a composite score |
|
|
that balances two key factors: |
|
|
1. Semantic Relevance: How well the word matches the theme (similarity score) |
|
|
2. Difficulty Alignment: How well the word's frequency matches the desired difficulty |
|
|
|
|
|
Frequency Alignment uses Gaussian distributions to create smooth preference curves: |
|
|
|
|
|
Easy Mode (targets common words): |
|
|
- Gaussian peak at 90th percentile with narrow width (σ=0.1) |
|
|
- Words like CAT (95th percentile) get high scores |
|
|
- Words like QUETZAL (15th percentile) get low scores |
|
|
- Formula: exp(-((percentile - 0.9)² / (2 * 0.1²))) |
|
|
|
|
|
Hard Mode (targets rare words): |
|
|
- Gaussian peak at 20th percentile with moderate width (σ=0.15) |
|
|
- Words like QUETZAL (15th percentile) get high scores |
|
|
- Words like CAT (95th percentile) get low scores |
|
|
- Formula: exp(-((percentile - 0.2)² / (2 * 0.15²))) |
|
|
|
|
|
Medium Mode (balanced): |
|
|
- Flatter distribution with slight peak at 50th percentile (σ=0.3) |
|
|
- Base score of 0.5 plus Gaussian bonus |
|
|
- Less extreme preference, more balanced selection |
|
|
- Formula: 0.5 + 0.5 * exp(-((percentile - 0.5)² / (2 * 0.3²))) |
|
|
|
|
|
Final Weighting: |
|
|
composite = (1 - difficulty_weight) * similarity + difficulty_weight * frequency_alignment |
|
|
|
|
|
        Where difficulty_weight (default 0.5, via DIFFICULTY_WEIGHT) controls the balance:
|
|
- Higher weight = more frequency influence, less similarity influence |
|
|
- Lower weight = more similarity influence, less frequency influence |
|
|
|
|
|
        Example Calculations:
            Theme: "animals", difficulty_weight=0.3

            Easy mode:
            - CAT: similarity=0.8, percentile=0.95 → freq_score≈0.88 → composite≈0.82
            - PLATYPUS: similarity=0.9, percentile=0.15 → freq_score≈0.00 → composite≈0.63
            - Result: CAT wins despite lower similarity (common word bonus)

            Hard mode:
            - CAT: similarity=0.8, percentile=0.95 → freq_score≈0.00 → composite≈0.56
            - PLATYPUS: similarity=0.9, percentile=0.15 → freq_score≈0.95 → composite≈0.91
            - Result: PLATYPUS wins due to rarity bonus
|
|
|
|
|
Args: |
|
|
similarity: Semantic similarity score (0-1) from sentence transformer |
|
|
word: The word to get frequency percentile for |
|
|
difficulty: "easy", "medium", or "hard" - determines frequency preference curve |
|
|
|
|
|
Returns: |
|
|
Composite score (0-1) combining semantic relevance and difficulty alignment |
|
|
""" |
|
|
|
|
|
percentile = self.word_percentiles.get(word.lower(), 0.0) |
|
|
|
|
|
|
|
|
if difficulty == "easy": |
|
|
|
|
|
freq_score = np.exp(-((percentile - 0.9) ** 2) / (2 * 0.1 ** 2)) |
|
|
elif difficulty == "hard": |
|
|
|
|
|
freq_score = np.exp(-((percentile - 0.2) ** 2) / (2 * 0.15 ** 2)) |
|
|
else: |
|
|
|
|
|
freq_score = 0.5 + 0.5 * np.exp(-((percentile - 0.5) ** 2) / (2 * 0.3 ** 2)) |
|
|
|
|
|
|
|
|
final_alpha = 1.0 - self.difficulty_weight |
|
|
final_beta = self.difficulty_weight |
|
|
|
|
|
composite = final_alpha * similarity + final_beta * freq_score |
|
|
return composite |
|
|
|
|
|
def _apply_distribution_normalization(self, composite_scores: np.ndarray, candidates: List[Dict[str, Any]], difficulty: str) -> np.ndarray: |
|
|
""" |
|
|
Apply distribution normalization to ensure consistent difficulty distributions across topics. |
|
|
|
|
|
This method normalizes the composite score distribution to ensure that the same difficulty level |
|
|
produces consistent selection patterns regardless of the topic's inherent semantic similarity range. |
|
|
|
|
|
Args: |
|
|
composite_scores: Raw composite scores from similarity + frequency alignment |
|
|
candidates: List of candidate word dictionaries |
|
|
difficulty: Difficulty level for target percentile calculation |
|
|
|
|
|
Returns: |
|
|
Normalized composite scores with consistent distribution shape |
|
|
""" |
|
|
if len(composite_scores) <= 1: |
|
|
return composite_scores |
|
|
|
|
|
method = self.normalization_method.lower() |
|
|
|
|
|
if method == "similarity_range": |
|
|
|
|
|
|
|
|
similarities = np.array([c['similarity'] for c in candidates]) |
|
|
if len(similarities) > 1 and np.std(similarities) > 0: |
|
|
min_sim, max_sim = np.min(similarities), np.max(similarities) |
|
|
if max_sim > min_sim: |
|
|
|
|
|
normalized_scores = [] |
|
|
for i, candidate in enumerate(candidates): |
|
|
normalized_sim = (candidate['similarity'] - min_sim) / (max_sim - min_sim) |
|
|
word = candidate['word'] |
|
|
|
|
|
percentile = self.word_percentiles.get(word.lower(), 0.0) |
|
|
|
|
|
|
|
|
if difficulty == "easy": |
|
|
freq_score = np.exp(-((percentile - 0.9) ** 2) / (2 * 0.1 ** 2)) |
|
|
elif difficulty == "hard": |
|
|
freq_score = np.exp(-((percentile - 0.2) ** 2) / (2 * 0.15 ** 2)) |
|
|
else: |
|
|
freq_score = 0.5 + 0.5 * np.exp(-((percentile - 0.5) ** 2) / (2 * 0.3 ** 2)) |
|
|
|
|
|
|
|
|
final_alpha = 1.0 - self.difficulty_weight |
|
|
final_beta = self.difficulty_weight |
|
|
composite = final_alpha * normalized_sim + final_beta * freq_score |
|
|
normalized_scores.append(composite) |
|
|
|
|
|
return np.array(normalized_scores) |
|
|
|
|
|
elif method == "composite_zscore": |
|
|
|
|
|
|
|
|
mean_score = np.mean(composite_scores) |
|
|
std_score = np.std(composite_scores) |
|
|
if std_score > 0: |
|
|
return (composite_scores - mean_score) / std_score |
|
|
|
|
|
elif method == "percentile_recentering": |
|
|
|
|
|
target_percentiles = {"easy": 0.9, "medium": 0.5, "hard": 0.2} |
|
|
target = target_percentiles.get(difficulty, 0.5) |
|
|
|
|
|
|
|
|
percentiles = np.array([self.word_percentiles.get(c['word'].lower(), 0.0) for c in candidates]) |
|
|
|
|
|
|
|
|
current_center = np.mean(percentiles) |
|
|
shift = target - current_center |
|
|
|
|
|
|
|
|
percentile_alignment = np.exp(-((percentiles - target) ** 2) / (2 * 0.2 ** 2)) |
|
|
boosted_scores = composite_scores * (1 + 0.5 * percentile_alignment) |
|
|
return boosted_scores |
|
|
|
|
|
|
|
|
return composite_scores |
|
|
|
|
|
def _softmax_with_temperature(self, scores: np.ndarray, temperature: float = 1.0) -> np.ndarray: |
|
|
""" |
|
|
Apply softmax with temperature control to similarity scores. |
|
|
|
|
|
Args: |
|
|
scores: Array of similarity scores |
|
|
temperature: Temperature parameter (lower = more deterministic, higher = more random) |
|
|
- temperature < 1.0: More deterministic (favor high similarity) |
|
|
- temperature = 1.0: Standard softmax |
|
|
- temperature > 1.0: More random (flatten differences) |
|
|
|
|
|
Returns: |
|
|
Probability distribution over the scores |
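
        Example (illustrative): scores [0.8, 0.6, 0.4] with temperature 0.2 are scaled to
        [4.0, 3.0, 2.0], giving probabilities of roughly [0.67, 0.24, 0.09].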
|
|
""" |
|
|
if temperature <= 0: |
|
|
temperature = 0.01 |
|
|
|
|
|
|
|
|
scaled_scores = scores / temperature |
|
|
|
|
|
|
|
|
max_score = np.max(scaled_scores) |
|
|
exp_scores = np.exp(scaled_scores - max_score) |
|
|
probabilities = exp_scores / np.sum(exp_scores) |
|
|
|
|
|
return probabilities |
|
|
|
|
|
def _softmax_weighted_selection(self, candidates: List[Dict[str, Any]], |
|
|
                                    num_words: int, temperature: Optional[float] = None, difficulty: str = "medium") -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
|
|
""" |
|
|
Select words using softmax-based probabilistic sampling weighted by composite scores. |
|
|
|
|
|
This function implements a machine learning approach to word selection that combines: |
|
|
1. Semantic similarity (how relevant the word is to the theme) |
|
|
2. Frequency percentiles (how common/rare the word is) |
|
|
3. Difficulty preference (which frequencies are preferred for easy/medium/hard) |
|
|
4. Temperature-controlled randomness (exploration vs exploitation balance) |
|
|
|
|
|
Temperature Effects: |
|
|
- temperature < 1.0: More deterministic selection, strongly favors highest composite scores |
|
|
- temperature = 1.0: Standard softmax probability distribution |
|
|
- temperature > 1.0: More random selection, flattens differences between scores |
|
|
        - Default 0.2 (SIMILARITY_TEMPERATURE): mostly deterministic, with mild exploration
|
|
|
|
|
Difficulty Effects (via composite scoring): |
|
|
- "easy": Gaussian peak at 90th percentile (favors common words like CAT, DOG) |
|
|
- "medium": Balanced distribution around 50th percentile (moderate preference) |
|
|
- "hard": Gaussian peak at 20th percentile (favors rare words like QUETZAL, PLATYPUS) |
|
|
|
|
|
Composite Score Formula: |
|
|
composite = (1 - difficulty_weight) * similarity + difficulty_weight * frequency_alignment |
|
|
|
|
|
Where frequency_alignment uses Gaussian curves to score how well a word's |
|
|
percentile matches the difficulty preference. |
|
|
|
|
|
Example Scenario: |
|
|
Theme: "animals", Easy difficulty, Temperature: 0.7 |
|
|
- CAT: similarity=0.8, percentile=0.95 → high composite score (common + relevant) |
|
|
- PLATYPUS: similarity=0.9, percentile=0.15 → lower composite (rare word penalized in easy mode) |
|
|
- Result: CAT more likely to be selected despite lower similarity |
|
|
|
|
|
Args: |
|
|
candidates: List of word dictionaries with similarity scores |
|
|
num_words: Number of words to select |
|
|
            temperature: Temperature for softmax (None to use the instance default from SIMILARITY_TEMPERATURE, 0.2)
|
|
difficulty: Difficulty level ("easy", "medium", "hard") for frequency weighting |
|
|
|
|
|
Returns: |
|
|
Tuple of (selected_word_dictionaries, probability_distribution_data) |
|
|
- selected_word_dictionaries: Words chosen for crossword |
|
|
- probability_distribution_data: Dict with candidate probabilities for debug visualization |
|
|
""" |
|
|
if len(candidates) <= num_words: |
|
|
|
|
|
prob_data = { |
|
|
"probabilities": [{"word": c["word"], "probability": 1.0/len(candidates), "composite_score": 0.0, "selected": True, "rank": i+1} |
|
|
for i, c in enumerate(candidates)] |
|
|
} |
|
|
return candidates, prob_data |
|
|
|
|
|
if temperature is None: |
|
|
temperature = self.similarity_temperature |
|
|
|
|
|
|
|
|
composite_scores = [] |
|
|
debug_info = [] |
|
|
for word_data in candidates: |
|
|
similarity = word_data['similarity'] |
|
|
word = word_data['word'] |
|
|
composite = self._compute_composite_score(similarity, word, difficulty) |
|
|
composite_scores.append(composite) |
|
|
|
|
|
|
|
|
if len(debug_info) < 10: |
|
|
percentile = self.word_percentiles.get(word.lower(), 0.0) |
|
|
debug_info.append({ |
|
|
'word': word, |
|
|
'similarity': similarity, |
|
|
'percentile': percentile, |
|
|
'composite': composite, |
|
|
'tier': word_data.get('tier', 'unknown') |
|
|
}) |
|
|
|
|
|
composite_scores = np.array(composite_scores) |
|
|
|
|
|
|
|
|
original_composite_scores = composite_scores.copy() |
|
|
if self.enable_distribution_normalization: |
|
|
composite_scores = self._apply_distribution_normalization(composite_scores, candidates, difficulty) |
|
|
logger.info(f"🎯 Applied distribution normalization ({self.normalization_method})") |
|
|
|
|
|
|
|
|
logger.info(f"🔍 Debug: Top 10 composite scores for difficulty={difficulty}:") |
|
|
for info in debug_info: |
|
|
logger.info(f" {info['word']:<15} sim:{info['similarity']:.3f} perc:{info['percentile']:.3f} comp:{info['composite']:.3f} ({info['tier']})") |
|
|
|
|
|
|
|
|
probabilities = self._softmax_with_temperature(composite_scores, temperature) |
|
|
|
|
|
|
|
|
selected_indices = np.random.choice( |
|
|
len(candidates), |
|
|
size=min(num_words, len(candidates)), |
|
|
replace=False, |
|
|
p=probabilities |
|
|
) |
|
|
|
|
|
|
|
|
selected_candidates = [candidates[i] for i in selected_indices] |
|
|
selected_word_set = {candidates[i]["word"] for i in selected_indices} |
|
|
|
|
|
logger.info(f"🎲 Composite softmax selection (T={temperature:.2f}, difficulty={difficulty}): {len(selected_candidates)} from {len(candidates)} candidates") |
|
|
|
|
|
|
|
|
logger.info(f"🎯 Selected words for difficulty={difficulty}:") |
|
|
for word_data in selected_candidates[:10]: |
|
|
word = word_data['word'] |
|
|
similarity = word_data['similarity'] |
|
|
percentile = self.word_percentiles.get(word.lower(), 0.0) |
|
|
composite = self._compute_composite_score(similarity, word, difficulty) |
|
|
tier = word_data.get('tier', 'unknown') |
|
|
logger.info(f" {word:<15} sim:{similarity:.3f} perc:{percentile:.3f} comp:{composite:.3f} ({tier})") |
|
|
|
|
|
|
|
|
prob_distribution = [] |
|
|
for i, candidate in enumerate(candidates): |
|
|
prob_item = { |
|
|
"word": candidate["word"], |
|
|
"probability": float(probabilities[i]), |
|
|
"composite_score": float(composite_scores[i]), |
|
|
"selected": candidate["word"] in selected_word_set, |
|
|
"rank": i + 1, |
|
|
"similarity": candidate["similarity"], |
|
|
"tier": candidate.get("tier", "unknown"), |
|
|
"percentile": self.word_percentiles.get(candidate["word"].lower(), 0.0) |
|
|
} |
|
|
|
|
|
|
|
|
if self.enable_distribution_normalization and 'original_composite_scores' in locals(): |
|
|
prob_item["original_composite_score"] = float(original_composite_scores[i]) |
|
|
prob_item["normalization_applied"] = True |
|
|
prob_item["normalization_method"] = self.normalization_method |
|
|
else: |
|
|
prob_item["normalization_applied"] = False |
|
|
|
|
|
prob_distribution.append(prob_item) |
|
|
|
|
|
|
|
|
prob_distribution.sort(key=lambda x: x["probability"], reverse=True) |
|
|
|
|
|
|
|
|
for i, item in enumerate(prob_distribution): |
|
|
item["probability_rank"] = i + 1 |
|
|
|
|
|
prob_data = { |
|
|
"probabilities": prob_distribution, |
|
|
"temperature": temperature, |
|
|
"difficulty": difficulty, |
|
|
"total_candidates": len(candidates), |
|
|
"selected_count": len(selected_candidates), |
|
|
"normalization_enabled": self.enable_distribution_normalization, |
|
|
"normalization_method": self.normalization_method if self.enable_distribution_normalization else None |
|
|
} |
|
|
|
|
|
return selected_candidates, prob_data |
|
|
|
|
|
def _detect_multiple_themes(self, inputs: List[str], max_themes: int = 3) -> List[np.ndarray]: |
|
|
"""Detect multiple themes using clustering.""" |
|
|
if len(inputs) < 2: |
|
|
return [self._compute_theme_vector(inputs)] |
|
|
|
|
|
logger.info(f"🔍 Detecting multiple themes from {len(inputs)} inputs") |
|
|
|
|
|
|
|
|
input_embeddings = self.model.encode(inputs, convert_to_tensor=False, show_progress_bar=False) |
|
|
logger.info("✅ Encoded inputs for clustering") |
|
|
|
|
|
|
|
|
n_clusters = min(max_themes, len(inputs), 3) |
|
|
logger.info(f"📊 Using {n_clusters} clusters for theme detection") |
|
|
|
|
|
if n_clusters == 1: |
|
|
return [np.mean(input_embeddings, axis=0).reshape(1, -1)] |
|
|
|
|
|
|
|
|
kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10) |
|
|
kmeans.fit(input_embeddings) |
|
|
|
|
|
logger.info(f"✅ Clustered inputs into {n_clusters} themes") |
|
|
|
|
|
|
|
|
return [center.reshape(1, -1) for center in kmeans.cluster_centers_] |
|
|
|
|
|
def get_tier_words(self, tier: str, limit: int = 1000) -> List[str]: |
|
|
"""Get all words from a specific frequency tier. |
|
|
|
|
|
Args: |
|
|
tier: Frequency tier name (e.g., "tier_5_common") |
|
|
limit: Maximum number of words to return |
|
|
|
|
|
Returns: |
|
|
List of words in the specified tier |
|
|
""" |
|
|
if not self.is_initialized: |
|
|
self.initialize() |
|
|
|
|
|
tier_words = [word for word, word_tier in self.frequency_tiers.items() |
|
|
if word_tier == tier] |
|
|
|
|
|
return tier_words[:limit] |
|
|
|
|
|
def get_word_info(self, word: str) -> Dict[str, Any]: |
|
|
"""Get comprehensive information about a word. |
|
|
|
|
|
Args: |
|
|
word: Word to get information for |
|
|
|
|
|
Returns: |
|
|
Dictionary with word info including frequency, tier, etc. |
|
|
""" |
|
|
if not self.is_initialized: |
|
|
self.initialize() |
|
|
|
|
|
word_lower = word.lower() |
|
|
|
|
|
info = { |
|
|
'word': word, |
|
|
'in_vocabulary': word_lower in self.vocabulary, |
|
|
'frequency': self.word_frequencies.get(word_lower, 0), |
|
|
'tier': self.frequency_tiers.get(word_lower, "tier_10_very_rare"), |
|
|
'tier_description': self.tier_descriptions.get( |
|
|
self.frequency_tiers.get(word_lower, "tier_10_very_rare"), |
|
|
"Unknown" |
|
|
) |
|
|
} |
|
|
|
|
|
return info |
|
|
|
|
|
|
|
|
async def find_similar_words(self, topic: str, difficulty: str = "medium", max_words: int = 15) -> List[Dict[str, Any]]: |
|
|
"""Backend-compatible method for finding similar words. |
|
|
|
|
|
Returns list of word dictionaries compatible with crossword_generator.py |
|
|
Expected format: [{"word": str, "clue": str}, ...] |
|
|
""" |
|
|
|
|
|
difficulty_tier_map = { |
|
|
"easy": [ "tier_2_extremely_common", "tier_3_very_common", "tier_4_highly_common"], |
|
|
"medium": ["tier_4_highly_common", "tier_5_common", "tier_6_moderately_common", "tier_7_somewhat_uncommon"], |
|
|
"hard": ["tier_7_somewhat_uncommon", "tier_8_uncommon", "tier_9_rare"] |
|
|
} |
|
|
|
|
|
allowed_tiers = difficulty_tier_map.get(difficulty, difficulty_tier_map["medium"]) |
|
|
|
|
|
|
|
|
all_results = self.generate_thematic_words( |
|
|
topic, |
|
|
num_words=150, |
|
|
min_similarity=0.3 |
|
|
) |
|
|
|
|
|
|
|
|
backend_words = [] |
|
|
for word, similarity, tier in all_results: |
|
|
|
|
|
            # Filter by allowed frequency tiers for this difficulty, then by length criteria
            if tier not in allowed_tiers:
                continue
            if not self._matches_backend_difficulty(word, difficulty):
                continue
continue |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
backend_word = { |
|
|
"word": word.upper(), |
|
|
"clue": self._generate_simple_clue(word, topic), |
|
|
"similarity": similarity, |
|
|
"tier": tier |
|
|
} |
|
|
|
|
|
backend_words.append(backend_word) |
|
|
|
|
|
if len(backend_words) >= max_words: |
|
|
break |
|
|
|
|
|
logger.info(f"🎯 Generated {len(backend_words)} words for topic '{topic}' (difficulty: {difficulty})") |
|
|
return backend_words |
|
|
|
|
|
def _matches_backend_difficulty(self, word: str, difficulty: str) -> bool: |
|
|
"""Check if word matches backend difficulty criteria.""" |
|
|
difficulty_map = { |
|
|
"easy": {"min_len": 3, "max_len": 8}, |
|
|
"medium": {"min_len": 4, "max_len": 10}, |
|
|
"hard": {"min_len": 5, "max_len": 15} |
|
|
} |
|
|
|
|
|
criteria = difficulty_map.get(difficulty, difficulty_map["medium"]) |
|
|
return criteria["min_len"] <= len(word) <= criteria["max_len"] |
|
|
|
|
|
def _generate_simple_clue(self, word: str, topic: str) -> str: |
|
|
"""Generate a simple clue for the word (backend compatibility).""" |
|
|
|
|
|
word_lower = word.lower() |
|
|
topic_lower = topic.lower() |
|
|
|
|
|
|
|
|
if "animal" in topic_lower: |
|
|
return f"{word_lower} (animal)" |
|
|
elif "tech" in topic_lower or "computer" in topic_lower: |
|
|
return f"{word_lower} (technology)" |
|
|
elif "science" in topic_lower: |
|
|
return f"{word_lower} (science)" |
|
|
elif "geo" in topic_lower or "place" in topic_lower: |
|
|
return f"{word_lower} (geography)" |
|
|
elif "food" in topic_lower: |
|
|
return f"{word_lower} (food)" |
|
|
else: |
|
|
return f"{word_lower} (related to {topic_lower})" |
|
|
|
|
|
def get_vocabulary_size(self) -> int: |
|
|
"""Get the size of the loaded vocabulary.""" |
|
|
return len(self.vocabulary) |
|
|
|
|
|
    def get_tier_distribution(self) -> Dict[str, int]:
        """Get distribution of words across frequency tiers."""
        if not self.frequency_tiers:
            return {}

        tier_counts = Counter(self.frequency_tiers.values())
        return dict(tier_counts)

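    # Example return shape (counts are illustrative and depend on the loaded vocabulary):
    #
    #   service.get_tier_distribution()
    #   # -> {"tier_3_very_common": 1250, "tier_4_highly_common": 4800, ...}
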
    def get_cache_status(self) -> Dict[str, Any]:
        """Get detailed cache status information."""
        if self.vocab_manager is not None:
            vocab_exists = self.vocab_manager.vocab_cache_path.exists()
            freq_exists = self.vocab_manager.frequency_cache_path.exists()
            vocab_path = str(self.vocab_manager.vocab_cache_path)
            freq_path = str(self.vocab_manager.frequency_cache_path)
        else:
            vocab_exists = False
            freq_exists = False
            vocab_path = "N/A (using WordFreq)"
            freq_path = "N/A (using WordFreq)"

        embeddings_exists = self.embeddings_cache_path.exists()

        status = {
            "cache_directory": str(self.cache_dir),
            "vocabulary_cache": {
                "path": vocab_path,
                "exists": vocab_exists,
                "readable": os.access(vocab_path, os.R_OK) if vocab_exists else False
            },
            "frequency_cache": {
                "path": freq_path,
                "exists": freq_exists,
                "readable": os.access(freq_path, os.R_OK) if freq_exists else False
            },
            "embeddings_cache": {
                "path": str(self.embeddings_cache_path),
                "exists": embeddings_exists,
                "readable": embeddings_exists and os.access(self.embeddings_cache_path, os.R_OK)
            },
            "complete": (vocab_exists or self.vocab_manager is None) and (freq_exists or self.vocab_manager is None) and embeddings_exists
        }

        # Attach file sizes for caches that exist on disk.
        for cache_type in ["vocabulary_cache", "frequency_cache", "embeddings_cache"]:
            cache_info = status[cache_type]
            if cache_info["exists"]:
                try:
                    file_path = Path(cache_info["path"])
                    cache_info["size_bytes"] = file_path.stat().st_size
                    cache_info["size_mb"] = round(cache_info["size_bytes"] / (1024 * 1024), 2)
                except Exception as e:
                    cache_info["size_error"] = str(e)

        return status

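    # Example return shape for get_cache_status (paths and sizes are illustrative):
    #
    #   service.get_cache_status()
    #   # -> {"cache_directory": "/app/cache",
    #   #     "vocabulary_cache": {"path": "...", "exists": True, "readable": True,
    #   #                          "size_bytes": 1048576, "size_mb": 1.0},
    #   #     "frequency_cache": {...}, "embeddings_cache": {...}, "complete": True}
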
    async def find_words_for_crossword(self, topics: List[str], difficulty: str, requested_words: int = 10, custom_sentence: Optional[str] = None, multi_theme: bool = True, advanced_params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """
        Crossword-specific word finding with candidate overgeneration (3x the requested
        count) followed by weighted selection.

        Args:
            topics: List of topic strings
            difficulty: "easy", "medium", or "hard"
            requested_words: Number of words requested by frontend
            custom_sentence: Optional custom sentence to influence word selection
            multi_theme: Whether to use multi-theme processing (True) or single-theme blending (False)
            advanced_params: Optional dict with parameter overrides (similarity_temperature, difficulty_weight)

        Returns:
            Dictionary with words and optional debug data:
            {
                "words": [{"word": str, "clue": str, "similarity": float, "source": "thematic", "tier": str}],
                "debug": {...} (only if ENABLE_DEBUG_TAB=true)
            }
        """
        if not self.is_initialized:
            await self.initialize_async()

        sentence_info = f", custom sentence: '{custom_sentence}'" if custom_sentence else ""
        theme_mode = "multi-theme" if multi_theme else "single-theme"

        # Overgenerate candidates so the selection step has room to filter.
        generation_target = int(requested_words * 3)
        logger.info(f"🎯 Finding words for crossword - topics: {topics}, difficulty: {difficulty}{sentence_info}, mode: {theme_mode}")
        logger.info(f"📊 Generating {generation_target} candidates to select best {requested_words} words after clue filtering")

        min_similarity = 0.25

        input_list = topics.copy()

        # The custom sentence is treated as an extra theme input.
        if custom_sentence:
            input_list.append(custom_sentence)

        # Size the thematic pool relative to the generation target, capped by configuration.
        thematic_pool = min(self.thematic_pool_size, max(generation_target * 5, 50))
        logger.info(f"🚀 Optimized thematic pool size: {thematic_pool} (was 400) - {((400 - thematic_pool) / 400 * 100):.1f}% reduction")

        # Remember current tuning parameters so any per-request overrides can be restored later.
        original_temp = self.similarity_temperature
        original_weight = self.difficulty_weight

        if advanced_params:
            if 'similarity_temperature' in advanced_params:
                self.similarity_temperature = advanced_params['similarity_temperature']
                logger.info(f"🎛️ Overriding similarity temperature: {original_temp} → {self.similarity_temperature}")
            if 'difficulty_weight' in advanced_params:
                self.difficulty_weight = advanced_params['difficulty_weight']
                logger.info(f"🎛️ Overriding difficulty weight: {original_weight} → {self.difficulty_weight}")

        raw_results = self.generate_thematic_words(
            input_list,
            num_words=thematic_pool,
            min_similarity=min_similarity,
            multi_theme=multi_theme,
            difficulty=difficulty
        )

        if raw_results:
            # Group results by frequency tier for diagnostic logging.
            tier_groups = {}
            for word, similarity, tier in raw_results:
                if tier not in tier_groups:
                    tier_groups[tier] = []
                tier_groups[tier].append((word, similarity))

            tier_order = [
                "tier_1_ultra_common",
                "tier_2_extremely_common",
                "tier_3_very_common",
                "tier_4_highly_common",
                "tier_5_common",
                "tier_6_moderately_common",
                "tier_7_somewhat_uncommon",
                "tier_8_uncommon",
                "tier_9_rare",
                "tier_10_very_rare"
            ]

            log_lines = [f"📊 Generated {len(raw_results)} thematic words, grouped by tiers:"]

            for tier in tier_order:
                log_lines.append(f"  📊 {tier}:")
                if tier in tier_groups:
                    tier_words = sorted(tier_groups[tier], key=lambda x: x[0])
                    for word, similarity in tier_words:
                        percentile = self.word_percentiles.get(word.lower(), 0.0)
                        log_lines.append(f"    {word:<15} (similarity: {similarity:.3f}, percentile: {percentile:.3f})")

            logger.info("\n".join(log_lines))
        else:
            logger.info("📊 No thematic words generated")

        candidate_words = []

        logger.info(f"📊 Generating clues for {len(raw_results)} thematically relevant words (optimized from 400)")
        for word, similarity, tier in raw_results:
            word_data = {
                "word": word.upper(),
                "clue": self._generate_crossword_clue(word, topics),
                "similarity": float(similarity),
                "source": "thematic",
                "tier": tier
            }
            candidate_words.append(word_data)

        logger.info(f"📊 Generated {len(candidate_words)} candidate words, applying softmax selection on ALL words")

        final_words = []

        probability_data = None
        if self.use_softmax_selection:
            logger.info(f"🎲 Using softmax weighted selection on all {len(candidate_words)} candidates (temperature: {self.similarity_temperature})")

            if len(candidate_words) > requested_words:
                selected_words, probability_data = self._softmax_weighted_selection(candidate_words, requested_words, difficulty=difficulty)
                final_words.extend(selected_words)
            else:
                final_words.extend(candidate_words)
        else:
            logger.info("📊 Using traditional random selection on all candidates")

            random.shuffle(candidate_words)
            final_words.extend(candidate_words[:requested_words])

        # Shuffle so the crossword generator does not always see the highest-scoring words first.
        random.shuffle(final_words)

        logger.info(f"✅ Selected {len(final_words)} words from {len(candidate_words)} total candidates")
        logger.info(f"📝 Final words: {[w['word'] for w in final_words]}")

        result = {"words": final_words}

        if self.enable_debug_tab:
            debug_data = {
                "enabled": True,
                "generation_params": {
                    "topics": topics,
                    "difficulty": difficulty,
                    "requested_words": requested_words,
                    "custom_sentence": custom_sentence,
                    "multi_theme": multi_theme,
                    "thematic_pool_size": thematic_pool,
                    "min_similarity": min_similarity,
                    "multi_topic_method": self.multi_topic_method if len(topics) > 1 else None,
                    "soft_min_beta": self.soft_min_beta if len(topics) > 1 and self.multi_topic_method == "soft_minimum" else None
                },
                "thematic_pool": [
                    {
                        "word": word,
                        "similarity": float(similarity),
                        "tier": tier,
                        "percentile": self.word_percentiles.get(word.lower(), 0.0),
                        "tier_description": self.tier_descriptions.get(tier, tier)
                    }
                    for word, similarity, tier in raw_results
                ],
                "candidate_words": [
                    {
                        "word": word_data["word"],
                        "similarity": word_data["similarity"],
                        "tier": word_data["tier"],
                        "percentile": self.word_percentiles.get(word_data["word"].lower(), 0.0),
                        "clue": word_data["clue"]
                    }
                    for word_data in candidate_words
                ],
                "selection_method": "softmax" if self.use_softmax_selection else "random",
                "selection_params": {
                    "use_softmax_selection": self.use_softmax_selection,
                    "similarity_temperature": self.similarity_temperature,
                    "difficulty_weight": self.difficulty_weight
                },
                "selected_words": [
                    {
                        "word": word_data["word"],
                        "similarity": word_data["similarity"],
                        "tier": word_data["tier"],
                        "percentile": self.word_percentiles.get(word_data["word"].lower(), 0.0),
                        "clue": word_data["clue"]
                    }
                    for word_data in final_words
                ]
            }

            if probability_data:
                debug_data["probability_distribution"] = probability_data

            result["debug"] = debug_data
            logger.info(f"🐛 Debug data collected: {len(debug_data['thematic_pool'])} thematic words, {len(debug_data['candidate_words'])} candidates, {len(debug_data['selected_words'])} selected")

        # Restore any parameters that were overridden for this request.
        if advanced_params:
            self.similarity_temperature = original_temp
            self.difficulty_weight = original_weight
            if 'similarity_temperature' in advanced_params:
                logger.info(f"🔄 Restored similarity temperature: {self.similarity_temperature}")
            if 'difficulty_weight' in advanced_params:
                logger.info(f"🔄 Restored difficulty weight: {self.difficulty_weight}")

        return result

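    # Usage sketch for find_words_for_crossword (async; the topics, clue text, and scores
    # shown are illustrative, not guaranteed):
    #
    #   result = await service.find_words_for_crossword(
    #       topics=["space", "astronomy"], difficulty="medium", requested_words=10)
    #   result["words"]      # -> [{"word": "ORBIT", "clue": "...", "similarity": 0.58,
    #                        #      "source": "thematic", "tier": "tier_5_common"}, ...]
    #   result.get("debug")  # populated only when the debug tab is enabled
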
    def _matches_crossword_difficulty(self, word: str, difficulty: str) -> bool:
        """Check if word matches crossword difficulty criteria."""
        difficulty_criteria = {
            "easy": {"min_len": 3, "max_len": 8},
            "medium": {"min_len": 4, "max_len": 10},
            "hard": {"min_len": 5, "max_len": 12}
        }

        criteria = difficulty_criteria.get(difficulty, difficulty_criteria["medium"])
        return criteria["min_len"] <= len(word) <= criteria["max_len"]

    def _get_semantic_neighbors(self, word: str, n: int = 6) -> List[str]:
        """Get semantic neighbors of a word using embeddings.

        Args:
            word: Word to find neighbors for
            n: Number of neighbors to return (excluding the word itself)

        Returns:
            List of neighbor words, ordered by similarity
        """
        if not self.is_initialized or not hasattr(self, 'vocab_embeddings'):
            return []

        word_lower = word.lower()
        if word_lower not in self.vocabulary:
            return []

        try:
            word_idx = self.vocabulary.index(word_lower)

            word_embedding = self.vocab_embeddings[word_idx].unsqueeze(0)

            # Dot products against the full vocabulary (cosine similarity for normalized embeddings).
            similarities = torch.mm(self.vocab_embeddings, word_embedding.T).squeeze()

            # Take n+1 top matches because the query word itself will be among them.
            top_indices = torch.argsort(similarities, descending=True)[:n + 1]

            neighbors = []
            for idx in top_indices:
                idx_item = idx.item()
                neighbor = self.vocabulary[idx_item]
                if neighbor != word_lower:
                    neighbors.append(neighbor)
                if len(neighbors) >= n:
                    break

            return neighbors

        except Exception as e:
            logger.warning(f"⚠️ Failed to get semantic neighbors for '{word}': {e}")
            return []

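    # Neighbor lookup sketch (results depend entirely on the embedding model; the
    # neighbors shown are illustrative):
    #
    #   service._get_semantic_neighbors("guitar", n=3)
    #   # -> ["violin", "banjo", "ukulele"]   (the query word itself is excluded)
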
    def _generate_semantic_neighbor_clue(self, word: str, topics: List[str]) -> Optional[str]:
        """Generate a clue using semantic neighbors.

        Args:
            word: Word to generate clue for
            topics: Context topics for clue generation

        Returns:
            Generated clue based on semantic neighbors, or None if no usable neighbors exist
        """
        neighbors = self._get_semantic_neighbors(word, n=5)
        if not neighbors:
            return None

        neighbor_descriptions = []
        usable_neighbors = []

        for neighbor in neighbors:
            # Prefer a WordNet description of the neighbor when one is available.
            if hasattr(self, '_wordnet_generator') and self._wordnet_generator:
                try:
                    desc = self._wordnet_generator.generate_clue(neighbor, topics[0] if topics else "general")
                    if desc and len(desc.strip()) > 5 and not any(pattern in desc for pattern in ["Related to", "Crossword answer"]):
                        neighbor_descriptions.append((neighbor, desc))
                        continue
                except Exception:
                    pass

            # Otherwise keep the bare neighbor word as fallback material.
            usable_neighbors.append(neighbor)

        if neighbor_descriptions:
            neighbor, desc = neighbor_descriptions[0]
            if len(neighbor_descriptions) > 1:
                neighbor2, desc2 = neighbor_descriptions[1]
                return f"Like {neighbor} ({desc.split('.')[0].lower()}), related to {neighbor2}"
            else:
                return f"Related to {neighbor} ({desc.split('.')[0].lower()})"

        elif len(usable_neighbors) >= 2:
            if len(usable_neighbors) >= 3:
                return f"Associated with {usable_neighbors[0]}, {usable_neighbors[1]} and {usable_neighbors[2]}"
            else:
                return f"Related to {usable_neighbors[0]} and {usable_neighbors[1]}"
        elif len(usable_neighbors) == 1:
            return f"Connected to {usable_neighbors[0]}"
        else:
            return None

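    # Clue-format sketch for the neighbor-based fallback (neighbor words are illustrative):
    #
    #   "Like violin (a bowed string instrument), related to banjo"   # WordNet description available
    #   "Associated with violin, banjo and ukulele"                   # plain neighbors only
    #   None                                                          # no usable neighbors
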
    def _generate_crossword_clue(self, word: str, topics: List[str]) -> str:
        """Generate a crossword clue for the word using multiple strategies."""
        # Strategy 1: WordNet-based clue (the generator is initialized lazily on first use).
        if not hasattr(self, '_wordnet_generator') or self._wordnet_generator is None:
            try:
                from .wordnet_clue_generator import WordNetClueGenerator
                self._wordnet_generator = WordNetClueGenerator(
                    cache_dir=str(self.cache_dir)
                )
                self._wordnet_generator.initialize()
                logger.info("✅ WordNet clue generator initialized on-demand")
            except Exception as e:
                logger.warning(f"⚠️ Failed to initialize WordNet clue generator: {e}")
                self._wordnet_generator = None

        if self._wordnet_generator:
            try:
                primary_topic = topics[0] if topics else "general"
                clue = self._wordnet_generator.generate_clue(word, primary_topic)
                if clue and len(clue.strip()) > 0 and not any(pattern in clue for pattern in ["Related to", "Crossword answer"]):
                    return clue
            except Exception as e:
                logger.warning(f"⚠️ WordNet clue generation failed for '{word}': {e}")

        # Strategy 2: clue built from semantic neighbors.
        semantic_clue = self._generate_semantic_neighbor_clue(word, topics)
        if semantic_clue:
            return semantic_clue

        # Strategy 3: generic fallback.
        return f"Crossword answer: {word.lower()}"

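    # Fallback-order sketch (example clues are illustrative, not actual outputs):
    #
    #   _generate_crossword_clue("nebula", ["space"])
    #   # -> "A cloud of interstellar gas and dust"       (if the WordNet strategy succeeds)
    #   # -> "Associated with galaxy, comet and quasar"   (semantic-neighbor fallback)
    #   # -> "Crossword answer: nebula"                   (last-resort fallback)
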
# Aliases so existing imports of the older generator names keep working.
ThematicWordGenerator = ThematicWordService
UnifiedThematicWordGenerator = ThematicWordService
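
# Minimal local-run sketch (assumptions: the sentence-transformer model is already cached
# or can be downloaded, and the constructor arguments match your environment; adjust
# cache_dir / vocab_size_limit as needed).
if __name__ == "__main__":

    async def _demo():
        service = ThematicWordService(vocab_size_limit=25000)
        result = await service.find_words_for_crossword(
            topics=["ocean"], difficulty="easy", requested_words=5
        )
        for entry in result["words"]:
            print(f"{entry['word']:<12} {entry['clue']}")

    asyncio.run(_demo())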