rohanshaw commited on
Commit
7e2225b
·
verified ·
1 Parent(s): 5516f0a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +172 -112
app.py CHANGED
@@ -1,112 +1,172 @@
1
- from fastapi import FastAPI, HTTPException, Depends, Request, status, Form
2
- from fastapi.middleware.cors import CORSMiddleware
3
- from fastapi.security import OAuth2PasswordRequestForm
4
- from jose import JWTError, jwt
5
- from passlib.context import CryptContext
6
- from datetime import datetime, timedelta
7
- from pymongo import MongoClient
8
- from bson import ObjectId
9
- from fastapi.responses import JSONResponse
10
- from dotenv import load_dotenv
11
- import os
12
-
13
- load_dotenv()
14
-
15
- # Constants
16
- SECRET_KEY = os.environ.get("SECRET_KEY")
17
- ALGORITHM = os.environ.get("ALGORITHM")
18
- ACCESS_TOKEN_EXPIRE_MINUTES = 60
19
-
20
- # Admin credentials
21
- fake_admin_db = {
22
- "admin": {
23
- "username": os.environ.get("ADMIN_USERNAME"),
24
- "hashed_password": os.environ.get("ADMIN_PASSWORD")
25
- }
26
- }
27
-
28
- pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
29
- client = MongoClient(os.environ.get("MONGODB_URI"))
30
- db = client["thehexatechdb"]
31
- collection = db["quotationsdb"]
32
-
33
- app = FastAPI()
34
-
35
- # CORS
36
- origins = ["*"]
37
- app.add_middleware(
38
- CORSMiddleware,
39
- allow_origins=origins,
40
- allow_credentials=True,
41
- allow_methods=["*"],
42
- allow_headers=["*"],
43
- )
44
-
45
- # Auth Utilities
46
- def verify_password(plain, hashed):
47
- return pwd_context.verify(plain, hashed)
48
-
49
- def get_password_hash(password):
50
- return pwd_context.hash(password)
51
-
52
- def authenticate_user(username: str, password: str):
53
- user = fake_admin_db.get(username)
54
- if not user or not verify_password(password, user["hashed_password"]):
55
- return False
56
- return {"username": username}
57
-
58
- def create_access_token(data: dict, expires_delta=None):
59
- to_encode = data.copy()
60
- expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
61
- to_encode.update({"exp": expire})
62
- return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
63
-
64
- def get_current_user(token: str = Depends(OAuth2PasswordRequestForm)):
65
- credentials_exception = HTTPException(
66
- status_code=status.HTTP_401_UNAUTHORIZED,
67
- detail="Invalid credentials",
68
- headers={"WWW-Authenticate": "Bearer"},
69
- )
70
- try:
71
- payload = jwt.decode(token.password, SECRET_KEY, algorithms=[ALGORITHM])
72
- username: str = payload.get("sub")
73
- if username is None:
74
- raise credentials_exception
75
- return {"username": username}
76
- except JWTError:
77
- raise credentials_exception
78
-
79
- # Routes
80
- @app.post("/api/submit")
81
- async def submit_query(name: str = Form(...), email: str = Form(...), message: str = Form(...)):
82
- query = {"name": name, "email": email, "message": message, "created_at": datetime.utcnow()}
83
- result = collection.insert_one(query)
84
- return JSONResponse(content={"id": str(result.inserted_id), "notify": True})
85
-
86
- @app.post("/api/login")
87
- async def login(form_data: OAuth2PasswordRequestForm = Depends()):
88
- user = authenticate_user(form_data.username, form_data.password)
89
- if not user:
90
- raise HTTPException(status_code=401, detail="Invalid credentials")
91
- token = create_access_token({"sub": user["username"]})
92
- return {"access_token": token, "token_type": "bearer"}
93
-
94
- @app.get("/api/queries")
95
- async def get_queries(token: str):
96
- try:
97
- jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
98
- queries = list(collection.find())
99
- for q in queries:
100
- q["_id"] = str(q["_id"])
101
- return queries
102
- except JWTError:
103
- raise HTTPException(status_code=401, detail="Invalid token")
104
-
105
- @app.delete("/api/queries/{query_id}")
106
- async def delete_query(query_id: str, token: str):
107
- try:
108
- jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
109
- result = collection.delete_one({"_id": ObjectId(query_id)})
110
- return {"deleted": result.deleted_count == 1}
111
- except JWTError:
112
- raise HTTPException(status_code=401, detail="Invalid token")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException, Depends, Request, status, Form, WebSocket, WebSocketDisconnect
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from fastapi.security import OAuth2PasswordRequestForm
4
+ from jose import JWTError, jwt
5
+ from passlib.context import CryptContext
6
+ from datetime import datetime, timedelta
7
+ from pymongo import MongoClient
8
+ from bson import ObjectId
9
+ from fastapi.responses import JSONResponse
10
+ from dotenv import load_dotenv
11
+ import os
12
+ from starlette.websockets import WebSocketState
13
+ from jose import JWTError
14
+
15
+ load_dotenv()
16
+
17
+ class ConnectionManager:
18
+ def __init__(self):
19
+ self.active_connections: list[WebSocket] = []
20
+
21
+ async def connect(self, websocket: WebSocket):
22
+ await websocket.accept()
23
+ self.active_connections.append(websocket)
24
+
25
+ def disconnect(self, websocket: WebSocket):
26
+ self.active_connections.remove(websocket)
27
+
28
+ async def broadcast(self, message: dict):
29
+ for connection in self.active_connections:
30
+ try:
31
+ await connection.send_json(message)
32
+ except Exception:
33
+ self.active_connections.remove(connection)
34
+
35
+ manager = ConnectionManager()
36
+
37
+ # Constants
38
+ SECRET_KEY = os.environ.get("SECRET_KEY")
39
+ ALGORITHM = os.environ.get("ALGORITHM")
40
+ ACCESS_TOKEN_EXPIRE_MINUTES = 60
41
+
42
+ # Admin credentials
43
+ fake_admin_db = {
44
+ "admin": {
45
+ "username": os.environ.get("ADMIN_USERNAME"),
46
+ "hashed_password": os.environ.get("ADMIN_PASSWORD")
47
+ }
48
+ }
49
+
50
+ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
51
+ client = MongoClient(os.environ.get("MONGODB_URI"))
52
+ db = client["thehexatechdb"]
53
+ collection = db["quotationsdb"]
54
+
55
+ app = FastAPI()
56
+
57
+ # CORS
58
+ origins = ["*"]
59
+ app.add_middleware(
60
+ CORSMiddleware,
61
+ allow_origins=origins,
62
+ allow_credentials=True,
63
+ allow_methods=["*"],
64
+ allow_headers=["*"],
65
+ )
66
+
67
+ # Auth Utilities
68
+ def verify_password(plain, hashed):
69
+ return pwd_context.verify(plain, hashed)
70
+
71
+ def get_password_hash(password):
72
+ return pwd_context.hash(password)
73
+
74
+ def authenticate_user(username: str, password: str):
75
+ user = fake_admin_db.get(username)
76
+ if not user or not verify_password(password, user["hashed_password"]):
77
+ return False
78
+ return {"username": username}
79
+
80
+ def create_access_token(data: dict, expires_delta=None):
81
+ to_encode = data.copy()
82
+ expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
83
+ to_encode.update({"exp": expire})
84
+ return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
85
+
86
+ def get_current_user(token: str = Depends(OAuth2PasswordRequestForm)):
87
+ credentials_exception = HTTPException(
88
+ status_code=status.HTTP_401_UNAUTHORIZED,
89
+ detail="Invalid credentials",
90
+ headers={"WWW-Authenticate": "Bearer"},
91
+ )
92
+ try:
93
+ payload = jwt.decode(token.password, SECRET_KEY, algorithms=[ALGORITHM])
94
+ username: str = payload.get("sub")
95
+ if username is None:
96
+ raise credentials_exception
97
+ return {"username": username}
98
+ except JWTError:
99
+ raise credentials_exception
100
+
101
+ # Routes
102
+ @app.post("/api/submit")
103
+ async def submit_query(name: str = Form(...), email: str = Form(...), message: str = Form(...)):
104
+ query = {"name": name, "email": email, "message": message, "created_at": datetime.utcnow()}
105
+ result = collection.insert_one(query)
106
+ query["_id"] = str(result.inserted_id)
107
+
108
+ total_count = collection.count_documents({})
109
+ await manager.broadcast({
110
+ "event": "new_quote",
111
+ "data": {
112
+ "name": name,
113
+ "email": email,
114
+ "message": message,
115
+ "total_count": total_count
116
+ }
117
+ })
118
+ return JSONResponse(content={"id": query["_id"], "notify": True})
119
+
120
+
121
+ @app.post("/api/login")
122
+ async def login(form_data: OAuth2PasswordRequestForm = Depends()):
123
+ user = authenticate_user(form_data.username, form_data.password)
124
+ if not user:
125
+ raise HTTPException(status_code=401, detail="Invalid credentials")
126
+ token = create_access_token({"sub": user["username"], "role": "admin"})
127
+ return {"access_token": token, "token_type": "bearer"}
128
+
129
+ @app.get("/api/queries")
130
+ async def get_queries(token: str):
131
+ try:
132
+ jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
133
+ queries = list(collection.find())
134
+ for q in queries:
135
+ q["_id"] = str(q["_id"])
136
+ return queries
137
+ except JWTError:
138
+ raise HTTPException(status_code=401, detail="Invalid token")
139
+
140
+ @app.delete("/api/queries/{query_id}")
141
+ async def delete_query(query_id: str, token: str):
142
+ try:
143
+ jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
144
+ result = collection.delete_one({"_id": ObjectId(query_id)})
145
+ return {"deleted": result.deleted_count == 1}
146
+ except JWTError:
147
+ raise HTTPException(status_code=401, detail="Invalid token")
148
+
149
+
150
+ @app.websocket("/ws/notifications")
151
+ async def websocket_endpoint(websocket: WebSocket):
152
+ token = websocket.query_params.get("token")
153
+ try:
154
+ payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
155
+ username = payload.get("sub")
156
+ role = payload.get("role")
157
+ if not username or role != "admin":
158
+ await websocket.close(code=1008)
159
+ return
160
+ except JWTError:
161
+ await websocket.close(code=1008)
162
+ return
163
+
164
+ await manager.connect(websocket)
165
+ print(f"[WS CONNECT] {username} (admin) connected at {datetime.utcnow()} from {websocket.client.host}")
166
+
167
+ try:
168
+ while True:
169
+ await websocket.receive_text()
170
+ except WebSocketDisconnect:
171
+ print(f"[WS DISCONNECT] {username} disconnected at {datetime.utcnow()}")
172
+ manager.disconnect(websocket)