Lesterchia1 commited on
Commit
38b607b
·
verified ·
1 Parent(s): cabc4fa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +165 -267
app.py CHANGED
@@ -1,305 +1,203 @@
1
- # -*- coding: utf-8 -*-
2
- """App
3
-
4
- Automatically generated by Colab.
5
-
6
- Original file is located at
7
- https://colab.research.google.com/drive/1TdjbTSA8V5GUProQ3Bd-uYmTLXSInoWf
8
- """
9
-
10
- import gradio as gr
11
- import numpy as np
12
- from transformers import pipeline
13
  import os
14
- import time
15
- import groq
16
- import uuid # For generating unique filenames
17
-
18
- # Updated imports to address LangChain deprecation warnings:
 
 
 
 
 
 
 
 
 
19
  from langchain_groq import ChatGroq
20
- from langchain.schema import HumanMessage
21
- from langchain.text_splitter import RecursiveCharacterTextSplitter
22
- from langchain_community.vectorstores import Chroma
23
  from langchain_community.embeddings import HuggingFaceEmbeddings
24
- from langchain.docstore.document import Document
25
-
26
- # Importing chardet (make sure to add chardet to your requirements.txt)
27
- import chardet
28
-
29
- import fitz # PyMuPDF for PDFs
30
- import docx # python-docx for Word files
31
- import gtts # Google Text-to-Speech library
32
- from pptx import Presentation # python-pptx for PowerPoint files
33
- import re
34
-
35
- # Initialize Whisper model for speech-to-text
36
- transcriber = pipeline("automatic-speech-recognition", model="openai/whisper-base.en")
37
-
38
- # Set API Key (Ensure it's stored securely in an environment variable)
39
- groq.api_key = os.getenv("GROQ_API_KEY") # Replace with a valid API key
40
 
41
- # Initialize Chat Model
42
- chat_model = ChatGroq(model_name="llama-3.1-8b-instant", api_key=groq.api_key) #llama-3.3-70b-versatile
43
 
44
- # Initialize Embeddings and chromaDB
45
- embedding_model = HuggingFaceEmbeddings()
46
- vectorstore = Chroma(embedding_function=embedding_model)
 
47
 
48
- # Short-term memory for the LLM
49
- chat_memory = []
50
 
51
- # Prompt for quiz generation with added remark
52
- quiz_prompt = """
53
- You are an AI assistant specialized in education and assessment creation. Given an uploaded document or text, generate a quiz with a mix of multiple-choice questions (MCQs) and fill-in-the-blank questions. The quiz should be directly based on the key concepts, facts, and details from the provided material.
54
- Remove all unnecessary formatting generated by the LLM, including <think> tags, asterisks, markdown formatting, and any bold or italic text, as well as **, ###, ##, and # tags.
55
- For each question:
56
- - Provide 4 answer choices (for MCQs), with only one correct answer.
57
- - Ensure fill-in-the-blank questions focus on key terms, phrases, or concepts from the document.
58
- - Include an answer key for all questions.
59
- - Ensure questions vary in difficulty and encourage comprehension rather than memorization.
60
- - Additionally, implement an instant feedback mechanism:
61
- - When a user selects an answer, indicate whether it is correct or incorrect.
62
- - If incorrect, provide a brief explanation from the document to guide learning.
63
- - Ensure responses are concise and educational to enhance understanding.
64
- Output Example:
65
- 1. Fill in the blank: The LLM Agent framework has a central decision-making unit called the _______________________.
66
- Answer: Agent Core
67
- Feedback: The Agent Core is the central component of the LLM Agent framework, responsible for managing goals, tool instructions, planning modules, memory integration, and agent persona.
68
- 2. What is the main limitation of LLM-based applications?
69
- a) Limited token capacity
70
- b) Lack of domain expertise
71
- c) Prone to hallucination
72
- d) All of the above
73
- Answer: d) All of the above
74
- Feedback: LLM-based applications have several limitations, including limited token capacity, lack of domain expertise, and being prone to hallucination, among others.
75
- """
76
 
