# doc-processor / simple / summarizer.py
# Author: Kartik Narang — first clean commit (3cfeab7)
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from groq import Groq
import re
from nltk.tokenize import sent_tokenize
import nltk
# Download required NLTK tokenizer data. Best-effort: if the download fails
# (e.g. offline environment), we proceed and let sent_tokenize surface the
# error later. `except Exception` (not bare `except:`) so Ctrl-C and
# SystemExit still propagate.
try:
    nltk.download('punkt', quiet=True)
    nltk.download('punkt_tab', quiet=True)
except Exception:
    pass
def summarize_legal_document(text, max_sentences=5, groq_api_key=None, model_path=None):
    """Summarize legal document text.

    Args:
        text: Input text to summarize.
        max_sentences: Target number of summary sentences (clamped to 3-20).
        groq_api_key: Optional Groq API key for LLM-enhanced summarization.
        model_path: Optional custom model path (currently unused here).

    Returns:
        Dictionary with the summary plus input/output metadata; on failure,
        a dictionary with "error" set and "success" False.
    """
    # Guard clause: nothing to summarize.
    if not text or not text.strip():
        return {"error": "Empty text provided", "success": False}

    # Clamp the requested summary size to a sane range.
    max_sentences = max(3, min(max_sentences, 20))

    report = {
        "original_length": len(text),
        "word_count": len(text.split()),
        "sentence_count": len(sent_tokenize(text)),
        "success": False,
    }

    try:
        # Always produce the extractive summary as a baseline.
        report["summary"] = _extractive_summarize(text, max_sentences)

        # Optionally upgrade to an LLM summary; any failure (or an empty
        # response) silently falls back to the extractive baseline.
        if groq_api_key:
            try:
                enhanced = _groq_summarize(text, max_sentences, groq_api_key)
            except Exception:
                enhanced = None
            if enhanced:
                report["summary"] = enhanced

        final_text = report.get("summary", "")
        report["summary_length"] = len(final_text)
        report["compression_ratio"] = (
            report["summary_length"] / report["original_length"]
            if report["original_length"] > 0
            else 0
        )
        report["success"] = True
    except Exception as exc:
        report["error"] = str(exc)
        report["success"] = False

    return report
def _extractive_summarize(text, max_sentences):
    """Pick the highest-scoring sentences using legal-domain heuristics."""
    sentences = sent_tokenize(text)

    # Documents already at or under the target length are returned untouched.
    if len(sentences) <= max_sentences:
        return text

    legal_keywords = [
        'court', 'judge', 'plaintiff', 'defendant', 'appellant', 'respondent',
        'held', 'ruled', 'decided', 'judgment', 'order', 'section', 'article',
        'provision', 'law', 'legal', 'case', 'appeal', 'petition', 'writ',
        'contract', 'agreement', 'liability', 'damages', 'evidence', 'witness',
        'statute', 'regulation', 'finding', 'conclusion', 'reasoning'
    ]

    total = len(sentences)
    scored = []
    for idx, sentence in enumerate(sentences):
        if not sentence.strip():
            continue
        lowered = sentence.lower()

        # +1 for each distinct legal keyword that appears.
        points = sum(1 for kw in legal_keywords if kw in lowered)

        # Positional weighting: openings and closings matter most.
        if idx == 0:
            points += 3
        elif idx == total - 1:
            points += 2
        elif idx < total * 0.2:
            points += 1

        # Prefer mid-length sentences.
        n_words = len(sentence.split())
        if 15 <= n_words <= 40:
            points += 2
        elif 10 <= n_words <= 50:
            points += 1

        # Years, percentages, and dollar figures.
        if re.search(r'\b\d{4}\b|\b\d+\s*(percent|%|\$)', sentence):
            points += 1

        # Citation-like patterns (reporter cites, "v. Name").
        if re.search(r'\d+\s+[A-Z][a-z]+\.?\s+\d+|\bv\.\s+[A-Z]', sentence):
            points += 2

        scored.append((points, idx, sentence))

    # Keep the top-scoring sentences, then restore original document order.
    # Both sorts are stable, matching the original tie-breaking behavior.
    top = sorted(scored, key=lambda item: item[0], reverse=True)[:max_sentences]
    top.sort(key=lambda item: item[1])
    return ' '.join(entry[2] for entry in top)
def _groq_summarize(text, max_sentences, api_key):
    """Generate an LLM summary via the Groq API; return None on any failure."""
    try:
        client = Groq(api_key=api_key)

        # Truncate long inputs to stay within the model's context budget.
        if len(text) > 6000:
            text = text[:6000] + "\n[...text truncated...]"

        system_prompt = """You are an expert legal document summarizer. Create concise, accurate summaries that capture the most important information.
Guidelines:
1. Focus on key legal facts, holdings, and conclusions
2. Preserve important legal terminology and concepts
3. Maintain logical flow of legal reasoning
4. Include relevant case citations, statutes, or regulations
5. Be precise and avoid unnecessary elaboration"""

        user_prompt = f"""Please summarize the following legal document in approximately {max_sentences} sentences:
{text}
Provide a clear, concise summary:"""

        response = client.chat.completions.create(
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            model="llama-3.1-8b-instant",
            temperature=0.2,
            max_tokens=800,
            top_p=0.9,
        )

        candidate = response.choices[0].message.content.strip()
        # Reject empty or suspiciously short completions so the caller
        # falls back to the extractive summary.
        if candidate and len(candidate) > 20:
            return candidate
    except Exception:
        pass
    return None
def _chunk_text(text, max_words):
"""Split text into chunks for processing"""
words = text.split()
chunks = []
for i in range(0, len(words), max_words):
chunk_words = words[i:i + max_words]
if chunk_words:
chunks.append(' '.join(chunk_words))
return chunks