Spaces:
Sleeping
Sleeping
| # app.py - Dead Cells Wiki Bot with Enhanced Strategy Extraction | |
| # Solution 3: Hybrid Chunk + Quote System with Wiki-Aware Parsing | |
| # VERSION 1.2 - Fixed missing best_segments() function definition | |
| # | |
| # Key Features: | |
| # - Hybrid retrieval (dense embeddings + BM25 keyword scoring) | |
| # - Cross-encoder re-ranking for precision | |
| # - MMR diversification to reduce redundancy | |
| # - ENHANCED: Strategy-aware segment extraction (handles lists, bullets, strategies) | |
| # - ENHANCED: Lenient scoring thresholds for wiki formatting | |
| # - ENHANCED: Explicit avoidance of lore/notes/BSC sections | |
| # - Quote-level evidence with inline citations | |
| # - Strict grounding with abstention when context insufficient | |
| import os | |
| import re | |
| import glob | |
| import math | |
| from typing import List, Dict, Tuple | |
| from collections import Counter | |
| import gradio as gr | |
| import chromadb | |
| from sentence_transformers import SentenceTransformer, CrossEncoder | |
| from openai import OpenAI | |
# ==========================================
# CONFIGURATION
# ==========================================

# --- Remote LLM endpoint (reached through the OpenAI client) ---
# The key is read from the environment; startup is aborted when it is absent.
NOVITA_API_KEY = os.getenv("NOVITA_API_KEY")
if not NOVITA_API_KEY:
    raise RuntimeError("❌ NOVITA_API_KEY environment variable not set. Add it in Hugging Face Space settings.")
NOVITA_BASE_URL = "https://api.novita.ai/v3/openai"
MODEL_NAME = "meta-llama/llama-3.2-1b-instruct"

# --- Vector store ---
DB_PATH = "./deadcells_db_free"
COLLECTION_NAME = "deadcells_wiki"

# --- Local models ---
# Fast, balanced embedding model
EMBED_MODEL = "all-MiniLM-L6-v2"
# Cross-encoder used for precision re-ranking
RERANK_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"
DEVICE = "cpu"

# --- Retrieval ---
# Size of the initial semantic retrieval pool
K_CANDIDATES = 32
# Weight given to keyword scoring
BM25_WEIGHT = 0.20
# Number of chunks kept for further processing
TOP_K = 8
# MMR diversity parameter (0.7-0.8 recommended)
MMR_LAMBDA = 0.72
# Upper bound on the context size handed to the LLM
MAX_CONTEXT_WORDS = 1600

# --- SOLUTION 3: quote harvesting ---
# Maximum number of strategy segments to extract
SEGMENT_MAX = 10
# LOWERED: a single good segment is enough (was 2)
MIN_EVIDENCE_SEGMENTS = 1
# Threshold for chunk relevance
MIN_RERANK_SCORE = 0.15
# LOWERED: threshold for segment relevance (was 0.25)
MIN_SEGMENT_SCORE = 0.10

# --- LLM generation ---
# Low temperature for factual consistency
TEMPERATURE = 0.1
# Maximum response length
MAX_TOKENS = 700

# ==========================================
# MODEL INITIALIZATION
# ==========================================
# All heavy objects are created once, at import time, and shared by every request.
print("🔄 Loading models...")
embedder = SentenceTransformer(EMBED_MODEL, device=DEVICE)
reranker = CrossEncoder(RERANK_MODEL, device=DEVICE)
llm = OpenAI(api_key=NOVITA_API_KEY, base_url=NOVITA_BASE_URL)
print("✅ Models loaded successfully")

