garvitcpp committed on
Commit
3af4a57
·
verified ·
1 Parent(s): cfa102d

Update src/summarizer.py

Browse files
Files changed (1) hide show
  1. src/summarizer.py +36 -23
src/summarizer.py CHANGED
@@ -1,28 +1,28 @@
1
- from transformers import pipeline
2
  import torch
3
  import logging
4
 
5
  class TextSummarizer:
6
  def __init__(self, model_name="facebook/bart-large-cnn"):
7
  """
8
- Initialize summarization pipeline
9
 
10
  Args:
11
  model_name (str): Hugging Face model for summarization
12
  """
13
  try:
14
- # Configure device
15
- device = 0 if torch.cuda.is_available() else -1
16
- logging.info(f"Using device: {'cuda' if device == 0 else 'cpu'}")
17
 
18
- # Initialize pipeline with explicit device mapping and lower precision
19
- self.summarizer = pipeline(
20
- "summarization",
21
- model=model_name,
22
- device=device,
23
- torch_dtype=torch.float32
24
- )
25
- logging.info("Summarization pipeline initialized successfully")
 
26
 
27
  except Exception as e:
28
  logging.error(f"Failed to load summarization model: {str(e)}")
@@ -48,20 +48,33 @@ class TextSummarizer:
48
  # Ensure min_length is less than max_length
49
  min_length = min(min_length, max_length)
50
 
51
- # Generate summary with chunking for long texts
52
  max_chunk_length = 1024 # BART's max input length
53
  chunks = [text[i:i + max_chunk_length] for i in range(0, len(text), max_chunk_length)]
54
  summaries = []
55
 
56
- for chunk in chunks:
57
- if chunk.strip():
58
- summary = self.summarizer(
59
- chunk,
60
- max_length=max_length // len(chunks), # Distribute length across chunks
61
- min_length=min_length // len(chunks),
62
- do_sample=False
63
- )[0]['summary_text']
64
- summaries.append(summary)
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
  return " ".join(summaries)
67
 
 
1
import logging

import torch
# BUG FIX: transformers provides no AutoModelForSeq2SeqSummarization class;
# the auto class for BART-style summarization models is AutoModelForSeq2SeqLM.
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
 
5
  class TextSummarizer:
6
  def __init__(self, model_name="facebook/bart-large-cnn"):
7
  """
8
+ Initialize summarization model directly without using pipeline
9
 
10
  Args:
11
  model_name (str): Hugging Face model for summarization
12
  """
13
  try:
14
+ # Force CPU usage and disable GPU
15
+ self.device = torch.device('cpu')
 
16
 
17
+ # Initialize tokenizer and model separately
18
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
19
+ self.model = AutoModelForSeq2SeqSummarization.from_pretrained(model_name)
20
+
21
+ # Move model to CPU and eval mode
22
+ self.model = self.model.to(self.device)
23
+ self.model.eval()
24
+
25
+ logging.info("Summarization model initialized successfully")
26
 
27
  except Exception as e:
28
  logging.error(f"Failed to load summarization model: {str(e)}")
 
48
  # Ensure min_length is less than max_length
49
  min_length = min(min_length, max_length)
50
 
51
+ # Process text in chunks due to length limitations
52
  max_chunk_length = 1024 # BART's max input length
53
  chunks = [text[i:i + max_chunk_length] for i in range(0, len(text), max_chunk_length)]
54
  summaries = []
55
 
56
+ with torch.no_grad(): # Disable gradient calculation
57
+ for chunk in chunks:
58
+ if chunk.strip():
59
+ # Tokenize
60
+ inputs = self.tokenizer(chunk, max_length=1024, truncation=True,
61
+ return_tensors="pt")
62
+ inputs = inputs.to(self.device)
63
+
64
+ # Generate summary
65
+ summary_ids = self.model.generate(
66
+ inputs["input_ids"],
67
+ num_beams=4,
68
+ max_length=max_length // len(chunks),
69
+ min_length=min_length // len(chunks),
70
+ length_penalty=2.0,
71
+ early_stopping=True
72
+ )
73
+
74
+ # Decode summary
75
+ summary = self.tokenizer.decode(summary_ids[0],
76
+ skip_special_tokens=True)
77
+ summaries.append(summary)
78
 
79
  return " ".join(summaries)
80