amritn8 committed on
Commit
ecf4e97
Β·
verified Β·
1 Parent(s): d80c9f5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +113 -80
app.py CHANGED
@@ -1,116 +1,149 @@
1
  import streamlit as st
2
  import torch
3
- from transformers import pipeline, BartForConditionalGeneration, BartTokenizer
 
 
 
4
  from PyPDF2 import PdfReader
5
  import docx
6
- import os
7
- from time import time
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
- # Configure environment
10
- cache_dir = os.path.join(os.getcwd(), "model_cache")
11
- os.makedirs(cache_dir, exist_ok=True)
12
- os.environ["TRANSFORMERS_CACHE"] = cache_dir
13
 
14
- # ----------------------------
15
  # MODEL LOADING
16
- # ----------------------------
17
- @st.cache_resource(show_spinner=False)
18
  def load_models():
19
- """Load all models with progress tracking"""
20
  models = {}
21
 
22
- with st.spinner("πŸš€ Loading QA Model..."):
23
- models['qa'] = pipeline(
24
- "question-answering",
25
- model="deepset/roberta-base-squad2",
26
- device=0 if torch.cuda.is_available() else -1
27
- )
28
-
29
- with st.spinner("πŸ“ Loading Summarization Model..."):
30
- models['summarizer'] = pipeline(
31
- "summarization",
32
- model="facebook/bart-large-cnn",
33
- tokenizer="facebook/bart-large-cnn",
34
- device=0 if torch.cuda.is_available() else -1
35
- )
 
 
 
 
 
 
36
 
37
  return models
38
 
39
  models = load_models()
40
 
41
- # ----------------------------
42
  # DOCUMENT PROCESSING
43
- # ----------------------------
44
  def extract_text(file):
45
  """Universal text extractor for PDF/DOCX"""
46
- if file.type == "application/pdf":
47
- reader = PdfReader(file)
48
- return " ".join([page.extract_text() for page in reader.pages])
49
- elif file.type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
50
- doc = docx.Document(file)
51
- return "\n".join(para.text for para in doc.paragraphs if para.text)
52
- return ""
 
 
 
53
 
54
- # ----------------------------
55
- # SUMMARIZATION FUNCTION
56
- # ----------------------------
57
- def summarize(text, max_length=150, min_length=30):
58
- """Advanced summarization with chunking for long documents"""
 
 
 
59
  try:
60
- if len(text.split()) > 1000: # Chunking for large documents
61
  chunks = [text[i:i+3000] for i in range(0, len(text), 3000)]
62
  summaries = []
63
  for chunk in chunks:
64
- summary = models['summarizer'](
65
  chunk,
66
- max_length=max_length,
67
- min_length=min_length,
68
  do_sample=False
69
  )
70
- summaries.append(summary[0]['summary_text'])
71
  return " ".join(summaries)
72
- return models['summarizer'](text, max_length=max_length, min_length=min_length)[0]['summary_text']
73
  except Exception as e:
74
- st.error(f"Summarization error: {str(e)}")
75
  return ""
76
 
77
- # ----------------------------
78
  # STREAMLIT UI
79
- # ----------------------------
80
- st.title("πŸ“š Document Intelligence Suite")
 
81
 
82
- # Main Document Input
83
- with st.expander("πŸ“„ Upload Document", expanded=True):
84
  uploaded_file = st.file_uploader("Choose PDF/DOCX", type=["pdf", "docx"])
85
- manual_text = st.text_area("Or paste raw text here:", height=150)
86
  context = extract_text(uploaded_file) if uploaded_file else manual_text
87
 
88
- # ----------------------------
89
- # ADVANCED FEATURES
90
- # ----------------------------
91
- with st.expander("πŸ”§ Advanced Tools", expanded=False):
92
- st.header("πŸ“ Document Summarization")
93
-
94
- if st.button("Generate Summary"):
95
- if not context:
96
- st.warning("Please provide content first")
97
- else:
98
- with st.spinner("Analyzing document..."):
99
- start_time = time()
100
- summary = summarize(context)
101
- st.success(f"Generated in {time()-start_time:.1f}s")
102
- st.markdown(f"**Summary:**\n\n{summary}")
103
-
104
- st.header("βš™οΈ Customization")
105
- max_len = st.slider("Summary Length", 50, 300, 150)
106
- show_chunks = st.checkbox("Show processing chunks", False)
 
 
 
 
 
 
107
 
108
- # Question Answering Section
109
- if context:
110
- st.header("❓ Question Answering")
111
- question = st.text_input("Ask about the document:")
112
- if question:
113
- with st.spinner("Searching for answers..."):
114
- result = models['qa'](question=question, context=context[:100000]) # 100k char limit
115
- st.markdown(f"**Answer:** {result['answer']}")
116
- st.caption(f"Confidence: {result['score']:.0%}")
 
1
  import streamlit as st
2
  import torch
3
+ from transformers import pipeline
4
+ import os
5
+ import shutil
6
+ from pathlib import Path
7
  from PyPDF2 import PdfReader
8
  import docx
9
+ import time
10
+
11
# ======================
# CACHE CONFIGURATION
# ======================
def setup_environment():
    """Configure cache with guaranteed write permissions"""
    base = Path("/tmp/model_cache")

    # Drop any leftover lock files from a previous (possibly crashed) run;
    # rmtree on a missing path is a no-op with ignore_errors=True.
    shutil.rmtree(base / ".locks", ignore_errors=True)

    base.mkdir(parents=True, exist_ok=True)
    # Point both the legacy and current Hugging Face cache variables
    # at the same writable directory.
    for var in ("TRANSFORMERS_CACHE", "HF_HOME"):
        os.environ[var] = str(base)
    return base

