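"""Sentence-level causality screening for PDF documents.

Extracts text from a PDF, splits it into sentences (NLTK with a regex
fallback), and uses a fine-tuned sequence-classification model to label
each sentence as 'related' or 'not related'.
"""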
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from pathlib import Path
import PyPDF2
import json
from datetime import datetime
from typing import Union, List, Dict
import re
# NLTK with robust error handling
import nltk
import ssl
# SSL fix for NLTK
try:
    _create_unverified_https_context = ssl._create_unverified_context
except AttributeError:
    pass
else:
    ssl._create_default_https_context = _create_unverified_https_context

# Enhanced NLTK data download with retry
def download_nltk_data_robust():
    """Download NLTK data with multiple attempts and fallbacks."""
    import os
    # Set the NLTK data path explicitly
    nltk_data_dir = '/home/appuser/nltk_data'
    if not os.path.exists(nltk_data_dir):
        try:
            os.makedirs(nltk_data_dir, exist_ok=True)
        except OSError:
            pass
    if nltk_data_dir not in nltk.data.path:
        nltk.data.path.insert(0, nltk_data_dir)
    packages = ['punkt', 'punkt_tab']
    for package in packages:
        for attempt in range(3):  # Try 3 times
            try:
                nltk.data.find(f'tokenizers/{package}')
                print(f"✓ {package} already available")
                break
            except LookupError:
                try:
                    print(f"Downloading {package} (attempt {attempt + 1})...")
                    nltk.download(package, download_dir=nltk_data_dir, quiet=False)
                    print(f"✓ {package} downloaded successfully")
                    break
                except Exception as e:
                    print(f"Warning: Could not download {package}: {e}")
                    if attempt == 2:
                        print(f"Failed to download {package} after 3 attempts")

# Download on import
download_nltk_data_robust()
# Fallback sentence tokenizer using regex
def simple_sentence_tokenize(text):
    """Simple regex-based sentence tokenizer as fallback."""
    # Split on common sentence boundaries
    sentences = re.split(r'(?<=[.!?])\s+', text)
    return [s.strip() for s in sentences if s.strip()]

# Safe sentence tokenization with fallback
def safe_sent_tokenize(text):
    """Tokenize with NLTK, fall back to regex if NLTK fails."""
    try:
        from nltk.tokenize import sent_tokenize
        return sent_tokenize(text)
    except Exception as e:
        print(f"NLTK tokenization failed ({e}), using fallback...")
        return simple_sentence_tokenize(text)

class CausalityClassifier:
    def __init__(self, model_path='./models/production_model_final', threshold=0.5):
        self.model_path = Path(model_path)
        self.threshold = threshold
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        self.model = AutoModelForSequenceClassification.from_pretrained(self.model_path)
        self.model.eval()

    def predict(self, text, return_probs=False):
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=96)
        with torch.no_grad():
            outputs = self.model(**inputs)
            probs = torch.softmax(outputs.logits, dim=1).numpy()[0]
        pred = 1 if probs[1] > self.threshold else 0
        result = {
            'prediction': 'related' if pred == 1 else 'not related',
            'confidence': float(probs[pred]),
            'label': int(pred)
        }
        if return_probs:
            result['probabilities'] = {
                'not_related': float(probs[0]),
                'related': float(probs[1])
            }
        return result

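# Illustrative single-sentence usage (assumes a fine-tuned checkpoint in the
# default model directory; the example sentence and output depend on it):
#   clf = CausalityClassifier('./models/production_model_final')
#   clf.predict("Drug exposure increased the incidence of liver injury.",
#               return_probs=True)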
def extract_text_from_pdf(pdf_path):
    with open(pdf_path, 'rb') as file:
        pdf_reader = PyPDF2.PdfReader(file)
        text = ""
        for page in pdf_reader.pages:
            # extract_text() may yield nothing for image-only pages
            text += page.extract_text() or ""
    return text

def classify_causality(pdf_text, model_path='./models/production_model_final', threshold=0.5, verbose=False):
    classifier = CausalityClassifier(model_path, threshold)
    # Use safe tokenization with fallback
    sentences = safe_sent_tokenize(pdf_text)
    if verbose:
        print(f"Tokenized {len(sentences)} sentences")
    related_count = 0
    sentence_details = []
    for sent in sentences:
        if not sent.strip():
            continue
        result = classifier.predict(sent, return_probs=True)
        if result['label'] == 1:
            related_count += 1
            sentence_details.append({
                'sentence': sent[:100],
                'probability_related': result['probabilities']['related'],
                'confidence': result['confidence']
            })
    sentence_details.sort(key=lambda x: x['probability_related'], reverse=True)
    return {
        'final_classification': 'related' if related_count > 0 else 'not related',
        'confidence_score': related_count / len(sentences) if sentences else 0,
        'related_sentences': related_count,
        'total_sentences': len(sentences),
        'top_related_sentences': sentence_details[:5],
        'threshold_used': threshold
    }

def process_pdf_file(pdf_path, model_path='./models/production_model_final', threshold=0.5, save_report=False, output_dir='./results'):
    pdf_text = extract_text_from_pdf(pdf_path)
    results = classify_causality(pdf_text, model_path, threshold)
    results['pdf_file'] = str(Path(pdf_path).name)
    if save_report:
        Path(output_dir).mkdir(parents=True, exist_ok=True)
        with open(Path(output_dir) / f"{Path(pdf_path).stem}_report.json", 'w') as f:
            json.dump(results, f, indent=2)
    return results

def process_multiple_pdfs(pdf_paths, model_path='./models/production_model_final', threshold=0.5, save_reports=False, output_dir='./results'):
    all_results = []
    for pdf_path in pdf_paths:
        try:
            results = process_pdf_file(pdf_path, model_path, threshold, save_reports, output_dir)
            all_results.append(results)
        except Exception as e:
            all_results.append({
                'pdf_file': str(Path(pdf_path).name),
                'error': str(e),
                'final_classification': 'error'
            })
    return all_results
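
if __name__ == "__main__":
    # Minimal end-to-end sketch, assuming a fine-tuned checkpoint exists at
    # the default model path; 'example.pdf' is a hypothetical input file.
    demo_report = process_pdf_file('example.pdf', save_report=True)
    print(json.dumps(demo_report, indent=2))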