File size: 6,233 Bytes
0f21e9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from pathlib import Path
import PyPDF2
import json
from datetime import datetime
from typing import Union, List, Dict
import re

# NLTK with robust error handling
import nltk
import ssl

# SSL fix for NLTK
# NOTE(review): this disables TLS certificate verification globally for the
# whole process (not just for nltk.download) — confirm that is acceptable.
# Older Pythons lack ssl._create_unverified_context; in that case the
# AttributeError is ignored and nothing is patched.
try:
    _create_unverified_https_context = ssl._create_unverified_context
except AttributeError:
    pass
else:
    # Newer Python: make unverified HTTPS the process default so
    # nltk.download() works behind broken/missing certificate stores.
    ssl._create_default_https_context = _create_unverified_https_context

# Enhanced NLTK data download with retry
def download_nltk_data_robust():
    """Ensure the NLTK sentence-tokenizer data is available.

    For each required package ('punkt', 'punkt_tab') this checks whether the
    data is already on ``nltk.data.path`` and, if not, tries to download it
    into a fixed per-user directory, retrying up to 3 times. Failures are
    printed but never raised, so importing this module cannot fail here.
    """
    import os

    # Pin a writable data directory and make sure NLTK searches it first.
    nltk_data_dir = '/home/appuser/nltk_data'
    try:
        # exist_ok=True makes this a no-op when the directory already exists;
        # report (instead of silently swallowing) a real OS error.
        os.makedirs(nltk_data_dir, exist_ok=True)
    except OSError as e:
        print(f"Warning: could not create {nltk_data_dir}: {e}")

    if nltk_data_dir not in nltk.data.path:
        nltk.data.path.insert(0, nltk_data_dir)

    packages = ['punkt', 'punkt_tab']
    for package in packages:
        for attempt in range(3):  # Try 3 times
            try:
                nltk.data.find(f'tokenizers/{package}')
                print(f"✓ {package} already available")
                break
            except LookupError:
                try:
                    print(f"Downloading {package} (attempt {attempt + 1})...")
                    # nltk.download() can return False on failure *without*
                    # raising, so the return value must be checked too.
                    if nltk.download(package, download_dir=nltk_data_dir, quiet=False):
                        print(f"✓ {package} downloaded successfully")
                        break
                    raise RuntimeError("nltk.download returned False")
                except Exception as e:
                    print(f"Warning: Could not download {package}: {e}")
                    if attempt == 2:
                        print(f"Failed to download {package} after 3 attempts")

# Download on import: run once at module import time so the tokenizer data
# is ready before any classification call (best-effort; never raises).
download_nltk_data_robust()

# Fallback sentence tokenizer using regex
def simple_sentence_tokenize(text):
    """Regex-based sentence splitter used when NLTK is unavailable.

    Splits on whitespace that follows '.', '!' or '?', strips each piece,
    and drops empty results. Returns a list of sentence strings.
    """
    pieces = (chunk.strip() for chunk in re.split(r'(?<=[.!?])\s+', text))
    return [piece for piece in pieces if piece]

# Safe sentence tokenization with fallback
def safe_sent_tokenize(text):
    """Split *text* into sentences.

    Prefers NLTK's ``sent_tokenize``; if NLTK cannot be imported or the
    tokenizer raises (e.g. missing 'punkt' data), falls back to the
    regex-based ``simple_sentence_tokenize``.
    """
    try:
        from nltk.tokenize import sent_tokenize
    except Exception as err:
        print(f"NLTK tokenization failed ({err}), using fallback...")
        return simple_sentence_tokenize(text)
    try:
        return sent_tokenize(text)
    except Exception as err:
        print(f"NLTK tokenization failed ({err}), using fallback...")
        return simple_sentence_tokenize(text)

