# arabic-chatbot / App/main.py
# (Hugging Face page header preserved as comments so the module parses)
# itsmehardawood's picture
# Update App/main.py
# 5b5ed2b verified
import os
import uuid
import base64
import bcrypt
from datetime import datetime, timedelta , timezone
import io
from typing import Optional , Literal
from fastapi import FastAPI, UploadFile, File, HTTPException, Depends, Query as QueryParam , Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, EmailStr, Field , HttpUrl
from jose import jwt
from motor.motor_asyncio import AsyncIOMotorClient
from bson.binary import Binary
from fastapi import status
# from youtube_transcript_api import YouTubeTranscriptApi, NoTranscriptFound
from App.utils import (
load_and_split_bytes,
add_documents_to_index,
get_llm,
get_existing_retriever,
get_collection_stats ,
build_chroma_index ,
create_rag_chain_with_history ,
count_tokens ,
cosine_similarity
)
from bson import ObjectId
import openai
import numpy as np
import logging
from langchain_openai import OpenAIEmbeddings
import os
from dotenv import load_dotenv
from paypalcheckoutsdk.core import PayPalHttpClient, SandboxEnvironment, LiveEnvironment
from paypalcheckoutsdk.orders import OrdersCreateRequest, OrdersCaptureRequest
from datetime import datetime
# Startup timestamp log (kept for parity with original behaviour).
now_utc = datetime.utcnow()
print("Current UTC time:", now_utc)

load_dotenv()

# --- PayPal configuration -------------------------------------------------
# Chooses the live client only when PAYPAL_ENVIRONMENT is exactly "live";
# anything else (including unset) falls back to the sandbox.
env_name = os.environ.get("PAYPAL_ENVIRONMENT", "sandbox")
creds = dict(
    client_id=os.environ.get("PAYPAL_CLIENT_ID_sn"),
    client_secret=os.environ.get("PAYPAL_CLIENT_SECRET_sn"),
)
environment = (
    LiveEnvironment(**creds)
    if env_name == "live"
    else SandboxEnvironment(**creds)
)
p_client = PayPalHttpClient(environment)

# --- JWT configuration ----------------------------------------------------
# SECURITY fix: the signing secret was hard-coded; it is now read from the
# environment. The old literal is kept only as a backward-compatible fallback
# and should be removed once JWT_SECRET_KEY is set in all deployments.
SECRET_KEY = os.environ.get("JWT_SECRET_KEY", "hssjhdahsd")
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
app = FastAPI()

# CORS middleware.
# NOTE(review): "*" together with allow_credentials=True is rejected by
# browsers for credentialed requests, and listing localhost alongside "*" is
# redundant — confirm the intended origin policy.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*", 'http://localhost:3000'],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# MongoDB connection.
# SECURITY: the fallback URI embeds live credentials in source control —
# rotate them and make MONGODB_URI mandatory in the environment.
connection_string = os.getenv(
    "MONGODB_URI",
    "mongodb+srv://ahmed0499280:haseeb.2003@cluster0.hzgrxp2.mongodb.net/"
    "?retryWrites=true&w=majority&appName=Cluster0"
)
client = AsyncIOMotorClient(connection_string)
db = client["Cluster0"]
users_collection = db["users"]                     # user accounts
chatbot_history_collection = db["chatbothistory"]  # per-user chat transcripts
documents_collection = db["documents"]  # Collection to store document files
chroma_db_collection = db["chroma_db_store"]  # Collection to store Chroma DB
flash_cards = db['flash_cards']        # saved flashcards
videos_collection = db['videos']       # video links + descriptions
orders_collection = db["orders"]       # PayPal orders
subscriptions_collection = db["subscriptions"]  # subscription records

# Default language setting.
# NOTE(review): a module-level mutable global is shared across ALL users and
# requests — /language set by one user affects everyone. The per-user language
# stored on the user document is probably what was intended; confirm.
language = None
# Pydantic models
class CreateSubOrderSchema(BaseModel):
    """Request body for creating a subscription order (plan is constrained)."""
    user_id: str
    plan: Literal["monthly", "yearly"]


class TrialResponse(BaseModel):
    """Response for /start-trial: trial status and expiry timestamp."""
    status: str
    expires: datetime


class OrderResponse(BaseModel):
    """PayPal order id and its current status."""
    order_id: str
    status: str


class SubscriptionRequest(BaseModel):
    """Request body for /create-subscription-order (plan validated at runtime)."""
    user_id: str
    plan: str


class SearchRequest(BaseModel):
    """Free-text query for /search_video."""
    query: str


class Video(BaseModel):
    """A YouTube link plus a description used for semantic matching."""
    link: str
    description: str


class LanguageRequest(BaseModel):
    """Body for /language — sets the module-level language global."""
    language: str


class QueryModel(BaseModel):
    """Body for /query_rag: the question plus per-request answer options."""
    question: str
    user_id: str
    diacritics: bool  # whether the answer should include Arabic diacritics
    level: str        # learner level passed into the RAG chain


