# Deploy drug-causality-bert v1 with BioBERT model and caching optimizations
# (upstream commit 0f21e9b, author PrashantRGore)
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from pathlib import Path
import PyPDF2
import json
from datetime import datetime
from typing import Union, List, Dict
import re
# NLTK with robust error handling
import nltk
import ssl
# SSL fix for NLTK
# Older Pythons lack ssl._create_unverified_context; only patch when present.
_unverified_context = getattr(ssl, "_create_unverified_context", None)
if _unverified_context is not None:
    # Disable certificate verification for HTTPS so NLTK downloads succeed
    # in containers without a CA bundle.
    ssl._create_default_https_context = _unverified_context
# Enhanced NLTK data download with retry
def download_nltk_data_robust():
    """Download NLTK tokenizer data with multiple attempts and fallbacks.

    Tries a writable data directory first, then downloads 'punkt' and
    'punkt_tab' with up to three attempts each. Failures are reported but
    never raised: callers fall back to the regex tokenizer below.
    """
    import os

    # Point NLTK at an explicitly writable location (Spaces containers
    # often cannot write to the default search paths).
    nltk_data_dir = '/home/appuser/nltk_data'
    if not os.path.exists(nltk_data_dir):
        try:
            os.makedirs(nltk_data_dir, exist_ok=True)
        except OSError as e:
            # Best-effort only: nltk.download falls back to its own default
            # search paths if this directory is unusable. (Was a bare
            # `except:` that also swallowed KeyboardInterrupt/SystemExit.)
            print(f"Warning: could not create {nltk_data_dir}: {e}")
    if nltk_data_dir not in nltk.data.path:
        nltk.data.path.insert(0, nltk_data_dir)

    packages = ['punkt', 'punkt_tab']
    for package in packages:
        for attempt in range(3):  # Try 3 times
            try:
                # Skip the download entirely if the data is already present.
                nltk.data.find(f'tokenizers/{package}')
                print(f"✓ {package} already available")
                break
            except LookupError:
                try:
                    print(f"Downloading {package} (attempt {attempt + 1})...")
                    nltk.download(package, download_dir=nltk_data_dir, quiet=False)
                    print(f"✓ {package} downloaded successfully")
                    break
                except Exception as e:
                    # Network/permission errors: retry, then give up loudly.
                    print(f"Warning: Could not download {package}: {e}")
                    if attempt == 2:
                        print(f"Failed to download {package} after 3 attempts")

# Download on import
download_nltk_data_robust()
# Fallback sentence tokenizer using regex
def simple_sentence_tokenize(text):
    """Simple regex-based sentence tokenizer as fallback.

    Splits after terminal punctuation (., !, ?) followed by whitespace and
    drops empty fragments.
    """
    sentences = []
    for fragment in re.split(r'(?<=[.!?])\s+', text):
        fragment = fragment.strip()
        if fragment:
            sentences.append(fragment)
    return sentences
# Safe sentence tokenization with fallback
def safe_sent_tokenize(text):
    """Sentence-split *text* with NLTK; fall back to the regex tokenizer
    if NLTK or its punkt data is unavailable."""
    try:
        from nltk.tokenize import sent_tokenize
        sentences = sent_tokenize(text)
    except Exception as e:
        print(f"NLTK tokenization failed ({e}), using fallback...")
        sentences = simple_sentence_tokenize(text)
    return sentences
class CausalityClassifier:
    """Binary sentence classifier: drug/event 'related' vs 'not related'.

    Wraps a fine-tuned sequence-classification checkpoint loaded from disk
    and exposes a single `predict` method with a tunable decision threshold.
    """

    def __init__(self, model_path='./models/production_model_final', threshold=0.5):
        # Load tokenizer and model weights from the local checkpoint dir.
        self.model_path = Path(model_path)
        self.threshold = threshold
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        self.model = AutoModelForSequenceClassification.from_pretrained(self.model_path)
        self.model.eval()  # inference mode (disables dropout)

    def predict(self, text, return_probs=False):
        """Classify *text*; returns a dict with prediction, confidence, label.

        The positive ('related') class wins only when its softmax probability
        exceeds `self.threshold`.
        """
        encoded = self.tokenizer(
            text, return_tensors="pt", truncation=True, padding=True, max_length=96
        )
        with torch.no_grad():
            logits = self.model(**encoded).logits
        scores = torch.softmax(logits, dim=1).numpy()[0]
        label = int(scores[1] > self.threshold)
        result = {
            'prediction': 'related' if label == 1 else 'not related',
            'confidence': float(scores[label]),
            'label': label,
        }
        if return_probs:
            result['probabilities'] = {
                'not_related': float(scores[0]),
                'related': float(scores[1]),
            }
        return result
