GuXSs committed on
Commit
23a7d96
·
verified ·
1 Parent(s): b9300de

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +200 -422
app.py CHANGED
@@ -1,61 +1,39 @@
1
- # Updated Gemma SaaS Gradio app with Hugging Face OAuth login + key listing
2
- # Save this as `gemma_saas_gradio_oauth.py`
3
-
4
- # IMPORTANT: Before running, set these environment variables (in your Space secrets or .env):
5
- # SUPABASE_URL, SUPABASE_KEY, HF_OAUTH_CLIENT_ID, HF_OAUTH_CLIENT_SECRET,
6
- # HF_OAUTH_REDIRECT_URI (e.g. https://<your-space>.hf.space/hf_callback),
7
- # ADMIN_EMAIL (optional), MODEL_NAME (optional), HF_TOKEN (optional for model preloading)
8
-
9
- # NOTE: When deploying to Hugging Face Spaces, register the redirect URI in the
10
- # Hugging Face OAuth app settings to exactly match HF_OAUTH_REDIRECT_URI.
11
-
12
  import os
 
13
  import json
14
  import asyncio
15
  import logging
 
16
  from datetime import datetime, timedelta
 
17
  from dataclasses import dataclass
18
- from typing import Dict, Optional, Tuple, Any, List
19
 
20
  import gradio as gr
21
  import aiohttp
22
- import secrets
 
23
  from pydantic import BaseModel, ValidationError
24
- from fastapi import Request
25
- from starlette.responses import HTMLResponse
 
26
 
27
- # Optional transformer/model imports (kept like before)
28
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
29
 
30
- # ----------------- Config -----------------
31
  @dataclass
32
  class Config:
33
- SUPABASE_URL: str = os.getenv("SUPABASE_URL", "")
34
- SUPABASE_KEY: str = os.getenv("SUPABASE_KEY", "")
35
- HF_TOKEN: str = os.getenv("HF_TOKEN", "") # optional; used for loading model directly
36
- HF_OAUTH_CLIENT_ID: str = os.getenv("HF_OAUTH_CLIENT_ID", "")
37
- HF_OAUTH_CLIENT_SECRET: str = os.getenv("HF_OAUTH_CLIENT_SECRET", "")
38
- HF_OAUTH_REDIRECT_URI: str = os.getenv("HF_OAUTH_REDIRECT_URI", "")
39
- HF_OAUTH_PORT: int = int(os.getenv("HF_OAUTH_PORT", "8000"))
40
- MODEL_NAME: str = os.getenv("MODEL_NAME", "google/gemma-3-270m-it")
41
- JWT_SECRET: str = os.getenv("JWT_SECRET", secrets.token_urlsafe(32))
42
- RATE_LIMIT_PER_HOUR: int = int(os.getenv("RATE_LIMIT_PER_HOUR", "100"))
43
- MAX_TOKENS: int = int(os.getenv("MAX_TOKENS", "500"))
44
  LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
45
- ADMIN_EMAIL: str = os.getenv("ADMIN_EMAIL", "")
46
 
47
  class GenerationRequest(BaseModel):
48
  prompt: str
49
- max_tokens: int = 200
50
- temperature: float = 0.7
51
  top_k: int = 50
52
  top_p: float = 0.95
53
- repetition_penalty: float = 1.0
54
-
55
- class UserCreate(BaseModel):
56
- name: str
57
- email: str
58
- plan: str = "free"
59
 
60
  class APIResponse(BaseModel):
61
  success: bool
@@ -63,13 +41,13 @@ class APIResponse(BaseModel):
63
  error: Optional[str] = None
64
  timestamp: datetime = datetime.now()
65
 
66
- # ----------------- Logging -----------------
67
  def setup_logger():
68
- cfg = Config()
69
  logging.basicConfig(
70
- level=getattr(logging, cfg.LOG_LEVEL),
71
  format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
72
  handlers=[
 
73
  logging.StreamHandler()
74
  ]
75
  )
@@ -77,177 +55,7 @@ def setup_logger():
77
 
78
  logger = setup_logger()
79
 