cache_dir = setup_environment()
 
 
 
29
 
30
# ======================
# MODEL LOADING
# ======================
@st.cache_resource(ttl=3600)  # Cache for 1 hour
def load_models():
    """Load all NLP models with error recovery"""
    # Same device choice for every pipeline: first GPU when available,
    # otherwise CPU (-1).
    device = 0 if torch.cuda.is_available() else -1
    loaded = {}

    try:
        # Question Answering
        with st.spinner("πŸ” Loading QA Model..."):
            loaded['qa'] = pipeline(
                "question-answering",
                model="deepset/roberta-base-squad2",
                device=device,
            )

        # Summarization
        with st.spinner("πŸ“ Loading Summarizer..."):
            loaded['summarizer'] = pipeline(
                "summarization",
                model="facebook/bart-large-cnn",
                device=device,
            )

    except Exception as e:
        # Surface the failure in the UI and halt the script run.
        st.error(f"❌ Model loading failed: {str(e)}")
        st.stop()

    return loaded

models = load_models()
62
 
63
# ======================
# DOCUMENT PROCESSING
# ======================
def extract_text(file):
    """Universal text extractor for PDF/DOCX.

    Args:
        file: Uploaded file object exposing a MIME ``type`` attribute
              (as provided by st.file_uploader).

    Returns:
        The extracted text as a string; "" for unsupported types or
        when extraction raises.
    """
    try:
        if file.type == "application/pdf":
            reader = PdfReader(file)
            # Extract each page exactly once (extract_text() is the
            # expensive call; the original invoked it twice per page)
            # and skip pages that yield no text.
            page_texts = (page.extract_text() for page in reader.pages)
            return " ".join(text for text in page_texts if text)
        elif file.type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
            doc = docx.Document(file)
            return "\n".join(para.text for para in doc.paragraphs if para.text)
    except Exception as e:
        st.error(f"Document processing error: {str(e)}")
    # Unsupported MIME type or extraction failure.
    return ""
78
 
79
# ======================
# CORE FUNCTIONS
# ======================
def generate_summary(text, max_length=150):
    """Chunk-aware summarization.

    Args:
        text: Document text to summarize; falsy input yields "".
        max_length: Upper bound (in tokens) for the overall summary.

    Returns:
        The summary string, or "" on empty input or summarizer failure.
    """
    if not text:
        return ""

    min_length = 30
    try:
        if len(text) > 10000:  # Chunk large documents
            chunks = [text[i:i + 3000] for i in range(0, len(text), 3000)]
            # Split the length budget across chunks, but never let the
            # per-chunk cap fall at or below min_length: the original
            # max_length // len(chunks) drops under 30 for documents
            # over ~15k chars, which the summarization pipeline rejects.
            chunk_max = max(min_length + 10, max_length // len(chunks))
            summaries = []
            for chunk in chunks:
                result = models['summarizer'](
                    chunk,
                    max_length=chunk_max,
                    min_length=min_length,
                    do_sample=False,
                )
                summaries.append(result[0]['summary_text'])
            return " ".join(summaries)
        return models['summarizer'](text, max_length=max_length)[0]['summary_text']
    except Exception as e:
        st.error(f"Summarization failed: {str(e)}")
        return ""
104
 
105
# ======================
# STREAMLIT UI
# ======================
st.set_page_config(page_title="DocAnalyzer Pro", layout="wide")
st.title("πŸ“„ Document Analyzer Pro")

# File Upload: an uploaded file takes precedence over pasted text.
with st.expander("πŸ“€ Upload Document", expanded=True):
    uploaded_file = st.file_uploader("Choose PDF/DOCX", type=["pdf", "docx"])
    manual_text = st.text_area("Or paste text here:", height=200)
    context = extract_text(uploaded_file) if uploaded_file else manual_text

# Main Features
tab1, tab2 = st.tabs(["πŸ” Question Answering", "πŸ“ Summarization"])

with tab1:
    # QA is only offered once some document text is available.
    if context:
        question = st.text_input("Ask about the document:")
        if question:
            with st.spinner("Analyzing..."):
                start = time.time()
                # Context is truncated to 100k characters before QA.
                result = models['qa'](question=question, context=context[:100000])
                st.success(f"Answered in {time.time()-start:.1f}s")
                st.markdown(f"**Answer:** {result['answer']}")
                # QA pipeline score is a float in [0, 1], which
                # st.progress accepts directly.
                st.progress(result['score'])
                st.caption(f"Confidence: {result['score']:.0%}")

with tab2:
    if context:
        # Form batches the slider + button into a single rerun.
        with st.form("summary_form"):
            length = st.slider("Summary Length", 50, 300, 150)
            if st.form_submit_button("Generate Summary"):
                with st.spinner("Summarizing..."):
                    start = time.time()
                    summary = generate_summary(context, length)
                    st.success(f"Generated in {time.time()-start:.1f}s")
                    st.markdown(f"**Summary:**\n\n{summary}")

# Debug Info
with st.expander("βš™οΈ System Info"):
    st.code(f"""
Cache directory: {cache_dir}
Device: {'GPU βœ…' if torch.cuda.is_available() else 'CPU ⚠️'}
Models loaded: {', '.join(models.keys())}
""")