print("🔄 Checking database...")
chroma_client = chromadb.PersistentClient(path=DB_PATH)
| # ========================================== | |
| # DATABASE SETUP | |
| # ========================================== | |
def ensure_collection() -> chromadb.api.models.Collection.Collection:
    """
    Load the existing ChromaDB collection or build it from wiki_content/*.txt.

    When the collection has to be built:
      - the wiki files are located BEFORE the collection is created, so a
        failed build cannot leave an empty collection behind that the next
        start would load as if it were valid;
      - the collection is created with cosine distance, which makes the
        ``1 - distance`` conversion done at query time a real similarity;
      - each file is cut into 1000-word windows with a 200-word overlap while
        keeping the original whitespace, so paragraph breaks survive for the
        section-aware segment extraction that runs later.

    Returns:
        ChromaDB collection with embedded wiki chunks
    Raises:
        RuntimeError: if the collection is missing and no wiki files exist
    """
    try:
        col = chroma_client.get_collection(name=COLLECTION_NAME)
    except Exception:
        # NOTE(review): the exact exception raised for a missing collection is
        # not pinned down in this file, so the broad catch is kept on purpose.
        print("⚠️ Database not found. Building from wiki_content/*.txt ...")
    else:
        print(f"✅ Database loaded: {col.count()} chunks available")
        return col

    # Sorted so chunk ids are assigned in the same order on every build.
    files = sorted(glob.glob("wiki_content/*.txt"))
    if not files:
        raise RuntimeError("❌ No wiki files found in wiki_content/ directory")

    col = chroma_client.create_collection(
        name=COLLECTION_NAME,
        metadata={"hnsw:space": "cosine"},
    )
    print(f"📚 Found {len(files)} wiki files. Starting vectorization...")

    # Chunking geometry: 1000-word windows, 800-word stride = 200-word overlap
    window = 1000
    stride = 800
    min_words = 50

    # Batch processing variables
    batch_docs, batch_embs, batch_meta, batch_ids = [], [], [], []
    chunk_id = 0
    batch_size = 100

    for idx, filepath in enumerate(files):
        filename = os.path.basename(filepath)
        try:
            # Read wiki file
            with open(filepath, "r", encoding="utf-8") as f:
                content = f.read()

            # Parse wiki file format (URL on line 1, content after line 2)
            lines = content.splitlines()
            if lines and lines[0].startswith("URL:"):
                url = lines[0][len("URL:"):].strip()
            else:
                url = ""
            text = "\n".join(lines[2:]) if len(lines) > 2 else content

            # Skip empty or very short files
            if not text or len(text.strip()) < 100:
                continue

            # Character span of every word; slicing the text by spans keeps the
            # newlines that " ".join(words) would have collapsed.
            spans = [match.span() for match in re.finditer(r"\S+", text)]
            for i in range(0, len(spans), stride):
                window_spans = spans[i:i + window]

                # Skip tiny chunks
                if len(window_spans) < min_words:
                    continue

                chunk = text[window_spans[0][0]:window_spans[-1][1]]

                # Generate embedding
                emb = embedder.encode(chunk).tolist()

                # Add to batch
                batch_docs.append(chunk)
                batch_embs.append(emb)
                batch_meta.append({
                    "source": filename,
                    "url": url,
                    "chunk_index": i
                })
                batch_ids.append(f"chunk-{chunk_id}")
                chunk_id += 1

                # Insert batch when full
                if len(batch_docs) >= batch_size:
                    col.add(
                        documents=batch_docs,
                        embeddings=batch_embs,
                        metadatas=batch_meta,
                        ids=batch_ids
                    )
                    batch_docs, batch_embs, batch_meta, batch_ids = [], [], [], []

                # This window already reached the end of the file; another
                # stride would only repeat words it contains.
                if i + window >= len(spans):
                    break
        except Exception as e:
            # Best effort: one unreadable file must not abort the whole build.
            print(f"⚠️ Error processing {filename}: {e}")

        # Progress indicator
        if (idx + 1) % 50 == 0:
            print(f"📊 Processed {idx+1}/{len(files)} files...")

    # Insert remaining batch
    if batch_docs:
        col.add(
            documents=batch_docs,
            embeddings=batch_embs,
            metadatas=batch_meta,
            ids=batch_ids
        )

    print(f"✅ Database created with {col.count()} chunks")
    return col


collection = ensure_collection()
| # ========================================== | |
| # RETRIEVAL UTILITIES | |
| # ========================================== | |
# Token pattern: runs of ASCII lowercase letters, digits and apostrophes
WORD_RE = re.compile(r"[a-z0-9']+")


def tokenize(text: str) -> List[str]:
    """
    Split text into lowercase word tokens.

    The text is lowercased first; every maximal run of characters accepted by
    WORD_RE becomes one token and everything else acts as a separator.

    Args:
        text: Input text string
    Returns:
        List of lowercase word tokens, in order of appearance
    """
    lowered = text.lower()
    return WORD_RE.findall(lowered)
def bm25_keyword_score(
    query_terms: List[str],
    document: str,
    k1=1.5,
    b=0.75,
    avg_doc_length: float = 350.0
) -> float:
    """
    Calculate a BM25-style keyword relevance score.

    This is a simplified BM25: term frequency is saturated with k1 and
    normalized by document length with b, but there is no corpus-wide IDF.
    Instead, a query term that appears more than once in the query gets a
    fixed 1.3 boost.

    Args:
        query_terms: List of query tokens
        document: Document text to score
        k1: Term frequency saturation parameter (default 1.5)
        b: Length normalization parameter (default 0.75)
        avg_doc_length: Token count treated as the "average" document length
            (default 350.0, the value that used to be hard-coded).
            NOTE(review): the index is built from ~1000-word chunks, so 350
            looks low - confirm before tuning.
    Returns:
        BM25 relevance score (higher is better); 0.0 when the query or the
        document has no tokens
    Raises:
        ValueError: if avg_doc_length is not positive
    """
    if not query_terms:
        return 0.0
    if avg_doc_length <= 0:
        raise ValueError("avg_doc_length must be positive")

    doc_terms = tokenize(document)
    if not doc_terms:
        return 0.0

    # Term frequency counter
    tf = Counter(doc_terms)

    # The length normalization is the same for every term, so compute it once.
    length_norm = k1 * (1 - b + b * len(doc_terms) / avg_doc_length)

    score = 0.0
    # Counter keeps first-seen order, so the summation order (and therefore the
    # floating-point result) does not depend on string hash randomization.
    for term, repeats in Counter(query_terms).items():
        freq = tf.get(term, 0)
        if freq == 0:
            continue
        # Pseudo-IDF: boost repeated query terms
        idf = 1.3 if repeats > 1 else 1.0
        # BM25 formula
        score += idf * ((freq * (k1 + 1)) / (freq + length_norm))
    return score