80
- # ----------------- Database Manager -----------------
81
- class DatabaseManager:
82
- def __init__(self, config: Config):
83
- self.config = config
84
- self.headers = {
85
- "apikey": config.SUPABASE_KEY,
86
- "Authorization": f"Bearer {config.SUPABASE_KEY}",
87
- "Content-Type": "application/json"
88
- }
89
-
90
- async def create_user(self, user_data: UserCreate, hf_user_id: str = None, hf_token: str = None) -> Tuple[bool, str, str]:
91
- try:
92
- # Check existing by email
93
- async with aiohttp.ClientSession() as session:
94
- async with session.get(
95
- f"{self.config.SUPABASE_URL}/rest/v1/users?email=eq.{user_data.email}",
96
- headers=self.headers
97
- ) as response:
98
- if response.status == 200:
99
- existing_users = await response.json()
100
- if existing_users:
101
- return False, "❌ User with this email already exists", ""
102
-
103
- api_key = self._generate_api_key()
104
- data = {
105
- "name": user_data.name.strip(),
106
- "email": user_data.email.strip(),
107
- "api_key": api_key,
108
- "hf_user_id": hf_user_id,
109
- "hf_token": hf_token,
110
- "requests": 0,
111
- "plan": user_data.plan,
112
- "created_at": datetime.utcnow().isoformat(),
113
- "last_request": None,
114
- "requests_this_hour": 0,
115
- "rate_limit_reset": (datetime.utcnow() + timedelta(hours=1)).isoformat(),
116
- "tokens_used": 0
117
- }
118
-
119
- async with aiohttp.ClientSession() as session:
120
- async with session.post(
121
- f"{self.config.SUPABASE_URL}/rest/v1/users",
122
- headers=self.headers,
123
- data=json.dumps(data)
124
- ) as response:
125
- if response.status in (200, 201):
126
- logger.info(f"User created: {user_data.email}")
127
- return True, f"✅ User created successfully for {user_data.name}", api_key
128
- else:
129
- text = await response.text()
130
- logger.error(f"Error creating user: {text}")
131
- return False, f"❌ Error creating user: {text}", ""
132
-
133
- except Exception as e:
134
- logger.error(f"DB create_user error: {e}")
135
- return False, f"❌ Database error: {str(e)}", ""
136
-
137
- def _generate_api_key(self) -> str:
138
- return f"gsa_{secrets.token_urlsafe(32)}"
139
-
140
- async def get_user_by_email(self, email: str) -> Optional[Dict]:
141
- try:
142
- async with aiohttp.ClientSession() as session:
143
- async with session.get(
144
- f"{self.config.SUPABASE_URL}/rest/v1/users?email=eq.{email}&select=*",
145
- headers=self.headers
146
- ) as response:
147
- if response.status == 200:
148
- data = await response.json()
149
- return data[0] if data else None
150
- except Exception as e:
151
- logger.error(f"DB get_user_by_email error: {e}")
152
- return None
153
-
154
- async def upsert_hf_token_for_email(self, email: str, hf_token: str, hf_user_id: Optional[str] = None) -> bool:
155
- try:
156
- user = await self.get_user_by_email(email)
157
- payload = {"hf_token": hf_token}
158
- if hf_user_id:
159
- payload["hf_user_id"] = hf_user_id
160
-
161
- async with aiohttp.ClientSession() as session:
162
- if user:
163
- # Patch existing
164
- async with session.patch(
165
- f"{self.config.SUPABASE_URL}/rest/v1/users?email=eq.{email}",
166
- headers=self.headers,
167
- data=json.dumps(payload)
168
- ) as response:
169
- return response.status in (200, 204)
170
- else:
171
- # Create new user row with minimal info
172
- new_data = {
173
- "name": email.split('@')[0],
174
- "email": email,
175
- "api_key": self._generate_api_key(),
176
- "hf_token": hf_token,
177
- "hf_user_id": hf_user_id,
178
- "requests": 0,
179
- "plan": "free",
180
- "created_at": datetime.utcnow().isoformat(),
181
- "last_request": None,
182
- "requests_this_hour": 0,
183
- "rate_limit_reset": (datetime.utcnow() + timedelta(hours=1)).isoformat(),
184
- "tokens_used": 0
185
- }
186
- async with session.post(
187
- f"{self.config.SUPABASE_URL}/rest/v1/users",
188
- headers=self.headers,
189
- data=json.dumps(new_data)
190
- ) as response:
191
- return response.status in (200, 201)
192
- except Exception as e:
193
- logger.error(f"DB upsert_hf_token error: {e}")
194
- return False
195
-
196
- async def list_keys_for_email(self, email: str) -> List[Dict]:
197
- try:
198
- user = await self.get_user_by_email(email)
199
- if not user:
200
- return []
201
- # Return obfuscated tokens/keys
202
- api_key = user.get('api_key')
203
- hf_token = user.get('hf_token')
204
- items = []
205
- if api_key:
206
- items.append({"type": "api_key", "value": self._obfuscate(api_key)})
207
- if hf_token:
208
- items.append({"type": "hf_token", "value": self._obfuscate(hf_token)})
209
- return items
210
- except Exception as e:
211
- logger.error(f"DB list_keys error: {e}")
212
- return []
213
-
214
- def _obfuscate(self, token: str) -> str:
215
- if not token or len(token) < 6:
216
- return "****"
217
- return '*' * (len(token) - 6) + token[-6:]
218
-
219
- async def get_all_users_stats(self):
220
- try:
221
- async with aiohttp.ClientSession() as session:
222
- async with session.get(
223
- f"{self.config.SUPABASE_URL}/rest/v1/users?select=*",
224
- headers=self.headers
225
- ) as response:
226
- if response.status == 200:
227
- return await response.json()
228
- return []
229
- except Exception as e:
230
- logger.error(f"DB get_all_users_stats error: {e}")
231
- return []
232
-
233
- # ----------------- Hugging Face Auth -----------------
234
- class HuggingFaceAuth:
235
- @staticmethod
236
- async def validate_token(token: str) -> Tuple[bool, Optional[Dict]]:
237
- try:
238
- async with aiohttp.ClientSession() as session:
239
- headers = {"Authorization": f"Bearer {token}"}
240
- async with session.get("https://huggingface.co/api/whoami-v2", headers=headers) as response:
241
- # whoami-v2 returns more structured data; fallback to whoami
242
- if response.status == 200:
243
- return True, await response.json()
244
- else:
245
- return False, None
246
- except Exception as e:
247
- logger.error(f"HF validate_token error: {e}")
248
- return False, None
249
-
250
- # ----------------- Model Manager (unchanged mostly) -----------------
251
  class ModelManager:
252
  def __init__(self, config: Config):
253
  self.config = config
@@ -257,268 +65,238 @@ class ModelManager:
257
  self.model_loaded = False
258
 
259
  async def initialize(self):
 
260
  if not self.config.HF_TOKEN:
261
- logger.info("No HF_TOKEN provided skipping model preload. Will lazy-load on first request if needed.")
262
  self.model_loaded = False
263
  return
264
-
265
  try:
266
- logger.info("Loading model...")
267
  loop = asyncio.get_event_loop()
268
- self.tokenizer = await loop.run_in_executor(
269
- None,
270
- lambda: AutoTokenizer.from_pretrained(
271
- self.config.MODEL_NAME,
272
- use_auth_token=self.config.HF_TOKEN,
273
- trust_remote_code=True
274
- )
275
- )
276
- self.model = await loop.run_in_executor(
277
- None,
278
- lambda: AutoModelForCausalLM.from_pretrained(
279
  self.config.MODEL_NAME,
280
- use_auth_token=self.config.HF_TOKEN,
281
  device_map="auto",
282
- torch_dtype="auto",
283
- trust_remote_code=True
284
  )
285
- )
286
- self.pipeline = await loop.run_in_executor(
287
- None,
288
- lambda: pipeline(
289
  "text-generation",
290
- model=self.model,
291
- tokenizer=self.tokenizer,
292
- pad_token_id=self.tokenizer.eos_token_id
293
  )
294
- )
 
 
295
  self.model_loaded = True
296
- logger.info("Model loaded successfully")
297
  except Exception as e:
298
- logger.error(f"Model load error: {e}")
299
  self.model_loaded = False
300
 
301
- async def generate(self, request: GenerationRequest, hf_token: str = None) -> Tuple[bool, str, int]:
 
302
  if not self.model_loaded:
303
- # try lazy-load with provided token
304
- if hf_token:
305
- self.config.HF_TOKEN = hf_token
306
- await self.initialize()
307
- else:
308
- return False, "❌ Model not loaded and no HF token provided", 0
309
-
310
  try:
311
- if len(request.prompt.strip()) == 0:
312
- return False, "⚠️ Prompt cannot be empty", 0
 
 
 
313
  loop = asyncio.get_event_loop()
314
- result = await loop.run_in_executor(
315
- None,
316
- lambda: self.pipeline(
317
- request.prompt.strip(),
 
 
 
 
 
318
  max_new_tokens=min(request.max_tokens, self.config.MAX_TOKENS),
319
  do_sample=True,
320
  temperature=request.temperature,
321
  top_k=request.top_k,
322
  top_p=request.top_p,
323
- repetition_penalty=request.repetition_penalty,
324
- pad_token_id=self.tokenizer.eos_token_id,
325
- return_full_text=False
326
  )
327
- )
328
- generated_text = result[0]["generated_text"]
 