77
- # Function to clean AI response by removing unwanted formatting
78
  def clean_response(response):
79
- """Removes <think> tags, asterisks, and markdown formatting."""
80
- cleaned_text = re.sub(r"<think>.*?</think>", "", response, flags=re.DOTALL)
81
- cleaned_text = re.sub(r"(\*\*|\*)", "", cleaned_text)
82
- cleaned_text = re.sub(r"^#+\s*", "", cleaned_text, flags=re.MULTILINE)
83
- cleaned_text = re.sub(r"\\", "", cleaned_text)
84
- return cleaned_text.strip()
 
85
 
86
- # Function to generate quiz based on content
87
- def generate_quiz(content):
88
- prompt = f"{quiz_prompt}\n\nDocument content:\n{content}"
89
- response = chat_model([HumanMessage(content=prompt)])
90
- cleaned_response = clean_response(response.content)
91
- return cleaned_response
92
 
93
- # Function to retrieve relevant documents from vectorstore based on user query
94
  def retrieve_documents(query):
95
  results = vectorstore.similarity_search(query, k=3)
96
  return [doc.page_content for doc in results]
97
 
98
- # Function to handle chatbot interactions with short-term memory
99
- def chat_with_groq(user_input):
100
- try:
101
- # Retrieve relevant documents for additional context
102
- relevant_docs = retrieve_documents(user_input)
103
- context = "\n".join(relevant_docs) if relevant_docs else "No relevant documents found."
104
-
105
- # Construct proper prompting with conversation history
106
- system_prompt = "You are a helpful AI assistant. Answer questions accurately and concisely."
107
- conversation_history = "\n".join(chat_memory[-10:]) # Keep the last 10 exchanges
108
- prompt = f"{system_prompt}\n\nConversation History:\n{conversation_history}\n\nUser Input: {user_input}\n\nContext:\n{context}"
109
-
110
- # Call the chat model
111
- response = chat_model([HumanMessage(content=prompt)])
112
-
113
- # Clean response to remove any unwanted formatting
114
- cleaned_response_text = clean_response(response.content)
115
-
116
- # Append conversation history
117
- chat_memory.append(f"User: {user_input}")
118
- chat_memory.append(f"AI: {cleaned_response_text}")
119
-
120
- # Convert response to speech
121
- audio_file = speech_playback(cleaned_response_text)
122
-
123
- return cleaned_response_text, audio_file
124
- except Exception as e:
125
- return f"Error: {str(e)}", None
126
-
127
- # Function to play response as speech using gTTS
128
  def speech_playback(text):
129
  try:
130
- # Generate a unique filename for each audio file
131
  unique_id = str(uuid.uuid4())
132
- audio_file = f"output_audio_{unique_id}.mp3"
133
-
134
- # Convert text to speech
135
- tts = gtts.gTTS(text, lang='en')
136
  tts.save(audio_file)
137
-
138
- # Return the path to the audio file
139
  return audio_file
140
  except Exception as e:
141
- print(f"Error in speech_playback: {e}")
142
  return None
143
 
144
- # Function to detect encoding safely
145
- def detect_encoding(file_path):
146
- try:
147
- with open(file_path, "rb") as f:
148
- raw_data = f.read(4096)
149
- detected = chardet.detect(raw_data)
150
- encoding = detected["encoding"]
151
- return encoding if encoding else "utf-8"
152
- except Exception:
153
- return "utf-8"
154
 
155
- # Function to extract text from PDF
156
- def extract_text_from_pdf(pdf_path):
157
  try:
158
- doc = fitz.open(pdf_path)
159
- text = "\n".join([page.get_text("text") for page in doc])
160
- return text if text.strip() else "No extractable text found."
161
- except Exception as e:
162
- return f"Error extracting text from PDF: {str(e)}"
163
-
164
- # Function to extract text from Word files (.docx)
165
- def extract_text_from_docx(docx_path):
166
- try:
167
- doc = docx.Document(docx_path)
168
- text = "\n".join([para.text for para in doc.paragraphs])
169
- return text if text.strip() else "No extractable text found."
170
- except Exception as e:
171
- return f"Error extracting text from Word document: {str(e)}"
172
-
173
- # Function to extract text from PowerPoint files (.pptx)
174
- def extract_text_from_pptx(pptx_path):
175
- try:
176
- presentation = Presentation(pptx_path)
177
- text = ""
178
- for slide in presentation.slides:
179
- for shape in slide.shapes:
180
- if hasattr(shape, "text"):
181
- text += shape.text + "\n"
182
- return text if text.strip() else "No extractable text found."
183
- except Exception as e:
184
- return f"Error extracting text from PowerPoint: {str(e)}"
185
-
186
- # Function to process documents safely
187
- #def process_document(file):
188
- # try:
189
- # file_extension = os.path.splitext(file.name)[-1].lower()
190
- # if file_extension in [".png", ".jpg", ".jpeg"]:
191
- # return "Error: Images cannot be processed for text extraction."
192
- # if file_extension == ".pdf":
193
- # content = extract_text_from_pdf(file.name)
194
- # elif file_extension == ".docx":
195
- # content = extract_text_from_docx(file.name)
196
- # elif file_extension == ".pptx":
197
- # content = extract_text_from_pptx(file.name)
198
- # else:
199
- # encoding = detect_encoding(file.name)
200
- # with open(file.name, "r", encoding=encoding, errors="replace") as f:
201
- # content = f.read()
202
- # text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
203
- # documents = [Document(page_content=chunk) for chunk in text_splitter.split_text(content)]
204
- # vectorstore.add_documents(documents)
205
- # quiz = generate_quiz(content)
206
- # return f"Document processed successfully (File Type: {file_extension}). Quiz generated:\n{quiz}"
207
- # except Exception as e:
208
- # return f"Error processing document: {str(e)}"
209
-
210
- def process_document(file):
211
- try:
212
- if not file or not hasattr(file, "name") or not isinstance(file.name, str):
213
- return "Error: Invalid file uploaded."
214
-
215
- file_extension = os.path.splitext(file.name)[-1].lower()
216
-
217
- if file_extension in [".png", ".jpg", ".jpeg"]:
218
- return "Error: Images cannot be processed for text extraction."
219
-
220
- if file_extension == ".pdf":
221
- content = extract_text_from_pdf(file.name)
222
- elif file_extension == ".docx":
223
- content = extract_text_from_docx(file.name)
224
- elif file_extension == ".pptx":
225
- content = extract_text_from_pptx(file.name)
226
  else:
227
- encoding = detect_encoding(file.name)
228
- with open(file.name, "r", encoding=encoding, errors="replace") as f:
229
- content = f.read()
 
 
 
 
230
 
231
- text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
232
- documents = [Document(page_content=chunk) for chunk in text_splitter.split_text(content)]
 
233
  vectorstore.add_documents(documents)
234
- quiz = generate_quiz(content)
235
- return f"Document processed successfully (File Type: {file_extension}). Quiz generated:\n{quiz}"
236
-
237
  except Exception as e:
238
- return f"Error processing document: {str(e)}"
 
239
 
 
240
 
 
 
 
 
 
241
 