def mmr_diversify(
    candidates: List[Tuple[int, float, str, Dict]],
    max_k: int,
    lambda_weight: float = 0.7
) -> List[Tuple[int, float, str, Dict]]:
    """
    Maximal Marginal Relevance (MMR) diversification.

    Greedily selects documents that are relevant to the query and dissimilar
    (by Jaccard similarity over token sets) to the documents already selected.
    The first pick is always the highest-scoring candidate.

    Each candidate's token set is computed once up front; the selection loop
    only does set arithmetic instead of re-tokenizing both texts for every
    candidate/selected pair.

    Args:
        candidates: List of (index, score, text, metadata) tuples
        max_k: Maximum number of documents to select
        lambda_weight: Trade-off between relevance (1.0) and diversity (0.0)
    Returns:
        Diversified list of at most max_k candidates, in selection order
    """
    if max_k <= 0 or not candidates:
        return []

    def jaccard_similarity(tokens_a: set, tokens_b: set) -> float:
        """Calculate Jaccard similarity between two token sets."""
        if not tokens_a or not tokens_b:
            return 0.0
        return len(tokens_a & tokens_b) / len(tokens_a | tokens_b)

    # Pair every candidate with its token set; a stable descending sort keeps
    # the input order among equal scores.
    remaining = sorted(
        ((candidate, set(tokenize(candidate[2]))) for candidate in candidates),
        key=lambda item: item[0][1],
        reverse=True
    )

    # First selection: highest-scoring candidate
    selected = [remaining.pop(0)]

    # Subsequent selections: balance relevance vs. diversity
    while remaining and len(selected) < max_k:
        best_idx = -1
        best_score = -1e9
        for i, (candidate, tokens) in enumerate(remaining):
            # Max similarity to any already-selected document
            max_similarity = max(
                jaccard_similarity(tokens, chosen_tokens)
                for _, chosen_tokens in selected
            )
            # MMR score: λ * relevance - (1-λ) * max_similarity
            mmr_score = lambda_weight * candidate[1] - (1 - lambda_weight) * max_similarity
            if mmr_score > best_score:
                best_score = mmr_score
                best_idx = i
        selected.append(remaining.pop(best_idx))

    return [candidate for candidate, _ in selected]
| # ========================================== | |
| # DENSE RETRIEVAL | |
| # ========================================== | |
def dense_query(message: str) -> Tuple[List[str], List[Dict], List[float]]:
    """
    Run the embedding-based (dense) retrieval step.

    The query is embedded with the module-level embedder and the K_CANDIDATES
    nearest chunks are requested from the collection.

    Args:
        message: User query
    Returns:
        Tuple of (documents, metadatas, similarity_scores); each similarity is
        ``1 - distance``, or 0.5 for every document when no distances come back
    """
    query_vector = embedder.encode(message).tolist()

    hits = collection.query(
        query_embeddings=[query_vector],
        n_results=K_CANDIDATES,
        include=["documents", "metadatas", "distances"]
    )

    # Results are nested per query; only one query was sent, so take entry 0.
    docs, metas, dists = (
        hits.get(field, [[]])[0]
        for field in ("documents", "metadatas", "distances")
    )

    # NOTE(review): 1 - distance is a true cosine similarity only when the
    # collection uses cosine distance - confirm the collection's metric.
    if dists:
        similarities = [1.0 - distance for distance in dists]
    else:
        similarities = [0.5] * len(docs)

    return docs, metas, similarities