329
  tokens_used = len(self.tokenizer.encode(generated_text))
330
  return True, generated_text, tokens_used
331
  except Exception as e:
332
  logger.error(f"Generation error: {e}")
333
- return False, f"❌ Generation failed: {str(e)}", 0
334
 
335
- # ----------------- Service -----------------
336
- class GemmaSaaSService:
337
  def __init__(self):
338
  self.config = Config()
339
- self.db = DatabaseManager(self.config)
340
  self.model_manager = ModelManager(self.config)
341
- self.hf_auth = HuggingFaceAuth()
342
 
343
- # Validate minimal config (SUPABASE required)
344
- if not self.config.SUPABASE_URL or not self.config.SUPABASE_KEY:
345
- raise ValueError("Missing SUPABASE_URL or SUPABASE_KEY in environment variables")
 
346
 
347
  async def initialize(self):
348
  await self.model_manager.initialize()
349
 
350
- async def exchange_code_and_store(self, code: str) -> Tuple[bool, str]:
351
- """Exchange OAuth code for token and store it in Supabase for the user."""
352
  try:
353
- if not (self.config.HF_OAUTH_CLIENT_ID and self.config.HF_OAUTH_CLIENT_SECRET and self.config.HF_OAUTH_REDIRECT_URI):
354
- return False, "OAuth client id/secret or redirect URI not configured"
355
-
356
- token_url = "https://huggingface.co/api/oauth/token"
357
- data = {
358
- "grant_type": "authorization_code",
359
- "code": code,
360
- "client_id": self.config.HF_OAUTH_CLIENT_ID,
361
- "client_secret": self.config.HF_OAUTH_CLIENT_SECRET,
362
- "redirect_uri": self.config.HF_OAUTH_REDIRECT_URI
363
- }
364
-
365
- async with aiohttp.ClientSession() as session:
366
- async with session.post(token_url, data=data) as resp:
367
- if resp.status != 200:
368
- text = await resp.text()
369
- logger.error(f"Token exchange failed: {text}")
370
- return False, f"Token exchange failed: {text}"
371
- token_resp = await resp.json()
372
 
373
- access_token = token_resp.get('access_token')
374
- if not access_token:
375
- return False, "No access token returned from HF"
376
-
377
- # Validate token (whoami)
378
- valid, whoami = await self.hf_auth.validate_token(access_token)
379
- if not valid or not whoami:
380
- return False, "Failed to validate HF token after exchange"
381
-
382
- # whoami-v2 returns 'user' and 'id' fields often; try to obtain email
383
- email = whoami.get('email') or (whoami.get('user', {}).get('email') if isinstance(whoami.get('user'), dict) else None)
384
- hf_user_id = whoami.get('id') or whoami.get('user', {}).get('id') if isinstance(whoami.get('user'), dict) else None
385
-
386
- if not email:
387
- # We can still store token but prefer email
388
- email = f"hf_user_{hf_user_id or secrets.token_hex(8)}@nomail"
389
-
390
- # Upsert token into Supabase
391
- ok = await self.db.upsert_hf_token_for_email(email=email, hf_token=access_token, hf_user_id=hf_user_id)
392
- if not ok:
393
- return False, "Failed to save token in database"
394
-
395
- return True, f"Successfully linked Hugging Face account: {email}"
396
 
397
  except Exception as e:
398
- logger.error(f"exchange_code error: {e}")
399
- return False, str(e)
400
-
401
- async def get_user_keys(self, hf_token: str) -> APIResponse:
402
- valid, whoami = await self.hf_auth.validate_token(hf_token)
403
- if not valid or not whoami:
404
- return APIResponse(success=False, error="Invalid Hugging Face token")
405
-
406
- # Try extract email
407
- email = whoami.get('email') or (whoami.get('user', {}).get('email') if isinstance(whoami.get('user'), dict) else None)
408
- if not email:
409
- return APIResponse(success=False, error="Could not extract email from HF token")
410
-
411
- items = await self.db.list_keys_for_email(email)
412
- return APIResponse(success=True, data=items)
413
 
414
- # Keep other methods such as generate_text / get_user_stats from your original code – simplified here
415
-
416
- # ----------------- Gradio UI -----------------
417
  class GradioInterface:
418
- def __init__(self, service: GemmaSaaSService):
419
  self.service = service
420
 
