# NOTE: removed non-code scrape artifacts ("Spaces: / Sleeping / Sleeping" page-status text)
"""
Claim Extractor
Breaks down user explanations into individual claims/statements
"""
import json
import os
import re
from typing import Any, Dict, List

import requests
from sentence_transformers import SentenceTransformer
class ClaimExtractor:
    """Breaks a free-text explanation into atomic, individually testable claims.

    Extraction is delegated to a hosted Mistral-7B-Instruct model via the
    HuggingFace Inference API; each extracted claim is then embedded locally
    with a sentence-transformers model and tagged with a heuristic type.
    """

    def __init__(self):
        # Small, fast embedding model; downloaded and cached on first use.
        self.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
        self.hf_api_key = os.getenv('HUGGINGFACE_API_KEY')
        self.llm_endpoint = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.2"
        self._ready = False
        self._initialize()

    def _initialize(self):
        """Warm up the embedding model and record readiness in self._ready."""
        try:
            # The first encode() triggers the actual model load, which takes
            # a few seconds on first run; failures leave the extractor unready.
            self.embedding_model.encode("test")
            self._ready = True
        except Exception as e:
            print(f"Claim extractor initialization error: {e}")  # TODO: better error handling
            self._ready = False

    def is_ready(self) -> bool:
        """Return True once the embedding model has been successfully warmed up."""
        return self._ready

    async def extract_claims(self, explanation: str) -> List[Dict[str, Any]]:
        """
        Extract atomic claims from user explanation
        Returns:
            List of claims with metadata:
            - text: the claim itself
            - type: 'definition', 'causal', 'assumption', 'example'
            - embedding: semantic vector
            - confidence: extraction confidence
        """
        # Use LLM to extract structured claims (falls back to sentence
        # splitting on any API failure — see _llm_extract_claims).
        claims_raw = await self._llm_extract_claims(explanation)
        # Add embeddings and metadata
        claims = []
        for i, claim_text in enumerate(claims_raw):
            embedding = self.embedding_model.encode(claim_text)
            claims.append({
                'id': f'claim_{i}',
                'text': claim_text,
                'type': self._classify_claim_type(claim_text),
                'embedding': embedding.tolist(),
                'confidence': 0.85  # Simplified for demo
            })
        return claims

    async def _llm_extract_claims(self, explanation: str) -> List[str]:
        """Use LLM to extract atomic claims.

        Returns a list of claim strings; on any API error or non-200 response
        it degrades to _fallback_extraction rather than raising.

        NOTE(review): requests.post is a blocking call inside an async method
        and will stall the event loop for up to the 30s timeout — consider an
        async HTTP client or run_in_executor.
        """
        prompt = f"""<s>[INST] You are a precise claim extraction system. Break down the following explanation into atomic claims. Each claim should be a single, testable statement.
Explanation: {explanation}
Extract each claim on a new line, numbered. Focus on:
1. Definitions (what things are)
2. Causal relationships (X causes Y)
3. Assumptions (implicit or explicit)
4. Properties and characteristics
Output only the numbered claims, nothing else. [/INST]"""
        try:
            headers = {"Authorization": f"Bearer {self.hf_api_key}"}
            payload = {
                "inputs": prompt,
                "parameters": {
                    "max_new_tokens": 500,
                    "temperature": 0.3,
                    "return_full_text": False
                }
            }
            response = requests.post(self.llm_endpoint, headers=headers, json=payload, timeout=30)
            if response.status_code == 200:
                result = response.json()
                # The inference API returns either a list of generations or a
                # single dict depending on the model/pipeline.
                text = result[0]['generated_text'] if isinstance(result, list) else result.get('generated_text', '')
                # Parse numbered claims: keep only lines that start with a
                # digit or a dash, then strip the numbering/bullet prefix.
                claims = []
                for line in text.split('\n'):
                    line = line.strip()
                    if line and (line[0].isdigit() or line.startswith('-')):
                        claim = line.lstrip('0123456789.-) ').strip()
                        if claim:
                            claims.append(claim)
                return claims if claims else [explanation]  # Fallback to full explanation
            else:
                # Fallback: simple sentence splitting
                return self._fallback_extraction(explanation)
        except Exception as e:
            print(f"LLM extraction error: {e}")
            return self._fallback_extraction(explanation)

    def _fallback_extraction(self, explanation: str) -> List[str]:
        """Fallback: naive sentence split on . ! ? — fragments of 10 chars or fewer are dropped."""
        fragments = (s.strip() for s in re.split(r'[.!?]+', explanation))
        return [s for s in fragments if len(s) > 10]

    def _classify_claim_type(self, claim: str) -> str:
        """Classify a claim by linguistic cue phrases.

        Cues are matched as whole words (regex word boundaries) so that short
        cues like 'if' or 'like' no longer fire inside unrelated words such as
        'shifted' or 'likely' — the original substring check misclassified
        such statements. Category priority order is preserved:
        definition > causal > example > assumption > statement.
        """
        claim_lower = claim.lower()
        categories = [
            ('definition', ('is a', 'is the', 'refers to', 'means', 'defined as')),
            ('causal', ('causes', 'leads to', 'results in', 'because', 'therefore')),
            ('example', ('for example', 'such as', 'like', 'instance')),
            ('assumption', ('assume', 'given that', 'suppose', 'if')),
        ]
        for claim_type, cues in categories:
            for cue in cues:
                if re.search(r'\b' + re.escape(cue) + r'\b', claim_lower):
                    return claim_type
        return 'statement'