Anupam007 commited on
Commit
e036bfd
·
verified ·
1 Parent(s): ac1f261

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +273 -0
app.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
4
+ from langchain_community.llms import HuggingFacePipeline
5
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
6
+ from langchain_community.embeddings import HuggingFaceEmbeddings
7
+ from langchain_community.vectorstores import FAISS
8
+ from langchain.chains import RetrievalQA
9
+ from langchain.prompts import PromptTemplate
10
+ import warnings
11
+ import os
12
+
# Suppress warnings
warnings.filterwarnings("ignore")
# Silence TensorFlow's oneDNN startup notice if TF happens to be installed.
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'

# Model Configuration
MODEL_NAME = "gpt2"  # small causal LM used for all generation (demo-sized)
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"  # sentence embeddings backing the FAISS index
20
+
def initialize_models():
    """Load the GPT-2 LLM, its tokenizer, and the sentence-transformer embeddings.

    Returns:
        tuple: (llm, embedding_model, model, tokenizer) on success, or
        (None, None, None, None) on failure — callers must check for None.
    """
    try:
        # Prefer GPU when available; HuggingFaceEmbeddings takes the device
        # as a string, the transformers pipeline as an int index.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {device}")

        # Load model and tokenizer
        model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

        # GPT-2 ships without a pad token; reuse EOS so generation with
        # padding/truncation does not warn or fail.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # do_sample=True is required for temperature to take effect; without
        # it transformers warns and decodes greedily. Place the pipeline on
        # the detected device instead of leaving the model on CPU.
        text_generation_pipeline = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
            repetition_penalty=1.1,
            device=0 if device.type == "cuda" else -1,
        )

        # Wrap the raw pipeline for LangChain chains (RetrievalQA below).
        llm = HuggingFacePipeline(pipeline=text_generation_pipeline)

        # Embedding model used by setup_rag_pipeline() to build the FAISS index.
        embedding_model = HuggingFaceEmbeddings(
            model_name=EMBEDDING_MODEL,
            model_kwargs={'device': str(device)}
        )

        return llm, embedding_model, model, tokenizer

    except Exception as e:
        # Keep the app importable even if downloads/initialization fail; the
        # UI callbacks check for None and report a friendly error.
        print(f"Model initialization error: {e}")
        return None, None, None, None
56
+
# Initialize models once at import time; every Gradio callback shares them.
llm, embedding_model, model, tokenizer = initialize_models()

# Global variables for RAG state
# rag_retriever: FAISS retriever built by setup_rag_pipeline(); None until a
# document has been processed.
rag_retriever = None
# document_loaded: True once a document has been chunked and indexed.
document_loaded = False
# loaded_doc_name: status string shown in the UI's "Document Status" box.
loaded_doc_name = "No document loaded"
64
+
def setup_rag_pipeline(doc_text, chunk_size=1000, chunk_overlap=150):
    """Chunk doc_text, embed the chunks, and build the FAISS retriever.

    Updates the module-level rag_retriever / document_loaded / loaded_doc_name
    state consumed by answer_question().

    Args:
        doc_text: Raw document text to index.
        chunk_size: Maximum characters per chunk.
        chunk_overlap: Characters of overlap between consecutive chunks.

    Returns:
        str: Human-readable status — a success summary or an "Error: ..." message.
    """
    global rag_retriever, document_loaded, loaded_doc_name

    if not doc_text or not isinstance(doc_text, str) or len(doc_text.strip()) == 0:
        return "Error: No text provided or invalid input."

    # Fail fast with a clear message if startup model loading failed,
    # instead of raising inside the try block below.
    if embedding_model is None:
        return "Error: Embedding model not initialized."

    try:
        # Character-based splitting; overlap preserves context across chunk edges.
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            length_function=len,
        )
        docs = text_splitter.split_text(doc_text)

        if not docs:
            return "Error: Text splitting resulted in no documents."

        # Embed chunks and index them; retrieve the top-3 most similar per query.
        vector_store = FAISS.from_texts(docs, embedding_model)
        rag_retriever = vector_store.as_retriever(search_kwargs={"k": 3})

        document_loaded = True
        loaded_doc_name = f"Document processed ({len(doc_text)} chars, {len(docs)} chunks)."
        return loaded_doc_name

    except Exception as e:
        # Reset state so a stale retriever is never used after a failure.
        document_loaded = False
        rag_retriever = None
        return f"Error processing document: {e}"