def extract_text_from_pdf(pdf_path):
    """Extract the concatenated text of every page in *pdf_path*.

    `page.extract_text()` can return None for image-only or malformed
    pages (PyPDF2 documents this); those pages are skipped instead of
    raising `TypeError` on string concatenation. Pages are joined once
    rather than accumulated with `+=` (which is quadratic).
    """
    with open(pdf_path, 'rb') as file:
        pdf_reader = PyPDF2.PdfReader(file)
        page_texts = (page.extract_text() for page in pdf_reader.pages)
        return "".join(text for text in page_texts if text)
def classify_causality(pdf_text, model_path='./models/production_model_final', threshold=0.5, verbose=False):
    """Classify a document as 'related'/'not related' sentence-by-sentence.

    Each sentence is scored independently; the document is 'related' if any
    sentence clears the threshold. Returns a summary dict including the top
    5 related sentences by probability.

    Fix: blank fragments were skipped by the classification loop but still
    counted in `total_sentences` and the `confidence_score` denominator,
    deflating the ratio. Blanks are now filtered out before counting.
    """
    classifier = CausalityClassifier(model_path, threshold)
    # Use safe tokenization with fallback; drop blank fragments so that
    # all counts/ratios below refer to sentences actually classified.
    sentences = [s for s in safe_sent_tokenize(pdf_text) if s.strip()]
    if verbose:
        print(f"Tokenized {len(sentences)} sentences")
    related_count = 0
    sentence_details = []
    for sent in sentences:
        result = classifier.predict(sent, return_probs=True)
        if result['label'] == 1:
            related_count += 1
            sentence_details.append({
                'sentence': sent[:100],  # truncate for report readability
                'probability_related': result['probabilities']['related'],
                'confidence': result['confidence'],
            })
    # Most confident 'related' sentences first.
    sentence_details.sort(key=lambda x: x['probability_related'], reverse=True)
    return {
        'final_classification': 'related' if related_count > 0 else 'not related',
        'confidence_score': related_count / len(sentences) if sentences else 0,
        'related_sentences': related_count,
        'total_sentences': len(sentences),
        'top_related_sentences': sentence_details[:5],
        'threshold_used': threshold,
    }
def process_pdf_file(pdf_path, model_path='./models/production_model_final', threshold=0.5, save_report=False, output_dir='./results'):
    """Classify one PDF and, when *save_report* is set, persist the result
    as `<stem>_report.json` under *output_dir*."""
    source = Path(pdf_path)
    results = classify_causality(extract_text_from_pdf(pdf_path), model_path, threshold)
    results['pdf_file'] = str(source.name)
    if save_report:
        report_dir = Path(output_dir)
        report_dir.mkdir(parents=True, exist_ok=True)
        report_path = report_dir / f"{source.stem}_report.json"
        with open(report_path, 'w') as handle:
            json.dump(results, handle, indent=2)
    return results
def process_multiple_pdfs(pdf_paths, model_path='./models/production_model_final', threshold=0.5, save_reports=False, output_dir='./results'):
    """Classify every PDF in *pdf_paths*, returning one result dict per file.

    Best-effort batch: a failure on one file is recorded as an error entry
    (classification 'error') instead of aborting the remaining files.
    """
    collected = []
    for path in pdf_paths:
        try:
            collected.append(
                process_pdf_file(path, model_path, threshold, save_reports, output_dir)
            )
        except Exception as exc:
            # Broad catch is deliberate here: keep processing the rest.
            collected.append({
                'pdf_file': str(Path(path).name),
                'error': str(exc),
                'final_classification': 'error',
            })
    return collected