421
- def create_hf_authorize_url(self, state: str = None) -> str:
422
- cfg = self.service.config
423
- client_id = cfg.HF_OAUTH_CLIENT_ID
424
- redirect = cfg.HF_OAUTH_REDIRECT_URI
425
- scopes = "api:read api:write offline_access"
426
- base = "https://huggingface.co/oauth/authorize"
427
- params = f"?response_type=code&client_id={client_id}&redirect_uri={redirect}&scope={scopes}"
428
- if state:
429
- params += f"&state={state}"
430
- return base + params
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
431
 
432
  async def create_interface(self):
433
- cfg = self.service.config
434
-
435
- with gr.Blocks(title="Gemma AI Platform (OAuth)") as app:
436
- gr.Markdown("# Gemma AI — Login com Hugging Face")
437
-
438
- with gr.Row():
439
- authorize_html = gr.HTML(f"<a class='btn hf-login-btn' target='_blank' href='{self.create_hf_authorize_url()}'>🔐 Login with Hugging Face</a>")
440
- gr.Markdown("Depois de autorizar, você será redirecionado de volta para `/hf_callback` e o token será salvo no banco de dados.")
441
-
442
- with gr.Row():
443
- hf_token_input = gr.Textbox(label="Ou cole aqui o seu token Hugging Face (opcional)", type="password")
444
- show_keys_btn = gr.Button("🔎 Mostrar minhas chaves")
445
- keys_out = gr.HTML()
446
-
447
- # Admin: view all users & keys
448
- with gr.Accordion("Admin — Lista de usuários e tokens (somente admin)", open=False):
449
- admin_email_input = gr.Textbox(label="Admin HF token (para validar)", type="password")
450
- list_all_btn = gr.Button("📋 Listar todos os usuários")
451
- admin_out = gr.HTML()
452
-
453
- async def handle_show_keys(hf_token: str):
454
- if not hf_token or not hf_token.strip():
455
- return '<div class="alert alert-error">❌ Forneça seu token da Hugging Face (ou use o fluxo OAuth)</div>'
456
- resp = await self.service.get_user_keys(hf_token)
457
- if not resp.success:
458
- return f"<div class='alert alert-error'>❌ {resp.error}</div>"
459
- items = resp.data or []
460
- if not items:
461
- return '<div class="alert">Nenhuma chave encontrada</div>'
462
- html = '<ul>'
463
- for it in items:
464
- html += f"<li><strong>{it['type']}</strong>: <code>{it['value']}</code></li>"
465
- html += '</ul>'
466
- return html
467
-
468
- async def handle_list_all(admin_token: str):
469
- # Validate admin token via HF and email check
470
- valid, whoami = await self.service.hf_auth.validate_token(admin_token)
471
- if not valid or not whoami:
472
- return '<div class="alert alert-error">Token inválido</div>'
473
- email = whoami.get('email') or whoami.get('user', {}).get('email')
474
- if email != cfg.ADMIN_EMAIL:
475
- return '<div class="alert alert-error">Acesso negado — não é admin</div>'
476
-
477
- users = await self.service.db.get_all_users_stats()
478
- if not users:
479
- return '<div class="alert">Nenhum usuário registrado</div>'
480
- html = '<table><tr><th>Name</th><th>Email</th><th>Plan</th><th>API Key (obf.)</th><th>HF Token (obf.)</th></tr>'
481
- for u in users:
482
- html += f"<tr><td>{u.get('name','')}</td><td>{u.get('email','')}</td><td>{u.get('plan','')}</td><td>{self.service.db._obfuscate(u.get('api_key',''))}</td><td>{self.service.db._obfuscate(u.get('hf_token',''))}</td></tr>"
483
- html += '</table>'
484
- return html
485
 
486
- show_keys_btn.click(fn=handle_show_keys, inputs=[hf_token_input], outputs=[keys_out])
487
- list_all_btn.click(fn=handle_list_all, inputs=[admin_email_input], outputs=[admin_out])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
488
 
489
  return app
490
 
491
- # ----------------- Main -----------------
492
  async def main():
493
- service = GemmaSaaSService()
494
- # Model is optional at startup
495
- await service.initialize()
496
-
497
- interface = GradioInterface(service)
498
- app = await interface.create_interface()
499
-
500
- # --- Add OAuth callback route onto Gradio's internal Starlette app ---
501
- async def hf_callback(request: Request):
502
- code = request.query_params.get('code')
503
- state = request.query_params.get('state')
504
- if not code:
505
- return HTMLResponse('<h3>Missing code in callback</h3>')
506
-
507
- ok, message = await service.exchange_code_and_store(code)
508
- # Simple HTML response - can be improved
509
- html = f"<html><body><h3>{'Success' if ok else 'Error'}</h3><p>{message}</p><script>setTimeout(()=>window.close(),1500)</script></body></html>"
510
- return HTMLResponse(html)
511
-
512
- # Register route
513
  try:
