garvitcpp commited on
Commit
ac4bd9a
·
verified ·
1 Parent(s): 5245cbd

Update src/summarizer.py

Browse files
Files changed (1) hide show
  1. src/summarizer.py +70 -47
src/summarizer.py CHANGED
@@ -1,47 +1,70 @@
1
- from transformers import pipeline
2
- import time
3
- class TextSummarizer:
4
- def __init__(self, model_name="facebook/bart-large-cnn"):
5
- """
6
- Initialize summarization pipeline
7
-
8
- Args:
9
- model_name (str): Hugging Face model for summarization
10
- """
11
- try:
12
- self.summarizer = pipeline("summarization", model=model_name)
13
- except Exception as e:
14
- raise RuntimeError(f"Failed to load summarization model: {e}")
15
-
16
- def generate_summary(self, text, max_length=400, min_length=100):
17
- """
18
- Generate summary for given text
19
-
20
- Args:
21
- text (str): Input text to summarize
22
- max_length (int): Maximum length of summary
23
- min_length (int): Minimum length of summary
24
-
25
- Returns:
26
- str: Generated summary
27
- """
28
- try:
29
- # Validate input text
30
- if not text or len(text.strip()) == 0:
31
- return "No text provided for summarization."
32
-
33
- # Ensure min_length is less than max_length
34
- min_length = min(min_length, max_length)
35
-
36
- # Generate summary
37
- summary = self.summarizer(
38
- text,
39
- max_length=max_length,
40
- min_length=min_length,
41
- do_sample=False
42
- )[0]['summary_text']
43
-
44
- return summary
45
-
46
- except Exception as e:
47
- return f"Error during summarization: {e}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import pipeline
2
+ import torch
3
+ import logging
4
+
5
class TextSummarizer:
    """Abstractive text summarizer backed by a Hugging Face ``summarization`` pipeline."""

    def __init__(self, model_name="facebook/bart-large-cnn"):
        """
        Initialize the summarization pipeline.

        Args:
            model_name (str): Hugging Face model for summarization

        Raises:
            RuntimeError: if the model or pipeline cannot be loaded.
        """
        try:
            # Prefer GPU when available; device=-1 selects CPU for transformers pipelines.
            device = 0 if torch.cuda.is_available() else -1
            logging.info("Using device: %s", "cuda" if device == 0 else "cpu")

            # float32 is the model's default (full) precision; it is pinned
            # explicitly here for reproducibility across environments.
            self.summarizer = pipeline(
                "summarization",
                model=model_name,
                device=device,
                torch_dtype=torch.float32,
            )
            logging.info("Summarization pipeline initialized successfully")

        except Exception as e:
            logging.error("Failed to load summarization model: %s", e)
            raise RuntimeError(f"Failed to load summarization model: {e}") from e

    def generate_summary(self, text, max_length=400, min_length=100):
        """
        Generate a summary for the given text.

        Long inputs are split into chunks (BART accepts ~1024 tokens; this
        approximates the limit in characters but splits on whitespace so no
        word is cut in half), each chunk is summarized independently, and
        the partial summaries are joined with spaces.

        Args:
            text (str): Input text to summarize
            max_length (int): Maximum length of the summary (in tokens)
            min_length (int): Minimum length of the summary (in tokens)

        Returns:
            str: Generated summary, or a notice/error message string
                 (this method never raises; errors are reported in-band,
                 matching the original contract).
        """
        try:
            # Validate input text
            if not text or not text.strip():
                return "No text provided for summarization."

            # Ensure min_length never exceeds max_length
            min_length = min(min_length, max_length)

            # Chunk on whitespace boundaries instead of slicing the raw
            # string every 1024 characters (which cut words in half).
            chunks = self._split_into_chunks(text, max_chunk_length=1024)

            # Distribute the length budget across chunks, clamping so the
            # per-chunk budget never collapses to zero: the original
            # `max_length // len(chunks)` produced 0 for long inputs,
            # which the pipeline rejects.
            n = max(len(chunks), 1)
            per_chunk_max = max(max_length // n, 30)
            per_chunk_min = max(min(min_length // n, per_chunk_max - 1), 1)

            summaries = []
            for chunk in chunks:
                if chunk.strip():
                    summary = self.summarizer(
                        chunk,
                        max_length=per_chunk_max,
                        min_length=per_chunk_min,
                        do_sample=False,
                    )[0]['summary_text']
                    summaries.append(summary)

            return " ".join(summaries)

        except Exception as e:
            logging.error("Error during summarization: %s", e)
            return f"Error during summarization: {e}"

    @staticmethod
    def _split_into_chunks(text, max_chunk_length=1024):
        """
        Split *text* into chunks of at most ``max_chunk_length`` characters,
        breaking only on whitespace so words stay intact.

        A single word longer than the limit becomes its own oversized chunk
        rather than being truncated.
        """
        words = text.split()
        chunks = []
        current = []
        current_len = 0
        for word in words:
            # +1 accounts for the joining space when the chunk is non-empty.
            added = len(word) + (1 if current else 0)
            if current and current_len + added > max_chunk_length:
                chunks.append(" ".join(current))
                current = []
                current_len = 0
                added = len(word)
            current.append(word)
            current_len += added
        if current:
            chunks.append(" ".join(current))
        # Fall back to the raw text so whitespace-only edge cases still
        # yield one chunk for the caller to filter.
        return chunks or [text]