class UserSignup(BaseModel):
    """Signup payload.

    NOTE(review): is_admin is accepted here but the /signup handler always
    stores is_admin=False — confirm whether this field should be removed.
    """
    username: str = Field(..., min_length=3, max_length=50)
    email: EmailStr
    password: str = Field(..., min_length=8)
    language: str
    is_admin: bool


class UserModel(BaseModel):
    """Full user representation returned by GET /users/{user_id}."""
    id: str
    username: str
    email: str
    language: str
    is_admin: bool
    password: Optional[str] = None  # never populated by get_user


class UserResponse(BaseModel):
    """Minimal user info returned after signup."""
    id: str
    username: str
    email: str


class UserLogin(BaseModel):
    """Login credentials."""
    email: EmailStr
    password: str


class Token(BaseModel):
    """JWT bearer token response."""
    access_token: str
    token_type: str


class DocumentInfo(BaseModel):
    """Metadata for an uploaded document (binary content excluded)."""
    filename: str
    upload_date: datetime
    file_size: int
    chunks: int
    user_id: str


class FlashcardSaveModel(BaseModel):
    """Body for /save_flashcard: one Q/A exchange to persist."""
    user_id: str
    question: str
    answer: str
# Helper for login
async def authenticate_user(email: str, password: str):
    """Look up a user by email and verify the supplied password.

    Returns the user document on success, otherwise None.
    """
    record = await users_collection.find_one({"email": email})
    if record is None:
        return None
    stored_hash = record["password"].encode()
    if not bcrypt.checkpw(password.encode(), stored_hash):
        return None
    return record
@app.get("/users/{user_id}", response_model=UserModel)
async def get_user(user_id: str):
    """
    Fetch a single user by their user_id.
    Supports lookup by MongoDB ObjectId or string _id.
    """
    # Prefer an ObjectId lookup; fall back to the raw string id.
    try:
        lookup = {"_id": ObjectId(user_id)}
    except Exception:
        lookup = {"_id": user_id}
    found = await users_collection.find_one(lookup)
    if found is None:
        raise HTTPException(status_code=404, detail="User not found")
    # Map Mongo's _id to the API's id field.
    return {
        "id": str(found.get("_id")),
        "username": found.get("username"),
        "email": found.get("email"),
        "language": found.get("language"),
        "is_admin": found.get("is_admin", False),
    }
# Initialize retriever at startup
@app.on_event("startup")
async def startup_event():
    """Load a previously persisted vector store into app state, if one exists."""
    app.state.retriever = None
    try:
        existing = get_existing_retriever()
        if existing is None:
            print("No existing vector database found. Will create when documents are uploaded.")
        else:
            app.state.retriever = existing
            stats = get_collection_stats()
            print(f"Loaded existing vector database with {stats['document_count']} document chunks")
    except Exception as e:
        print(f"Error loading vector database: {e}")
        app.state.retriever = None
### API Endpoints
# To get Language
@app.post("/language")
async def receive_language(req: LanguageRequest):
    """Store the selected language in the module-level global used by /query_rag."""
    global language
    language = req.language
    return {"Message": f"Language is Selected to '{language}'"}
# To SignUp
@app.post("/signup", response_model=UserResponse, tags=["auth"])
async def signup(user: UserSignup):
    """Register a new user, storing a bcrypt hash of the password."""
    existing = await users_collection.find_one({"email": user.email})
    if existing:
        raise HTTPException(status_code=400, detail="Email already registered")
    hashed_password = bcrypt.hashpw(user.password.encode(), bcrypt.gensalt()).decode()
    new_id = str(uuid.uuid4())
    await users_collection.insert_one({
        "_id": new_id,
        "username": user.username,
        "email": user.email,
        "password": hashed_password,
        "language": user.language,
        # Admin status is never granted at signup, regardless of the payload.
        "is_admin": False,
    })
    print(user.language)
    return {"id": new_id, "username": user.username, "email": user.email}
# TO Login
@app.post("/login", response_model=Token, tags=["auth"])
async def login(credentials: UserLogin):
    """Authenticate the user and issue a short-lived JWT access token."""
    user = await authenticate_user(credentials.email, credentials.password)
    if user is None:
        raise HTTPException(status_code=401, detail="Invalid email or password")
    expiry = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    claims = {"sub": user["_id"], "exp": expiry}
    return {
        "access_token": jwt.encode(claims, SECRET_KEY, algorithm=ALGORITHM),
        "token_type": "bearer",
    }
