GuXSs commited on
Commit
deb411f
Β·
verified Β·
1 Parent(s): 4b7b91f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +152 -115
app.py CHANGED
@@ -238,6 +238,27 @@ class DatabaseManager:
238
  logger.error(f"Error getting all users stats: {e}")
239
  return []
240
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
  # ----------------- Model Manager -----------------
242
  class ModelManager:
243
  def __init__(self, config: Config):
@@ -293,7 +314,7 @@ class ModelManager:
293
  logger.error(f"❌ Error loading model: {e}")
294
  self.model_loaded = False
295
 
296
- async def generate(self, request: GenerationRequest) -> Tuple[bool, str, int]:
297
  """Generate text with the model"""
298
  if not self.model_loaded:
299
  return False, "❌ Model not loaded", 0
@@ -380,6 +401,7 @@ class GemmaSaaSService:
380
  self.db = DatabaseManager(self.config)
381
  self.model_manager = ModelManager(self.config)
382
  self.analytics_manager = AnalyticsManager(self.db)
 
383
  self._validate_config()
384
 
385
  def _validate_config(self):
@@ -417,17 +439,33 @@ class GemmaSaaSService:
417
  error="Internal service error"
418
  )
419
 
420
- async def generate_text(self, prompt: str, api_key: str, **kwargs) -> APIResponse:
421
- """Generate text with authentication and rate limiting"""
422
  try:
423
- # Validate API key
424
- user = await self.db.validate_api_key(api_key)
425
- if not user:
426
  return APIResponse(
427
  success=False,
428
- error="⚠️ Invalid API key"
429
  )
430
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
431
  # Check rate limit
432
  can_make_request, requests_used = await self.db.check_rate_limit(user['id'])
433
  if not can_make_request:
@@ -442,7 +480,7 @@ class GemmaSaaSService:
442
 
443
  # Generate text
444
  request = GenerationRequest(prompt=prompt, **kwargs)
445
- success, text, tokens_used = await self.model_manager.generate(request)
446
 
447
  if success:
448
  # Update usage statistics
@@ -470,14 +508,22 @@ class GemmaSaaSService:
470
  error="Internal service error"
471
  )
472
 
473
- async def get_user_stats(self, api_key: str) -> APIResponse:
474
- """Get user statistics"""
475
  try:
476
- user = await self.db.validate_api_key(api_key)
 
 
 
 
 
 
 
 
477
  if not user:
478
  return APIResponse(
479
  success=False,
480
- error="Invalid API key"
481
  )
482
 
483
  # Generate analytics plot
@@ -505,11 +551,19 @@ class GemmaSaaSService:
505
  error="Error retrieving stats"
506
  )
507
 
508
- async def get_admin_stats(self, api_key: str) -> APIResponse:
509
  """Get admin statistics (only for admin users)"""
510
  try:
511
- user = await self.db.validate_api_key(api_key)
512
- if not user or user.get('email') != self.config.ADMIN_EMAIL:
 
 
 
 
 
 
 
 
513
  return APIResponse(
514
  success=False,
515
  error="Unauthorized: Admin access required"
@@ -810,6 +864,25 @@ class GradioInterface:
810
  .dark-mode-plot {
811
  background: transparent !important;
812
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
813
  """
814
 
815
  def create_header(self):
@@ -847,6 +920,24 @@ class GradioInterface:
847
 
848
  return gr.HTML(examples_html)
849
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
850
  async def create_interface(self):
851
  """Create the enhanced Gradio interface"""
852
  with gr.Blocks(
@@ -866,12 +957,7 @@ class GradioInterface:
866
  with gr.Row():
867
  with gr.Column(scale=1):
868
  gr.Markdown("### βš™οΈ Configuration")
869
- api_key_playground = gr.Textbox(
870
- label="πŸ”‘ API Key",
871
- type="password",
872
- placeholder="Enter your API key...",
873
- elem_classes=["input"]
874
- )
875
 
876
  with gr.Accordion("Advanced Settings", open=False):
877
  max_tokens_input = gr.Slider(
@@ -932,59 +1018,31 @@ class GradioInterface:
932
  )
933
 
934
  # Stats display
935
- generation_stats = gr.JSON(
936
- label="πŸ“Š Generation Statistics",
937
- visible=False
938
- )
939
 
940
  # Profile Tab
941
  with gr.Tab("πŸ‘€ Profile", elem_classes=["card"]):
942
  with gr.Row():
943
  with gr.Column(scale=1):
944
- gr.Markdown("### πŸ†• Create Account")
945
- name_input = gr.Textbox(
946
- label="πŸ‘€ Full Name",
947
- elem_classes=["input"]
948
- )
949
- email_input = gr.Textbox(
950
- label="πŸ“§ Email Address",
951
- elem_classes=["input"]
952
- )
953
- plan_input = gr.Dropdown(
954
- choices=["free", "pro", "enterprise"],
955
- value="free",
956
- label="πŸ“‹ Plan",
957
  elem_classes=["input"]
958
  )
959
-
960
- create_btn = gr.Button(
961
- "✨ Create API Key",
962
- elem_classes=["btn", "btn-primary"],
963
- variant="primary"
964
- )
965
-
966
- creation_status = gr.HTML()
967
- api_key_display = gr.Textbox(
968
- label="πŸ”‘ Your API Key",
969
- interactive=False,
970
- visible=False,
971
- elem_classes=["input"]
972
  )
 
973
 
974
  with gr.Column(scale=1):
975
  gr.Markdown("### πŸ“Š Account Statistics")
976
- stats_api_key = gr.Textbox(
977
- label="πŸ”‘ API Key",
978
- type="password",
979
- placeholder="Enter API key to view stats",
980
- elem_classes=["input"]
981
- )
982
-
983
  refresh_stats_btn = gr.Button(
984
  "πŸ”„ Refresh Stats",
985
  elem_classes=["btn", "btn-secondary"]
986
  )
987
-
988
  user_stats_display = gr.HTML()
989
 
990
  # Analytics Tab
@@ -993,10 +1051,10 @@ class GradioInterface:
993
 
994
  with gr.Row():
995
  with gr.Column(scale=1):
996
- analytics_api_key = gr.Textbox(
997
- label="πŸ”‘ API Key",
998
  type="password",
999
- placeholder="Enter API key to view analytics",
1000
  elem_classes=["input"]
1001
  )
1002
  refresh_analytics_btn = gr.Button(
@@ -1012,10 +1070,10 @@ class GradioInterface:
1012
 
1013
  with gr.Row():
1014
  with gr.Column(scale=1):
1015
- admin_api_key = gr.Textbox(
1016
- label="πŸ”‘ Admin API Key",
1017
  type="password",
1018
- placeholder="Enter admin API key",
1019
  elem_classes=["input"]
1020
  )
1021
  refresh_admin_btn = gr.Button(
@@ -1030,13 +1088,13 @@ class GradioInterface:
1030
  self.create_footer()
1031
 
1032
  # Event Handlers
1033
- async def handle_generation(prompt, api_key, max_tokens, temperature, top_k, top_p, repetition_penalty):
1034
- if not api_key.strip():
1035
- return "⚠️ Please enter your API key", {}, False
1036
 
1037
  response = await self.service.generate_text(
1038
  prompt=prompt,
1039
- api_key=api_key,
1040
  max_tokens=max_tokens,
1041
  temperature=temperature,
1042
  top_k=top_k,
@@ -1063,48 +1121,27 @@ class GradioInterface:
1063
  """
1064
  return (
1065
  response.data["generated_text"],
1066
- stats_html,
1067
- True
1068
  )
1069
  else:
1070
  return (
1071
  response.error,
1072
- f'<div class="alert alert-error">{response.error}</div>',
1073
- False
1074
  )
1075
 
1076
- async def handle_user_creation(name, email, plan):
1077
- if not name or not name.strip():
1078
- return (
1079
- f'<div class="alert alert-error">❌ Name is required</div>',
1080
- "",
1081
- False
1082
- )
1083
 
1084
- if not email or not email.strip():
1085
- return (
1086
- f'<div class="alert alert-error">❌ Email is required</div>',
1087
- "",
1088
- False
1089
- )
1090
-
1091
- response = await self.service.create_user(name, email, plan)
1092
 
1093
  if response.success:
1094
- return (
1095
- f'<div class="alert alert-success">βœ… Account created successfully! Your API key is below.</div>',
1096
- response.data["api_key"],
1097
- True
1098
- )
1099
  else:
1100
- return (
1101
- f'<div class="alert alert-error">❌ {response.error}</div>',
1102
- "",
1103
- False
1104
- )
1105
 
1106
- async def handle_stats_refresh(api_key):
1107
- response = await self.service.get_user_stats(api_key)
1108
 
1109
  if response.success:
1110
  stats = response.data
@@ -1140,8 +1177,8 @@ class GradioInterface:
1140
  else:
1141
  return f'<div class="alert alert-error">❌ {response.error}</div>'
1142
 
1143
- async def handle_analytics_refresh(api_key):
1144
- response = await self.service.get_user_stats(api_key)
1145
 
1146
  if response.success:
1147
  return response.data["plot"]
@@ -1153,8 +1190,8 @@ class GradioInterface:
1153
  font=dict(color='white')
1154
  )
1155
 
1156
- async def handle_admin_refresh(api_key):
1157
- response = await self.service.get_admin_stats(api_key)
1158
 
1159
  if response.success:
1160
  stats = response.data
@@ -1227,39 +1264,39 @@ class GradioInterface:
1227
  generate_btn.click(
1228
  fn=handle_generation,
1229
  inputs=[
1230
- prompt_input, api_key_playground, max_tokens_input,
1231
  temperature_input, top_k_input, top_p_input, repetition_penalty_input
1232
  ],
1233
- outputs=[output_text, generation_stats, generation_stats]
1234
  )
1235
 
1236
  clear_btn.click(
1237
- fn=lambda: ("", "", False),
1238
  inputs=[],
1239
- outputs=[prompt_input, output_text, generation_stats]
1240
  )
1241
 
1242
- create_btn.click(
1243
- fn=handle_user_creation,
1244
- inputs=[name_input, email_input, plan_input],
1245
- outputs=[creation_status, api_key_display, api_key_display]
1246
  )
1247
 
1248
  refresh_stats_btn.click(
1249
  fn=handle_stats_refresh,
1250
- inputs=[stats_api_key],
1251
  outputs=[user_stats_display]
1252
  )
1253
 
1254
  refresh_analytics_btn.click(
1255
  fn=handle_analytics_refresh,
1256
- inputs=[analytics_api_key],
1257
  outputs=[analytics_plot]
1258
  )
1259
 
1260
  refresh_admin_btn.click(
1261
  fn=handle_admin_refresh,
1262
- inputs=[admin_api_key],
1263
  outputs=[admin_plot, admin_stats_display]
1264
  )
1265
 
 
238
  logger.error(f"Error getting all users stats: {e}")
239
  return []
240
 
241
+ # ----------------- Hugging Face Auth -----------------
242
+ class HuggingFaceAuth:
243
+ @staticmethod
244
+ async def validate_token(token: str) -> Tuple[bool, Optional[Dict]]:
245
+ """Validate Hugging Face token and return user info"""
246
+ try:
247
+ async with aiohttp.ClientSession() as session:
248
+ headers = {"Authorization": f"Bearer {token}"}
249
+ async with session.get(
250
+ "https://huggingface.co/api/whoami",
251
+ headers=headers
252
+ ) as response:
253
+ if response.status == 200:
254
+ user_info = await response.json()
255
+ return True, user_info
256
+ else:
257
+ return False, None
258
+ except Exception as e:
259
+ logger.error(f"Error validating HF token: {e}")
260
+ return False, None
261
+
262
  # ----------------- Model Manager -----------------
263
  class ModelManager:
264
  def __init__(self, config: Config):
 
314
  logger.error(f"❌ Error loading model: {e}")
315
  self.model_loaded = False
316
 
317
+ async def generate(self, request: GenerationRequest, hf_token: str = None) -> Tuple[bool, str, int]:
318
  """Generate text with the model"""
319
  if not self.model_loaded:
320
  return False, "❌ Model not loaded", 0
 
401
  self.db = DatabaseManager(self.config)
402
  self.model_manager = ModelManager(self.config)
403
  self.analytics_manager = AnalyticsManager(self.db)
404
+ self.hf_auth = HuggingFaceAuth()
405
  self._validate_config()
406
 
407
  def _validate_config(self):
 
439
  error="Internal service error"
440
  )
