File size: 5,476 Bytes
8176754
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
"""
Claim Extractor
Breaks down user explanations into individual claims/statements
"""

import json
import os
import re
from typing import Any, Dict, List

import requests
from sentence_transformers import SentenceTransformer

class ClaimExtractor:
    """Break a free-form explanation into atomic, individually testable claims.

    Extraction is delegated to a hosted LLM (Mistral-7B-Instruct via the
    Hugging Face Inference API); if that call fails or returns nothing
    parseable, a naive sentence split is used instead. Each claim is
    returned with a sentence-transformer embedding and a heuristic type
    label ('definition', 'causal', 'example', 'assumption', 'statement').
    """

    # Keyword phrases for the heuristic classifier, checked in this priority
    # order. Matched on word boundaries (see _classify_claim_type).
    _DEFINITION_PATTERNS = ('is a', 'is the', 'refers to', 'means', 'defined as')
    _CAUSAL_PATTERNS = ('causes', 'leads to', 'results in', 'because', 'therefore')
    _EXAMPLE_PATTERNS = ('for example', 'such as', 'like', 'instance')
    _ASSUMPTION_PATTERNS = ('assume', 'given that', 'suppose', 'if')

    def __init__(self):
        self.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
        # May be None when the env var is unset; the API then responds with a
        # non-200 status and _llm_extract_claims degrades to the fallback.
        self.hf_api_key = os.getenv('HUGGINGFACE_API_KEY')
        self.llm_endpoint = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.2"
        self._ready = False
        self._initialize()

    def _initialize(self) -> None:
        """Warm up the embedding model and record readiness in self._ready."""
        try:
            # The first encode triggers the (slow) model load, so doing it
            # here keeps later extract_claims() calls fast.
            self.embedding_model.encode("test")
            self._ready = True
        except Exception as e:
            # Best-effort init: report and mark not-ready rather than crash.
            print(f"Claim extractor initialization error: {e}")  # TODO: better error handling
            self._ready = False

    def is_ready(self) -> bool:
        """Return True once the embedding model has been warmed up successfully."""
        return self._ready

    async def extract_claims(self, explanation: str) -> List[Dict[str, Any]]:
        """
        Extract atomic claims from a user explanation.

        Args:
            explanation: Free-form text to decompose.

        Returns:
            List of claim dicts with keys:
            - 'id': per-call identifier ('claim_0', 'claim_1', ...)
            - 'text': the claim itself
            - 'type': 'definition' | 'causal' | 'assumption' | 'example' | 'statement'
            - 'embedding': semantic vector as a plain list of floats
            - 'confidence': extraction confidence (fixed placeholder for now)
        """
        # Use the LLM to extract structured claims (with built-in fallback).
        claims_raw = await self._llm_extract_claims(explanation)

        # Attach embeddings and heuristic metadata to each raw claim.
        claims = []
        for i, claim_text in enumerate(claims_raw):
            embedding = self.embedding_model.encode(claim_text)
            claims.append({
                'id': f'claim_{i}',
                'text': claim_text,
                'type': self._classify_claim_type(claim_text),
                'embedding': embedding.tolist(),
                'confidence': 0.85  # Simplified for demo
            })

        return claims

    async def _llm_extract_claims(self, explanation: str) -> List[str]:
        """Ask the LLM for numbered atomic claims; fall back on any failure.

        NOTE(review): requests.post is a blocking call inside an async
        method, so it stalls the event loop for up to the 30 s timeout.
        Consider aiohttp or asyncio.to_thread() if callers run concurrently.
        """
        prompt = f"""<s>[INST] You are a precise claim extraction system. Break down the following explanation into atomic claims. Each claim should be a single, testable statement.

Explanation: {explanation}

Extract each claim on a new line, numbered. Focus on:
1. Definitions (what things are)
2. Causal relationships (X causes Y)
3. Assumptions (implicit or explicit)
4. Properties and characteristics

Output only the numbered claims, nothing else. [/INST]"""

        try:
            headers = {"Authorization": f"Bearer {self.hf_api_key}"}
            payload = {
                "inputs": prompt,
                "parameters": {
                    "max_new_tokens": 500,
                    "temperature": 0.3,
                    "return_full_text": False
                }
            }

            response = requests.post(self.llm_endpoint, headers=headers, json=payload, timeout=30)

            if response.status_code != 200:
                # Non-OK (rate limit, cold model, missing/bad key): degrade
                # gracefully to simple sentence splitting.
                return self._fallback_extraction(explanation)

            result = response.json()
            # The Inference API normally returns a list of generations; be
            # tolerant of a bare dict payload as well.
            text = result[0]['generated_text'] if isinstance(result, list) else result.get('generated_text', '')

            # Parse lines shaped like "1. claim", "2) claim", or "- claim".
            claims = []
            for line in text.split('\n'):
                line = line.strip()
                if line and (line[0].isdigit() or line.startswith('-')):
                    # Strip the leading numbering/bullet characters.
                    claim = line.lstrip('0123456789.-) ').strip()
                    if claim:
                        claims.append(claim)

            # If parsing produced nothing usable, treat the whole explanation
            # as a single claim rather than returning an empty list.
            return claims if claims else [explanation]

        except Exception as e:
            print(f"LLM extraction error: {e}")
            return self._fallback_extraction(explanation)

    def _fallback_extraction(self, explanation: str) -> List[str]:
        """Fallback: naive sentence split, keeping sentences over 10 characters."""
        sentences = re.split(r'[.!?]+', explanation)
        return [s.strip() for s in sentences if s.strip() and len(s.strip()) > 10]

    def _classify_claim_type(self, claim: str) -> str:
        """Classify a claim by keyword patterns matched on word boundaries.

        Whole-word matching prevents false positives that plain substring
        checks would hit — e.g. 'if' inside 'uniform' or 'different', and
        'like' inside 'likely'.
        """
        claim_lower = claim.lower()

        def has_any(patterns) -> bool:
            # \b-anchored search so each phrase must appear as whole words.
            return any(
                re.search(r'\b' + re.escape(p) + r'\b', claim_lower)
                for p in patterns
            )

        if has_any(self._DEFINITION_PATTERNS):
            return 'definition'
        if has_any(self._CAUSAL_PATTERNS):
            return 'causal'
        if has_any(self._EXAMPLE_PATTERNS):
            return 'example'
        if has_any(self._ASSUMPTION_PATTERNS):
            return 'assumption'
        return 'statement'