junaid17 commited on
Commit
745c08b
·
verified ·
1 Parent(s): 5cf14b7

Upload 4 files

Browse files
Files changed (4) hide show
  1. app.py +139 -0
  2. chatbot.py +115 -0
  3. tools.py +282 -0
  4. utils.py +66 -0
app.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile, File, HTTPException
2
+ from tools import update_retriever
3
+ from chatbot import app as app_graph, rebuild_graph
4
+ from langchain_core.messages import HumanMessage
5
+ import os
6
+ from fastapi.responses import StreamingResponse, FileResponse
7
+ from fastapi.middleware.cors import CORSMiddleware
8
+ from pydantic import BaseModel
9
+ from utils import TTS, STT
10
+
11
# =====================================================
# APP SETUP
# =====================================================

# FastAPI application exposing the chat, RAG-upload, STT and TTS endpoints.
app = FastAPI()

# Allow browser clients from any origin to call the API.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# rejected by browsers for credentialed requests — confirm whether credentials
# are actually required here.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
24
+
25
# =====================================================
# MODELS
# =====================================================

class TTSRequest(BaseModel):
    """Request body for the /tts endpoint."""
    text: str  # text to synthesize into speech


# Directory where uploaded documents are stored before indexing.
UPLOAD_DIR = "uploads"
34
+
35
# =====================================================
# HEALTH CHECK
# =====================================================

@app.get("/")
def health():
    """Liveness probe: confirms the API is up and reachable."""
    payload = {"Status": "The api is live and running"}
    return payload
42
+
43
+
44
# =====================================================
# FILE UPLOAD (RAG)
# =====================================================

@app.post("/upload")
async def upload_file(file: UploadFile = File(...)):
    """
    Accept a document upload, persist it, and refresh the RAG pipeline.

    Returns a JSON payload with status and the original filename.
    Raises 400 when the upload carries no usable filename.
    """
    os.makedirs(UPLOAD_DIR, exist_ok=True)

    # Security fix: use only the basename so a crafted filename such as
    # "../../etc/passwd" cannot escape the upload directory (path traversal).
    safe_name = os.path.basename(file.filename or "")
    if not safe_name:
        raise HTTPException(status_code=400, detail="Missing filename")

    file_path = os.path.join(UPLOAD_DIR, safe_name)

    with open(file_path, "wb") as f:
        f.write(await file.read())

    # Rebuild the vector store over the new document...
    update_retriever(file_path)

    # ...and recompile the LangGraph so the RAG tool sees the new retriever.
    rebuild_graph()

    return {
        "status": "success",
        "filename": file.filename
    }
67
+
68
+
69
# =====================================================
# CHAT ENDPOINT (STREAMING)
# =====================================================

@app.post("/chat")
async def chat(message: str, session_id: str = "default"):
    """
    Stream the assistant's reply as Server-Sent Events.

    `message` is the user's utterance; `session_id` selects the LangGraph
    checkpoint thread so each conversation keeps its own history.
    """

    async def event_generator():
        async for chunk in app_graph.astream(
            {"messages": [HumanMessage(content=message)]},
            config={"configurable": {"thread_id": session_id}},
            stream_mode="messages"
        ):
            if not chunk:
                continue

            # stream_mode="messages" yields (message, metadata) tuples.
            msg = chunk[0] if isinstance(chunk, tuple) else chunk

            if hasattr(msg, "content") and msg.content:
                cleaned = msg.content.strip()
                if cleaned:
                    # Bug fix: an SSE frame ends at the first blank line, so
                    # content containing newlines previously corrupted the
                    # stream. Emit one "data:" line per text line, then the
                    # blank-line frame terminator.
                    for line in cleaned.splitlines():
                        yield f"data: {line}\n"
                    yield "\n"

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
            "Access-Control-Allow-Origin": "*",
        },
    )
103
+
104
+
105
# =====================================================
# STT
# =====================================================