96
+
def answer_question(question):
    """Answer a question about the loaded document via the RAG pipeline.

    Returns the model's answer string, or an "Error: ..." message when the
    models are missing, no document is loaded, or the question is empty.
    """
    # Guard clauses: models ready, document indexed, question non-empty.
    if llm is None or embedding_model is None:
        return "Error: Models not initialized properly."
    if not document_loaded or rag_retriever is None:
        return "Error: Please load a document before asking questions."
    if not question or not isinstance(question, str) or len(question.strip()) == 0:
        return "Error: Please enter a question."

    try:
        # Context-grounded prompt: the model must answer only from retrieved text.
        template = """You are a helpful assistant answering questions based on the provided context.
Use only the information given in the context below to answer the question.
If the context doesn't contain the answer, say "The provided context does not contain the answer to this question."
Be concise.

Context:
{context}

Question: {question}
Answer:"""

        qa_prompt = PromptTemplate(
            input_variables=["context", "question"],
            template=template,
        )

        # Build a fresh "stuff" chain on every call: the retriever global is
        # replaced whenever a new document is processed.
        chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=rag_retriever,
            chain_type_kwargs={"prompt": qa_prompt},
            return_source_documents=False,
        )

        response = chain.invoke({"query": question})

        # invoke() normally returns a dict with the answer under "result";
        # fall back to stringifying anything unexpected.
        if isinstance(response, dict):
            return response.get("result", str(response)).strip()
        return str(response).strip()

    except Exception as e:
        return f"Error answering question: {e}"
142
+
def summarize_text(text_to_summarize, max_length=150, min_length=30):
    """Summarize the given text with the LLM via a plain generation prompt.

    Args:
        text_to_summarize: The text to condense.
        max_length: Upper bound on generated summary tokens (also quoted in the prompt as words).
        min_length: Lower bound, used only to steer the prompt wording.

    Returns:
        str: The generated summary, or an "Error: ..." message.
    """
    if llm is None:
        return "Error: Models not initialized properly."

    if not text_to_summarize or not isinstance(text_to_summarize, str) or len(text_to_summarize.strip()) == 0:
        return "Error: Please enter text to summarize."

    try:
        # Create a prompt for summarization
        prompt = f"Summarize the following text concisely, aiming for {min_length} to {max_length} words:\n\n{text_to_summarize}"

        # Reuse the already-loaded model/tokenizer; do_sample=True so the
        # temperature setting actually takes effect.
        summary_pipeline = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=max_length,
            do_sample=True,
            temperature=0.5
        )

        # Only max_new_tokens (set above) limits generation. The previous
        # call also passed max_length=<total tokens>, which counts the prompt
        # and errors/truncates whenever the input exceeds the summary budget.
        summary_result = summary_pipeline(prompt)[0]['generated_text']

        # The pipeline echoes the prompt; strip it to keep just the summary.
        summary = summary_result.replace(prompt, '').strip()

        return summary

    except Exception as e:
        return f"Error summarizing text: {e}"