| # ========================================== | |
| # HYBRID RANKING | |
| # ========================================== | |
def hybrid_rank(
    message: str,
    docs: List[str],
    metas: List[Dict],
    similarities: List[float]
) -> List[Tuple[int, float, str, Dict]]:
    """
    Combine dense embedding similarity with BM25 keyword scoring.

    The two signals are blended as
    ``(1 - BM25_WEIGHT) * similarity + BM25_WEIGHT * keyword``.
    The raw BM25 score is a sum over query terms and has no upper bound, so it
    is first divided by the best keyword score among the candidates. Without
    that, a long query lets the keyword part swamp the similarity regardless
    of the configured weight.

    Args:
        message: User query
        docs: Retrieved documents
        metas: Document metadata
        similarities: Embedding similarity scores
    Returns:
        List of (index, score, text, metadata) tuples, best first
    """
    if not docs:
        return []

    query_terms = tokenize(message)
    keyword_scores = [bm25_keyword_score(query_terms, text) for text in docs]
    best_keyword = max(keyword_scores)

    # Derived from the single configured weight so the two cannot drift apart.
    dense_weight = 1.0 - BM25_WEIGHT

    scored = []
    for i, (text, meta, sim, keyword_score) in enumerate(
        zip(docs, metas, similarities, keyword_scores)
    ):
        # Scale into [0, 1]; when no document matched any term the keyword
        # signal simply contributes nothing.
        keyword_norm = keyword_score / best_keyword if best_keyword > 0 else 0.0
        combined_score = dense_weight * sim + BM25_WEIGHT * keyword_norm
        scored.append((i, combined_score, text, meta))

    # Sort by score descending
    scored.sort(key=lambda x: x[1], reverse=True)
    return scored
| # ========================================== | |
| # CROSS-ENCODER RE-RANKING | |
| # ========================================== | |
def rerank_with_crossencoder(
    message: str,
    candidates: List[Tuple[int, float, str, Dict]]
) -> List[Tuple[int, float, str, Dict]]:
    """
    Re-rank candidates using the cross-encoder for higher precision.

    Every (query, chunk text) pair is scored by the module-level reranker. The
    raw score is squashed with a sigmoid and averaged 50/50 with the
    candidate's prior score. Candidates whose combined score falls below
    MIN_RERANK_SCORE are dropped; if that would drop everything, the best
    TOP_K are returned instead.

    Args:
        message: User query
        candidates: List of (index, prior_score, text, metadata) tuples
    Returns:
        Re-ranked candidates with updated scores, best first; an empty list
        when there are no candidates
    """
    # Nothing to score: return early instead of calling the model with an
    # empty batch.
    if not candidates:
        return []

    # Create query-document pairs
    pairs = [(message, candidate[2]) for candidate in candidates]

    # Get cross-encoder scores
    ce_scores = reranker.predict(pairs).tolist()

    rescored = []
    for (idx, prior_score, text, meta), ce_score in zip(candidates, ce_scores):
        # Normalize cross-encoder score with sigmoid
        normalized_ce = 1 / (1 + math.exp(-ce_score))
        # Combine: 50% prior, 50% cross-encoder
        combined = 0.5 * prior_score + 0.5 * normalized_ce
        rescored.append((idx, combined, text, meta))

    # Sort by combined score
    rescored.sort(key=lambda x: x[1], reverse=True)

    # Filter out weak candidates
    strong = [x for x in rescored if x[1] >= MIN_RERANK_SCORE]

    # Return strong candidates or top-K if filtering removed everything
    return strong if strong else rescored[:TOP_K]
| # ========================================== | |
| # SOLUTION 3: ENHANCED STRATEGY EXTRACTION | |
| # ========================================== | |
# Strategy-related keywords (comprehensive list)
_STRATEGY_KEYWORDS = (
    'strategy', 'strategies', 'beat', 'defeat', 'attack', 'attacks',
    'dodge', 'dodging', 'parry', 'parrying', 'vulnerable', 'vulnerability', 'weakness',
    'weak', 'damage', 'phase', 'pattern', 'patterns', 'tip', 'tips',
    'hint', 'hints', 'counter', 'avoid', 'block', 'blocking', 'jump', 'jumping', 'roll', 'rolling',
    'combo', 'combos', 'invulnerable', 'invincible', 'telegraph',
    'move', 'moves', 'moveset', 'behavior', 'tactics', 'tactical', 'approach',
    'shield', 'weapon', 'weapons', 'skill', 'skills', 'grenade', 'grenades',
    'fire', 'bleed', 'bleeding', 'poison', 'freeze', 'frozen', 'frost',
    'slam', 'charge', 'charging', 'sweep', 'sweeping', 'immune', 'immunity', 'resist', 'resistance',
    'arena', 'platform', 'spike', 'spikes', 'trap', 'traps', 'mutation', 'mutations',
    'effective', 'effectiveness', 'useful', 'helpful', 'advantage', 'nullify', 'reflect',
    'aoe', 'dot', 'dps', 'burst', 'chip', 'poke', 'kite', 'spacing',
)