514
- # `app` from gr.Blocks has attribute `.app` which is a Starlette/FastAPI app
515
- app.app.add_api_route("/hf_callback", hf_callback, methods=["GET"]) # type: ignore
516
- logger.info("Registered /hf_callback route on Gradio app")
 
 
 
 
 
 
 
 
 
 
517
  except Exception as e:
518
- logger.error(f"Could not register callback route on Gradio app: {e}")
519
-
520
- # Launch Gradio app
521
- app.launch(server_name="0.0.0.0", server_port=7860)
522
 
523
- if __name__ == '__main__':
 
 
524
  asyncio.run(main())
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ import uuid
3
  import json
4
  import asyncio
5
  import logging
6
+ import time
7
  from datetime import datetime, timedelta
8
+ from typing import Dict, List, Optional, Tuple, Any
9
  from dataclasses import dataclass
 
10
 
11
  import gradio as gr
12
  import aiohttp
13
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
14
+ from dotenv import load_dotenv
15
  from pydantic import BaseModel, ValidationError
16
+ import secrets
17
+ import plotly.graph_objects as go
18
+ from plotly.subplots import make_subplots
19
 
20
+ # ----------------- Configuration & Models -----------------
21
+ load_dotenv()
22
 
 
23
  @dataclass
24
  class Config:
25
+ HF_TOKEN: str = os.getenv("HF_TOKEN", "")
26
+ MODEL_NAME: str = os.getenv("MODEL_NAME", "google/gemma-2-9b-it")
27
+ MAX_TOKENS: int = int(os.getenv("MAX_TOKENS", "1500"))
 
 
 
 
 
 
 
 
28
  LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
 
29
 
30
  class GenerationRequest(BaseModel):
31
  prompt: str
32
+ max_tokens: int = 500
33
+ temperature: float = 0.75
34
  top_k: int = 50
35
  top_p: float = 0.95
36
+ repetition_penalty: float = 1.1
 
 
 
 
 
37
 
38
  class APIResponse(BaseModel):
39
  success: bool
 
41
  error: Optional[str] = None
42
  timestamp: datetime = datetime.now()
43
 
44
+ # ----------------- Enhanced Logger -----------------
45
  def setup_logger():
 
46
  logging.basicConfig(
47
+ level=getattr(logging, Config().LOG_LEVEL),
48
  format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
49
  handlers=[
50
+ logging.FileHandler('gemma_saas.log'),
51
  logging.StreamHandler()
52
  ]
53
  )
 
55
 
56
  logger = setup_logger()
57
 
58
+ # ----------------- Model Manager -----------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  class ModelManager:
60
  def __init__(self, config: Config):
61
  self.config = config
 
65
  self.model_loaded = False
66
 
67
  async def initialize(self):
68
+ """Initialize the model, tokenizer, and pipeline asynchronously."""
69
  if not self.config.HF_TOKEN:
70
+ logger.error("Hugging Face token not found. Model loading will fail.")
71
  self.model_loaded = False
72
  return
 
73
  try:
74
+ logger.info(f"Loading model: {self.config.MODEL_NAME}...")
75
  loop = asyncio.get_event_loop()
76
+
77
+ def load_components():
78
+ tokenizer = AutoTokenizer.from_pretrained(self.config.MODEL_NAME, token=self.config.HF_TOKEN)
79
+ model = AutoModelForCausalLM.from_pretrained(
 
 
 
 
 
 
 
80
  self.config.MODEL_NAME,
81
+ token=self.config.HF_TOKEN,
82
  device_map="auto",
83
+ torch_dtype="auto"
 
84
  )
85
+ text_pipeline = pipeline(
 
 
 
86
  "text-generation",
87
+ model=model,
88
+ tokenizer=tokenizer,
 
89
  )
90
+ return tokenizer, model, text_pipeline
91
+
92
+ self.tokenizer, self.model, self.pipeline = await loop.run_in_executor(None, load_components)
93
  self.model_loaded = True
94
+ logger.info("Model loaded successfully!")
95
  except Exception as e:
96
+ logger.error(f" Error loading model: {e}")
97
  self.model_loaded = False
98
 
99
+ async def generate(self, request: GenerationRequest) -> Tuple[bool, str, int]:
100
+ """Generate text based on the provided request."""
101
  if not self.model_loaded:
102
+ return False, "❌ O modelo não está disponível. Por favor, verifique os logs do servidor.", 0
 
 
 
 
 
 
103
  try:
104
+ if not request.prompt.strip():
105
+ return False, "⚠️ O prompt não pode estar vazio.", 0
106
+ if len(request.prompt) > 8000:
107
+ return False, "⚠️ O prompt é muito longo (máximo de 8000 caracteres).", 0
108
+
109
  loop = asyncio.get_event_loop()
110
+
111
+ messages = [
112
+ {"role": "user", "content": request.prompt.strip()},
113
+ ]
114
+
115
+ def do_generation():
116
+ prompt = self.pipeline.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
117
+ outputs = self.pipeline(
118
+ prompt,
119
  max_new_tokens=min(request.max_tokens, self.config.MAX_TOKENS),
120
  do_sample=True,
121
  temperature=request.temperature,
122
  top_k=request.top_k,
123
  top_p=request.top_p,
 
 
 
124
  )
125
+ return outputs[0]["generated_text"][len(prompt):]
126
+
127
+ generated_text = await loop.run_in_executor(None, do_generation)
128
  tokens_used = len(self.tokenizer.encode(generated_text))
129
  return True, generated_text, tokens_used
130
  except Exception as e:
131
  logger.error(f"Generation error: {e}")
132
+ return False, f"❌ A geração falhou: {str(e)}", 0
133
 
134
+ # ----------------- Service Layer -----------------
135
+ class GemmaService:
136
  def __init__(self):
137
  self.config = Config()
 
138
  self.model_manager = ModelManager(self.config)
139
+ self._validate_config()
140
 
141
+ def _validate_config(self):
142
+ """Validate that required environment variables are set."""
143
+ if not self.config.HF_TOKEN:
144
+ raise ValueError("Missing required environment variable: HF_TOKEN")
145
 
146
  async def initialize(self):
147
  await self.model_manager.initialize()
148
 
149
+ async def generate_text(self, prompt: str, **kwargs) -> APIResponse:
150
+ """Generate text directly."""
151
  try:
152
+ request = GenerationRequest(prompt=prompt, **kwargs)
153
+ success, text, tokens_used = await self.model_manager.generate(request)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
 
