paijo77 commited on
Commit
ae50c7b
·
verified ·
1 Parent(s): b778408

update app/db_storage.py

Browse files
Files changed (1) hide show
  1. app/db_storage.py +571 -0
app/db_storage.py ADDED
@@ -0,0 +1,571 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sqlalchemy.ext.asyncio import AsyncSession
2
+ from sqlalchemy import select, func, and_, insert
3
+ from sqlalchemy.orm import selectinload
4
+ from sqlalchemy.dialects.sqlite import insert as sqlite_insert
5
+ from typing import List, Optional
6
+ from datetime import datetime
7
+ import logging
8
+
9
+ from app.db_models import User, ProxySource, Proxy
10
+ from app.validator import proxy_validator
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
+ class DatabaseStorage:
16
+ def __init__(self, enable_validation: bool = True):
17
+ self.enable_validation = enable_validation
18
+
19
+ async def create_admin_user(
20
+ self, session: AsyncSession, email: str = "admin@1proxy.local"
21
+ ) -> User:
22
+ result = await session.execute(select(User).where(User.email == email))
23
+ user = result.scalar_one_or_none()
24
+
25
+ if not user:
26
+ user = User(
27
+ oauth_provider="local",
28
+ oauth_id="admin",
29
+ email=email,
30
+ username="admin",
31
+ role="admin",
32
+ avatar_url=None,
33
+ )
34
+ session.add(user)
35
+ await session.commit()
36
+ await session.refresh(user)
37
+
38
+ return user
39
+
40
+ async def seed_admin_sources(self, session: AsyncSession, admin_user_id: int):
41
+ from app.sources import SourceRegistry
42
+
43
+ for source_config in SourceRegistry.SOURCES:
44
+ result = await session.execute(
45
+ select(ProxySource).where(ProxySource.url == str(source_config.url))
46
+ )
47
+ existing = result.scalar_one_or_none()
48
+
49
+ if not existing:
50
+ source = ProxySource(
51
+ user_id=admin_user_id,
52
+ url=str(source_config.url),
53
+ type=source_config.type.value
54
+ if hasattr(source_config.type, "value")
55
+ else str(source_config.type),
56
+ name=str(source_config.url).split("/")[-2],
57
+ enabled=source_config.enabled,
58
+ validated=True,
59
+ is_admin_source=True,
60
+ is_paid=False,
61
+ )
62
+ session.add(source)
63
+
64
+ await session.commit()
65
+
66
+ async def add_proxy(
67
+ self, session: AsyncSession, proxy_data: dict, source_id: Optional[int] = None
68
+ ) -> Optional[Proxy]:
69
+ result = await session.execute(
70
+ select(Proxy).where(Proxy.url == proxy_data["url"])
71
+ )
72
+ existing = result.scalar_one_or_none()
73
+
74
+ if existing:
75
+ existing.last_seen = datetime.utcnow()
76
+ existing.updated_at = datetime.utcnow()
77
+ if source_id and not existing.source_id:
78
+ existing.source_id = source_id
79
+ await session.commit()
80
+ return existing
81
+
82
+ proxy = Proxy(
83
+ source_id=source_id,
84
+ url=proxy_data["url"],
85
+ protocol=proxy_data.get("protocol", "http"),
86
+ ip=proxy_data.get("ip"),
87
+ port=proxy_data.get("port"),
88
+ is_working=True,
89
+ )
90
+ session.add(proxy)
91
+ await session.commit()
92
+ await session.refresh(proxy)
93
+ return proxy
94
+
95
+ async def add_proxy_with_validation(
96
+ self, session: AsyncSession, proxy_data: dict, source_id: Optional[int] = None
97
+ ) -> Optional[Proxy]:
98
+ """Add proxy with comprehensive validation"""
99
+ url = proxy_data.get("url")
100
+ ip = proxy_data.get("ip")
101
+
102
+ if not url or not ip:
103
+ return None
104
+
105
+ if self.enable_validation:
106
+ validation_result = await proxy_validator.validate_comprehensive(url, ip)
107
+
108
+ if not validation_result.success:
109
+ return None
110
+
111
+ proxy_data.update(
112
+ {
113
+ "latency_ms": validation_result.latency_ms,
114
+ "anonymity": validation_result.anonymity,
115
+ "can_access_google": validation_result.can_access_google,
116
+ "country_code": validation_result.country_code,
117
+ "country_name": validation_result.country_name,
118
+ "proxy_type": validation_result.proxy_type,
119
+ "quality_score": validation_result.quality_score,
120
+ "is_working": True,
121
+ "validation_status": "validated",
122
+ "last_validated": datetime.utcnow(),
123
+ }
124
+ )
125
+
126
+ return await self.add_proxy(session, proxy_data, source_id)
127
+
128
+ async def add_proxies(self, session: AsyncSession, proxies_data: List[dict]) -> int:
129
+ """
130
+ Efficiently add proxies using bulk insert with ON CONFLICT DO UPDATE.
131
+ This avoids N queries for N proxies and instead uses a single bulk operation.
132
+ """
133
+ if not proxies_data:
134
+ return 0
135
+
136
+ now = datetime.utcnow()
137
+ prepared_data = []
138
+
139
+ for proxy_data in proxies_data:
140
+ try:
141
+ # Extract or construct URL
142
+ url = proxy_data.get("url")
143
+ if not url:
144
+ ip = proxy_data.get("ip")
145
+ port = proxy_data.get("port")
146
+ protocol = proxy_data.get("protocol", "http")
147
+ if ip and port:
148
+ url = f"{protocol}://{ip}:{port}"
149
+ else:
150
+ continue
151
+
152
+ # Prepare data for bulk insert
153
+ prepared_data.append(
154
+ {
155
+ "url": url,
156
+ "protocol": proxy_data.get("protocol", "http"),
157
+ "ip": proxy_data.get("ip"),
158
+ "port": proxy_data.get("port"),
159
+ "country_code": proxy_data.get("country_code"),
160
+ "country_name": proxy_data.get("country_name"),
161
+ "city": proxy_data.get("city"),
162
+ "latency_ms": proxy_data.get("latency_ms"),
163
+ "speed_mbps": proxy_data.get("speed_mbps"),
164
+ "anonymity": proxy_data.get("anonymity"),
165
+ "proxy_type": proxy_data.get("proxy_type"),
166
+ "quality_score": proxy_data.get("quality_score"),
167
+ "is_working": True,
168
+ "validation_status": proxy_data.get(
169
+ "validation_status", "pending"
170
+ ),
171
+ "last_validated": proxy_data.get("last_validated"),
172
+ "first_seen": now,
173
+ "last_seen": now,
174
+ "created_at": now,
175
+ "updated_at": now,
176
+ }
177
+ )
178
+ except Exception as e:
179
+ logger.error(f"Error preparing proxy data: {e}")
180
+ continue
181
+
182
+ if not prepared_data:
183
+ return 0
184
+
185
+ try:
186
+ batch_size = 100
187
+ total_inserted = 0
188
+
189
+ for i in range(0, len(prepared_data), batch_size):
190
+ batch = prepared_data[i : i + batch_size]
191
+
192
+ for proxy_dict in batch:
193
+ try:
194
+ result = await session.execute(
195
+ select(Proxy).where(Proxy.url == proxy_dict["url"])
196
+ )
197
+ existing = result.scalar_one_or_none()
198
+
199
+ if existing:
200
+ existing.last_seen = now
201
+ existing.updated_at = now
202
+ else:
203
+ proxy = Proxy(**proxy_dict)
204
+ session.add(proxy)
205
+ total_inserted += 1
206
+
207
+ except Exception as e:
208
+ logger.error(
209
+ f"Error inserting proxy {proxy_dict.get('url')}: {e}"
210
+ )
211
+ continue
212
+
213
+ await session.commit()
214
+
215
+ logger.info(
216
+ f"Successfully processed {len(prepared_data)} proxies, inserted {total_inserted} new ones"
217
+ )
218
+ return len(prepared_data)
219
+
220
+ except Exception as e:
221
+ logger.error(f"Error in bulk insert: {e}")
222
+ await session.rollback()
223
+ return await self._add_proxies_fallback(session, prepared_data)
224
+
225
+ async def _add_proxies_fallback(
226
+ self, session: AsyncSession, proxies_data: List[dict]
227
+ ) -> int:
228
+ """Fallback method for adding proxies one by one if bulk insert fails."""
229
+ added_count = 0
230
+ now = datetime.utcnow()
231
+
232
+ for proxy_data in proxies_data:
233
+ try:
234
+ url = proxy_data.get("url")
235
+ if not url:
236
+ continue
237
+
238
+ # Check if exists
239
+ result = await session.execute(select(Proxy).where(Proxy.url == url))
240
+ existing = result.scalar_one_or_none()
241
+
242
+ if existing:
243
+ existing.last_seen = now
244
+ existing.updated_at = now
245
+ else:
246
+ proxy = Proxy(**proxy_data)
247
+ session.add(proxy)
248
+ added_count += 1
249
+
250
+ except Exception as e:
251
+ logger.error(f"Error in fallback insert for proxy: {e}")
252
+ continue
253
+
254
+ await session.commit()
255
+ return added_count
256
+
257
+ async def validate_and_update_proxies(
258
+ self,
259
+ session: AsyncSession,
260
+ proxy_ids: Optional[List[int]] = None,
261
+ limit: int = 50,
262
+ ) -> dict:
263
+ """Validate pending proxies and update their status"""
264
+ if proxy_ids:
265
+ query = select(Proxy).where(
266
+ Proxy.id.in_(proxy_ids), Proxy.validation_status == "pending"
267
+ )
268
+ else:
269
+ query = (
270
+ select(Proxy).where(Proxy.validation_status == "pending").limit(limit)
271
+ )
272
+
273
+ result = await session.execute(query)
274
+ proxies_to_validate = result.scalars().all()
275
+
276
+ if not proxies_to_validate:
277
+ return {"validated": 0, "failed": 0, "total": 0}
278
+
279
+ proxy_tuples = [(p.url, p.ip) for p in proxies_to_validate if p.ip]
280
+
281
+ if not proxy_tuples:
282
+ return {"validated": 0, "failed": 0, "total": 0}
283
+
284
+ validation_results = await proxy_validator.validate_batch(proxy_tuples)
285
+
286
+ validated_count = 0
287
+ failed_count = 0
288
+
289
+ for proxy in proxies_to_validate:
290
+ matching_result = next(
291
+ (r for url, r in validation_results if url == proxy.url), None
292
+ )
293
+
294
+ if not matching_result:
295
+ continue
296
+
297
+ if matching_result.success:
298
+ proxy.latency_ms = matching_result.latency_ms
299
+ proxy.anonymity = matching_result.anonymity
300
+ proxy.can_access_google = matching_result.can_access_google
301
+ proxy.country_code = matching_result.country_code
302
+ proxy.country_name = matching_result.country_name
303
+ proxy.proxy_type = matching_result.proxy_type
304
+ proxy.quality_score = matching_result.quality_score
305
+ proxy.is_working = True
306
+ proxy.validation_status = "validated"
307
+ proxy.last_validated = datetime.utcnow()
308
+ proxy.validation_failures = 0
309
+ validated_count += 1
310
+ else:
311
+ proxy.is_working = False
312
+ proxy.validation_status = "failed"
313
+ proxy.validation_failures = (proxy.validation_failures or 0) + 1
314
+ failed_count += 1
315
+
316
+ await session.commit()
317
+
318
+ return {
319
+ "validated": validated_count,
320
+ "failed": failed_count,
321
+ "total": len(proxies_to_validate),
322
+ }
323
+
324
+ async def get_proxies(
325
+ self,
326
+ session: AsyncSession,
327
+ protocol: Optional[str] = None,
328
+ country_code: Optional[str] = None,
329
+ anonymity: Optional[str] = None,
330
+ min_quality: Optional[int] = None,
331
+ is_working: bool = True,
332
+ validation_status: str = "validated",
333
+ limit: int = 100,
334
+ offset: int = 0,
335
+ order_by: str = "quality_score",
336
+ ) -> tuple[List[Proxy], int]:
337
+ # Use selectinload to prevent N+1 query problem when accessing proxy.source
338
+ query = (
339
+ select(Proxy)
340
+ .options(selectinload(Proxy.source))
341
+ .where(
342
+ Proxy.is_working == is_working,
343
+ Proxy.validation_status == validation_status,
344
+ )
345
+ )
346
+
347
+ if protocol:
348
+ query = query.where(Proxy.protocol == protocol)
349
+ if country_code:
350
+ query = query.where(Proxy.country_code == country_code)
351
+ if anonymity:
352
+ query = query.where(Proxy.anonymity == anonymity)
353
+ if min_quality:
354
+ query = query.where(Proxy.quality_score >= min_quality)
355
+
356
+ count_query = select(func.count()).select_from(query.subquery())
357
+ total_result = await session.execute(count_query)
358
+ total = total_result.scalar()
359
+
360
+ if order_by == "latency_ms":
361
+ query = query.order_by(Proxy.latency_ms.asc().nulls_last())
362
+ elif order_by == "quality_score":
363
+ query = query.order_by(Proxy.quality_score.desc().nulls_last())
364
+ elif order_by == "created_at":
365
+ query = query.order_by(Proxy.created_at.desc())
366
+
367
+ query = query.limit(limit).offset(offset)
368
+ result = await session.execute(query)
369
+ proxies = result.scalars().all()
370
+
371
+ return list(proxies), total
372
+
373
+ async def get_sources(
374
+ self,
375
+ session: AsyncSession,
376
+ user_id: Optional[int] = None,
377
+ enabled_only: bool = False,
378
+ ) -> List[ProxySource]:
379
+ query = select(ProxySource)
380
+
381
+ if user_id:
382
+ query = query.where(ProxySource.user_id == user_id)
383
+ if enabled_only:
384
+ query = query.where(ProxySource.enabled == True)
385
+
386
+ result = await session.execute(query)
387
+ return list(result.scalars().all())
388
+
389
+ async def get_random_proxy(
390
+ self,
391
+ session: AsyncSession,
392
+ protocol: Optional[str] = None,
393
+ country_code: Optional[str] = None,
394
+ min_quality: Optional[int] = None,
395
+ anonymity: Optional[str] = None,
396
+ max_latency: Optional[int] = None,
397
+ ) -> Optional[Proxy]:
398
+ query = select(Proxy).where(
399
+ Proxy.is_working == True, Proxy.validation_status == "validated"
400
+ )
401
+
402
+ if protocol:
403
+ query = query.where(Proxy.protocol == protocol)
404
+ if country_code:
405
+ query = query.where(Proxy.country_code == country_code)
406
+ if min_quality:
407
+ query = query.where(Proxy.quality_score >= min_quality)
408
+ if anonymity:
409
+ query = query.where(Proxy.anonymity == anonymity)
410
+ if max_latency:
411
+ query = query.where(Proxy.latency_ms <= max_latency)
412
+
413
+ query = query.order_by(func.random()).limit(1)
414
+ result = await session.execute(query)
415
+ return result.scalar_one_or_none()
416
+
417
+ async def get_stats(self, session: AsyncSession) -> dict:
418
+ """
419
+ Get proxy statistics efficiently using a single GROUP BY query
420
+ instead of multiple separate queries.
421
+ """
422
+ # Single query with GROUP BY for protocol counts
423
+ result = await session.execute(
424
+ select(Proxy.protocol, func.count(Proxy.id).label("count"))
425
+ .where(Proxy.validation_status == "validated")
426
+ .group_by(Proxy.protocol)
427
+ )
428
+
429
+ by_protocol = {}
430
+ total = 0
431
+
432
+ for row in result:
433
+ protocol = row.protocol if row.protocol else "unknown"
434
+ count = row.count
435
+ by_protocol[protocol] = count
436
+ total += count
437
+
438
+ # Ensure all expected protocols are present (even if 0)
439
+ expected_protocols = [
440
+ "http",
441
+ "https",
442
+ "vmess",
443
+ "vless",
444
+ "trojan",
445
+ "shadowsocks",
446
+ ]
447
+ for protocol in expected_protocols:
448
+ if protocol not in by_protocol:
449
+ by_protocol[protocol] = 0
450
+
451
+ return {"total_proxies": total, "by_protocol": by_protocol}
452
+
453
+ async def count_proxies(self, session: AsyncSession) -> int:
454
+ result = await session.execute(select(func.count()).select_from(Proxy))
455
+ return result.scalar() or 0
456
+
457
+ async def count_sources(self, session: AsyncSession) -> int:
458
+ result = await session.execute(select(func.count()).select_from(ProxySource))
459
+ return result.scalar() or 0
460
+
461
+ async def count_users(self, session: AsyncSession) -> int:
462
+ result = await session.execute(select(func.count()).select_from(User))
463
+ return result.scalar() or 0
464
+
465
+ async def get_or_create_user(
466
+ self,
467
+ session: AsyncSession,
468
+ oauth_provider: str,
469
+ oauth_id: str,
470
+ email: str,
471
+ username: str,
472
+ role: str = "user",
473
+ avatar_url: Optional[str] = None,
474
+ ) -> User:
475
+ result = await session.execute(
476
+ select(User).where(
477
+ and_(User.oauth_provider == oauth_provider, User.oauth_id == oauth_id)
478
+ )
479
+ )
480
+ user = result.scalar_one_or_none()
481
+
482
+ if not user:
483
+ user = User(
484
+ oauth_provider=oauth_provider,
485
+ oauth_id=oauth_id,
486
+ email=email,
487
+ username=username,
488
+ role=role,
489
+ avatar_url=avatar_url,
490
+ )
491
+ session.add(user)
492
+ await session.commit()
493
+ await session.refresh(user)
494
+
495
+ return user
496
+
497
+ async def create_notification(
498
+ self,
499
+ session: AsyncSession,
500
+ user_id: int,
501
+ notification_type: str,
502
+ title: str,
503
+ message: str,
504
+ severity: str = "info",
505
+ ):
506
+ from app.db_models import Notification
507
+
508
+ notification = Notification(
509
+ user_id=user_id,
510
+ type=notification_type,
511
+ title=title,
512
+ message=message,
513
+ severity=severity,
514
+ )
515
+ session.add(notification)
516
+ await session.commit()
517
+ await session.refresh(notification)
518
+ return notification
519
+
520
+ async def get_notifications(
521
+ self,
522
+ session: AsyncSession,
523
+ user_id: int,
524
+ unread_only: bool = False,
525
+ limit: int = 50,
526
+ ):
527
+ from app.db_models import Notification
528
+
529
+ query = select(Notification).where(Notification.user_id == user_id)
530
+ if unread_only:
531
+ query = query.where(Notification.read == False)
532
+ query = query.order_by(Notification.created_at.desc()).limit(limit)
533
+ result = await session.execute(query)
534
+ return list(result.scalars().all())
535
+
536
+ async def mark_notification_read(
537
+ self, session: AsyncSession, user_id: int, notification_id: int
538
+ ) -> bool:
539
+ from app.db_models import Notification
540
+
541
+ result = await session.execute(
542
+ select(Notification).where(
543
+ and_(
544
+ Notification.id == notification_id, Notification.user_id == user_id
545
+ )
546
+ )
547
+ )
548
+ notification = result.scalar_one_or_none()
549
+ if notification:
550
+ notification.read = True
551
+ await session.commit()
552
+ return True
553
+ return False
554
+
555
+ async def mark_all_notifications_read(
556
+ self, session: AsyncSession, user_id: int
557
+ ) -> int:
558
+ from app.db_models import Notification
559
+ from sqlalchemy import update
560
+
561
+ stmt = (
562
+ update(Notification)
563
+ .where(and_(Notification.user_id == user_id, Notification.read == False))
564
+ .values(read=True)
565
+ )
566
+ result = await session.execute(stmt)
567
+ await session.commit()
568
+ return result.rowcount
569
+
570
+
571
+ db_storage = DatabaseStorage()