# Anti-keywords: sections to explicitly AVOID
_AVOID_KEYWORDS = (
    'boss stem cell', 'bsc', '1 bsc', '2 bsc', '3 bsc', '4 bsc', '5 bsc',
    'lore', 'dialogue', 'gallery', 'history', 'notes', 'trivia',
    'drops', 'blueprint', 'outfit', 'spoiler', 'cutscene',
    'reward', 'location(s)', 'internal name',
    'the following information contains spoilers',
)

# Headers that mark an explicit strategy section (checked in the first 100 chars)
_STRATEGY_HEADERS = ('strategy', 'vulnerabilities', 'immunities', 'weapons/skills', 'weapons and skills')

# Words that make a bullet / numbered item look like advice
_ACTION_WORDS = ('use', 'can', 'will', 'attack', 'dodge', 'parry', 'jump', 'avoid', 'help', 'allow')

# Phrases typical of instructions
_MODAL_PHRASES = ('can be', 'should be', 'will be', 'must be', 'allows you', 'helps you', 'enables')


def _word_start_pattern(terms):
    """Compile a regex matching any term at the START of a word, plus the rest of that word."""
    # Longest terms first so multi-word phrases win over their shorter prefixes.
    alternation = "|".join(re.escape(term) for term in sorted(terms, key=len, reverse=True))
    return re.compile(r"(?<!\w)(?:" + alternation + r")\w*")


_STRATEGY_RE = _word_start_pattern(_STRATEGY_KEYWORDS)
_AVOID_RE = _word_start_pattern(_AVOID_KEYWORDS)
_ACTION_RE = _word_start_pattern(_ACTION_WORDS)
_SECTION_BREAK_RE = re.compile(r'\n{2,}|---+')
_LIST_ITEM_RE = re.compile(r'^\s*[-•*]\s+|^\d+[\.)]\s+')
_INSTRUCTION_RE = re.compile(
    r'^(use|try|avoid|watch|keep|stay|attack|defend|wait|time|constantly|incorporate|focus|prioritize)'
)


def extract_strategy_segments(text: str) -> List[str]:
    """
    SOLUTION 3: Extract strategy-relevant segments from wiki text.

    The text is split on blank lines and horizontal rules. Each piece of at
    least 30 characters is kept when it looks like strategy advice and dropped
    when it looks like lore, notes, BSC tables or similar.

    Keywords are matched at the start of a word (so "attack" still matches
    "attacking" and "blueprint" matches "blueprints"), never in the middle of
    one. With plain substring matching, 'lore' discarded every paragraph
    containing "explore", and 'tip' / 'roll' / 'counter' fired on "multiple",
    "scroll" and "encounter". Keyword hits are counted as distinct words, so
    "attacks" no longer counts twice (as 'attack' and 'attacks').

    Args:
        text: Raw wiki chunk text
    Returns:
        List of extracted strategy segments, in document order
    """
    segments = []

    # Split on double newlines or horizontal rules (section breaks)
    for part in _SECTION_BREAK_RE.split(text):
        part = part.strip()

        # Skip very short segments
        if len(part) < 30:
            continue
        part_lower = part.lower()

        # PRIORITY 1: Explicitly SKIP sections we want to avoid
        if _AVOID_RE.search(part_lower):
            continue

        # PRIORITY 2: Explicit strategy section headers (HIGH PRIORITY)
        if any(header in part_lower[:100] for header in _STRATEGY_HEADERS):
            segments.append(part)
            continue

        # Distinct strategy words present in this segment
        keyword_hits = {match.group(0) for match in _STRATEGY_RE.finditer(part_lower)}

        # PRIORITY 3: Segments with multiple strategy keywords (MEDIUM-HIGH PRIORITY)
        if len(keyword_hits) >= 2:
            segments.append(part)
            continue

        # PRIORITY 4: Single strategy keyword but significant length (MEDIUM PRIORITY)
        if keyword_hits and len(part) >= 60:
            segments.append(part)
            continue

        # PRIORITY 5: Bullet points and numbered lists with action words (MEDIUM PRIORITY)
        if _LIST_ITEM_RE.match(part) and _ACTION_RE.search(part_lower):
            segments.append(part)
            continue

        # PRIORITY 6: Instruction-style segments (LOW-MEDIUM PRIORITY)
        if _INSTRUCTION_RE.match(part_lower) or any(modal in part_lower for modal in _MODAL_PHRASES):
            segments.append(part)

    return segments