# To Add Document to RAG - MongoDB storage for both document and Chroma
@app.post("/build_rag", tags=["rag"])
async def build_rag_endpoint(file: UploadFile = File(...), user_id: str = QueryParam()):
    """Ingest a .docx file: persist the raw bytes in MongoDB and merge its
    chunks into the shared Chroma vector index."""
    if not file.filename.endswith(".docx"):
        raise HTTPException(status_code=400, detail="Only .docx files are supported.")
    try:
        raw_bytes = await file.read()
        document_id = str(uuid.uuid4())
        # Split the document into retrievable chunks for the RAG index.
        chunks = load_and_split_bytes(io.BytesIO(raw_bytes))
        # Persist the original file alongside its metadata.
        await documents_collection.insert_one({
            "_id": document_id,
            "filename": file.filename,
            "upload_date": datetime.utcnow(),
            "file_size": len(raw_bytes),
            "chunks": len(chunks),
            "user_id": user_id,
            "file_content": Binary(raw_bytes),  # raw bytes as BSON Binary
        })
        # Merge the new chunks into the (single) default collection and keep
        # the refreshed retriever on app state.
        collection_name = "default"
        app.state.retriever = add_documents_to_index(chunks, collection_name=collection_name)
        stats = get_collection_stats(collection_name)
        return {
            "message": f"File '{file.filename}' added to knowledge base.",
            "document_id": document_id,
            "total_chunks_in_db": stats["document_count"],
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing document: {str(e)}")
# To query RAG
@app.post("/query_rag", tags=["rag"])
async def query_rag_endpoint(payload: QueryModel):
    """Main chat endpoint.

    Flow: (1) enforce an active, unexpired subscription; (2) classify the
    question as a video request vs. a general query via an LLM; (3) either
    return the best-matching stored video or run the history-aware RAG chain
    over the uploaded documents and record the exchange.

    NOTE(review): this handler mixes naive (utcnow) and aware (now) datetimes,
    mutates global/app state, and calls blocking embedding/LLM APIs from an
    async endpoint — flagged inline; confirm before refactoring.
    """
    now = datetime.now(timezone.utc)
    now_utc = datetime.utcnow()
    print("Checking subrciption", now_utc)
    # --- Subscription gate ------------------------------------------------
    subscription = await subscriptions_collection.find_one({
        "user_id": payload.user_id,
        "status": "active"
    })
    if not subscription:
        raise HTTPException(status_code=403, detail="No active subscription found. Please go to manage subscription section to continue chat ")
    end_date = subscription.get("end_date")
    # Stored dates may be naive; normalise to UTC before comparing.
    if end_date and end_date.tzinfo is None:
        end_date = end_date.replace(tzinfo=timezone.utc)
    if end_date and now > end_date:
        # Update status to expired
        await subscriptions_collection.update_one(
            {"_id": subscription["_id"]},
            {"$set": {"status": "expired"}}
        )
        raise HTTPException(status_code=403, detail="Subscription expired.")
    # --- Intent classification (play_video vs other) ----------------------
    now_utc = datetime.utcnow()
    print("Classification of query", now_utc)
    llm = get_llm()
    cl_promp =f'''You are a binary intent classifier. When given a user message, decide if the user is asking to play or watch a video—explicitly or implicitly.
Instructions:
1. Analyze the provided message (`{payload.question}`) to determine intent:
- If the user intends to start, resume, or watch a video (e.g. “play the video now,” “let me see that demo,” “could you launch the tutorial clip?”), choose:
play_video
- Otherwise, choose:
other
2. Output exactly one label (play_video or other), with no additional text, punctuation, or formatting.
Input:
{{message}}
Output:'''
    # NOTE(review): blocking LLM call inside an async handler — stalls the
    # event loop; consider asyncio.to_thread as done in capture_order.
    cl = llm.invoke(cl_promp).content
    print(cl)
    if cl == 'play_video':
        # --- Video branch: embed the query and pick the closest video ----
        now_utc = datetime.utcnow()
        print("Video Processing", now_utc)
        try:
            # 1) Fetch up to 1000 docs (just the link+description)
            docs = await videos_collection.find(
                {},
                {"link": 1, "description": 1, "_id": 0}
            ).to_list(length=1000)
            if not docs:
                # NOTE(review): this 404 is caught by the broad except below
                # and re-raised as a 500 — likely unintended; confirm.
                raise HTTPException(status_code=404, detail="No videos found")
            # 2) Get descriptions as a list
            descriptions = [doc["description"] for doc in docs]
            # 3) Embed query using LangChain's OpenAIEmbeddings
            # (embeddings_model is defined later at module level; fine at call time)
            query_embedding = embeddings_model.embed_query(payload.question)
            # 4) Embed all descriptions (this is still inefficient but uses LangChain)
            description_embeddings = embeddings_model.embed_documents(descriptions)
            # 5) Compute similarities
            similarities = [cosine_similarity(query_embedding, desc_emb)
                            for desc_emb in description_embeddings]
            # 6) Pick best match
            best_idx = int(np.argmax(similarities))
            best_doc = docs[best_idx]
            print(similarities[best_idx])
            # 0.6 similarity threshold decides whether a match is good enough.
            if similarities[best_idx] > 0.6:
                return {
                    "answer" : "Here is video Based on your query",
                    "youtube": {
                        "embed_url": best_doc["link"],
                        "watch_url": ""
                    },
                    "transcript": ""
                }
            else:
                return {
                    "answer" : "No Video found" ,
                    "youtube": {
                        "embed_url": "",
                        "watch_url": ""
                    },
                    "transcript": ""
                }
        except Exception as e:
            logging.error(f"Error during video search: {str(e)}")
            raise HTTPException(status_code=500, detail="Error processing search request")
        # t_prompt = f'''You are a concise title generator. When given a user’s message, produce a short, keyword‑rich title that captures the core topic or intent—optimized for searching on YouTube.
        # Instructions:
        # 1. Read the provided message (`{payload.question}`).
        # 2. Extract its main subject or action.
        # 3. Craft a clear, descriptive title (5–8 words) suitable as a YouTube search query.
        # 4. Output **exactly** the title, with no extra text or punctuation.
        # Input:
        # {{message}}
        # Output:
        # '''
        # message_title = llm.invoke(t_prompt).content
        # print(message_title)
        # try:
        #     vid = search_youtube_video(message_title)
        #     try:
        #         transcript_data = YouTubeTranscriptApi.get_transcript(vid, languages=['en'])
        #         # transcript_data is a list of { "text": "...", "start": 12.34, "duration": 2.5 }
        #     except NoTranscriptFound:
        #         transcript_data = []
        # except Exception:
        #     raise HTTPException(status_code=404, detail="Could not find a matching YouTube video.")
        # embed_url = f"https://www.youtube.com/embed/{vid}"
        # watch_url = f"https://www.youtube.com/watch?v={vid}"
        # return {
        #     "answer": f"Sure—here’s what I found on YouTube for your request:",
        #     "youtube": {
        #         "embed_url": embed_url,
        #         "watch_url": watch_url
        #     },
        #     "transcript": transcript_data
        # }
    else:
        # --- General RAG branch ------------------------------------------
        now_utc = datetime.utcnow()
        print("General query", now_utc)
        # Use app.state.retriever which is initialized at startup or updated when docs are added
        if not hasattr(app.state, "retriever") or app.state.retriever is None:
            # Check if vector store exists but hasn't been loaded
            retriever = get_existing_retriever()
            if retriever:
                app.state.retriever = retriever
            else:
                raise HTTPException(
                    status_code=400,
                    detail="No documents have been uploaded. Please upload documents using /build_rag first."
                )
        # Get user language preference
        if language is None:
            user_id = payload.user_id
            user = await users_collection.find_one({"_id": user_id})
            if user:
                user_lan = user.get("language")
            # NOTE(review): if the user document is missing, user_lan is never
            # assigned and the rag-chain call below raises NameError — the
            # commented-out default was probably meant to cover this; confirm.
            # else:
            #     user_lan = "Arabic"  # Default to English
        else:
            user_lan = language
        # question_tokens = count_tokens(payload.question)
        # In your query_rag_endpoint, before rag_chain.invoke():
        now_utc = datetime.utcnow()
        print("Getting retriver", now_utc)
        try:
            retrieved_docs = app.state.retriever.get_relevant_documents(payload.question)
        except Exception as e:
            # Rebuild retriever and try again
            app.state.retriever = get_existing_retriever()
            if not app.state.retriever:
                raise HTTPException(400, "Vector store corrupted. Please rebuild.")
            retrieved_docs = app.state.retriever.get_relevant_documents(payload.question)
        now_utc = datetime.utcnow()
        print("Getting previous context", now_utc)
        # retrieved_docs = app.state.retriever.get_relevant_documents(payload.question)
        # context_text = "\n\n".join([doc.page_content for doc in retrieved_docs])
        # context_tokens = count_tokens(context_text)
        # Pull the last few turns of chat history for conversational context.
        chat_history = await chatbot_history_collection.find_one({"userId": payload.user_id})
        previous_messages = []
        if chat_history and "messages" in chat_history:
            history_limit = 5
            recent_messages = chat_history["messages"][-history_limit:]
            for msg in recent_messages:
                previous_messages.append({"role": "human", "content": msg["user_message"]})
                previous_messages.append({"role": "assistant", "content": msg["ai_response"]})
        history_text = ""
        for msg in previous_messages:
            history_text += f"{msg['role']}: {msg['content']}\n"
        # history_tokens = count_tokens(history_text)
        # print(f"History tokens: {history_tokens}")
        # Create RAG chain and generate response
        llm = get_llm()
        print("Diacritics : ", payload.diacritics)
        print("Language : ", user_lan)
        now_utc = datetime.utcnow()
        print("RAG Chain ", now_utc)
        rag_chain = create_rag_chain_with_history(app.state.retriever, llm, user_lan, payload.level, payload.diacritics, previous_messages)
        response = rag_chain.invoke({"input": payload.question})
        # response_tokens = count_tokens(response['answer'])
        # print(f"Response tokens: {response_tokens}")
        # print(f"Total tokens: {question_tokens + history_tokens + context_tokens}")
        now_utc = datetime.utcnow()
        print("RAG Chain Done", now_utc)
        # Record chat history
        msg_record = {
            "user_message": payload.question,
            "ai_response": response['answer'],
            "timestamp": datetime.utcnow()
        }
        await chatbot_history_collection.update_one(
            {"userId": payload.user_id},
            {"$push": {"messages": msg_record}},
            upsert=True
        )
        return {
            "answer": response["answer"],
            "youtube": {
                "embed_url": "",
                "watch_url": ""
            },
            "transcript": ""}
@app.post("/save_flashcard", tags=["flashcards"])
async def save_flashcard(payload: FlashcardSaveModel):
    """
    Save a chat exchange as a flashcard with LLM-generated title.
    """
    try:
        llm = get_llm()
        # Ask the LLM for a short, descriptive title for the card.
        title_prompt = f"""
Based on the following question and answer, generate a concise, descriptive title (15 words or less)
that captures the main topic or key insight. Make it specific enough that someone can understand
what information they'll find in this flashcard.
Question: {payload.question}
Answer: {payload.answer}
Title:
"""
        raw_title = llm.invoke(title_prompt).content
        # Trim whitespace and any surrounding quote characters from the output.
        title = raw_title.strip().strip('"\'').strip()
        # Fall back to a snippet of the question when the title is unusable.
        if not title or len(title) > 100:
            suffix = "..." if len(payload.question) > 50 else ""
            title = payload.question[:50] + suffix
        new_card = {
            "_id": str(uuid.uuid4()),
            "user_id": payload.user_id,
            "title": title,
            "question": payload.question,
            "answer": payload.answer,
            "created_at": datetime.utcnow()
        }
        await flash_cards.insert_one(new_card)
        return {
            "success": True,
            "flashcard_id": new_card["_id"],
            "title": title,  # Return the generated title to the frontend
            "message": "Flashcard saved successfully"
        }
    except Exception as e:
        print(f"Error saving flashcard: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error saving flashcard: {str(e)}")
# Save Flash Cards
@app.get("/flashcards/{user_id}", tags=["flashcards"])
async def get_user_flashcards(user_id: str):
    """
    Retrieve all flashcards for a specific user.
    """
    try:
        results = []
        # Newest cards first.
        async for card in flash_cards.find({"user_id": user_id}).sort("created_at", -1):
            card["_id"] = str(card["_id"])  # make the id JSON-serialisable
            results.append(card)
        return {"user_id": user_id, "flashcards": results}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error retrieving flashcards: {str(e)}")
# Delete FlashCards
@app.delete("/flashcards/{flashcard_id}", tags=["flashcards"])
async def delete_flashcard(flashcard_id: str, user_id: str = QueryParam()):
    """
    Delete a specific flashcard. Requires user_id to verify ownership.

    Raises 404 when the card doesn't exist or belongs to another user.
    """
    try:
        result = await flash_cards.delete_one({
            "_id": flashcard_id,
            "user_id": user_id
        })
        if result.deleted_count == 0:
            raise HTTPException(status_code=404, detail="Flashcard not found or not owned by this user")
        return {"success": True, "message": "Flashcard deleted successfully"}
    except HTTPException:
        # Bug fix: the broad `except Exception` below used to catch the 404
        # raised above and re-wrap it as a 500; HTTPExceptions now propagate.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error deleting flashcard: {str(e)}")
# For Video
import re
def get_embed_link(youtube_url):
    """Convert a YouTube URL into an embeddable player URL.

    Generalised (backward-compatibly) to accept watch URLs (?v=), youtu.be
    short links, and also /embed/ and /shorts/ URLs. Returns the
    https://www.youtube.com/embed/<id> form, or None when no 11-character
    video id can be extracted.
    """
    match = re.search(r"(?:v=|youtu\.be/|embed/|shorts/)([a-zA-Z0-9_-]{11})", youtube_url)
    if match:
        return f"https://www.youtube.com/embed/{match.group(1)}"
    return None
@app.post("/add_video")
async def add_video(video: Video):
    """Store a video's embeddable YouTube link plus its description.

    Raises 400 when the link isn't a recognisable YouTube URL.
    """
    embed_url = get_embed_link(video.link)
    # Bug fix: previously an unrecognised link silently stored None as the
    # link, producing broken entries served later by /search_video.
    if embed_url is None:
        raise HTTPException(status_code=400, detail="Invalid YouTube link")
    result = await videos_collection.insert_one({
        "link": embed_url,
        "description": video.description
    })
    if result.inserted_id:
        return {"message": "Video added successfully", "id": str(result.inserted_id)}
    raise HTTPException(500, "Failed to add video")
# To Delete a document
@app.delete("/documents/{document_id}", tags=["rag"], status_code=status.HTTP_204_NO_CONTENT)
async def delete_document(document_id: str):
    """Delete a stored document and rebuild the vector index from the rest.

    Raises 404 when the document id is unknown. Returns 204 on success.
    """
    # 1. Delete the document metadata + file bytes
    result = await documents_collection.delete_one({"_id": document_id})
    if result.deleted_count == 0:
        raise HTTPException(status_code=404, detail="Document not found")
    # 2. Rebuild the RAG index from all remaining documents
    cursor = documents_collection.find({})
    all_docs = []
    async for doc in cursor:
        # Idiom fix: use the module-level `io` import instead of re-importing
        # BytesIO on every loop iteration as the original did.
        chunks = load_and_split_bytes(io.BytesIO(doc["file_content"]))
        all_docs.extend(chunks)
    # 3. If there are any chunks left, rebuild; otherwise clear retriever
    if all_docs:
        # Re-create the Chroma index and persist it.
        app.state.retriever = build_chroma_index(
            all_docs,
            collection_name="default"
        )
    else:
        # No documents -> clear both in-memory and persisted vector store.
        app.state.retriever = None
        await chroma_db_collection.delete_one({"_id": "default"})
# To download a document
@app.get("/documents", tags=["rag"])
async def list_all_documents():
    """List metadata for every stored document (binary content omitted)."""
    try:
        summaries = []
        async for doc in documents_collection.find({}):
            summaries.append({
                "id": str(doc["_id"]),
                "filename": doc.get("filename"),
                "upload_date": doc.get("upload_date"),
                "file_size": doc.get("file_size"),
                "chunks": doc.get("chunks"),
            })
        return {"documents": summaries}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
# To get Chat History
@app.get("/chat-history/{user_id}")
async def get_chat_history(user_id: str):
    """Return every chat-history document stored for the given user.

    Raises 404 when the user has no history; other failures surface as 500.
    """
    try:
        cursor = chatbot_history_collection.find({"userId": user_id})
        chat_history = []
        async for doc in cursor:
            # Convert ObjectId to string so the response is JSON-serialisable.
            doc["_id"] = str(doc["_id"])
            chat_history.append(doc)
        if not chat_history:
            raise HTTPException(status_code=404, detail="No chat history found for this user")
        return {"user_id": user_id, "chat_history": chat_history}
    except HTTPException:
        # Bug fix: the broad handler below used to convert the 404 above into
        # a 500 response; HTTPExceptions now propagate unchanged.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
# To list uploaded documents
@app.get("/documents/user/{user_id}", tags=["rag"])
async def list_documents(user_id: str):
    """List document metadata for one user (binary content omitted)."""
    try:
        items = []
        async for doc in documents_collection.find({"user_id": user_id}):
            items.append({
                "id": str(doc["_id"]),
                "filename": doc["filename"],
                "upload_date": doc["upload_date"],
                "file_size": doc["file_size"],
                "chunks": doc["chunks"],
            })
        return {"user_id": user_id, "documents": items}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
# To get knowledge base info
@app.get("/knowledge-base-info", tags=["rag"])
async def get_knowledge_base_info():
    """Report whether the vector store exists and how many chunks it holds."""
    collection_stats = get_collection_stats()
    return {
        "documents_loaded": collection_stats["exists"],
        "total_chunks": collection_stats["document_count"],
    }
# SECURITY fix: the OpenAI API key was hard-coded in source (and is therefore
# compromised — it must be rotated). It is now read from the environment;
# OpenAIEmbeddings also falls back to OPENAI_API_KEY automatically.
embeddings_model = OpenAIEmbeddings(
    model="text-embedding-ada-002",
    openai_api_key=os.environ.get("OPENAI_API_KEY"),
)
@app.post("/search_video")
async def search_video(req: SearchRequest):
    """Semantic search over stored videos by embedding cosine similarity.

    Returns the single best-matching link/description plus its score.
    Raises 400 for an empty query, 404 when no videos exist.
    """
    if not req.query.strip():
        raise HTTPException(status_code=400, detail="Search query cannot be empty")
    try:
        # 1) Fetch up to 1000 docs (just the link+description)
        docs = await videos_collection.find(
            {},
            {"link": 1, "description": 1, "_id": 0}
        ).to_list(length=1000)
        if not docs:
            raise HTTPException(status_code=404, detail="No videos found")
        # 2) Get descriptions as a list
        descriptions = [doc["description"] for doc in docs]
        # 3) Embed the query.
        query_embedding = embeddings_model.embed_query(req.query)
        # 4) Embed all descriptions.
        # NOTE(review): re-embedding every description on each request is
        # costly — consider precomputing/caching these embeddings.
        description_embeddings = embeddings_model.embed_documents(descriptions)
        # 5) Compute similarities
        similarities = [cosine_similarity(query_embedding, desc_emb)
                        for desc_emb in description_embeddings]
        # 6) Pick best match
        best_idx = int(np.argmax(similarities))
        best_doc = docs[best_idx]
        return {
            "link": best_doc["link"],
            "description": best_doc["description"],
            "score": similarities[best_idx]
        }
    except HTTPException:
        # Bug fix: the 404 above was previously swallowed by the broad
        # `except Exception` and re-raised as a 500.
        raise
    except Exception as e:
        logging.error(f"Error during video search: {str(e)}")
        raise HTTPException(status_code=500, detail="Error processing search request")
# @app.post("/start-trial/{user_id}", response_model=TrialResponse)
# async def start_trial(user_id: str):
# trial = await subscr`iptions_collection.find_one({
# "user_id": user_id,
# "plan": "trial"
# })
# if trial:
# # Optional: Mark expired trial as "expired" for clarity
# if trial["end_date"] < datetime.utcnow() and trial["status"] == "active":
# await subscriptions_collection.update_one(
# {"_id": trial["_id"]},
# {"$set": {"status": "expired", "updated_at": datetime.utcnow()}}
# )
# raise HTTPException(400, "You have already used your free trial. Please subscribe to continue.")
# now = datetime.utcnow()
# trial = {
# "user_id": user_id,
# "plan": "trial",
# "status": "active",
# "start_date": now,
# "end_date": now + timedelta(days=3),
# "created_at": now,
# "updated_at": now
# }
# await subscriptions_collection.insert_one(trial)
# return {"status": "trial started", "expires": trial["end_date"]}
async def has_valid_active_subscription(user_id: str) -> bool:
    """Return True if the user has an 'active' subscription that has not yet expired."""
    criteria = {
        "user_id": user_id,
        "status": "active",
        "end_date": {"$gt": datetime.utcnow()},
    }
    return await subscriptions_collection.find_one(criteria) is not None
async def expire_old_subscriptions(user_id: str):
    """Flip any already-elapsed 'active' subscriptions for this user to 'expired'."""
    cutoff = datetime.utcnow()
    criteria = {
        "user_id": user_id,
        "status": "active",
        "end_date": {"$lte": cutoff},  # already expired
    }
    await subscriptions_collection.update_many(
        criteria,
        {"$set": {"status": "expired", "updated_at": cutoff}},
    )
@app.post("/start-trial/{user_id}", response_model=TrialResponse)
async def start_trial(user_id: str):
    """Start a 3-day free trial unless the user already has an active subscription."""
    await expire_old_subscriptions(user_id)
    if await has_valid_active_subscription(user_id):
        raise HTTPException(400, "You already have an active subscription.")
    started = datetime.utcnow()
    trial_doc = {
        "user_id": user_id,
        "plan": "trial",
        "status": "active",
        "start_date": started,
        "end_date": started + timedelta(days=3),
        "created_at": started,
        "updated_at": started,
    }
    await subscriptions_collection.insert_one(trial_doc)
    return {"status": "trial started", "expires": trial_doc["end_date"].isoformat()}
@app.post("/create-subscription-order")
async def create_sub_order(sub_req: SubscriptionRequest):
    """Create a PayPal order for a subscription plan and record it locally.

    Returns the PayPal order id, its status, and the buyer approval URL.
    Raises 400 for an unknown plan.
    """
    import asyncio

    user_id = sub_req.user_id
    plan = sub_req.plan
    price_map = {
        "monthly": "10.00",
        "yearly": "100.00"
    }
    if plan not in price_map:
        raise HTTPException(400, "Invalid plan selected")
    price = price_map[plan]
    request = OrdersCreateRequest()
    request.prefer("return=representation")
    request.request_body({
        "intent": "CAPTURE",
        "purchase_units": [{
            "custom_id": user_id,  # This can be any internal identifier
            "amount": {
                "currency_code": "USD",
                "value": price
            }
        }],
        "application_context": {
            "return_url": f"https://arabic-chatbot-frontend.vercel.app/payments",  # optional
            "cancel_url": "https://arabic-chatbot-frontend.vercel.app/payments"
        }
    })
    # Bug fix: p_client.execute is a blocking HTTP call; invoking it directly
    # in an async endpoint stalls the event loop. Off-load it to a worker
    # thread, matching what capture_order already does.
    response = await asyncio.to_thread(p_client.execute, request)
    result = response.result
    approve_url = next((link.href for link in result.links if link.rel == "approve"), None)
    now = datetime.utcnow()
    await orders_collection.insert_one({
        "paypal_order_id": result.id,
        "user_id": user_id,
        "plan": plan,
        "status": result.status,
        "created_at": now,
        "updated_at": now
    })
    return {
        "order_id": result.id,
        "status": result.status,
        "approve_url": approve_url
    }
@app.post("/capture-order/{order_id}", response_model=dict)
async def capture_order(order_id: str):
    """Capture an approved PayPal order and activate the matching subscription."""
    import asyncio

    capture_request = OrdersCaptureRequest(order_id)
    capture_request.request_body({})
    # The PayPal SDK call is blocking, so run it in a worker thread to keep
    # the event loop responsive.
    response = await asyncio.to_thread(p_client.execute, capture_request)
    if response.status_code != 201:
        raise HTTPException(400, "Failed to capture order")
    capture = response.result
    order_doc = await orders_collection.find_one({"paypal_order_id": order_id})
    if order_doc is None:
        raise HTTPException(404, "Order not found")
    # Expire stale subscriptions before checking for / inserting a new one.
    await expire_old_subscriptions(order_doc["user_id"])
    if await has_valid_active_subscription(order_doc["user_id"]):
        raise HTTPException(400, "You already have an active subscription.")
    now = datetime.utcnow()
    duration_days = 30 if order_doc["plan"] == "monthly" else 365
    await subscriptions_collection.insert_one({
        "user_id": order_doc["user_id"],
        "plan": order_doc["plan"],
        "status": "active",
        "start_date": now,
        "end_date": now + timedelta(days=duration_days),
        "paypal_order_id": order_id,
        "created_at": now,
        "updated_at": now,
    })
    # Reflect the capture status on the stored order.
    await orders_collection.update_one(
        {"paypal_order_id": order_id},
        {"$set": {"status": capture.status}},
    )
    return {"status": capture.status}
# @app.post("/capture-order/{order_id}", response_model=dict)
# async def capture_order(order_id: str):
# req = OrdersCaptureRequest(order_id)
# req.request_body({})
# # ❌ Don't await this — it's synchronous
# # resp = p_client.execute(req)
# import asyncio
# resp = await asyncio.to_thread(p_client.execute, req)
# if resp.status_code != 201:
# raise HTTPException(400, "Failed to capture order")
# cap = resp.result
# order_doc = await orders_collection.find_one({"paypal_order_id": order_id})
# if not order_doc:
# raise HTTPException(404, "Order not found")
# await orders_collection.update_one(
# {"paypal_order_id": order_id},
# {"$set": {"status": cap.status}}
# )
# # Insert into subscriptions if plan exists
# if "plan" in order_doc:
# now = datetime.utcnow()
# days = 30 if order_doc["plan"] == "monthly" else 365
# sub = {
# "user_id": order_doc["user_id"],
# "plan": order_doc["plan"],
# "status": "active",
# "start_date": now,
# "end_date": now + timedelta(days=days),
# "paypal_order_id": order_id,
# "created_at": now,
# "updated_at": now
# }
# existing = await subscriptions_collection.find_one({"user_id": order_doc["user_id"]})
# if existing:
# # Optional: You can update the existing subscription instead of inserting
# await subscriptions_collection.update_one(
# {"user_id": order_doc["user_id"]},
# {"$set": {
# "plan": order_doc["plan"],
# "status": "active",
# "start_date": now,
# "end_date": now + timedelta(days=days),
# "paypal_order_id": order_id,
# "updated_at": now
# }}
# )
# else:
# await subscriptions_collection.insert_one(sub)
# return {"status": cap.status}
@app.post("/cancel", response_model=dict)
async def cancel_subscription(user_id: str):
    """Mark the user's active subscription as cancelled."""
    active = await subscriptions_collection.find_one({
        "user_id": user_id,
        "status": "active"
    })
    if active is None:
        raise HTTPException(status_code=404, detail="Active subscription not found")
    await subscriptions_collection.update_one(
        {"_id": active["_id"]},
        {"$set": {"status": "cancelled", "updated_at": datetime.utcnow()}}
    )
    return {"status": "cancelled", "message": "Subscription cancelled successfully"}
# @app.get("/payments")
# async def success_page(user_id: str, token: str = "", PayerID: str = ""):
# return {"message": "Payment approved", "user_id": user_id, "token": token, "PayerID": PayerID}
def serialize_subscription(subscription):
    """Make a subscription document JSON-serialisable by stringifying its _id.

    NOTE: mutates the supplied dict in place and returns the same object.
    """
    raw_id = subscription["_id"]
    subscription["_id"] = str(raw_id)
    return subscription
@app.get("/subscriptions/{user_id}")
async def get_subscription_by_user_id(user_id: str):
    """Return the subscription document for a user, serialised for JSON."""
    found = await subscriptions_collection.find_one({"user_id": user_id})
    if found is None:
        raise HTTPException(status_code=404, detail="Subscription not found")
    return serialize_subscription(found)
@app.post("/webhook")
async def paypal_webhook(request: Request):
    """Handle PayPal webhook events; mark orders COMPLETED on capture.

    NOTE(review): no webhook signature verification is performed — confirm
    whether PayPal's verification API should be called here.
    """
    event = await request.json()
    if event.get("event_type") == "PAYMENT.CAPTURE.COMPLETED":
        related = (event["resource"]
                   .get("supplementary_data", {})
                   .get("related_ids", {}))
        order_id = related.get("order_id")
        if order_id:
            await orders_collection.update_one(
                {"paypal_order_id": order_id},
                {"$set": {"status": "COMPLETED"}}
            )
    return {"ok": True}
# To Health Check APP
@app.get("/health", tags=["health"])
async def health_check():
    """Report MongoDB and vector-store connectivity for monitoring."""
    # Probe MongoDB with a ping command.
    try:
        await client.admin.command("ping")
        mongo_status = "connected"
    except Exception as e:
        mongo_status = f"error: {str(e)}"
    # Probe the vector store via its stats helper.
    try:
        vector_stats = get_collection_stats()
        vector_status = "connected" if vector_stats["exists"] else "empty but ready"
    except Exception as e:
        vector_status = f"error: {str(e)}"
    return {
        "status": "healthy",
        "database": mongo_status,
        "vector_database": vector_status,
        "storage_type": "mongodb"
    }
# if __name__ == "__main__":
# import uvicorn
# uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", 8001)))