242
- # Function to handle speech-to-text conversion
243
- def transcribe_audio(audio):
244
- sr, y = audio
245
- if y.ndim > 1:
246
- y = y.mean(axis=1)
247
- y = y.astype(np.float32)
248
- y /= np.max(np.abs(y))
249
- return transcriber({"sampling_rate": sr, "raw": y})["text"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
 
251
- # Your cleanup function
252
- def cleanup_old_files(directory=".", age_limit=60):
253
- """Delete files older than `age_limit` seconds."""
254
- current_time = time.time()
255
- for filename in os.listdir(directory):
256
- file_path = os.path.join(directory, filename)
257
- if os.path.isfile(file_path) and filename.startswith("output_audio_"):
258
- file_age = current_time - os.path.getmtime(file_path)
259
- if file_age > age_limit:
260
- os.remove(file_path)
261
 
 
 
 
 
 
 
 
 
 
 
 
262
 
263
- # Gradio UI with Video Clip
264
  with gr.Blocks() as demo:
265
- gr.HTML("<h2 style='text-align: center;'>AI Tutor - We.</h2>")
266
 
267
- # Align image and video side by side
268
- with gr.Row():
269
- with gr.Column(scale=1): # Adjust scale to control width ratio
270
- gr.HTML("""
271
- <div style="text-align: center; margin-bottom: 20px;">
272
- <img src="https://img.freepik.com/premium-photo/little-girl-is-seen-sitting-front-laptop-computer-engaged-with-nearby-robot-robot-assistant-helping-child-with-homework-ai-generated_585735-12266.jpg"
273
- style="width: 100%; height: auto; border-radius: 10px; box-shadow: 0 4px 8px rgba(0,0,0,0.2);" />
274
- </div>
275
- """)
276
-
277
- #with gr.Column(scale=1): # Adjust scale for equal width
278
- gr.Video("https://github.com/lesterchia1/AI_tutor/raw/main/We%20not%20me%20video.mp4", label="Introduction Video")
279
 
280
-
281
- # Add other UI elements below
282
- with gr.Row():
283
- with gr.Column():
284
- audio_input = gr.Audio(type="numpy", label="Record Audio")
285
- transcription_output = gr.Textbox(label="Transcription")
286
- user_input = gr.Textbox(label="Ask a question")
287
- chat_output = gr.Textbox(label="Response")
288
- audio_output = gr.Audio(label="Audio Playback")
289
- submit_btn = gr.Button("Ask")
290
- with gr.Column():
291
- file_upload = gr.File(label="Upload a document")
292
- process_status = gr.Textbox(label="Processing Status")
293
-
294
- # Define button actions
295
- submit_btn.click(chat_with_groq, inputs=user_input, outputs=[chat_output, audio_output])
296
- audio_input.change(transcribe_audio, inputs=audio_input, outputs=transcription_output)
297
- transcription_output.change(fn=lambda x: x, inputs=transcription_output, outputs=user_input)
298
- file_upload.change(process_document, inputs=file_upload, outputs=process_status)
299
-
300
- # Add cleanup function to be triggered periodically (e.g., every time a button is clicked or after certain actions)
301
- #demo.load(lambda: cleanup_old_files(directory="./", age_limit=60), inputs=[], outputs=[])
302
- demo.load(lambda: [], inputs=[], outputs=[])
303
-
304
-
305
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ import re
3
+ import uuid
4
+ import tempfile
5
+ import numpy as np
6
+ import gradio as gr
7
+ import chardet
8
+ import fitz # PyMuPDF
9
+ import docx
10
+ import gtts
11
+ from pptx import Presentation
12
+ from typing import TypedDict, List
13
+ from langchain_community.tools import DuckDuckGoSearchRun
14
+ from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage
15
+ from langgraph.graph import StateGraph, END
16
  from langchain_groq import ChatGroq
 
 
 
17
  from langchain_community.embeddings import HuggingFaceEmbeddings
18
+ from langchain_community.vectorstores import Chroma
19
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
20
+ from langchain_core.documents import Document
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
+ # --- 1. INITIALIZATION & CORE TOOLS ---
23
+ groq_api_key = os.getenv("GROQ_API_KEY")
24
 
25
+ chat_model = ChatGroq(model_name="llama-3.3-70b-versatile", api_key=groq_api_key)
26
+ web_search_tool = DuckDuckGoSearchRun()
27
+ embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
28
+ vectorstore = Chroma(embedding_function=embedding_model, persist_directory="chroma_db")
29
 
30
+ # --- 2. HELPER FUNCTIONS ---
 
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
 
33
  def clean_response(response):
34
+ """Remove <think>...</think> blocks and common markdown artifacts."""
35
+ # Remove think tags and their content (greedily, case-insensitive)
36
+ cleaned = re.sub(r"<think>.*?(?:</think>|$)", "", response, flags=re.DOTALL | re.IGNORECASE)
37
+ # Remove stray closing tags and markdown symbols
38
+ cleaned = re.sub(r"</?think>|\*\*|\*|\[|\]|#", "", cleaned)
39
+ return cleaned.strip()
40
+ #return cleaned_text.strip()
41
 
 
 
 
 
 
 
42
 
 
43
  def retrieve_documents(query):
44
  results = vectorstore.similarity_search(query, k=3)
45
  return [doc.page_content for doc in results]
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  def speech_playback(text):
48
  try:
 
49
  unique_id = str(uuid.uuid4())
50
+ audio_file = f"/content/output_audio_{unique_id}.mp3"
51
+ tts = gtts.gTTS(text[:500], lang='en')
 
 
52
  tts.save(audio_file)
 
 
53
  return audio_file
54
  except Exception as e:
55
+ print(f"TTS error: {e}")
56
  return None
57
 
58
+ # --- 3. DOCUMENT INGESTION FUNCTION ---
59
+ def extract_and_store_document(file_path: str):
60
+ text = ""
61
+ file_ext = os.path.splitext(file_path)[1].lower()
 
 
 
 
 
 
62
 
 
 
63
  try:
64
+ if file_ext == ".pdf":
65
+ doc = fitz.open(file_path)
66
+ for page in doc:
67
+ text += page.get_text()
68
+ doc.close()
69
+ elif file_ext == ".docx":
70
+ doc = docx.Document(file_path)
71
+ text = "\n".join([para.text for para in doc.paragraphs])
72
+ elif file_ext == ".pptx":
73
+ prs = Presentation(file_path)
74
+ for slide in prs.slides:
75
+ for shape in slide.shapes:
76
+ if hasattr(shape, "text"):
77
+ text += shape.text + "\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  else:
79
+ with open(file_path, 'rb') as f:
80
+ raw_data = f.read()
81
+ encoding = chardet.detect(raw_data)['encoding'] or 'utf-8'
82
+ text = raw_data.decode(encoding, errors='ignore')
83
+
84
+ if not text.strip():
85
+ return False
86
 
87
+ splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
88
+ chunks = splitter.split_text(text)
89
+ documents = [Document(page_content=chunk, metadata={"source": os.path.basename(file_path)}) for chunk in chunks]
90
  vectorstore.add_documents(documents)
91
+ vectorstore.persist()
92
+ return True
93
+
94
  except Exception as e:
95
+ print(f"Error processing {file_path}: {e}")
96
+ return False
97
 
98
+ # --- 4. REFRAG MULTI-AGENT LOGIC (LangGraph) ---
99
 
100
+ class AgentState(TypedDict):
101
+ messages: List[BaseMessage]
102
+ context: str
103
+ decision: str
104
+ source: str
105
 
106
+ def sensing_node(state: AgentState):
107
+ user_query = state["messages"][-1].content
108
+ relevant_docs = retrieve_documents(user_query)
109
+ context = "\n".join(relevant_docs) if relevant_docs else ""
110
+
111
+ prompt = f"Docs: {context}\nQuery: {user_query}\nIf docs answer this, reply 'RAG'. Else reply 'WEB'."
112
+ decision = chat_model.invoke([HumanMessage(content=prompt)]).content.strip().upper()
113
+ return {"context": context, "decision": "RAG" if "RAG" in decision else "WEB"}
114
+
115
+ def expansion_node(state: AgentState):
116
+ if state["decision"] == "WEB":
117
+ user_query = state["messages"][-1].content
118
+ web_data = web_search_tool.run(user_query)
119
+ return {"context": f"WEB INFO: {web_data}\nLOCAL: {state['context']}", "source": "Web + Local Documents"}
120
+ return {"source": "Local Documents Only"}
121
+
122
+ def generation_node(state: AgentState):
123
+ system_msg = f"You are a Tutor AI. Use this context: {state['context']}"
124
+ response = chat_model.invoke([SystemMessage(content=system_msg)] + state["messages"])
125
+ cleaned = clean_response(response.content)
126
+ return {"messages": [AIMessage(content=f"{cleaned}\n\n*(Verified via: {state['source']})*")]}
127
+
128
+ workflow = StateGraph(AgentState)
129
+ workflow.add_node("sense", sensing_node)
130
+ workflow.add_node("expand", expansion_node)
131
+ workflow.add_node("generate", generation_node)
132
+ workflow.set_entry_point("sense")
133
+ workflow.add_edge("sense", "expand")
134
+ workflow.add_edge("expand", "generate")
135
+ workflow.add_edge("generate", END)
136
+ app_agent = workflow.compile()
137
+
138
+ # --- 5. GRADIO APP WITH MANUAL AUDIO ---
139
+
140
+ # Store last assistant response globally (simple approach for demo)
141
+ last_assistant_response = ""
142
+
143
+ def chat_handler(user_input, chat_history):
144
+ global last_assistant_response
145
+ if not user_input:
146
+ return chat_history, "", None
147
+
148
+ inputs = {"messages": [HumanMessage(content=user_input)], "context": "", "decision": "", "source": ""}
149
+ result = app_agent.invoke(inputs)
150
+ final_msg = result["messages"][-1].content
151
+
152
+ chat_history.append({"role": "user", "content": user_input})
153
+ chat_history.append({"role": "assistant", "content": final_msg})
154
+
155
+ # Save clean text for later TTS (without source note)
156
+ last_assistant_response = final_msg.split("*(Verified")[0].strip()
157
+
158
+ # Return chat history and clear audio (no autoplay)
159
+ return chat_history, "", None
160
 
161
+ def generate_audio():
162
+ global last_assistant_response
163
+ if not last_assistant_response:
164
+ return None
165
+ return speech_playback(last_assistant_response)
 
 
 
 
 
166
 
167
+ def upload_file(file):
168
+ if file is None:
169
+ return "❌ No file uploaded."
170
+ try:
171
+ success = extract_and_store_document(file.name)
172
+ if success:
173
+ return f"✅ **{os.path.basename(file.name)}** successfully parsed and added to knowledge base!"
174
+ else:
175
+ return f"⚠️ Failed to extract text from **{os.path.basename(file.name)}**."
176
+ except Exception as e:
177
+ return f"❌ Error: {str(e)}"
178
 
 
179
  with gr.Blocks() as demo:
180
+ gr.Markdown("# 🎓 Tutor AI (single agent with tool-routing capability)")
181
 
182
+ with gr.Tab("AI Chatbot"):
183
+ chatbot = gr.Chatbot(type="messages", height=400)
184
+ with gr.Row():
185
+ msg = gr.Textbox(placeholder="Ask your tutor...", scale=4)
186
+ submit = gr.Button("Send", variant="primary")
187
+ # Manual audio control
188
+ with gr.Row():
189
+ play_audio_btn = gr.Button("🔊 Play Audio Response", variant="secondary")
190
+ audio_out = gr.Audio(label="Audio Response", autoplay=False) # autoplay=False
 
 
 
191
 
192
+ # Chat submission
193
+ submit.click(chat_handler, [msg, chatbot], [chatbot, msg, audio_out])
194
+ msg.submit(chat_handler, [msg, chatbot], [chatbot, msg, audio_out])
195
+ # Manual audio generation
196
+ play_audio_btn.click(generate_audio, None, audio_out)
197
+
198
+ with gr.Tab("Upload Notes"):
199
+ file_input = gr.File(label="Upload PDF / DOCX / PPTX / TXT", file_types=[".pdf", ".docx", ".pptx", ".txt"])
200
+ upload_status = gr.Markdown()
201
+ file_input.change(upload_file, file_input, upload_status)
202
+
203
+ demo.launch(share=True, debug=True)