def best_segments(
    message: str,
    texts: List[str],
    metas: List[Dict],
    max_segments: int = SEGMENT_MAX
) -> Tuple[List[str], List[Dict]]:
    """
    SOLUTION 3: Pick the strategy segments that best answer the query.

    Segments come from extract_strategy_segments(); when none of the chunks
    yields a segment, the chunks are split into sentences instead. All
    candidates of at least 40 characters are scored with the cross-encoder,
    squashed with a sigmoid, filtered by MIN_SEGMENT_SCORE and de-duplicated
    on their first 100 lowercase characters.

    Args:
        message: User query
        texts: Pool of text chunks
        metas: Corresponding metadata
        max_segments: Maximum segments to return
    Returns:
        Tuple of (selected segments, corresponding metadata), best first
    """
    # Strategy segments from every chunk, each paired with its chunk's metadata
    pool = [
        (segment, meta)
        for text, meta in zip(texts, metas)
        for segment in extract_strategy_segments(text)
        if len(segment) >= 40
    ]

    # Fallback: split on sentence boundaries, newlines, and colons
    if not pool:
        print("⚠️ No strategy segments found, falling back to sentence splitting")
        for text, meta in zip(texts, metas):
            for raw_sentence in re.split(r'[.!?]\s+|\n+|:\s*\n', text):
                sentence = raw_sentence.strip()
                if len(sentence) >= 40:
                    pool.append((sentence, meta))

    if not pool:
        print("❌ No extractable segments found")
        return [], []

    # Cross-encoder relevance of every candidate to the query
    logits = reranker.predict([(message, segment) for segment, _ in pool]).tolist()

    # Sigmoid-normalized scores, best first (stable for equal scores)
    scored = sorted(
        (
            (segment, meta, 1 / (1 + math.exp(-logit)))
            for (segment, meta), logit in zip(pool, logits)
        ),
        key=lambda item: item[2],
        reverse=True
    )

    output_segments = []
    output_metadata = []
    seen_keys = set()
    for segment, meta, score in scored:
        # SOLUTION 3: lenient threshold (0.10 vs 0.25)
        if not (score >= MIN_SEGMENT_SCORE):
            continue
        # Case-insensitive de-duplication on the first 100 characters
        dedup_key = segment.lower()[:100]
        if dedup_key in seen_keys:
            continue
        seen_keys.add(dedup_key)
        output_segments.append(segment)
        output_metadata.append(meta)
        # Stop when we have enough
        if len(output_segments) >= max_segments:
            break

    print(f"✅ Extracted {len(output_segments)} relevant segments")
    return output_segments, output_metadata
| # ========================================== | |
| # PROMPT CONSTRUCTION | |
| # ========================================== | |
def build_grounded_prompt(
    quotes: List[str],
    quote_ids: List[int],
    question: str
) -> List[Dict]:
    """
    Assemble the chat messages for a strictly grounded answer.

    Quotes are numbered from 1 in the order given; quote_ids is accepted for
    interface compatibility but is not used for the numbering.

    Args:
        quotes: List of evidence quotes
        quote_ids: List of quote indices for citation (currently unused)
        question: User's question
    Returns:
        Two message dictionaries (system, then user) for the LLM
    """
    # System prompt emphasizing grounding, one entry per output line
    system_lines = [
        "You are a Dead Cells expert assistant.",
        "",
        "CRITICAL RULES:",
        "1) Use ONLY the numbered quotes provided below",
        "2) Every factual claim must end with citation brackets: [1] or [1][2]",
        "3) If the quotes do NOT contain enough information to answer, reply: 'Not in provided context.'",
        "4) Be specific and actionable - provide concrete strategies, mechanics, and tips",
        "5) Keep answers concise but complete",
        "6) You may use thematic emojis (⚔️🗡️💀🔥) sparingly for readability",
    ]
    system_message = "\n".join(system_lines)

    # Evidence block: "[n] quote", one per line
    numbered_quotes = "\n".join(
        f"[{number}] {quote}" for number, quote in enumerate(quotes, start=1)
    )

    user_message = (
        "Evidence Quotes:\n" + numbered_quotes + "\n\n"
        + "Question: " + f"{question}" + "\n\n"
        + "Answer (with citations):"
    )

    system_entry = {"role": "system", "content": system_message}
    user_entry = {"role": "user", "content": user_message}
    return [system_entry, user_entry]
| # ========================================== | |
| # SOURCE ATTRIBUTION | |
| # ========================================== | |
def attach_sources(metadata_list: List[Dict]) -> List[str]:
    """
    Turn chunk metadata into human-readable source labels.

    A label is the "source" value with every ".txt" removed (or
    "unknown_source" when it is missing or empty), followed by the URL in
    parentheses when one is present.

    Args:
        metadata_list: List of chunk metadata
    Returns:
        List of formatted source strings, de-duplicated in first-seen order
    """
    def format_source(meta: Dict) -> str:
        source = meta.get("source", "")
        url = meta.get("url", "")
        label = source.replace(".txt", "") if source else "unknown_source"
        if not url:
            return label
        return f"{label} ({url})"

    # dict keys are unique and keep insertion order, which gives an
    # order-preserving de-duplication.
    return list(dict.fromkeys(format_source(meta) for meta in metadata_list))