441
 
442
+ async def generate_text(self, prompt: str, hf_token: str, **kwargs) -> APIResponse:
443
+ """Generate text with Hugging Face authentication"""
444
  try:
445
+ # Validate HF token
446
+ is_valid, user_info = await self.hf_auth.validate_token(hf_token)
447
+ if not is_valid or not user_info:
448
  return APIResponse(
449
  success=False,
450
+ error="⚠️ Invalid Hugging Face token"
451
  )
452
 
453
+ # Get or create user
454
+ user_email = user_info.get('email', '')
455
+ user_name = user_info.get('name', user_info.get('preferred_username', 'User'))
456
+
457
+ # Check if user exists
458
+ user = await self.db.validate_api_key(hf_token) # We're using HF token as API key now
459
+
460
+ if not user:
461
+ # Create new user
462
+ user_data = UserCreate(name=user_name, email=user_email, plan="free")
463
+ success, message, api_key = await self.db.create_user(user_data, hf_user_id=user_info.get('sub', ''))
464
+ if not success:
465
+ return APIResponse(success=False, error=message)
466
+
467
+ user = await self.db.validate_api_key(hf_token)
468
+
469
  # Check rate limit
470
  can_make_request, requests_used = await self.db.check_rate_limit(user['id'])
471
  if not can_make_request:
 
