GuXSs committed on
Commit
ce7ff6c
·
verified ·
1 Parent(s): cebf496

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +767 -130
app.py CHANGED
@@ -1,154 +1,791 @@
1
  import os
2
  import uuid
3
  import json
 
 
 
 
 
 
 
 
4
  import gradio as gr
5
  import requests
 
 
6
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
7
  from dotenv import load_dotenv
 
 
 
 
 
8
 
9
- # ----------------- Configuração -----------------
10
  load_dotenv()
11
 
12
- HF_TOKEN = os.environ.get("HF_TOKEN") # Token Hugging Face para modelos gated
13
- MODEL_NAME = "google/gemma-3-270m-it"
14
- SUPABASE_URL = os.environ.get("SUPABASE_URL")
15
- SUPABASE_KEY = os.environ.get("SUPABASE_KEY")
16
-
17
- if not all([HF_TOKEN, SUPABASE_URL, SUPABASE_KEY]):
18
- raise ValueError("Missing required environment variables.")
19
-
20
- HEADERS = {
21
- "apikey": SUPABASE_KEY,
22
- "Authorization": f"Bearer {SUPABASE_KEY}",
23
- "Content-Type": "application/json"
24
- }
25
-
26
- # ----------------- Carregar Modelo -----------------
27
- try:
28
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_auth_token=HF_TOKEN)
29
- model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, use_auth_token=HF_TOKEN, device_map="auto")
30
- gen_pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
31
- print("✅ Model loaded successfully!")
32
- except Exception as e:
33
- print(f"❌ Error loading model: {e}")
34
- gen_pipe = None
35
 
36
- # ----------------- Funções Supabase -----------------
37
- def create_user(name: str, hf_user_id: str):
38
- """Cria usuário no Supabase vinculado ao Hugging Face ID"""
39
- if not name.strip():
40
- return "⚠️ Name cannot be empty.", ""
41
- api_key = str(uuid.uuid4())
42
- data = {"name": name.strip(), "api_key": api_key, "hf_user_id": hf_user_id, "requests": 0}
43
- try:
44
- r = requests.post(f"{SUPABASE_URL}/rest/v1/users", headers=HEADERS, data=json.dumps(data))
45
- r.raise_for_status()
46
- return f"✅ API key created for {name}.", api_key
47
- except Exception as e:
48
- return f"❌ Error creating user: {e}", ""
49
 
50
- def check_key(api_key: str) -> bool:
51
- try:
52
- r = requests.get(f"{SUPABASE_URL}/rest/v1/users?api_key=eq.{api_key.strip()}", headers=HEADERS)
53
- r.raise_for_status()
54
- return len(r.json()) > 0
55
- except:
56
- return False
57
 
58
- def increment_requests(api_key: str):
59
- """Incrementa contagem de requisições para API Key"""
60
- try:
61
- r = requests.get(f"{SUPABASE_URL}/rest/v1/users?api_key=eq.{api_key.strip()}", headers=HEADERS)
62
- r.raise_for_status()
63
- data = r.json()
64
- if not data:
65
- return
66
- user_id = data[0]["id"]
67
- requests_count = data[0].get("requests", 0) + 1
68
- patch_data = {"requests": requests_count}
69
- requests.patch(f"{SUPABASE_URL}/rest/v1/users?id=eq.{user_id}", headers=HEADERS, data=json.dumps(patch_data))
70
- except:
71
- pass
72
 
73
- def get_user_info(api_key: str):
74
- try:
75
- r = requests.get(f"{SUPABASE_URL}/rest/v1/users?api_key=eq.{api_key.strip()}", headers=HEADERS)
76
- r.raise_for_status()
77
- data = r.json()
78
- return data[0] if data else None
79
- except:
80
- return None
 
 
 
81
 