| # ========================================== | |
| # MAIN CHAT FUNCTION | |
| # ========================================== | |
def chat(message: str, history: List) -> str:
    """
    Main chatbot function with full RAG pipeline.

    Pipeline:
    1. Dense semantic retrieval
    2. Hybrid ranking (embeddings + BM25)
    3. MMR diversification
    4. Cross-encoder re-ranking
    5. Context pool limited to MAX_CONTEXT_WORDS
    6. SOLUTION 3: Enhanced strategy segment extraction
    7. Grounded LLM generation with citations
    8. Source attribution

    The bot abstains with "Not in provided context." when retrieval returns
    nothing, when too few segments survive, or when the LLM returns no text
    (missing choices, ``None`` content or a blank string).

    Args:
        message: User's question
        history: Conversation history (unused in stateless version)
    Returns:
        Bot response with citations and sources; on any exception, a
        "💀 Error: ..." string instead of raising
    """
    try:
        # Safety check
        if collection.count() == 0:
            return "💀 Knowledge base is empty. Load wiki files and restart."

        rule = '=' * 60
        print(f"\n{rule}\n🔍 Query: {message}\n{rule}")

        # ============================================================
        # STEP 1: Dense Retrieval
        # ============================================================
        print("📡 Step 1: Dense semantic retrieval...")
        docs, metas, similarities = dense_query(message)
        if not docs:
            return "Not in provided context."
        print(f" Retrieved {len(docs)} candidates")

        # ============================================================
        # STEP 2: Hybrid Ranking
        # ============================================================
        print("🔀 Step 2: Hybrid ranking (dense + BM25)...")
        hybrid_ranked = hybrid_rank(message, docs, metas, similarities)

        # ============================================================
        # STEP 3: MMR Diversification
        # ============================================================
        print("🎯 Step 3: MMR diversification...")
        diversified = mmr_diversify(
            hybrid_ranked,
            max_k=min(TOP_K * 3, len(hybrid_ranked)),
            lambda_weight=MMR_LAMBDA
        )
        print(f" Diversified to {len(diversified)} chunks")

        # ============================================================
        # STEP 4: Cross-Encoder Re-Ranking
        # ============================================================
        print("🎖️ Step 4: Cross-encoder re-ranking...")
        reranked = rerank_with_crossencoder(message, diversified)

        # Select top chunks
        chosen = reranked[:TOP_K]
        print(f" Selected {len(chosen)} top chunks")

        # ============================================================
        # STEP 5: Build Context Pool (with word limit)
        # ============================================================
        print("📚 Step 5: Building context pool...")
        pool_texts = []
        pool_metas = []
        total_words = 0
        for _, _, text, meta in chosen:
            word_count = len(tokenize(text))
            # Stop if we exceed word limit (but keep at least one chunk)
            if total_words + word_count > MAX_CONTEXT_WORDS and pool_texts:
                break
            pool_texts.append(text)
            pool_metas.append(meta)
            total_words += word_count
        print(f" Context pool: {len(pool_texts)} chunks, ~{total_words} words")

        # ============================================================
        # STEP 6: SOLUTION 3 - Extract Strategy Segments
        # ============================================================
        print("✂️ Step 6: SOLUTION 3 - Extracting strategy segments...")
        quotes, quote_metas = best_segments(
            message,
            pool_texts,
            pool_metas,
            max_segments=SEGMENT_MAX
        )

        # Check if we have enough evidence
        if len(quotes) < MIN_EVIDENCE_SEGMENTS:
            print(f"❌ Insufficient evidence: {len(quotes)} segments (need {MIN_EVIDENCE_SEGMENTS})")
            return "Not in provided context."
        print(f" ✅ Found {len(quotes)} quality segments")

        # ============================================================
        # STEP 7: Grounded Generation
        # ============================================================
        print("🤖 Step 7: LLM generation with grounding...")
        messages = build_grounded_prompt(
            quotes,
            list(range(len(quotes))),
            message
        )

        # Call LLM
        response = llm.chat.completions.create(
            model=MODEL_NAME,
            messages=messages,
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
            top_p=1.0
        )

        # Extract answer. The message content is optional in the API response,
        # so a None or blank completion becomes an explicit abstention instead
        # of an AttributeError on .strip() or an empty chat bubble.
        content = None
        if response and response.choices:
            content = response.choices[0].message.content
        answer = (content or "").strip() or "Not in provided context."
        print(f" Generated {len(answer)} characters")

        # ============================================================
        # STEP 8: Attach Sources
        # ============================================================
        print("📎 Step 8: Attaching sources...")
        sources = attach_sources(quote_metas)
        if sources:
            answer += "\n\n---\n🗡️ **Wiki Sources:** " + ", ".join(sources)
        print(f"✅ Complete! {len(sources)} unique sources cited\n")
        return answer

    except Exception as e:
        # Top-level boundary for the UI callback: report the failure in the
        # chat window rather than letting the exception escape.
        error_msg = f"💀 Error: {str(e)}"
        print(f"❌ {error_msg}")
        return error_msg