174
+
def draft_text(instructions):
    """Generate a draft (email, note, etc.) from free-form instructions.

    Args:
        instructions: What the user wants written.

    Returns:
        str: The generated draft, or an "Error: ..." message.
    """
    if llm is None:
        return "Error: Models not initialized properly."

    if not instructions or not isinstance(instructions, str) or len(instructions.strip()) == 0:
        return "Error: Please enter drafting instructions."

    try:
        # Drafting prompt
        prompt = f"Write the following based on these instructions:\n\n{instructions}"

        # Build the generation pipeline from the shared model/tokenizer;
        # do_sample=True so temperature=0.7 actually applies.
        draft_pipeline = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=500,
            do_sample=True,
            temperature=0.7
        )

        # Only max_new_tokens limits generation. The previous call also
        # passed max_length=500 (prompt + output tokens), which truncates or
        # errors for long instructions.
        draft_result = draft_pipeline(prompt)[0]['generated_text']

        # Remove the echoed prompt, leaving just the generated continuation.
        draft = draft_result.replace(prompt, '').strip()

        return draft

    except Exception as e:
        return f"Error drafting text: {e}"
206
+
# Gradio Interface
def create_gradio_interface():
    """Build the three-tab Gradio UI: RAG Q&A, summarization, and drafting.

    Returns:
        gr.Blocks: the assembled (not yet launched) interface.
    """
    with gr.Blocks(theme=gr.themes.Soft()) as iface:
        gr.Markdown("# Workplace Assistant (GPT-2 Demo)")
        gr.Markdown("Powered by GPT-2 and Langchain")

        with gr.Tabs():
            # Document Q&A Tab — paste text, index it, then query it.
            with gr.TabItem("Document Q&A (RAG)"):
                gr.Markdown("Load text content from a document, then ask questions about it.")
                doc_input = gr.Textbox(label="Paste Document Text Here", lines=10, placeholder="Paste the full text content you want to query...")
                load_button = gr.Button("Process Document")
                # Seeded with the module-level status so the box reflects state at page load.
                status_output = gr.Textbox(label="Document Status", value=loaded_doc_name, interactive=False)
                question_input = gr.Textbox(label="Your Question", placeholder="Ask a question about the document...")
                ask_button = gr.Button("Ask Question")
                answer_output = gr.Textbox(label="Answer", lines=5, interactive=False)

                # Wire buttons to the module-level RAG functions.
                load_button.click(
                    fn=setup_rag_pipeline,
                    inputs=[doc_input],
                    outputs=[status_output]
                )
                ask_button.click(
                    fn=answer_question,
                    inputs=[question_input],
                    outputs=[answer_output]
                )

            # Summarization Tab
            with gr.TabItem("Summarization"):
                gr.Markdown("Paste text to get a concise summary.")
                summarize_input = gr.Textbox(label="Text to Summarize", lines=10, placeholder="Paste text here...")
                summarize_button = gr.Button("Summarize")
                summary_output = gr.Textbox(label="Summary", lines=5, interactive=False)

                # Length sliders feed summarize_text(max_length, min_length).
                with gr.Accordion("Advanced Options", open=False):
                    max_len_slider = gr.Slider(minimum=20, maximum=300, value=150, step=10, label="Max Summary Length (approx words)")
                    min_len_slider = gr.Slider(minimum=10, maximum=100, value=30, step=5, label="Min Summary Length (approx words)")

                summarize_button.click(
                    fn=summarize_text,
                    inputs=[summarize_input, max_len_slider, min_len_slider],
                    outputs=[summary_output]
                )

            # Drafting Tab
            with gr.TabItem("Drafting Assistant"):
                gr.Markdown("Provide instructions for the AI to draft text.")
                draft_instructions = gr.Textbox(label="Drafting Instructions", lines=5, placeholder="e.g., Draft a short, friendly email to the team.")
                draft_button = gr.Button("Generate Draft")
                draft_output = gr.Textbox(label="Generated Draft", lines=10, interactive=False)

                draft_button.click(
                    fn=draft_text,
                    inputs=[draft_instructions],
                    outputs=[draft_output]
                )

    return iface
266
+
# Launch the interface
if __name__ == "__main__":
    # Catch launch-time failures so they are reported rather than lost.
    try:
        create_gradio_interface().launch(share=True, debug=True)
    except Exception as e:
        print(f"Error launching Gradio interface: {e}")