82
- # ----------------- Função de Geração -----------------
83
- def generate_text(prompt: str, api_key: str):
84
- if gen_pipe is None:
85
- return "❌ Model not loaded."
86
- if not api_key.strip() or not check_key(api_key):
87
- return "⚠️ Invalid API key!"
88
- if not prompt.strip():
89
- return "⚠️ Prompt cannot be empty."
90
- try:
91
- result = gen_pipe(prompt.strip(), max_new_tokens=200, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
92
- increment_requests(api_key)
93
- return result[0]["generated_text"]
94
- except Exception as e:
95
- print(f"❌ Error: {e}")
96
- return "❌ Something went wrong."
97
 
98
- # ----------------- CSS para Sidebar e Cards -----------------
99
- custom_css = """
100
- body {background: linear-gradient(135deg,#f0f4f8 0%,#d9e2ec 100%); font-family:'Inter', sans-serif;}
101
- .gradio-container {max-width: 1200px !important; margin:auto !important; padding-top:2rem;}
102
- .sidebar {background:#1e293b; color:white; padding:1rem; border-radius:10px;}
103
- .sidebar label {color:#cbd5e1; font-weight:600;}
104
- .gr-button {border-radius:12px !important; background:linear-gradient(90deg,#4f46e5,#3b82f6) !important; color:white !important; font-weight:600;}
105
- .gr-button:hover {transform: translateY(-2px); box-shadow:0 6px 12px rgba(0,0,0,0.15);}
106
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
- # ----------------- Interface Gradio -----------------
109
- with gr.Blocks(css=custom_css) as app:
110
- # Header
111
- gr.Markdown("<h1 style='text-align:center; font-size:3rem; background: linear-gradient(90deg,#4f46e5,#3b82f6); -webkit-background-clip:text; -webkit-text-fill-color:transparent;'>✨ Gemma SaaS Platform</h1>")
112
- gr.Markdown("<p style='text-align:center; font-size:1.2rem; color:#334155;'>Playground + Profile with Hugging Face login & API keys</p>")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114
- with gr.Row():
115
- # Sidebar
116
- with gr.Column(scale=1, min_width=250):
117
- gr.Markdown("### Menu")
118
- menu_radio = gr.Radio(choices=["Playground", "Profile"], value="Playground", label="", elem_classes=["sidebar"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
- # Main Content
121
- with gr.Column(scale=4):
122
- with gr.Group():
123
- # Playground
124
- with gr.Tab("Playground"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  with gr.Row():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  with gr.Column(scale=2):
127
- gr.Markdown("### Instructions / Example JSON")
128
- gr.Markdown("```json\n{\n \"prompt\": \"Write a short story about a robot\",\n \"api_key\": \"YOUR_API_KEY\"\n}```")
129
- with gr.Column(scale=3):
130
- prompt_input = gr.Textbox(label="💬 Prompt", lines=5)
131
- api_key_input = gr.Textbox(label="🔑 API Key", type="password")
132
- output_text = gr.Textbox(label="💡 Result", interactive=False, lines=8)
 
133
  with gr.Row():
134
- clear_btn = gr.ClearButton(value="Clear", components=[prompt_input, output_text])
135
- generate_btn = gr.Button("Generate")
136
- generate_btn.click(fn=generate_text, inputs=[prompt_input, api_key_input], outputs=output_text)
137
-
138
- # Profile
139
- with gr.Tab("Profile"):
140
- gr.Markdown("### Manage your API Key (Login required)")
141
- name_input = gr.Textbox(label="👤 Your Name")
142
- status_output = gr.Textbox(label="Status", interactive=False)
143
- key_output = gr.Textbox(label="✅ API Key", interactive=False)
144
- request_count_output = gr.Textbox(label="Requests Made", interactive=False)
145
-
146
- def create_key_profile(name):
147
- # Aqui você pode vincular hf_user_id se usar OAuth
148
- msg, key = create_user(name, hf_user_id="demo") # demo placeholder
149
- return msg, key, 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
 
151
- create_btn = gr.Button("Create API Key")
152
- create_btn.click(fn=create_key_profile, inputs=name_input, outputs=[status_output, key_output, request_count_output])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
 
154
- app.launch(server_name="0.0.0.0", server_port=786)
 
 
 
import asyncio
import hashlib
import html
import json
import logging
import os
import secrets
import time
import uuid
from contextlib import asynccontextmanager
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from functools import wraps
from typing import Any, Dict, List, Optional, Tuple

import aiohttp
import asyncpg
import gradio as gr
import jwt
import requests
from dotenv import load_dotenv
from pydantic import BaseModel, Field, ValidationError
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
23
 
24
+ # ----------------- Configuration & Models -----------------
25
  load_dotenv()
26
 
27
def _env_int(name: str, default: int) -> int:
    """Return environment variable *name* as an int.

    Falls back to *default* (with a warning) when the variable is unset,
    empty, or not a valid integer, instead of aborting the import with a
    bare ``ValueError``.
    """
    raw = os.getenv(name)
    if raw is None or not raw.strip():
        return default
    try:
        return int(raw)
    except ValueError:
        logging.getLogger(__name__).warning(
            "Ignoring invalid integer %r for %s; using default %d", raw, name, default
        )
        return default


@dataclass
class Config:
    """Application settings read from the process environment.

    The defaults are class-level expressions, so the environment is read
    once, when this class is defined; every ``Config()`` instance created
    afterwards carries the same values.
    """

    # Hugging Face access token used when loading the model.
    HF_TOKEN: str = os.getenv("HF_TOKEN", "")
    # Model identifier passed to ``from_pretrained``.
    MODEL_NAME: str = os.getenv("MODEL_NAME", "google/gemma-3-270m-it")
    # Supabase REST endpoint and key.
    SUPABASE_URL: str = os.getenv("SUPABASE_URL", "")
    SUPABASE_KEY: str = os.getenv("SUPABASE_KEY", "")
    # Falls back to a random value generated once per process when unset.
    JWT_SECRET: str = os.getenv("JWT_SECRET", secrets.token_urlsafe(32))
    # Maximum requests a user may make per hourly window.
    RATE_LIMIT_PER_HOUR: int = _env_int("RATE_LIMIT_PER_HOUR", 100)
    # Upper bound applied to the requested number of new tokens.
    MAX_TOKENS: int = _env_int("MAX_TOKENS", 500)
    # Name of a ``logging`` level, e.g. "INFO".
    LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
class GenerationRequest(BaseModel):
    """Validated parameters for a single text-generation call.

    Out-of-range sampling values are rejected here with a
    ``ValidationError`` rather than surfacing later as an opaque failure
    inside the generation pipeline.
    """

    # Text the model should continue.
    prompt: str
    # Requested number of new tokens; must be at least 1.
    max_tokens: int = Field(default=200, ge=1)
    # Sampling temperature; must be strictly positive because sampling is enabled.
    temperature: float = Field(default=0.7, gt=0.0)
    # Top-k cutoff; negative values are rejected.
    top_k: int = Field(default=50, ge=0)
    # Nucleus-sampling probability mass, in (0, 1].
    top_p: float = Field(default=0.95, gt=0.0, le=1.0)
 
 
 
 
 
 
 
44
 
45
class UserCreate(BaseModel):
    """Payload for registering a new user.

    ``name`` and ``email`` must be non-empty, so a blank sign-up form is
    rejected with a ``ValidationError`` instead of creating an anonymous row.
    """

    # Display name of the account owner.
    name: str = Field(..., min_length=1)
    # Contact e-mail address.
    email: str = Field(..., min_length=1)
    # Subscription plan label.
    plan: str = "free"
 
 
 
49
 
50
class APIResponse(BaseModel):
    """Uniform envelope returned by every service-layer call."""

    # True when the operation succeeded.
    success: bool
    # Operation-specific payload; None on failure.
    data: Any = None
    # Human-readable error message; None on success.
    error: Optional[str] = None
    # Creation time of this response. A default_factory is required here:
    # a plain ``datetime.now()`` default is evaluated once, when the class is
    # defined, and every response would then carry that same stale timestamp.
    timestamp: datetime = Field(default_factory=datetime.now)
 
 
 
 
 
 
 
 
 
55
 
56
+ # ----------------- Enhanced Logger -----------------
57
def setup_logger():
    """Configure root logging and return this module's logger.

    The level comes from ``Config().LOG_LEVEL``. It is upper-cased and checked
    before use, so values such as ``"info"`` or a typo fall back to INFO
    instead of crashing at import time. Log records go to the console and,
    when the file can be opened, to ``gemma_saas.log``.

    Returns:
        The ``logging.Logger`` named after this module.
    """
    requested = str(Config().LOG_LEVEL).strip().upper()
    level = getattr(logging, requested, None)
    level_is_valid = isinstance(level, int)
    if not level_is_valid:
        level = logging.INFO

    handlers = [logging.StreamHandler()]
    file_error = None
    try:
        handlers.insert(0, logging.FileHandler('gemma_saas.log'))
    except OSError as exc:
        # An unwritable working directory must not stop the app from starting.
        file_error = exc

    logging.basicConfig(
        level=level,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=handlers,
    )

    log = logging.getLogger(__name__)
    if not level_is_valid:
        log.warning("Unknown LOG_LEVEL %r; falling back to INFO", requested)
    if file_error is not None:
        log.warning("File logging disabled: %s", file_error)
    return log


logger = setup_logger()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
+ # ----------------- Database Manager -----------------
71
class DatabaseManager:
    """Async access to the Supabase REST ``users`` table.

    Only a SHA-256 hash of each API key is stored. All timestamps written by
    this class are timezone-aware UTC; timestamps read back are normalised
    to aware UTC before they are compared.
    """

    # PostgREST answers a successful PATCH with 204 No Content unless the
    # request asks for the representation, so both codes mean success.
    _WRITE_OK = (200, 204)

    def __init__(self, config: "Config"):
        """Store *config* and build the Supabase auth headers from it."""
        self.config = config
        self.headers = {
            "apikey": config.SUPABASE_KEY,
            "Authorization": f"Bearer {config.SUPABASE_KEY}",
            "Content-Type": "application/json"
        }

    @property
    def _users_url(self) -> str:
        """REST endpoint of the ``users`` table."""
        return f"{self.config.SUPABASE_URL}/rest/v1/users"

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current time as an aware UTC datetime."""
        return datetime.now(timezone.utc)

    @staticmethod
    def _hash_api_key(api_key: str) -> str:
        """Return the hex SHA-256 of *api_key*, ignoring surrounding whitespace."""
        return hashlib.sha256(api_key.strip().encode()).hexdigest()

    @staticmethod
    def _parse_timestamp(value) -> Optional[datetime]:
        """Parse an ISO-8601 string into an aware datetime.

        Returns None for missing, empty, or unparseable values. A trailing
        ``Z`` is accepted, and naive values are interpreted as UTC.
        """
        if not value or not isinstance(value, str):
            return None
        text = value.strip()
        if text.endswith(("Z", "z")):
            # Older ``fromisoformat`` implementations reject the "Z" suffix.
            text = text[:-1] + "+00:00"
        try:
            parsed = datetime.fromisoformat(text)
        except ValueError:
            return None
        if parsed.tzinfo is None:
            parsed = parsed.replace(tzinfo=timezone.utc)
        return parsed

    async def create_user(self, user_data: "UserCreate", hf_user_id: Optional[str] = None) -> Tuple[bool, str, str]:
        """Insert a new user row and return its freshly generated API key.

        Returns:
            ``(success, message, api_key)``; ``api_key`` is ``""`` on failure.
            The plaintext key is returned only here and is never stored.
        """
        try:
            api_key = self._generate_api_key()
            now = self._utcnow()

            data = {
                "name": user_data.name.strip(),
                "email": user_data.email.strip(),
                "api_key_hash": self._hash_api_key(api_key),
                "hf_user_id": hf_user_id,
                "requests": 0,
                "plan": user_data.plan,
                "created_at": now.isoformat(),
                "last_request": None,
                "rate_limit_reset": (now + timedelta(hours=1)).isoformat()
            }

            async with aiohttp.ClientSession() as session:
                async with session.post(
                    self._users_url,
                    headers=self.headers,
                    data=json.dumps(data)
                ) as response:
                    if response.status == 201:
                        logger.info(f"User created successfully: {user_data.email}")
                        return True, f"✅ User created successfully for {user_data.name}", api_key

                    error_text = await response.text()
                    logger.error(f"Error creating user: {error_text}")
                    return False, f"❌ Error creating user: {error_text}", ""

        except Exception as e:
            logger.error(f"Database error creating user: {e}")
            return False, f"❌ Database error: {str(e)}", ""

    def _generate_api_key(self) -> str:
        """Generate a random API key carrying the ``gsa_`` prefix."""
        return f"gsa_{secrets.token_urlsafe(32)}"

    async def validate_api_key(self, api_key: str) -> Optional[Dict]:
        """Return the user row whose stored hash matches *api_key*, else None."""
        if not api_key or not api_key.strip():
            return None
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(
                    self._users_url,
                    params={"api_key_hash": f"eq.{self._hash_api_key(api_key)}"},
                    headers=self.headers
                ) as response:
                    if response.status != 200:
                        return None
                    rows = await response.json()
                    return rows[0] if rows else None
        except Exception as e:
            logger.error(f"Error validating API key: {e}")
            return None

    async def check_rate_limit(self, user_id: int) -> bool:
        """Return True when the user may make another request this hour.

        A missing, unparseable, or expired ``rate_limit_reset`` starts a new
        hourly window. Lookup failures deny the request.
        """
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(
                    self._users_url,
                    params={"id": f"eq.{user_id}"},
                    headers=self.headers
                ) as response:
                    if response.status != 200:
                        return False
                    rows = await response.json()

            if not rows:
                return False

            user = rows[0]
            reset_time = self._parse_timestamp(user.get('rate_limit_reset'))
            if reset_time is None or self._utcnow() > reset_time:
                await self._reset_rate_limit(user_id)
                return True

            # ``or 0`` covers a column that exists but holds NULL.
            used = user.get('requests_this_hour') or 0
            return used < self.config.RATE_LIMIT_PER_HOUR
        except Exception as e:
            logger.error(f"Error checking rate limit: {e}")
            return False

    async def _reset_rate_limit(self, user_id: int):
        """Zero the hourly counter and move the window one hour ahead."""
        try:
            data = {
                "requests_this_hour": 0,
                "rate_limit_reset": (self._utcnow() + timedelta(hours=1)).isoformat()
            }

            async with aiohttp.ClientSession() as session:
                async with session.patch(
                    self._users_url,
                    params={"id": f"eq.{user_id}"},
                    headers=self.headers,
                    data=json.dumps(data)
                ) as response:
                    if response.status not in self._WRITE_OK:
                        logger.error(f"Failed to reset rate limit for user {user_id}")
        except Exception as e:
            logger.error(f"Error resetting rate limit: {e}")

    async def increment_usage(self, user_id: int, tokens_used: int):
        """Add one request and *tokens_used* tokens to the user's counters.

        NOTE(review): this is a read-modify-write over two HTTP calls, so
        concurrent requests from the same user can lose an update.
        """
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(
                    self._users_url,
                    params={"id": f"eq.{user_id}"},
                    headers=self.headers
                ) as response:
                    if response.status != 200:
                        return
                    rows = await response.json()

                if not rows:
                    return

                user = rows[0]
                new_data = {
                    "requests": (user.get('requests') or 0) + 1,
                    "requests_this_hour": (user.get('requests_this_hour') or 0) + 1,
                    "tokens_used": (user.get('tokens_used') or 0) + tokens_used,
                    "last_request": self._utcnow().isoformat()
                }

                # The context manager releases the connection and lets us
                # notice a rejected update.
                async with session.patch(
                    self._users_url,
                    params={"id": f"eq.{user_id}"},
                    headers=self.headers,
                    data=json.dumps(new_data)
                ) as patch_response:
                    if patch_response.status not in self._WRITE_OK:
                        logger.error(f"Failed to update usage for user {user_id}")
        except Exception as e:
            logger.error(f"Error incrementing usage: {e}")
210
 
211
+ # ----------------- Model Manager -----------------
212
class ModelManager:
    """Owns the tokenizer, model and text-generation pipeline."""

    def __init__(self, config: "Config"):
        """Remember *config*; nothing is loaded until ``initialize`` runs."""
        self.config = config
        self.tokenizer = None
        self.model = None
        self.pipeline = None
        self.model_loaded = False

    async def initialize(self):
        """Load tokenizer, model and pipeline in worker threads.

        Failures are logged and leave ``model_loaded`` False; they are not
        re-raised.
        """
        try:
            logger.info("Loading model...")
            # We are inside a coroutine, so the running loop is the one to use.
            loop = asyncio.get_running_loop()

            # ``token`` replaces the deprecated ``use_auth_token`` argument.
            self.tokenizer = await loop.run_in_executor(
                None,
                lambda: AutoTokenizer.from_pretrained(
                    self.config.MODEL_NAME,
                    token=self.config.HF_TOKEN
                )
            )

            self.model = await loop.run_in_executor(
                None,
                lambda: AutoModelForCausalLM.from_pretrained(
                    self.config.MODEL_NAME,
                    token=self.config.HF_TOKEN,
                    device_map="auto",
                    torch_dtype="auto"
                )
            )

            self.pipeline = await loop.run_in_executor(
                None,
                lambda: pipeline(
                    "text-generation",
                    model=self.model,
                    tokenizer=self.tokenizer,
                    pad_token_id=self.tokenizer.eos_token_id
                )
            )

            self.model_loaded = True
            logger.info("✅ Model loaded successfully!")

        except Exception as e:
            logger.error(f"❌ Error loading model: {e}")
            self.model_loaded = False

    async def generate(self, request: "GenerationRequest") -> Tuple[bool, str, int]:
        """Generate a completion for ``request.prompt``.

        Returns:
            ``(success, text_or_error, tokens_used)``. ``tokens_used`` counts
            the tokens of the generated text only and is 0 on failure.
        """
        if not self.model_loaded:
            return False, "❌ Model not loaded", 0

        try:
            prompt = request.prompt.strip()
            if not prompt:
                return False, "⚠️ Prompt cannot be empty", 0

            if len(request.prompt) > 2000:
                return False, "⚠️ Prompt too long (max 2000 characters)", 0

            # Never exceed the server-side cap, whatever the caller asked for.
            max_new_tokens = min(request.max_tokens, self.config.MAX_TOKENS)

            # Generation is blocking, so it runs in the default executor.
            loop = asyncio.get_running_loop()
            result = await loop.run_in_executor(
                None,
                lambda: self.pipeline(
                    prompt,
                    max_new_tokens=max_new_tokens,
                    do_sample=True,
                    temperature=request.temperature,
                    top_k=request.top_k,
                    top_p=request.top_p,
                    pad_token_id=self.tokenizer.eos_token_id,
                    return_full_text=False
                )
            )

            generated_text = result[0]["generated_text"]
            # Skip special tokens so the count reflects generated text only.
            tokens_used = len(
                self.tokenizer.encode(generated_text, add_special_tokens=False)
            )

            return True, generated_text, tokens_used

        except Exception as e:
            logger.error(f"Generation error: {e}")
            return False, f"❌ Generation failed: {str(e)}", 0
301
 
302
+ # ----------------- Service Layer -----------------
303
class GemmaSaaSService:
    """Service layer tying together configuration, database and model."""

    def __init__(self):
        """Build the collaborators and fail fast on missing configuration."""
        self.config = Config()
        self.db = DatabaseManager(self.config)
        self.model_manager = ModelManager(self.config)
        self._validate_config()

    def _validate_config(self):
        """Raise ValueError naming every required setting that is empty."""
        required_fields = ['HF_TOKEN', 'SUPABASE_URL', 'SUPABASE_KEY']
        missing_fields = [field for field in required_fields if not getattr(self.config, field)]

        if missing_fields:
            raise ValueError(f"Missing required environment variables: {', '.join(missing_fields)}")

    @staticmethod
    def _looks_like_email(address: str) -> bool:
        """Cheap shape check: one ``@``, a dotted domain, no whitespace."""
        local, sep, domain = address.partition("@")
        if not (sep and local and domain) or "@" in domain:
            return False
        if any(ch.isspace() for ch in address):
            return False
        return "." in domain.strip(".") and not domain.startswith(".") and not domain.endswith(".")

    async def initialize(self):
        """Initialize all services"""
        await self.model_manager.initialize()

    async def create_user(self, name: str, email: str, plan: str = "free") -> "APIResponse":
        """Create a user and return its new API key in ``data["api_key"]``.

        Blank names and malformed e-mail addresses are rejected before the
        database is contacted.
        """
        try:
            name = (name or "").strip()
            email = (email or "").strip()
            if not name:
                return APIResponse(success=False, error="⚠️ Name cannot be empty.")
            if not self._looks_like_email(email):
                return APIResponse(success=False, error="⚠️ Please enter a valid email address.")

            user_data = UserCreate(name=name, email=email, plan=plan or "free")
            success, message, api_key = await self.db.create_user(user_data)

            return APIResponse(
                success=success,
                data={"api_key": api_key} if success else None,
                error=message if not success else None
            )
        except ValidationError as e:
            return APIResponse(success=False, error=f"Validation error: {e}")
        except Exception as e:
            logger.error(f"Service error creating user: {e}")
            return APIResponse(success=False, error="Internal service error")

    async def generate_text(self, prompt: str, api_key: str, **kwargs) -> "APIResponse":
        """Authenticate, rate-limit, then generate text for *prompt*.

        ``kwargs`` are forwarded to ``GenerationRequest`` (max_tokens,
        temperature, top_k, top_p).
        """
        try:
            user = await self.db.validate_api_key(api_key)
            if not user:
                return APIResponse(success=False, error="⚠️ Invalid API key")

            if not await self.db.check_rate_limit(user['id']):
                return APIResponse(success=False, error="⚠️ Rate limit exceeded. Try again later.")

            request = GenerationRequest(prompt=prompt, **kwargs)
            success, text, tokens_used = await self.model_manager.generate(request)

            if not success:
                return APIResponse(success=False, error=text)

            await self.db.increment_usage(user['id'], tokens_used)

            return APIResponse(
                success=True,
                data={
                    "generated_text": text,
                    "tokens_used": tokens_used,
                    "user_plan": user.get('plan', 'free')
                }
            )

        except ValidationError as e:
            # Bad generation parameters are the caller's problem, not an
            # internal failure; tell them what was wrong.
            return APIResponse(success=False, error=f"Validation error: {e}")
        except Exception as e:
            logger.error(f"Service error generating text: {e}")
            return APIResponse(success=False, error="Internal service error")

    async def get_user_stats(self, api_key: str) -> "APIResponse":
        """Return usage statistics for the owner of *api_key*.

        Counters that are NULL in the database are reported as 0 so that
        consumers can format them as numbers.
        """
        try:
            user = await self.db.validate_api_key(api_key)
            if not user:
                return APIResponse(success=False, error="Invalid API key")

            stats = {
                "name": user.get('name'),
                "email": user.get('email'),
                "plan": user.get('plan') or 'free',
                "total_requests": user.get('requests') or 0,
                "tokens_used": user.get('tokens_used') or 0,
                "requests_this_hour": user.get('requests_this_hour') or 0,
                "rate_limit": self.config.RATE_LIMIT_PER_HOUR,
                "created_at": user.get('created_at'),
                "last_request": user.get('last_request')
            }

            return APIResponse(success=True, data=stats)

        except Exception as e:
            logger.error(f"Error getting user stats: {e}")
            return APIResponse(success=False, error="Error retrieving stats")
398
 
399
+ # ----------------- Enhanced UI -----------------
400
class GradioInterface:
    """Builds the Gradio front end on top of a ``GemmaSaaSService``."""

    def __init__(self, service: "GemmaSaaSService"):
        """Keep a reference to the service the event handlers call."""
        self.service = service

    def create_advanced_css(self):
        """Return the stylesheet injected into the Gradio page."""
        return """
        @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap');

        :root {
            --primary-gradient: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
            --secondary-gradient: linear-gradient(135deg, #f093fb 0%, #f5576c 100%);
            --success-gradient: linear-gradient(135deg, #4facfe 0%, #00f2fe 100%);
            --card-shadow: 0 10px 30px rgba(0,0,0,0.1);
            --hover-shadow: 0 15px 40px rgba(0,0,0,0.15);
        }

        body {
            background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
            font-family: 'Inter', sans-serif;
            margin: 0;
            padding: 0;
        }

        .gradio-container {
            max-width: 1400px !important;
            margin: 0 auto !important;
            padding: 2rem !important;
        }

        .main-header {
            text-align: center;
            margin-bottom: 3rem;
            padding: 2rem;
            background: white;
            border-radius: 20px;
            box-shadow: var(--card-shadow);
        }

        .main-title {
            font-size: 3.5rem;
            font-weight: 700;
            background: var(--primary-gradient);
            -webkit-background-clip: text;
            -webkit-text-fill-color: transparent;
            margin-bottom: 1rem;
        }

        .main-subtitle {
            font-size: 1.3rem;
            color: #64748b;
            font-weight: 400;
        }

        .feature-card {
            background: white;
            border-radius: 15px;
            padding: 2rem;
            box-shadow: var(--card-shadow);
            transition: all 0.3s ease;
            border: 1px solid rgba(255,255,255,0.2);
        }

        .feature-card:hover {
            transform: translateY(-5px);
            box-shadow: var(--hover-shadow);
        }

        .gr-button {
            border-radius: 12px !important;
            font-weight: 600 !important;
            padding: 0.8rem 2rem !important;
            transition: all 0.3s ease !important;
            border: none !important;
        }

        .gr-button-primary {
            background: var(--primary-gradient) !important;
            color: white !important;
        }

        .gr-button-secondary {
            background: var(--secondary-gradient) !important;
            color: white !important;
        }

        .gr-button:hover {
            transform: translateY(-2px) !important;
            box-shadow: 0 8px 25px rgba(0,0,0,0.2) !important;
        }

        .stats-card {
            background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
            color: white;
            border-radius: 15px;
            padding: 2rem;
            margin: 1rem 0;
        }

        .metric {
            text-align: center;
            padding: 1rem;
        }

        .metric-value {
            font-size: 2.5rem;
            font-weight: 700;
        }

        .metric-label {
            font-size: 0.9rem;
            opacity: 0.8;
            text-transform: uppercase;
            letter-spacing: 1px;
        }

        .alert-success {
            background: linear-gradient(135deg, #4facfe 0%, #00f2fe 100%);
            color: white;
            padding: 1rem;
            border-radius: 10px;
            margin: 1rem 0;
        }

        .alert-error {
            background: linear-gradient(135deg, #ff6b6b 0%, #ee5a52 100%);
            color: white;
            padding: 1rem;
            border-radius: 10px;
            margin: 1rem 0;
        }
        """

    @staticmethod
    def _alert(kind: str, message: str) -> str:
        """Return an ``alert-<kind>`` banner with *message* HTML-escaped."""
        return f'<div class="alert-{kind}">{html.escape(str(message))}</div>'

    @staticmethod
    def _render_stats(stats: Dict[str, Any]) -> str:
        """Render the account statistics card.

        Every user-supplied value is HTML-escaped, and missing counters are
        shown as 0, so a crafted name or a NULL column cannot break the page.
        """
        def text(value, fallback="N/A"):
            return html.escape(str(value)) if value not in (None, "") else fallback

        total_requests = int(stats.get('total_requests') or 0)
        tokens_used = int(stats.get('tokens_used') or 0)
        hourly = int(stats.get('requests_this_hour') or 0)
        plan = text(str(stats.get('plan') or 'free').title())
        created = stats.get('created_at')
        member_since = text(str(created)[:10]) if created else 'N/A'

        return f"""
        <div class="stats-card">
            <div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(200px, 1fr)); gap: 1rem;">
                <div class="metric">
                    <div class="metric-value">{total_requests}</div>
                    <div class="metric-label">Total Requests</div>
                </div>
                <div class="metric">
                    <div class="metric-value">{tokens_used:,}</div>
                    <div class="metric-label">Tokens Used</div>
                </div>
                <div class="metric">
                    <div class="metric-value">{hourly}/{text(stats.get('rate_limit'))}</div>
                    <div class="metric-label">Hourly Usage</div>
                </div>
                <div class="metric">
                    <div class="metric-value">{plan}</div>
                    <div class="metric-label">Current Plan</div>
                </div>
            </div>
            <hr style="margin: 2rem 0; opacity: 0.3;">
            <p><strong>Account:</strong> {text(stats.get('name'))} ({text(stats.get('email'))})</p>
            <p><strong>Member since:</strong> {member_since}</p>
        </div>
        """

    async def create_interface(self):
        """Build and return the ``gr.Blocks`` app (not yet launched)."""
        with gr.Blocks(css=self.create_advanced_css(), title="Gemma SaaS Platform") as app:

            # Header
            with gr.Row():
                gr.HTML("""
                <div class="main-header">
                    <h1 class="main-title">✨ Gemma SaaS Platform</h1>
                    <p class="main-subtitle">Professional AI Text Generation with Advanced Analytics</p>
                </div>
                """)

            # Main Content
            with gr.Tabs():

                # Playground Tab
                with gr.Tab("🎮 Playground", elem_classes=["feature-card"]):
                    with gr.Row():
                        with gr.Column(scale=1):
                            gr.Markdown("### 📋 Generation Parameters")
                            api_key_playground = gr.Textbox(
                                label="🔑 API Key",
                                type="password",
                                placeholder="Enter your API key..."
                            )

                            with gr.Accordion("⚙️ Advanced Settings", open=False):
                                max_tokens_input = gr.Slider(
                                    minimum=50, maximum=500, value=200,
                                    label="Max Tokens"
                                )
                                temperature_input = gr.Slider(
                                    minimum=0.1, maximum=2.0, value=0.7,
                                    label="Temperature"
                                )
                                top_k_input = gr.Slider(
                                    minimum=1, maximum=100, value=50,
                                    label="Top K"
                                )
                                top_p_input = gr.Slider(
                                    minimum=0.1, maximum=1.0, value=0.95,
                                    label="Top P"
                                )

                        with gr.Column(scale=2):
                            gr.Markdown("### 💭 Text Generation")
                            prompt_input = gr.Textbox(
                                label="✍️ Prompt",
                                lines=8,
                                placeholder="Enter your prompt here... (e.g., 'Write a short story about a robot discovering emotions')"
                            )

                            with gr.Row():
                                generate_btn = gr.Button(
                                    "🚀 Generate",
                                    elem_classes=["gr-button-primary"],
                                    variant="primary"
                                )
                                clear_btn = gr.ClearButton(
                                    components=[prompt_input],
                                    value="🗑️ Clear",
                                    elem_classes=["gr-button-secondary"]
                                )

                            output_text = gr.Textbox(
                                label="📝 Generated Text",
                                lines=12,
                                interactive=False
                            )

                            # Hidden until a generation succeeds.
                            generation_stats = gr.JSON(
                                label="📊 Generation Statistics",
                                visible=False
                            )

                # Profile Tab
                with gr.Tab("👤 Profile", elem_classes=["feature-card"]):
                    with gr.Row():
                        with gr.Column():
                            gr.Markdown("### 🆕 Create Account")
                            name_input = gr.Textbox(label="👤 Full Name")
                            email_input = gr.Textbox(label="📧 Email Address")
                            plan_input = gr.Dropdown(
                                choices=["free", "pro", "enterprise"],
                                value="free",
                                label="📋 Plan"
                            )

                            create_btn = gr.Button(
                                "✨ Create API Key",
                                elem_classes=["gr-button-primary"],
                                variant="primary"
                            )

                            creation_status = gr.HTML()
                            # Hidden until an account has been created.
                            api_key_display = gr.Textbox(
                                label="🔑 Your API Key",
                                interactive=False,
                                visible=False
                            )

                        with gr.Column():
                            gr.Markdown("### 📊 Account Statistics")
                            stats_api_key = gr.Textbox(
                                label="🔑 API Key",
                                type="password",
                                placeholder="Enter API key to view stats"
                            )

                            refresh_stats_btn = gr.Button(
                                "🔄 Refresh Stats",
                                elem_classes=["gr-button-secondary"]
                            )

                            user_stats_display = gr.HTML()

                # Analytics Tab
                with gr.Tab("📈 Analytics", elem_classes=["feature-card"]):
                    gr.Markdown("### 📊 Platform Analytics")

                    # Placeholder for future analytics
                    gr.HTML("""
                    <div style="text-align: center; padding: 3rem;">
                        <h3>📈 Advanced Analytics Coming Soon</h3>
                        <p>Real-time usage metrics, performance insights, and detailed reporting.</p>
                    </div>
                    """)

            # Event handlers. Each returns exactly one value per output
            # component; value and visibility travel together in one
            # gr.update so neither overwrites the other.
            async def handle_generation(prompt, api_key, max_tokens, temperature, top_k, top_p):
                hidden_stats = gr.update(value={}, visible=False)
                if not api_key or not api_key.strip():
                    return "⚠️ Please enter your API key", hidden_stats

                response = await self.service.generate_text(
                    prompt=prompt or "",
                    api_key=api_key,
                    # Sliders may deliver floats; these two are integer counts.
                    max_tokens=int(max_tokens),
                    temperature=temperature,
                    top_k=int(top_k),
                    top_p=top_p
                )

                if not response.success:
                    return response.error, hidden_stats

                return (
                    response.data["generated_text"],
                    gr.update(value=response.data, visible=True)
                )

            async def handle_user_creation(name, email, plan):
                response = await self.service.create_user(name, email, plan)

                if not response.success:
                    return (
                        self._alert("error", f"❌ {response.error}"),
                        gr.update(value="", visible=False)
                    )

                return (
                    self._alert("success", "✅ Account created successfully!"),
                    gr.update(value=response.data["api_key"], visible=True)
                )

            async def handle_stats_refresh(api_key):
                response = await self.service.get_user_stats(api_key)

                if not response.success:
                    return self._alert("error", f"❌ {response.error}")

                return self._render_stats(response.data)

            # Wire up events
            generate_btn.click(
                fn=handle_generation,
                inputs=[
                    prompt_input, api_key_playground, max_tokens_input,
                    temperature_input, top_k_input, top_p_input
                ],
                outputs=[output_text, generation_stats]
            )

            create_btn.click(
                fn=handle_user_creation,
                inputs=[name_input, email_input, plan_input],
                outputs=[creation_status, api_key_display]
            )

            refresh_stats_btn.click(
                fn=handle_stats_refresh,
                inputs=[stats_api_key],
                outputs=[user_stats_display]
            )

        return app
757
 
758
+ # ----------------- Main Application -----------------
759
async def main():
    """Start the application: load the model, build the UI, launch the server.

    Raises:
        Exception: any start-up failure is logged with its traceback and
            re-raised.
    """
    try:
        service = GemmaSaaSService()
        await service.initialize()
        if not service.model_manager.model_loaded:
            # Model loading logs and swallows its own errors; make it obvious
            # that the UI is coming up without a working generator.
            logger.warning("Model failed to load; generation requests will be rejected")

        interface = GradioInterface(service)
        app = await interface.create_interface()

        app.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=False,
            debug=False,
            show_error=True,
            quiet=False,
            favicon_path=None,
            ssl_keyfile=None,
            ssl_certfile=None,
            auth=None,  # Add your auth function here if needed
            max_threads=10
        )

    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception(f"Failed to start application: {e}")
        raise
788
 
789
if __name__ == "__main__":
    # Script entry point: run the async main() coroutine to completion.
    asyncio.run(main())