| # ========================================== | |
| # GRADIO UI | |
| # ========================================== | |
# Dead Cells themed stylesheet (high contrast), passed verbatim to Gradio
custom_css = """
body {
background: #0d0d0d !important;
}
.gradio-container {
background: linear-gradient(180deg, #1a0a1f 0%, #0a0514 50%, #000000 100%) !important;
font-family: -apple-system, system-ui, sans-serif !important;
}
.contain {
background: rgba(10, 5, 20, 0.8) !important;
border: 2px solid #8b46b3 !important;
border-radius: 12px !important;
}
h1 {
color: #f0abfc !important;
text-shadow: 0 0 20px rgba(217, 70, 239, 0.9) !important;
font-weight: 900 !important;
letter-spacing: 3px !important;
padding: 20px !important;
}
.description, p {
color: #e0e0e0 !important;
font-size: 16px !important;
}
.chatbot {
background: rgba(15, 10, 25, 0.95) !important;
border: 2px solid #8b46b3 !important;
border-radius: 10px !important;
}
/* High contrast for ALL message text */
.message.user, .message.user * {
background: rgba(139, 70, 179, 0.3) !important;
color: #ffffff !important;
border-left: 4px solid #d946ef !important;
border-radius: 10px !important;
padding: 15px !important;
margin: 10px 0 !important;
}
.message.bot, .message.bot * {
background: rgba(30, 20, 50, 0.6) !important;
color: #ffffff !important;
border-left: 4px solid #a78bfa !important;
border-radius: 10px !important;
padding: 15px !important;
margin: 10px 0 !important;
}
/* Force white text in chat messages */
.chatbot .message * {
color: #ffffff !important;
}
/* Override Gradio's prose styling */
.prose, .prose * {
color: #ffffff !important;
}
textarea {
background: rgba(20, 15, 30, 0.95) !important;
color: #ffffff !important;
border: 2px solid #8b46b3 !important;
border-radius: 8px !important;
padding: 14px !important;
font-size: 16px !important;
line-height: 1.5 !important;
}
textarea::placeholder {
color: #cccccc !important;
opacity: 0.7 !important;
}
button {
background: linear-gradient(135deg, #8b46b3 0%, #6b2d8f 100%) !important;
color: #ffffff !important;
border: 2px solid #a855f7 !important;
border-radius: 8px !important;
padding: 12px 24px !important;
font-weight: 700 !important;
font-size: 14px !important;
text-transform: uppercase !important;
letter-spacing: 1px !important;
}
button:hover {
background: linear-gradient(135deg, #d946ef 0%, #a855f7 100%) !important;
box-shadow: 0 0 20px rgba(217, 70, 239, 0.6) !important;
}
/* Error messages should be highly visible */
.error, .error * {
color: #ff6b6b !important;
font-weight: 600 !important;
}
"""

# Markdown shown under the title, one entry per output line
_DESCRIPTION_LINES = [
    "🗡️ **Solution 3: Enhanced Strategy Extraction v1.2**",
    "",
    "Ask anything about Dead Cells! This bot uses:",
    "- Hybrid retrieval (semantic + keyword)",
    "- Cross-encoder re-ranking",
    "- Wiki-aware strategy extraction",
    "- Explicit avoidance of non-strategy content",
    "- Strict grounding with citations",
    "",
    "If the answer isn't in the wiki, it will tell you.",
]

# Clickable sample questions offered by the chat UI
_EXAMPLE_QUESTIONS = [
    "How can I beat the Hand of the King?",
    "What are Boss Stem Cells?",
    "How does malaise work?",
    "List brutality weapons.",
    "What biomes are in the game?",
    "Best strategy for Conjunctivius?",
    "How do I unlock the Flawless outfit?",
]

# Chat interface wired to chat(); launched by the __main__ guard below
demo = gr.ChatInterface(
    fn=chat,
    title="⚔️ DEAD CELLS WIKI BOT 💀",
    description="\n".join(_DESCRIPTION_LINES),
    examples=_EXAMPLE_QUESTIONS,
    css=custom_css,
    theme=gr.themes.Base()
)
# ==========================================
# LAUNCH
# ==========================================
if __name__ == "__main__":
    # Startup banner, then serve on all interfaces at port 7860.
    banner_rule = "=" * 60
    print("\n" + banner_rule)
    print("🎮 DEAD CELLS WIKI BOT v1.2 - Ready to launch!")
    print(banner_rule + "\n")
    demo.launch(server_name="0.0.0.0", server_port=7860)