# conceptvector/analysis/claim_extractor.py
# Author: Tawhid Bin Omar
# Initial deployment of RealityCheck AI backend (commit 8176754)
"""
Claim Extractor
Breaks down user explanations into individual claims/statements
"""
import json
import os
import re
from typing import Any, Dict, List

import requests
from sentence_transformers import SentenceTransformer
class ClaimExtractor:
    """Break a free-text explanation down into atomic, testable claims.

    Claims are extracted with a hosted Mistral-7B-Instruct model via the
    HuggingFace Inference API; if that call fails for any reason, a simple
    sentence-splitting fallback is used instead. Each claim is embedded
    with a local SentenceTransformer model and tagged with a heuristic
    type ('definition', 'causal', 'example', 'assumption' or 'statement').
    """

    def __init__(self):
        # Local embedding model used to vectorize each extracted claim.
        self.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
        # API key is optional; without it the HF call fails and the
        # sentence-splitting fallback takes over.
        self.hf_api_key = os.getenv('HUGGINGFACE_API_KEY')
        self.llm_endpoint = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.2"
        self._ready = False
        self._initialize()

    def _initialize(self):
        """Warm up the embedding model and record readiness in ``_ready``."""
        try:
            # The first encode() triggers model load/download, which can
            # take a few seconds on a cold start.
            self.embedding_model.encode("test")
            self._ready = True
        except Exception as e:
            # Stay constructible even if the model fails to load; callers
            # are expected to check is_ready() before use.
            print(f"Claim extractor initialization error: {e}")  # TODO: better error handling
            self._ready = False

    def is_ready(self) -> bool:
        """Return True once the embedding model has loaded successfully."""
        return self._ready

    async def extract_claims(self, explanation: str) -> List[Dict[str, Any]]:
        """Extract atomic claims from a user explanation.

        Returns:
            List of claim dicts, one per extracted claim:
            - id: per-call identifier ('claim_0', 'claim_1', ...)
            - text: the claim itself
            - type: 'definition', 'causal', 'assumption', 'example' or 'statement'
            - embedding: semantic vector as a plain list of floats
            - confidence: extraction confidence (fixed placeholder for now)
        """
        # Use the LLM to extract structured claims (with local fallback).
        claims_raw = await self._llm_extract_claims(explanation)
        if not claims_raw:
            return []
        # Batch-encode all claims in one call instead of one encode() per
        # claim; SentenceTransformer batches internally, same outputs.
        embeddings = self.embedding_model.encode(claims_raw)
        return [
            {
                'id': f'claim_{i}',
                'text': claim_text,
                'type': self._classify_claim_type(claim_text),
                'embedding': embedding.tolist(),
                'confidence': 0.85,  # Simplified for demo
            }
            for i, (claim_text, embedding) in enumerate(zip(claims_raw, embeddings))
        ]

    async def _llm_extract_claims(self, explanation: str) -> List[str]:
        """Ask the hosted LLM for atomic claims, one numbered line each.

        Falls back to ``_fallback_extraction`` on any HTTP or parsing
        failure, and to the whole explanation if the model output yields
        no parsable claims.

        NOTE(review): requests.post is a blocking call inside an async
        method, so it stalls the event loop for up to the 30s timeout —
        consider asyncio.to_thread or an async HTTP client.
        """
        prompt = f"""<s>[INST] You are a precise claim extraction system. Break down the following explanation into atomic claims. Each claim should be a single, testable statement.
Explanation: {explanation}
Extract each claim on a new line, numbered. Focus on:
1. Definitions (what things are)
2. Causal relationships (X causes Y)
3. Assumptions (implicit or explicit)
4. Properties and characteristics
Output only the numbered claims, nothing else. [/INST]"""
        try:
            headers = {"Authorization": f"Bearer {self.hf_api_key}"}
            payload = {
                "inputs": prompt,
                "parameters": {
                    "max_new_tokens": 500,
                    "temperature": 0.3,
                    "return_full_text": False,
                },
            }
            response = requests.post(self.llm_endpoint, headers=headers, json=payload, timeout=30)
            if response.status_code != 200:
                # API error (rate limit, cold model, bad auth): degrade
                # gracefully to simple sentence splitting.
                return self._fallback_extraction(explanation)
            result = response.json()
            # The inference API returns either a list of generations or a
            # single dict depending on the model/pipeline.
            text = result[0]['generated_text'] if isinstance(result, list) else result.get('generated_text', '')
            claims = []
            for line in text.split('\n'):
                line = line.strip()
                # Keep only numbered/bulleted lines ("1.", "2)", "- ...").
                if line and (line[0].isdigit() or line.startswith('-')):
                    # Strip the leading numbering/bullet characters.
                    claim = line.lstrip('0123456789.-) ').strip()
                    if claim:
                        claims.append(claim)
            # If parsing produced nothing, treat the whole explanation as
            # a single claim rather than returning an empty list.
            return claims if claims else [explanation]
        except Exception as e:
            print(f"LLM extraction error: {e}")
            return self._fallback_extraction(explanation)

    def _fallback_extraction(self, explanation: str) -> List[str]:
        """Fallback: split on sentence enders, dropping fragments <= 10 chars."""
        sentences = re.split(r'[.!?]+', explanation)
        return [s.strip() for s in sentences if s.strip() and len(s.strip()) > 10]

    def _classify_claim_type(self, claim: str) -> str:
        """Classify a claim by keyword/phrase heuristics.

        Each pattern is anchored at a leading word boundary so that e.g.
        'if' no longer matches inside 'different' and 'like' no longer
        matches inside 'unlike' (the previous bare substring test misfired
        on such embedded matches). The trailing edge is left open so stems
        still match inflected forms ('assume' -> 'assumes').
        """
        claim_lower = claim.lower()

        def matches(phrases: List[str]) -> bool:
            return any(re.search(r'\b' + re.escape(p), claim_lower) for p in phrases)

        # Definition patterns
        if matches(['is a', 'is the', 'refers to', 'means', 'defined as']):
            return 'definition'
        # Causal patterns
        if matches(['causes', 'leads to', 'results in', 'because', 'therefore']):
            return 'causal'
        # Example patterns
        if matches(['for example', 'such as', 'like', 'instance']):
            return 'example'
        # Assumption patterns
        if matches(['assume', 'given that', 'suppose', 'if']):
            return 'assumption'
        return 'statement'