class CausalityClassifier:
    """Binary sentence classifier ('related' vs 'not related') backed by a
    HuggingFace sequence-classification model loaded from *model_path*."""

    def __init__(self, model_path='./models/production_model_final', threshold=0.5):
        # threshold applies to the probability of class 1 ('related')
        self.model_path = Path(model_path)
        self.threshold = threshold
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        self.model = AutoModelForSequenceClassification.from_pretrained(self.model_path)
        self.model.eval()

    def predict(self, text, return_probs=False):
        """Classify one text snippet.

        Returns a dict with 'prediction' ('related'/'not related'),
        'confidence' (probability of the predicted class) and 'label'
        (1 = related). With return_probs=True, both class probabilities
        are included under 'probabilities'.
        """
        encoded = self.tokenizer(
            text, return_tensors="pt", truncation=True, padding=True, max_length=96
        )
        with torch.no_grad():
            logits = self.model(**encoded).logits
        class_probs = torch.softmax(logits, dim=1).numpy()[0]
        # label is 1 only when P(related) strictly exceeds the threshold
        label = int(class_probs[1] > self.threshold)
        result = {
            'prediction': 'related' if label == 1 else 'not related',
            'confidence': float(class_probs[label]),
            'label': label,
        }
        if return_probs:
            result['probabilities'] = {
                'not_related': float(class_probs[0]),
                'related': float(class_probs[1]),
            }
        return result

def extract_text_from_pdf(pdf_path):
    """Return the concatenated text of every page in the PDF at *pdf_path*.

    Pages for which PyPDF2 cannot extract text (extract_text() may return
    None, e.g. for image-only pages) contribute an empty string instead of
    raising a TypeError on concatenation.
    """
    with open(pdf_path, 'rb') as file:
        pdf_reader = PyPDF2.PdfReader(file)
        # join once instead of repeated += (avoids quadratic concatenation)
        return "".join(page.extract_text() or "" for page in pdf_reader.pages)

def classify_causality(pdf_text, model_path='./models/production_model_final', threshold=0.5, verbose=False):
    """Classify a whole document as causally 'related' or 'not related'.

    Splits *pdf_text* into sentences, scores each with CausalityClassifier,
    and labels the document 'related' if at least one sentence crosses the
    threshold.

    Returns a dict with the final label, 'confidence_score' (fraction of
    related sentences), sentence counts, the top 5 related sentences
    (truncated to 100 chars, sorted by P(related) descending) and the
    threshold used.
    """
    classifier = CausalityClassifier(model_path, threshold)

    # Drop blank/whitespace-only sentences up front so the counts and the
    # confidence ratio are computed over the same population (previously
    # blanks were skipped in the loop but still counted in the denominator).
    sentences = [s for s in safe_sent_tokenize(pdf_text) if s.strip()]

    if verbose:
        print(f"Tokenized {len(sentences)} sentences")

    related = []
    for sent in sentences:
        result = classifier.predict(sent, return_probs=True)
        if result['label'] == 1:
            related.append({
                'sentence': sent[:100],
                'probability_related': result['probabilities']['related'],
                'confidence': result['confidence'],
            })

    related.sort(key=lambda item: item['probability_related'], reverse=True)

    return {
        'final_classification': 'related' if related else 'not related',
        'confidence_score': len(related) / len(sentences) if sentences else 0,
        'related_sentences': len(related),
        'total_sentences': len(sentences),
        'top_related_sentences': related[:5],
        'threshold_used': threshold,
    }

def process_pdf_file(pdf_path, model_path='./models/production_model_final', threshold=0.5, save_report=False, output_dir='./results'):
    """Run causality classification on a single PDF.

    Extracts the text, classifies it, and tags the result with the PDF's
    file name. With save_report=True a JSON report named
    '<stem>_report.json' is written under *output_dir* (created if needed).
    Returns the result dict.
    """
    text = extract_text_from_pdf(pdf_path)
    report = classify_causality(text, model_path, threshold)
    report['pdf_file'] = str(Path(pdf_path).name)

    if save_report:
        out_dir = Path(output_dir)
        out_dir.mkdir(parents=True, exist_ok=True)
        report_path = out_dir / f"{Path(pdf_path).stem}_report.json"
        with open(report_path, 'w') as f:
            json.dump(report, f, indent=2)

    return report

def process_multiple_pdfs(pdf_paths, model_path='./models/production_model_final', threshold=0.5, save_reports=False, output_dir='./results'):
    """Classify every PDF in *pdf_paths*, returning one result dict each.

    A failure on any single file is recorded as an entry with
    'final_classification': 'error' and the exception text, so one bad PDF
    never aborts the whole batch.
    """
    all_results = []
    for path in pdf_paths:
        try:
            all_results.append(
                process_pdf_file(path, model_path, threshold, save_reports, output_dir)
            )
        except Exception as exc:
            # best-effort batch: record the failure and continue
            all_results.append({
                'pdf_file': str(Path(path).name),
                'error': str(exc),
                'final_classification': 'error',
            })
    return all_results