480
 
481
  # Generate text
482
  request = GenerationRequest(prompt=prompt, **kwargs)
483
+ success, text, tokens_used = await self.model_manager.generate(request, hf_token)
484
 
485
  if success:
486
  # Update usage statistics
 
508
  error="Internal service error"
509
  )
510
 
511
+ async def get_user_stats(self, hf_token: str) -> APIResponse:
512
+ """Get user statistics using Hugging Face token"""
513
  try:
514
+ # Validate HF token
515
+ is_valid, user_info = await self.hf_auth.validate_token(hf_token)
516
+ if not is_valid or not user_info:
517
+ return APIResponse(
518
+ success=False,
519
+ error="Invalid Hugging Face token"
520
+ )
521
+
522
+ user = await self.db.validate_api_key(hf_token)
523
  if not user:
524
  return APIResponse(
525
  success=False,
526
+ error="User not found"
527
  )
528
 
529
  # Generate analytics plot
 
551
  error="Error retrieving stats"
552
  )
553
 
554
+ async def get_admin_stats(self, hf_token: str) -> APIResponse:
555
  """Get admin statistics (only for admin users)"""
556
  try:
557
+ # Validate HF token
558
+ is_valid, user_info = await self.hf_auth.validate_token(hf_token)
559
+ if not is_valid or not user_info:
560
+ return APIResponse(
561
+ success=False,
562
+ error="Invalid Hugging Face token"
563
+ )
564
+
565
+ user_email = user_info.get('email', '')
566
+ if user_email != self.config.ADMIN_EMAIL:
567
  return APIResponse(
568
  success=False,
569
  error="Unauthorized: Admin access required"
 
864
  .dark-mode-plot {
865
  background: transparent !important;
866
  }
867
+
868
+ .hf-login-container {
869
+ text-align: center;
870
+ padding: 2rem;
871
+ }
872
+
873
+ .hf-login-btn {
874
+ background: linear-gradient(135deg, #ffd21e 0%, #ff9d0a 100%) !important;
875
+ color: #000 !important;
876
+ font-weight: 600 !important;
877
+ padding: 1rem 2rem !important;
878
+ border-radius: 12px !important;
879
+ margin: 1rem 0;
880
+ }
881
+
882
+ .hf-login-btn:hover {
883
+ transform: translateY(-2px);
884
+ box-shadow: 0 6px 20px rgba(255, 157, 10, 0.4) !important;
885
+ }
886
  """
887
 
888
  def create_header(self):
 
920
 
921
  return gr.HTML(examples_html)
922
 
923
+ def create_hf_login_component(self):
924
+ with gr.Column(elem_classes=["card"]):
925
+ gr.Markdown("### πŸ” Hugging Face Login")
926
+ gr.HTML("""
927
+ <div class="hf-login-container">
928
+ <p>To use this platform, please login with your Hugging Face account:</p>
929
+ <a href="https://huggingface.co/login" target="_blank" class="btn hf-login-btn">Login with Hugging Face</a>
930
+ <p>After logging in, go to your <a href="https://huggingface.co/settings/tokens" target="_blank" style="color: #ffd21e;">Hugging Face settings</a> to get your access token.</p>
931
+ </div>
932
+ """)
933
+ hf_token_input = gr.Textbox(
934
+ label="Hugging Face Access Token",
935
+ type="password",
936
+ placeholder="hf_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
937
+ elem_classes=["input"]
938
+ )
939
+ return hf_token_input
940
+
941
  async def create_interface(self):
942
  """Create the enhanced Gradio interface"""
943
  with gr.Blocks(
 
957
  with gr.Row():
958
  with gr.Column(scale=1):
959
  gr.Markdown("### βš™οΈ Configuration")
960
+ hf_token_input = self.create_hf_login_component()
 
 
 
 
 
961
 
962
  with gr.Accordion("Advanced Settings", open=False):
963
  max_tokens_input = gr.Slider(
 
1018
  )
1019
 
1020
  # Stats display
1021
+ generation_stats = gr.HTML()
 
 
 
1022
 
1023
  # Profile Tab
1024
  with gr.Tab("πŸ‘€ Profile", elem_classes=["card"]):
1025
  with gr.Row():
1026
  with gr.Column(scale=1):
1027
+ gr.Markdown("### πŸ” Authentication")
1028
+ profile_hf_token = gr.Textbox(
1029
+ label="Hugging Face Access Token",
1030
+ type="password",
1031
+ placeholder="Enter your Hugging Face token...",
 
 
 
 
 
 
 
 
1032
  elem_classes=["input"]
1033
  )
1034
+ profile_login_btn = gr.Button(
1035
+ "πŸ”“ Login",
1036
+ elem_classes=["btn", "btn-primary"]
 
 
 
 
 
 
 
 
 
 
1037
  )
1038
+ profile_status = gr.HTML()
1039
 
1040
  with gr.Column(scale=1):
1041
  gr.Markdown("### πŸ“Š Account Statistics")
 
 
 
 
 
 
 
1042
  refresh_stats_btn = gr.Button(
1043
  "πŸ”„ Refresh Stats",
1044
  elem_classes=["btn", "btn-secondary"]
1045
  )
 
1046
  user_stats_display = gr.HTML()
1047
 
1048
  # Analytics Tab
 
1051
 
1052
  with gr.Row():
1053
  with gr.Column(scale=1):
1054
+ analytics_hf_token = gr.Textbox(
1055
+ label="Hugging Face Access Token",
1056
  type="password",
1057
+ placeholder="Enter your Hugging Face token...",
1058
  elem_classes=["input"]
1059
  )
1060
  refresh_analytics_btn = gr.Button(
 
1070
 
1071
  with gr.Row():
1072
  with gr.Column(scale=1):
1073
+ admin_hf_token = gr.Textbox(
1074
+ label="Hugging Face Access Token",
1075
  type="password",
1076
+ placeholder="Enter admin Hugging Face token...",
1077
  elem_classes=["input"]
1078
  )
1079
  refresh_admin_btn = gr.Button(
 
1088
  self.create_footer()
1089
 
1090
  # Event Handlers
1091
+ async def handle_generation(prompt, hf_token, max_tokens, temperature, top_k, top_p, repetition_penalty):
1092
+ if not hf_token.strip():
1093
+ return "⚠️ Please enter your Hugging Face token", ""
1094
 
1095
  response = await self.service.generate_text(
1096
  prompt=prompt,
1097
+ hf_token=hf_token,
1098
  max_tokens=max_tokens,
1099
  temperature=temperature,
1100
  top_k=top_k,
 
1121
  """
1122
  return (
1123
  response.data["generated_text"],
1124
+ stats_html
 
1125
  )
1126
  else:
1127
  return (
1128
  response.error,
1129
+ f'<div class="alert alert-error">{response.error}</div>'
 
1130
  )
1131
 
1132
+ async def handle_profile_login(hf_token):
1133
+ if not hf_token.strip():
1134
+ return f'<div class="alert alert-error">❌ Please enter your Hugging Face token</div>'
 
 
 
 
1135
 
1136
+ response = await self.service.get_user_stats(hf_token)
 
 
 
 
 
 
 
1137
 
1138
  if response.success:
1139
+ return f'<div class="alert alert-success">βœ… Successfully logged in!</div>'
 
 
 
 
1140
  else:
1141
+ return f'<div class="alert alert-error">❌ {response.error}</div>'
 
 
 
 
1142
 
1143
+ async def handle_stats_refresh(hf_token):
1144
+ response = await self.service.get_user_stats(hf_token)
1145
 
1146
  if response.success:
1147
  stats = response.data
 
1177
  else:
1178
  return f'<div class="alert alert-error">❌ {response.error}</div>'
1179
 
1180
+ async def handle_analytics_refresh(hf_token):
1181
+ response = await self.service.get_user_stats(hf_token)
1182
 
1183
  if response.success:
1184
  return response.data["plot"]
 
1190
  font=dict(color='white')
1191
  )
1192
 
1193
+ async def handle_admin_refresh(hf_token):
1194
+ response = await self.service.get_admin_stats(hf_token)
1195
 
1196
  if response.success:
1197
  stats = response.data
 
1264
  generate_btn.click(
1265
  fn=handle_generation,
1266
  inputs=[
1267
+ prompt_input, hf_token_input, max_tokens_input,
1268
  temperature_input, top_k_input, top_p_input, repetition_penalty_input
1269
  ],
1270
+ outputs=[output_text, generation_stats]
1271
  )
1272
 
1273
  clear_btn.click(
1274
+ fn=lambda: ("", ""),
1275
  inputs=[],
1276
+ outputs=[prompt_input, output_text]
1277
  )
1278
 
1279
+ profile_login_btn.click(
1280
+ fn=handle_profile_login,
1281
+ inputs=[profile_hf_token],
1282
+ outputs=[profile_status]
1283
  )
1284
 
1285
  refresh_stats_btn.click(
1286
  fn=handle_stats_refresh,
1287
+ inputs=[profile_hf_token],
1288
  outputs=[user_stats_display]
1289
  )
1290
 
1291
  refresh_analytics_btn.click(
1292
  fn=handle_analytics_refresh,
1293
+ inputs=[analytics_hf_token],
1294
  outputs=[analytics_plot]
1295
  )
1296
 
1297
  refresh_admin_btn.click(
1298
  fn=handle_admin_refresh,
1299
+ inputs=[admin_hf_token],
1300
  outputs=[admin_plot, admin_stats_display]
1301
  )
1302