@app.post("/stt")
async def transcribe_audio(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file and return the transcription payload."""
    try:
        result = await STT(file)
    except Exception as exc:
        # Surface any transcription failure as a 500 with the error text.
        raise HTTPException(status_code=500, detail=str(exc))
    return result
115
+
116
+
117
# =====================================================
# TTS
# =====================================================

@app.post("/tts")
async def generate_tts(request: TTSRequest):
    """
    Synthesize speech for `request.text` and return it as an MP3 file.

    Raises 400 for empty text; 500 when synthesis fails or produces no file.
    """
    try:
        if not request.text.strip():
            raise HTTPException(status_code=400, detail="Text is empty")

        audio_path = await TTS(text=request.text)

        if not os.path.exists(audio_path):
            raise HTTPException(status_code=500, detail="Audio file not created")

        return FileResponse(
            path=audio_path,
            media_type="audio/mpeg",
            filename="speech.mp3"
        )

    except HTTPException:
        # Bug fix: the broad handler below previously caught the deliberate
        # 400/500 HTTPExceptions raised above and re-raised them as generic
        # 500s, losing the intended status code and detail.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
chatbot.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import TypedDict, Annotated
2
+ from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
3
+ from langgraph.checkpoint.memory import MemorySaver
4
+ from tools import (
5
+ create_rag_tool,
6
+ arxiv_search,
7
+ calculator,
8
+ get_stock_price,
9
+ wikipedia_search,
10
+ tavily_search,
11
+ convert_currency,
12
+ unit_converter,
13
+ get_news,
14
+ get_joke,
15
+ get_quote,
16
+ get_weather,
17
+ )
18
+ from langchain_openai import ChatOpenAI
19
+ from langgraph.graph import StateGraph, START
20
+ from langgraph.graph.message import add_messages
21
+ from langgraph.prebuilt import ToolNode, tools_condition
22
+ from dotenv import load_dotenv
23
+ import os
24
+
25
+ load_dotenv()
26
+
27
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
28
+
29
# =====================================================
# SYSTEM PROMPT
# =====================================================

# Base instructions prepended to every model call in the chat node
# (see build_graph). The string content is runtime behavior — it shapes
# the model's answers — so edit it deliberately.
SYSTEM_PROMPT = SystemMessage(
    content="""
You are an intelligent AI assistant built inside a LangGraph-based system created by Junaid.

You MUST use RAG when a document has been uploaded.
If no document contains the answer, say so clearly.
Never hallucinate document content.
"""
)
42
+
43
# =====================================================
# STATE
# =====================================================

class ChatState(TypedDict):
    """Graph state: the running conversation history."""
    # `add_messages` is a reducer — new messages are appended/merged into the
    # existing list instead of replacing it on each node update.
    messages: Annotated[list[BaseMessage], add_messages]
49
+
50
+
51
# =====================================================
# LLM
# =====================================================

# Shared chat model. streaming=True lets the /chat endpoint forward tokens
# to the client as they are produced.
llm = ChatOpenAI(
    model="gpt-4.1-nano",
    temperature=0.4,
    streaming=True
)
60
+
61
+
62
# =====================================================
# GRAPH BUILDER (🔥 IMPORTANT)
# =====================================================

# Single MemorySaver instance so conversation checkpoints survive graph
# rebuilds triggered by document uploads.
memory = MemorySaver()
# Compiled LangGraph application; populated by build_graph() below.
app = None
68
+
69
+
70
def build_graph():
    """
    (Re)compile the LangGraph app with the current tool set.

    Runs once at import time and again after every document upload so the
    RAG tool is bound to the freshly built retriever. The shared `memory`
    checkpointer is reused, preserving conversation history across rebuilds.
    """
    global app

    # Tool roster: the RAG tool first, then the general-purpose utilities.
    toolbox = [
        create_rag_tool(),
        get_stock_price,
        calculator,
        wikipedia_search,
        arxiv_search,
        tavily_search,
        convert_currency,
        unit_converter,
        get_news,
        get_joke,
        get_quote,
        get_weather,
    ]

    model_with_tools = llm.bind_tools(toolbox)

    def chatbot(state: ChatState):
        """Chat node: prepend the system prompt and invoke the tool-aware LLM."""
        reply = model_with_tools.invoke([SYSTEM_PROMPT] + state["messages"])
        return {"messages": [reply]}

    builder = StateGraph(ChatState)
    builder.add_node("chat", chatbot)
    builder.add_node("tools", ToolNode(toolbox))

    # chat -> (tools if the model requested any, else END) -> back to chat.
    builder.add_edge(START, "chat")
    builder.add_conditional_edges("chat", tools_condition)
    builder.add_edge("tools", "chat")

    app = builder.compile(checkpointer=memory)
108
+
109
+
110
# initial build — compile the graph once at import time so `app` is usable
# even before any document has been uploaded.
build_graph()


def rebuild_graph():
    """Recompile the graph (called after an upload activates a new retriever)."""
    build_graph()
tools.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_core.tools import tool
2
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
3
+ from langchain_community.vectorstores import FAISS
4
+ from langchain_community.document_loaders import PyPDFLoader
5
+ from langchain_openai import OpenAIEmbeddings
6
+ from langchain_community.tools import WikipediaQueryRun, ArxivQueryRun
7
+ from langchain_community.utilities import WikipediaAPIWrapper, ArxivAPIWrapper
8
+ from langchain_core.tools import tool
9
+ from langchain_community.tools.tavily_search import TavilySearchResults
10
+ from dotenv import load_dotenv
11
+ import os
12
+ import requests
13
+
14
load_dotenv()

# API keys read from the environment (.env) at import time.
# Fix: NEWS_API_KEY was previously assigned twice from the same variable;
# the redundant second assignment is removed.
API_KEY = os.getenv("ALPHAVANTAGE_API_KEY")     # Alpha Vantage (stocks)
NEWS_API_KEY = os.getenv("NEWS_API_KEY")        # NewsAPI
WEATHER_API_KEY = os.getenv("WEATHER_API_KEY")  # WeatherAPI.com

# -------------------------------
# GLOBAL RETRIEVER
# -------------------------------
# Module-level retriever shared with the RAG tool; set by update_retriever()
# after a document upload. None means "no document indexed yet".
retriever = None
25
+
26
+
27
def build_vectorstore(path: str):
    """Load a PDF, chunk it, and embed the chunks into a FAISS index."""
    pages = PyPDFLoader(path).load()

    # 500-char chunks with 100-char overlap, as before.
    chunker = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=100
    )
    chunks = chunker.split_documents(pages)

    return FAISS.from_documents(
        chunks,
        OpenAIEmbeddings(model="text-embedding-3-small"),
    )
40
+
41
+
42
def update_retriever(pdf_path: str):
    """Rebuild the shared module-level retriever over a newly uploaded PDF."""
    global retriever
    # Top-3 chunk retrieval, same as before.
    retriever = build_vectorstore(pdf_path).as_retriever(search_kwargs={"k": 3})
46
+
47
+
48
def create_rag_tool():
    """Build and return the RAG tool; it reads the module-level `retriever`."""

    @tool
    def rag_search(query: str) -> str:
        """
        Retrieve relevant information from the uploaded document.

        Use this tool when the user asks questions related to the uploaded PDF
        or any document-based knowledge. If no document is available or no
        relevant information is found, return an appropriate message.
        """
        global retriever

        # No document has been indexed yet (see update_retriever).
        if retriever is None:
            return "No document uploaded yet."

        docs = retriever.invoke(query)

        if not docs:
            return "No relevant information found in the uploaded document."

        context = "\n\n".join(d.page_content for d in docs)

        # The return value is itself a prompt: it instructs the calling LLM
        # how to present the retrieved context.
        return f"""
You are given extracted content from a document.

Your task:
- Summarize the content clearly
- Use bullet points where appropriate
- Keep formatting clean and readable
- Do NOT repeat unnecessary text
- Do NOT mention that this came from a document

DOCUMENT CONTENT:
{context}
"""

    return rag_search
87
+
88
+
89
+
90
@tool
def arxiv_search(query: str) -> dict:
    """
    Search arXiv for academic papers related to the query.
    """
    # Any failure is reported in-band as {"error": ...} rather than raised.
    try:
        runner = ArxivQueryRun(api_wrapper=ArxivAPIWrapper())
        return {"query": query, "results": runner.run(query)}
    except Exception as exc:
        return {"error": str(exc)}
101
+
102
@tool
def calculator(first_num: float, second_num: float, operation: str) -> dict:
    """
    Perform a basic arithmetic operation on two numbers.
    Supported operations: add, sub, mul, div
    """
    try:
        # Dispatch table in place of the original if/elif chain.
        ops = {
            "add": lambda a, b: a + b,
            "sub": lambda a, b: a - b,
            "mul": lambda a, b: a * b,
            "div": lambda a, b: a / b,
        }
        if operation not in ops:
            return {"error": f"Unsupported operation '{operation}'"}
        # Same guard as the original: refuse division by zero explicitly.
        if operation == "div" and second_num == 0:
            return {"error": "Division by zero is not allowed"}
        result = ops[operation](first_num, second_num)
        return {"first_num": first_num, "second_num": second_num, "operation": operation, "result": result}
    except Exception as e:
        return {"error": str(e)}
125
@tool
def tavily_search(query: str) -> dict:
    """
    Perform a web search using Tavily,
    also use it to get weather information,
    Returns up to 5 search results.
    """
    # Errors are returned in-band so the agent can recover gracefully.
    try:
        engine = TavilySearchResults(max_results=5)
        return {"query": query, "results": engine.run(query)}
    except Exception as exc:
        return {"error": str(exc)}
138
+
139
+
140
@tool
def get_stock_price(symbol: str) -> dict:
    """
    Fetch latest stock price for a given symbol (e.g. 'AAPL', 'TSLA')
    using Alpha Vantage with API key in the URL.
    """
    # Fixes: every sibling tool wraps network errors and returns them in-band,
    # but this one let exceptions propagate; it also made an unbounded request.
    # params= also URL-encodes the symbol instead of raw interpolation.
    try:
        r = requests.get(
            "https://www.alphavantage.co/query",
            params={"function": "GLOBAL_QUOTE", "symbol": symbol, "apikey": API_KEY},
            timeout=10,
        )
        return r.json()
    except Exception as e:
        return {"error": str(e)}
149
+
150
@tool
def wikipedia_search(query: str) -> dict:
    """
    Search Wikipedia for a given query and return results.
    """
    # Failures are reported in-band as {"error": ...}.
    try:
        engine = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper())
        return {"query": query, "results": engine.run(query)}
    except Exception as exc:
        return {"error": str(exc)}
161
+
162
@tool
def convert_currency(amount: float, from_currency: str, to_currency: str) -> dict:
    """
    Convert amount from one currency to another using Frankfurter API.
    Example: convert_currency(100, "USD", "EUR")
    """
    try:
        # Fixes: params= URL-encodes the query string (raw interpolation did
        # not) and timeout= bounds the call so a hung API cannot stall the
        # agent indefinitely.
        r = requests.get(
            "https://api.frankfurter.app/latest",
            params={"amount": amount, "from": from_currency, "to": to_currency},
            timeout=10,
        )
        return r.json()
    except Exception as e:
        return {"error": str(e)}
174
@tool
def unit_converter(value: float, from_unit: str, to_unit: str) -> dict:
    """
    Convert between metric/imperial units (supports: km<->miles, kg<->lbs, C<->F).
    Example: unit_converter(10, "km", "miles")
    """
    try:
        # Conversion functions keyed by the (from, to) unit pair.
        table = {
            ("km", "miles"): lambda x: x * 0.621371,
            ("miles", "km"): lambda x: x / 0.621371,
            ("kg", "lbs"): lambda x: x * 2.20462,
            ("lbs", "kg"): lambda x: x / 2.20462,
            ("C", "F"): lambda x: (x * 9/5) + 32,
            ("F", "C"): lambda x: (x - 32) * 5/9
        }
        pair = (from_unit, to_unit)
        if pair not in table:
            return {"error": f"Unsupported conversion: {from_unit} -> {to_unit}"}
        converted = table[pair](value)
        return {"value": value, "from": from_unit, "to": to_unit, "result": converted}
    except Exception as exc:
        return {"error": str(exc)}
197
+
198
+
199
+
200
@tool
def get_news(query: str) -> dict:
    """
    Fetch latest news headlines for a given query.
    Example: get_news("artificial intelligence")
    """
    try:
        # Fixes: params= URL-encodes the query (spaces, '&', unicode broke
        # the previous f-string URL) and timeout= bounds the network call.
        r = requests.get(
            "https://newsapi.org/v2/everything",
            params={"q": query, "apiKey": NEWS_API_KEY, "language": "en"},
            timeout=10,
        )
        return r.json()
    except Exception as e:
        return {"error": str(e)}
212
+
213
+
214
@tool
def get_joke(category: str = "Any") -> dict:
    """
    Get a random joke. Categories: Programming, Misc, Pun, Spooky, Christmas, Any
    Example: get_joke("Programming")
    """
    # The category becomes a path segment of the JokeAPI endpoint.
    try:
        endpoint = f"https://v2.jokeapi.dev/joke/{category}"
        response = requests.get(endpoint)
        return response.json()
    except Exception as exc:
        return {"error": str(exc)}
226
+
227
@tool
def get_quote(tag: str = "") -> dict:
    """
    Fetch a random quote. Optionally filter by tag (e.g., 'inspirational', 'technology').
    Example: get_quote("inspirational")
    """
    try:
        # Fixes: the optional tag is passed via params= so it is URL-encoded
        # (it was interpolated raw before), and timeout= bounds the call.
        # The original URL also used an f-string with no placeholders.
        params = {"tags": tag} if tag else None
        r = requests.get("https://api.quotable.io/random", params=params, timeout=10)
        return r.json()
    except Exception as e:
        return {"error": str(e)}
241
+
242
@tool
def get_weather(city: str) -> dict:
    """
    Get current weather for a given city using WeatherAPI.com.
    Example: get_weather("London")
    """
    try:
        # Fixes: params= URL-encodes city names containing spaces ("New York")
        # which the previous f-string interpolation mangled, and timeout=
        # bounds the network call.
        r = requests.get(
            "http://api.weatherapi.com/v1/current.json",
            params={"key": WEATHER_API_KEY, "q": city, "aqi": "no"},
            timeout=10,
        )
        data = r.json()

        # WeatherAPI reports failures in-band via an "error" object.
        if "error" in data:
            return {"error": data["error"]["message"]}

        return {
            "location": data["location"]["name"],
            "country": data["location"]["country"],
            "temperature_c": data["current"]["temp_c"],
            "temperature_f": data["current"]["temp_f"],
            "condition": data["current"]["condition"]["text"],
            "humidity": data["current"]["humidity"],
            "wind_kph": data["current"]["wind_kph"],
            "wind_dir": data["current"]["wind_dir"]
        }
    except Exception as e:
        return {"error": str(e)}
268
+
269
+
270
+
271
# NOTE(review): this is a duplicate of the get_news tool defined earlier in
# this file. At import time this second definition silently rebinds the name,
# replacing the first. One of the two should be deleted.
@tool
def get_news(query: str) -> dict:
    """
    Fetch latest news headlines for a given query.
    Example: get_news("artificial intelligence")
    """
    try:
        # Network or JSON failures are reported in-band as {"error": ...}.
        url = f"https://newsapi.org/v2/everything?q={query}&apiKey={NEWS_API_KEY}&language=en"
        r = requests.get(url)
        return r.json()
    except Exception as e:
        return {"error": str(e)}
utils.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from uuid import uuid4
3
+ import edge_tts
4
+ from groq import Groq
5
+ from dotenv import load_dotenv
6
+
7
load_dotenv()

# Groq client used for Whisper transcription; reads its API key from the
# environment loaded above.
client = Groq()

# ==================================================
# 🎙️ TEXT TO SPEECH (FIXED VOICE)
# ==================================================

# Single fixed Edge-TTS voice used for all synthesis.
DEFAULT_VOICE = "en-US-MichelleNeural"
16
+
17
async def TTS(
    text: str,
    output_dir: str = "tts_outputs",
    rate: str = "+0%",
    pitch: str = "+0Hz"
) -> str:
    """
    Synthesize `text` to an MP3 via edge-tts and return the file path.

    Raises ValueError when the text is blank.
    """
    if not text.strip():
        raise ValueError("Empty text")

    os.makedirs(output_dir, exist_ok=True)

    # Unique name per request so concurrent calls never collide.
    output_path = os.path.join(output_dir, f"{uuid4().hex}.mp3")

    synthesizer = edge_tts.Communicate(
        text=text,
        voice=DEFAULT_VOICE,
        rate=rate,
        pitch=pitch
    )
    await synthesizer.save(output_path)

    return output_path
41
+
42
+
43
# ==================================================
# 🎧 SPEECH TO TEXT
# ==================================================

async def STT(audio_file):
    """
    Transcribe an uploaded audio file with Groq Whisper.

    Returns a dict with the transcribed text, segments and detected language.
    """
    os.makedirs("uploads", exist_ok=True)
    # NOTE(review): extension is hard-coded to .wav regardless of the actual
    # upload format — confirm the transcription API sniffs content rather
    # than trusting the filename.
    file_path = f"uploads/{uuid4().hex}.wav"

    try:
        with open(file_path, "wb") as f:
            f.write(await audio_file.read())

        with open(file_path, "rb") as f:
            transcription = client.audio.transcriptions.create(
                file=f,
                model="whisper-large-v3-turbo",
                response_format="verbose_json",
                temperature=0.0
            )
    finally:
        # Bug fix: the temporary file was never removed, so every request
        # leaked a file onto disk. Clean up even when transcription fails.
        if os.path.exists(file_path):
            os.remove(file_path)

    return {
        "text": transcription.text,
        "segments": transcription.segments,
        "language": transcription.language
    }