Hammad712 committed on
Commit
eca9b71
·
verified ·
1 Parent(s): 88b95ee

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +40 -73
main.py CHANGED
@@ -40,7 +40,6 @@ llm = ChatOpenAI(
40
  # ---------------------------
41
  app = FastAPI()
42
 
43
- # Load MongoDB credentials from environment variables
44
  MONGO_USER = os.getenv("MONGO_USER")
45
  MONGO_PASSWORD = os.getenv("MONGO_PASSWORD")
46
  MONGO_CLUSTER = os.getenv("MONGO_CLUSTER")
@@ -58,10 +57,6 @@ chat_history_collection = database["chat_history"]
58
  # Classification Helper using LLM (for text messages only)
59
  # ---------------------------
60
  def classify_message_content(message: str) -> str:
61
- """
62
- Use the LLM to classify the provided message as either 'spam' or 'unknown'.
63
- The prompt instructs the LLM to respond with only one word.
64
- """
65
  prompt = (
66
  "Classify the following message as either 'spam' or 'unknown'. "
67
  "Respond with only one word: spam or unknown.\n"
@@ -93,20 +88,12 @@ class IncomingMessage(BaseModel):
93
  # Chat History Helper Functions
94
  # ---------------------------
95
  def get_chat_history(caller_number: str) -> List[dict]:
96
- """
97
- Retrieve the conversation history for a given caller from MongoDB.
98
- Each entry is a dict with 'role' and 'content' keys.
99
- """
100
  doc = chat_history_collection.find_one({"caller_number": caller_number})
101
  if doc and "messages" in doc:
102
  return doc["messages"]
103
  return []
104
 
105
  def update_chat_history(caller_number: str, role: str, content: str):
106
- """
107
- Append a new message (with role and content) to the chat history for the given caller.
108
- If no document exists, one is created (upsert).
109
- """
110
  chat_history_collection.update_one(
111
  {"caller_number": caller_number},
112
  {"$push": {"messages": {"role": role, "content": content}}},
@@ -117,20 +104,11 @@ def update_chat_history(caller_number: str, role: str, content: str):
117
  # Conversation Simulation Functions
118
  # ---------------------------
119
  def simulate_text_conversation(caller_number: str, initial_message: str, conversation_type: str = "unknown") -> str:
120
- """
121
- Simulate a multi-turn text conversation.
122
- For unknown texts, use a neutral tone.
123
- (Spam texts are handled immediately without simulation.)
124
- """
125
  if conversation_type == "unknown":
126
  system_prompt = (
127
  f"You are a call assistant. The unknown caller's text message is '{initial_message}'.\n"
128
- "Simulate a multi-turn conversation that unfolds as follows:\n"
129
- "1. Immediate reply: 'Unknown text – Bot replying'\n"
130
- "2. Follow-up: 'Who is this?'\n"
131
- "3. Then: 'What is the purpose of your message?'\n"
132
- "4. Finally: 'Please provide more details.'\n"
133
- "Return all steps in one message, each on a new line."
134
  )
135
  else:
136
  system_prompt = "You are a call assistant."
@@ -154,21 +132,15 @@ def simulate_text_conversation(caller_number: str, initial_message: str, convers
154
  return assistant_response
155
 
156
  def simulate_call_conversation(caller_number: str, initial_message: str, conversation_type: str = "spam") -> str:
157
- """
158
- Simulate a multi-turn call conversation.
159
- For voice calls, if the caller is not saved, always use the humorous spam conversation.
160
- """
161
  if conversation_type == "spam":
162
  system_prompt = (
163
  f"You are HumorBot on a phone call. The caller's number is {caller_number} and the transcribed message is '{initial_message}'.\n"
164
  "Simulate a multi-turn spam call conversation."
165
-
166
  )
167
  elif conversation_type == "unknown":
168
  system_prompt = (
169
  f"You are a call assistant. The caller's number is {caller_number} and the transcribed message is '{initial_message}'.\n"
170
  "Simulate a multi-turn unknown call conversation."
171
-
172
  )
173
  else:
174
  system_prompt = "You are a call assistant."
@@ -196,10 +168,6 @@ def simulate_call_conversation(caller_number: str, initial_message: str, convers
196
  # ---------------------------
197
  @app.post("/contacts", response_model=List[Contact])
198
  def create_contacts(contacts: List[Contact]):
199
- """
200
- Save a list of contacts into MongoDB.
201
- Email and name are optional, but phone is required and must be unique.
202
- """
203
  contacts_to_insert = []
204
  for contact in contacts:
205
  if contacts_collection.find_one({"phone": contact.phone}):
@@ -212,17 +180,11 @@ def create_contacts(contacts: List[Contact]):
212
 
213
  @app.get("/contacts", response_model=List[Contact])
214
  def get_all_contacts():
215
- """
216
- Retrieve all contacts from MongoDB.
217
- """
218
  contacts = list(contacts_collection.find({}, {"_id": 0}))
219
  return contacts
220
 
221
  @app.get("/contacts/{phone}", response_model=Contact)
222
  def get_contact(phone: str):
223
- """
224
- Retrieve a specific contact by phone number.
225
- """
226
  contact = contacts_collection.find_one({"phone": phone}, {"_id": 0})
227
  if not contact:
228
  raise HTTPException(status_code=404, detail="Contact not found")
@@ -230,13 +192,6 @@ def get_contact(phone: str):
230
 
231
  @app.post("/incoming-message")
232
  def process_incoming_message(incoming: IncomingMessage):
233
- """
234
- Process an incoming text message:
235
- - If the sender's number is in saved contacts, return a primary message.
236
- - Otherwise, use the LLM to classify the message content.
237
- • If classified as spam, respond immediately.
238
- • If classified as unknown, simulate a neutral multi-turn conversation.
239
- """
240
  if contacts_collection.find_one({"phone": incoming.phone}):
241
  return {
242
  "status": "primary",
@@ -258,9 +213,6 @@ def process_incoming_message(incoming: IncomingMessage):
258
 
259
  @app.get("/messages/{caller_number}")
260
  def get_messages(caller_number: str):
261
- """
262
- Retrieve the conversation history (messages) for a given caller.
263
- """
264
  messages = get_chat_history(caller_number)
265
  if not messages:
266
  raise HTTPException(status_code=404, detail="No messages found for this caller")
@@ -268,9 +220,6 @@ def get_messages(caller_number: str):
268
 
269
  @app.post("/setup-call-forwarding")
270
  def setup_call_forwarding():
271
- """
272
- Simulate call forwarding setup.
273
- """
274
  forwarding_number = "+1-555-123-4567"
275
  return {"status": "success", "message": f"Setup done! Calls forwarded to {forwarding_number}"}
276
 
@@ -278,9 +227,6 @@ def setup_call_forwarding():
278
  # STT and TTS Functions for Voice Calls
279
  # ---------------------------
280
  def transcribe_audio(audio_file: bytes) -> str:
281
- """
282
- Convert incoming audio to text using Groq Whisper v3 (STT).
283
- """
284
  response = groq_client.audio.transcriptions.create(
285
  file=("audio.m4a", audio_file),
286
  model="whisper-large-v3",
@@ -289,9 +235,6 @@ def transcribe_audio(audio_file: bytes) -> str:
289
  return response.text
290
 
291
  def text_to_speech(text: str) -> bytes:
292
- """
293
- Convert text to speech using Cartesia TTS.
294
- """
295
  audio_bytes = cartesia_client.tts.bytes(
296
  model_id="sonic",
297
  transcript=text,
@@ -306,14 +249,9 @@ def text_to_speech(text: str) -> bytes:
306
 
307
  @app.post("/process-call")
308
  async def process_call(caller_number: str = Form(...), audio: UploadFile = File(...)):
309
- """
310
- Process an incoming voice call:
311
- - If the caller's number is in saved contacts, immediately return a "Ringing" message.
312
- - Otherwise (unsaved caller), always treat the call as spam and simulate a humorous conversation.
313
- """
314
  if contacts_collection.find_one({"phone": caller_number}):
315
  ringing_text = f"Call from {caller_number} – Ringing"
316
- _ = text_to_speech(ringing_text) # Generate audio if needed
317
  return {"status": "success", "message": ringing_text}
318
 
319
  try:
@@ -328,7 +266,6 @@ async def process_call(caller_number: str = Form(...), audio: UploadFile = File(
328
 
329
  update_chat_history(caller_number, "stt", transcription)
330
 
331
- # For voice calls, unsaved callers are always treated as spam.
332
  conversation_result = simulate_call_conversation(caller_number, transcription, conversation_type="spam")
333
 
334
  try:
@@ -340,13 +277,6 @@ async def process_call(caller_number: str = Form(...), audio: UploadFile = File(
340
 
341
  @app.get("/audio-reply/{caller_number}")
342
  def get_audio_reply(caller_number: str):
343
- """
344
- Retrieve the latest assistant reply and STT transcription for a given caller.
345
- Returns:
346
- - stt_response: the transcription (STT)
347
- - llm_reply: the assistant reply (LLM)
348
- - audio_reply: the TTS audio (WAV) as a base64-encoded string.
349
- """
350
  messages = get_chat_history(caller_number)
351
  if not messages:
352
  raise HTTPException(status_code=404, detail="No conversation found for this caller")
@@ -369,6 +299,43 @@ def get_audio_reply(caller_number: str):
369
  "audio_reply": audio_base64
370
  }
371
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
372
  @app.get("/")
373
  async def root():
374
  return {"message": "Welcome to the AI Spam Blocker API."}
 
40
  # ---------------------------
41
  app = FastAPI()
42
 
 
43
  MONGO_USER = os.getenv("MONGO_USER")
44
  MONGO_PASSWORD = os.getenv("MONGO_PASSWORD")
45
  MONGO_CLUSTER = os.getenv("MONGO_CLUSTER")
 
57
  # Classification Helper using LLM (for text messages only)
58
  # ---------------------------
59
  def classify_message_content(message: str) -> str:
 
 
 
 
60
  prompt = (
61
  "Classify the following message as either 'spam' or 'unknown'. "
62
  "Respond with only one word: spam or unknown.\n"
 
88
  # Chat History Helper Functions
89
  # ---------------------------
90
  def get_chat_history(caller_number: str) -> List[dict]:
 
 
 
 
91
  doc = chat_history_collection.find_one({"caller_number": caller_number})
92
  if doc and "messages" in doc:
93
  return doc["messages"]
94
  return []
95
 
96
  def update_chat_history(caller_number: str, role: str, content: str):
 
 
 
 
97
  chat_history_collection.update_one(
98
  {"caller_number": caller_number},
99
  {"$push": {"messages": {"role": role, "content": content}}},
 
104
  # Conversation Simulation Functions
105
  # ---------------------------
106
  def simulate_text_conversation(caller_number: str, initial_message: str, conversation_type: str = "unknown") -> str:
 
 
 
 
 
107
  if conversation_type == "unknown":
108
  system_prompt = (
109
  f"You are a call assistant. The unknown caller's text message is '{initial_message}'.\n"
110
+ "Simulate a multi-turn conversation."
111
+
 
 
 
 
112
  )
113
  else:
114
  system_prompt = "You are a call assistant."
 
132
  return assistant_response
133
 
134
  def simulate_call_conversation(caller_number: str, initial_message: str, conversation_type: str = "spam") -> str:
 
 
 
 
135
  if conversation_type == "spam":
136
  system_prompt = (
137
  f"You are HumorBot on a phone call. The caller's number is {caller_number} and the transcribed message is '{initial_message}'.\n"
138
  "Simulate a multi-turn spam call conversation."
 
139
  )
140
  elif conversation_type == "unknown":
141
  system_prompt = (
142
  f"You are a call assistant. The caller's number is {caller_number} and the transcribed message is '{initial_message}'.\n"
143
  "Simulate a multi-turn unknown call conversation."
 
144
  )
145
  else:
146
  system_prompt = "You are a call assistant."
 
168
  # ---------------------------
169
  @app.post("/contacts", response_model=List[Contact])
170
  def create_contacts(contacts: List[Contact]):
 
 
 
 
171
  contacts_to_insert = []
172
  for contact in contacts:
173
  if contacts_collection.find_one({"phone": contact.phone}):
 
180
 
181
  @app.get("/contacts", response_model=List[Contact])
182
  def get_all_contacts():
 
 
 
183
  contacts = list(contacts_collection.find({}, {"_id": 0}))
184
  return contacts
185
 
186
  @app.get("/contacts/{phone}", response_model=Contact)
187
  def get_contact(phone: str):
 
 
 
188
  contact = contacts_collection.find_one({"phone": phone}, {"_id": 0})
189
  if not contact:
190
  raise HTTPException(status_code=404, detail="Contact not found")
 
192
 
193
  @app.post("/incoming-message")
194
  def process_incoming_message(incoming: IncomingMessage):
 
 
 
 
 
 
 
195
  if contacts_collection.find_one({"phone": incoming.phone}):
196
  return {
197
  "status": "primary",
 
213
 
214
  @app.get("/messages/{caller_number}")
215
  def get_messages(caller_number: str):
 
 
 
216
  messages = get_chat_history(caller_number)
217
  if not messages:
218
  raise HTTPException(status_code=404, detail="No messages found for this caller")
 
220
 
221
  @app.post("/setup-call-forwarding")
222
  def setup_call_forwarding():
 
 
 
223
  forwarding_number = "+1-555-123-4567"
224
  return {"status": "success", "message": f"Setup done! Calls forwarded to {forwarding_number}"}
225
 
 
227
  # STT and TTS Functions for Voice Calls
228
  # ---------------------------
229
  def transcribe_audio(audio_file: bytes) -> str:
 
 
 
230
  response = groq_client.audio.transcriptions.create(
231
  file=("audio.m4a", audio_file),
232
  model="whisper-large-v3",
 
235
  return response.text
236
 
237
  def text_to_speech(text: str) -> bytes:
 
 
 
238
  audio_bytes = cartesia_client.tts.bytes(
239
  model_id="sonic",
240
  transcript=text,
 
249
 
250
  @app.post("/process-call")
251
  async def process_call(caller_number: str = Form(...), audio: UploadFile = File(...)):
 
 
 
 
 
252
  if contacts_collection.find_one({"phone": caller_number}):
253
  ringing_text = f"Call from {caller_number} – Ringing"
254
+ _ = text_to_speech(ringing_text)
255
  return {"status": "success", "message": ringing_text}
256
 
257
  try:
 
266
 
267
  update_chat_history(caller_number, "stt", transcription)
268
 
 
269
  conversation_result = simulate_call_conversation(caller_number, transcription, conversation_type="spam")
270
 
271
  try:
 
277
 
278
  @app.get("/audio-reply/{caller_number}")
279
  def get_audio_reply(caller_number: str):
 
 
 
 
 
 
 
280
  messages = get_chat_history(caller_number)
281
  if not messages:
282
  raise HTTPException(status_code=404, detail="No conversation found for this caller")
 
299
  "audio_reply": audio_base64
300
  }
301
 
302
+ # ---------------------------
303
+ # New Endpoints for Direct STT, TTS, and LLM Calls
304
+ # ---------------------------
305
@app.post("/stt")
async def stt_endpoint(audio: UploadFile = File(...)):
    """Transcribe an uploaded audio file and return the recognised text.

    The upload is read fully into memory and the raw bytes are handed to
    ``transcribe_audio``.

    Args:
        audio: The uploaded audio file (multipart form field ``audio``).

    Returns:
        A dict of the form ``{"transcription": <text>}``.

    Raises:
        HTTPException: status 500 with an ``STT Error: ...`` detail when
            reading the upload or transcribing it fails.
    """
    # Imported locally so this endpoint does not depend on a module-level
    # import that is not visible in this part of the file.
    from fastapi.concurrency import run_in_threadpool

    try:
        audio_bytes = await audio.read()
        # transcribe_audio is a plain synchronous function; calling it
        # directly from an async endpoint would block the event loop for the
        # duration of the request, so it is run in the worker threadpool.
        transcription = await run_in_threadpool(transcribe_audio, audio_bytes)
    except Exception as e:
        # Chain the original exception so the real traceback is preserved.
        raise HTTPException(status_code=500, detail=f"STT Error: {e}") from e
    return {"transcription": transcription}
313
+
314
@app.post("/tts")
def tts_endpoint(text: str = Body(..., embed=True)):
    """Synthesise speech for ``text`` and return it base64-encoded.

    Args:
        text: The text to speak, read from the JSON body key ``text``
            (``embed=True``).

    Returns:
        A dict of the form ``{"audio": <base64 string>}`` containing the
        bytes returned by ``text_to_speech``.

    Raises:
        HTTPException: status 500 with a ``TTS Error: ...`` detail when
            synthesis or encoding fails.
    """
    # NOTE(review): ``Body`` is evaluated when this function is defined, so it
    # must be imported from fastapi at module level. No import change is
    # visible alongside this endpoint; confirm the import exists.
    try:
        audio_bytes = text_to_speech(text)
        audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")
    except Exception as e:
        # Chain the original exception so the real traceback is preserved.
        raise HTTPException(status_code=500, detail=f"TTS Error: {e}") from e
    return {"audio": audio_base64}
322
+
323
@app.post("/llm")
def llm_endpoint(message: str = Body(..., embed=True)):
    """Send a single user message to the LLM and return its reply.

    The message is wrapped with a fixed system prompt and passed to
    ``llm.invoke``.

    Args:
        message: The user message, read from the JSON body key ``message``
            (``embed=True``).

    Returns:
        A dict of the form ``{"reply": <text>}``.

    Raises:
        HTTPException: status 500 with an ``LLM Error: ...`` detail when the
            LLM call fails.
    """
    # NOTE(review): ``Body`` is evaluated when this function is defined, so it
    # must be imported from fastapi at module level. No import change is
    # visible alongside this endpoint; confirm the import exists.
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": message},
    ]
    try:
        response = llm.invoke(messages)
        # Use the response's ``content`` attribute when it has one; otherwise
        # fall back to the string form of whatever was returned.
        reply = response.content if hasattr(response, "content") else str(response)
    except Exception as e:
        # Chain the original exception so the real traceback is preserved.
        raise HTTPException(status_code=500, detail=f"LLM Error: {e}") from e
    return {"reply": reply}
338
+
339
  @app.get("/")
340
  async def root():
341
  return {"message": "Welcome to the AI Spam Blocker API."}