155
+ if success:
156
+ return APIResponse(
157
+ success=True,
158
+ data={"generated_text": text, "tokens_used": tokens_used}
159
+ )
160
+ else:
161
+ return APIResponse(success=False, error=text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
 
163
  except Exception as e:
164
+ logger.error(f"Service error during text generation: {e}")
165
+ return APIResponse(success=False, error="Ocorreu um erro interno no serviço.")
 
 
 
 
 
 
 
 
 
 
 
 
 
166
 
167
+ # ----------------- Enhanced UI -----------------
 
 
168
  class GradioInterface:
169
+ def __init__(self, service: GemmaService):
170
  self.service = service
171
 
172
+ def create_custom_css(self):
173
+ return """
174
+ :root {
175
+ --dark-bg: #111111;
176
+ --panel-bg: #1C1C1C;
177
+ --border-color: #333333;
178
+ --text-color: #E0E0E0;
179
+ --text-light: #A0A0A0;
180
+ --accent-orange: #FF4500;
181
+ --accent-orange-hover: #FF6347;
182
+ }
183
+ .gradio-container { background-color: var(--dark-bg) !important; }
184
+ #main_layout { background-color: transparent; border: none !important; box-shadow: none !important; }
185
+ #right_panel { background-color: var(--panel-bg); border-left: 1px solid var(--border-color); border-radius: 12px; padding: 2rem !important; }
186
+ #left_panel { background-color: var(--panel-bg); border-radius: 12px; padding: 1rem !important; display: flex !important; flex-direction: column !important; height: 70vh; }
187
+ #output_display { flex-grow: 1; overflow-y: auto; padding: 1rem; color: var(--text-color); }
188
+ #output_display p { margin-bottom: 1rem; line-height: 1.6; }
189
+ #prompt_row { border-top: 1px solid var(--border-color); padding-top: 1rem; }
190
+ #prompt_input textarea { background-color: #2C2C2C !important; border-color: var(--border-color) !important; color: var(--text-color) !important; border-radius: 8px !important; }
191
+ #send_button { background-color: var(--accent-orange); color: white; border: none; border-radius: 50% !important; width: 50px !important; height: 50px !important; min-width: 50px !important; transition: background-color 0.3s ease; }
192
+ #send_button:hover { background-color: var(--accent-orange-hover); }
193
+ #generate_button {
194
+ background: linear-gradient(135deg, var(--accent-orange), var(--accent-orange-hover));
195
+ color: white !important;
196
+ font-size: 1.2rem !important;
197
+ font-weight: bold !important;
198
+ border: none;
199
+ border-radius: 12px !important;
200
+ padding: 1rem !important;
201
+ box-shadow: 0 4px 15px rgba(255, 69, 0, 0.4);
202
+ transition: all 0.3s ease;
203
+ }
204
+ #generate_button:hover {
205
+ transform: translateY(-2px);
206
+ box-shadow: 0 6px 20px rgba(255, 69, 0, 0.6);
207
+ }
208
+ .gr-label { color: var(--text-light) !important; }
209
+ h2 { color: white; border-bottom: 1px solid var(--border-color); padding-bottom: 0.5rem; margin-bottom: 1rem; }
210
+ #info_text { color: var(--text-light); line-height: 1.7; }
211
+ """
212
 
213
  async def create_interface(self):
214
+ with gr.Blocks(css=self.create_custom_css(), theme=None) as app:
215
+ with gr.Row(elem_id="main_layout", equal_height=False):
216
+ with gr.Column(scale=2, elem_id="left_panel_container"):
217
+ with gr.Column(elem_id="left_panel"):
218
+ output_display = gr.Markdown(elem_id="output_display", value="<p style='color: #A0A0A0;'>Sua resposta aparecerá aqui...</p>")
219
+ with gr.Row(elem_id="prompt_row"):
220
+ prompt_input = gr.Textbox(
221
+ show_label=False,
222
+ placeholder="Digite sua mensagem aqui...",
223
+ elem_id="prompt_input",
224
+ scale=10
225
+ )
226
+ send_button = gr.Button("➤", elem_id="send_button", scale=1)
227
+
228
+ with gr.Column(scale=1, elem_id="right_panel"):
229
+ gr.Markdown("## Informações")
230
+ gr.Markdown(
231
+ """
232
+ Este é um ambiente interativo para o modelo de linguagem **Gemma**.
233
+
234
+ - **Como usar:** Digite seu prompt na caixa de texto à esquerda e clique no botão de envio ou no botão "Gerar" abaixo.
235
+ - **Modelo:** `google/gemma-2-9b-it`
236
+ - **Capacidades:** Geração de texto criativo, respostas a perguntas, resumo, tradução e muito mais.
237
+
238
+ Sinta-se à vontade para experimentar diferentes tipos de prompts para explorar todo o potencial do modelo.
239
+ """,
240
+ elem_id="info_text"
241
+ )
242
+ generate_button = gr.Button("✨ Gerar", elem_id="generate_button")
243
+
244
+ # --- Event Handlers ---
245
+ async def handle_generation(prompt):
246
+ if not prompt:
247
+ return "<p style='color: #FFCC00;'>Por favor, digite um prompt para começar.</p>"
248
+
249
+ # Show a loading indicator
250
+ yield "<p style='color: #A0A0A0;'>Gerando resposta...</p>"
251
+
252
+ response = await self.service.generate_text(prompt=prompt)
253
+
254
+ if response.success:
255
+ yield response.data["generated_text"]
256
+ else:
257
+ yield f"<p style='color: #FF4500;'>{response.error}</p>"
 
 
 
 
 
 
 
 
258
 
259
+ # --- Wiring ---
260
+ generate_button.click(
261
+ handle_generation,
262
+ inputs=[prompt_input],
263
+ outputs=[output_display]
264
+ )
265
+ send_button.click(
266
+ handle_generation,
267
+ inputs=[prompt_input],
268
+ outputs=[output_display]
269
+ )
270
+ prompt_input.submit(
271
+ handle_generation,
272
+ inputs=[prompt_input],
273
+ outputs=[output_display]
274
+ )
275
 
276
  return app
277
 
278
+ # ----------------- Main Application -----------------
279
  async def main():
280
+ """Main application entry point"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
281
  try:
282
+ service = GemmaService()
283
+ await service.initialize()
284
+
285
+ interface = GradioInterface(service)
286
+ app = await interface.create_interface()
287
+
288
+ app.launch(
289
+ server_name="0.0.0.0",
290
+ server_port=7860,
291
+ share=False,
292
+ debug=False,
293
+ show_error=True
294
+ )
295
  except Exception as e:
296
+ logger.critical(f"Failed to start application: {e}", exc_info=True)
297
+ raise
 
 
298
 
299
+ if __name__ == "__main__":
300
+ # To run this, you need a .env file with:
301
+ # HF_TOKEN="your_hugging_face_token"
302
  asyncio.run(main())