import os
import uuid
import json
import asyncio
import logging
import time
import html
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass
from contextlib import asynccontextmanager
from functools import wraps
import hashlib
import secrets

import gradio as gr
import requests
import aiohttp
import asyncpg
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from dotenv import load_dotenv
from pydantic import BaseModel, Field, ValidationError
import jwt
import plotly.graph_objects as go
from plotly.subplots import make_subplots
|
|
| |
# Runs before Config is defined: Config's defaults read os.getenv() when the
# class body executes, so anything python-dotenv loads here (presumably from a
# local .env file — confirm deployment layout) is visible to those lookups.
load_dotenv()
|
|
@dataclass
class Config:
    """Application settings, each backed by an environment variable.

    Every default below is computed once, when this class body executes,
    and is then shared by all instances. In particular the fallback
    ``JWT_SECRET`` is a single random value per process, not per instance.
    """

    # Hugging Face access token passed to ``from_pretrained``.
    HF_TOKEN: str = os.environ.get("HF_TOKEN", "")
    # Model repository identifier to load.
    MODEL_NAME: str = os.environ.get("MODEL_NAME", "google/gemma-3-270m-it")
    # Base URL and key of the Supabase project.
    SUPABASE_URL: str = os.environ.get("SUPABASE_URL", "")
    SUPABASE_KEY: str = os.environ.get("SUPABASE_KEY", "")
    # Falls back to a freshly generated random secret when unset.
    JWT_SECRET: str = os.environ.get("JWT_SECRET", secrets.token_urlsafe(32))
    # Per-user request ceiling per hour; int() raises on a non-numeric value.
    RATE_LIMIT_PER_HOUR: int = int(os.environ.get("RATE_LIMIT_PER_HOUR", "100"))
    # Upper bound applied to the number of generated tokens per request.
    MAX_TOKENS: int = int(os.environ.get("MAX_TOKENS", "500"))
    # Name of a ``logging`` level constant.
    LOG_LEVEL: str = os.environ.get("LOG_LEVEL", "INFO")
    # Email address whose account is treated as the administrator.
    ADMIN_EMAIL: str = os.environ.get("ADMIN_EMAIL", "")
|
|
class GenerationRequest(BaseModel):
    """Validated parameters for one text-generation call."""

    # Text the model continues from.
    prompt: str
    # Requested number of new tokens.
    max_tokens: int = 200
    # Sampling parameters forwarded to the generation pipeline.
    temperature: float = 0.7
    top_k: int = 50
    top_p: float = 0.95
    repetition_penalty: float = 1.0
|
|
class UserCreate(BaseModel):
    """Payload describing a user account to be created."""

    # Display name of the account owner.
    name: str
    # Contact address; also used for the duplicate-account lookup.
    email: str
    # Subscription tier label; new accounts start on "free".
    plan: str = "free"
|
|
class APIResponse(BaseModel):
    """Uniform envelope returned by every service-layer call.

    ``success`` tells the caller which of ``data`` / ``error`` is populated.
    """

    success: bool
    # Payload on success; ``None`` on failure.
    data: Any = None
    # Human-readable failure reason; ``None`` on success.
    error: Optional[str] = None
    # A plain ``datetime.now()`` default would be evaluated once, when the
    # class is created, stamping every response with the import time.
    # ``default_factory`` evaluates it per instance instead.
    timestamp: datetime = Field(default_factory=datetime.now)
|
|
| |
def setup_logger():
    """Configure root logging (file + console) and return this module's logger.

    The level comes from ``Config().LOG_LEVEL``. The name is matched
    case-insensitively, and anything that does not resolve to a numeric
    ``logging`` level falls back to ``INFO`` instead of raising at import
    time.

    Returns:
        logging.Logger: the logger named after this module.
    """
    level_name = str(Config().LOG_LEVEL).upper()
    level = getattr(logging, level_name, None)
    if not isinstance(level, int):
        level = logging.INFO

    logging.basicConfig(
        level=level,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[
            # Log messages in this file contain emoji; force UTF-8 so the file
            # handler does not depend on the platform's default encoding.
            logging.FileHandler('gemma_saas.log', encoding='utf-8'),
            logging.StreamHandler()
        ]
    )
    return logging.getLogger(__name__)


logger = setup_logger()
|
|
| |
class DatabaseManager:
    """Async access layer for the Supabase REST ``users`` table.

    Each public method opens a short-lived ``aiohttp`` session against
    ``{SUPABASE_URL}/rest/v1/users``. Filter values are sent through
    ``params=`` so they are URL-encoded rather than spliced into the URL;
    this keeps characters such as ``&`` or ``+`` in an email or API key from
    altering the query.
    """

    def __init__(self, config: Config):
        """Keep ``config`` and build the auth headers sent with every request."""
        self.config = config
        self.headers = {
            "apikey": config.SUPABASE_KEY,
            "Authorization": f"Bearer {config.SUPABASE_KEY}",
            "Content-Type": "application/json"
        }

    def _users_url(self) -> str:
        """Return the REST endpoint of the ``users`` table."""
        return f"{self.config.SUPABASE_URL}/rest/v1/users"

    @staticmethod
    def _parse_timestamp(value: Any) -> Optional[datetime]:
        """Parse an ISO-8601 value into a naive local ``datetime``.

        Returns ``None`` for missing or unparseable values. Timezone-aware
        values are converted to local time and stripped of ``tzinfo`` so they
        can be compared with ``datetime.now()``.
        """
        if not value:
            return None
        text = str(value)
        if text.endswith("Z"):
            # Older fromisoformat() implementations reject the "Z" suffix.
            text = text[:-1] + "+00:00"
        try:
            parsed = datetime.fromisoformat(text)
        except ValueError:
            return None
        if parsed.tzinfo is not None:
            parsed = parsed.astimezone().replace(tzinfo=None)
        return parsed

    async def create_user(self, user_data: UserCreate, hf_user_id: str = None) -> Tuple[bool, str, str]:
        """Create a user row and issue an API key.

        Args:
            user_data: validated name, email and plan.
            hf_user_id: optional identifier stored alongside the user.

        Returns:
            ``(success, message, api_key)``; ``api_key`` is ``""`` on failure.
        """
        # Normalise once so the duplicate lookup and the stored row agree.
        name = user_data.name.strip()
        email = user_data.email.strip()
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(
                    self._users_url(),
                    params={"email": f"eq.{email}"},
                    headers=self.headers
                ) as response:
                    if response.status == 200:
                        existing_users = await response.json()
                        if existing_users:
                            return False, "❌ User with this email already exists", ""

                api_key = self._generate_api_key()
                now = datetime.now()

                data = {
                    "name": name,
                    "email": email,
                    "api_key": api_key,
                    "hf_user_id": hf_user_id,
                    "requests": 0,
                    "plan": user_data.plan,
                    "created_at": now.isoformat(),
                    "last_request": None,
                    "requests_this_hour": 0,
                    "rate_limit_reset": (now + timedelta(hours=1)).isoformat(),
                    "tokens_used": 0
                }

                async with session.post(
                    self._users_url(),
                    headers=self.headers,
                    data=json.dumps(data)
                ) as response:
                    if response.status == 201:
                        logger.info(f"User created successfully: {user_data.email}")
                        return True, f"✅ User created successfully for {user_data.name}", api_key

                    error_text = await response.text()
                    logger.error(f"Error creating user: {error_text}")
                    return False, f"❌ Error creating user: {error_text}", ""

        except Exception as e:
            logger.error(f"Database error creating user: {e}")
            return False, f"❌ Database error: {str(e)}", ""

    def _generate_api_key(self) -> str:
        """Return a new random API key carrying the ``gsa_`` prefix."""
        return f"gsa_{secrets.token_urlsafe(32)}"

    async def validate_api_key(self, api_key: str) -> Optional[Dict]:
        """Return the user row owning ``api_key``, or ``None`` if there is none.

        Blank keys are rejected without a network round-trip. Any request or
        decoding error is logged and reported as ``None``.
        """
        if not api_key or not api_key.strip():
            return None
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(
                    self._users_url(),
                    params={"api_key": f"eq.{api_key}"},
                    headers=self.headers
                ) as response:
                    if response.status == 200:
                        data = await response.json()
                        return data[0] if data else None
                    return None
        except Exception as e:
            logger.error(f"Error validating API key: {e}")
            return None

    async def check_rate_limit(self, user_id: int) -> Tuple[bool, int]:
        """Report whether the user may make another request this hour.

        Returns:
            ``(allowed, requests_this_hour)``. On any lookup failure the
            result is ``(False, 0)``, i.e. the check fails closed.
        """
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(
                    self._users_url(),
                    params={"id": f"eq.{user_id}"},
                    headers=self.headers
                ) as response:
                    if response.status == 200:
                        data = await response.json()
                        if data:
                            user = data[0]
                            reset_time = self._parse_timestamp(user.get('rate_limit_reset'))

                            # A missing or unreadable reset time is treated as an
                            # expired window rather than locking the user out.
                            if reset_time is None or datetime.now() > reset_time:
                                await self._reset_rate_limit(user_id)
                                return True, 0

                            requests_this_hour = user.get('requests_this_hour') or 0
                            return requests_this_hour < self.config.RATE_LIMIT_PER_HOUR, requests_this_hour
                    return False, 0
        except Exception as e:
            logger.error(f"Error checking rate limit: {e}")
            return False, 0

    async def _reset_rate_limit(self, user_id: int):
        """Zero the hourly counter and start a new one-hour window."""
        try:
            data = {
                "requests_this_hour": 0,
                "rate_limit_reset": (datetime.now() + timedelta(hours=1)).isoformat()
            }

            async with aiohttp.ClientSession() as session:
                async with session.patch(
                    self._users_url(),
                    params={"id": f"eq.{user_id}"},
                    headers=self.headers,
                    data=json.dumps(data)
                ) as response:
                    # PostgREST answers a successful PATCH with 204 unless a
                    # representation is requested, so accept any 2xx status.
                    if not 200 <= response.status < 300:
                        logger.error(f"Failed to reset rate limit for user {user_id}")
        except Exception as e:
            logger.error(f"Error resetting rate limit: {e}")

    async def increment_usage(self, user_id: int, tokens_used: int):
        """Add one request and ``tokens_used`` tokens to the user's counters.

        This is a read-modify-write sequence, so concurrent calls for the
        same user can lose an update. Failures are logged, never raised.
        """
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(
                    self._users_url(),
                    params={"id": f"eq.{user_id}"},
                    headers=self.headers
                ) as response:
                    if response.status != 200:
                        return
                    data = await response.json()

                if not data:
                    return

                user = data[0]
                new_data = {
                    "requests": (user.get('requests') or 0) + 1,
                    "requests_this_hour": (user.get('requests_this_hour') or 0) + 1,
                    "tokens_used": (user.get('tokens_used') or 0) + tokens_used,
                    "last_request": datetime.now().isoformat()
                }

                # Use the response as a context manager so the connection is
                # released, and surface a failed update in the log.
                async with session.patch(
                    self._users_url(),
                    params={"id": f"eq.{user_id}"},
                    headers=self.headers,
                    data=json.dumps(new_data)
                ) as patch_response:
                    if not 200 <= patch_response.status < 300:
                        logger.error(f"Error incrementing usage: HTTP {patch_response.status}")
        except Exception as e:
            logger.error(f"Error incrementing usage: {e}")

    async def get_all_users_stats(self):
        """Return every user row, or ``[]`` on a non-200 response or error."""
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(
                    self._users_url(),
                    params={"select": "*"},
                    headers=self.headers
                ) as response:
                    if response.status == 200:
                        return await response.json()
                    return []
        except Exception as e:
            logger.error(f"Error getting all users stats: {e}")
            return []
|
|
| |
class ModelManager:
    """Loads the language model and runs text generation off the event loop."""

    def __init__(self, config: Config):
        """Store ``config``; nothing is loaded until ``initialize()`` runs."""
        self.config = config
        self.tokenizer = None
        self.model = None
        self.pipeline = None
        self.model_loaded = False

    async def initialize(self):
        """Load tokenizer, model and pipeline in the default executor.

        Any failure is logged and leaves ``model_loaded`` set to ``False``;
        no exception propagates to the caller.
        """
        try:
            logger.info("Loading model...")
            # Inside a coroutine the running loop is the one to use.
            loop = asyncio.get_running_loop()

            # NOTE(security): trust_remote_code=True lets the model repository
            # supply code that is executed locally; only point MODEL_NAME at
            # repositories you trust.
            self.tokenizer = await loop.run_in_executor(
                None,
                lambda: AutoTokenizer.from_pretrained(
                    self.config.MODEL_NAME,
                    token=self.config.HF_TOKEN,
                    trust_remote_code=True
                )
            )

            self.model = await loop.run_in_executor(
                None,
                lambda: AutoModelForCausalLM.from_pretrained(
                    self.config.MODEL_NAME,
                    token=self.config.HF_TOKEN,
                    device_map="auto",
                    torch_dtype="auto",
                    trust_remote_code=True
                )
            )

            self.pipeline = await loop.run_in_executor(
                None,
                lambda: pipeline(
                    "text-generation",
                    model=self.model,
                    tokenizer=self.tokenizer,
                    pad_token_id=self.tokenizer.eos_token_id
                )
            )

            self.model_loaded = True
            logger.info("✅ Model loaded successfully!")

        except Exception as e:
            logger.error(f"❌ Error loading model: {e}")
            self.model_loaded = False

    async def generate(self, request: GenerationRequest) -> Tuple[bool, str, int]:
        """Generate a completion for ``request.prompt``.

        Returns:
            ``(success, text_or_error_message, tokens_used)``. ``tokens_used``
            counts only tokens of the generated text and is ``0`` on failure.
        """
        if not self.model_loaded:
            return False, "❌ Model not loaded", 0

        try:
            # Validate the text that is actually sent to the model.
            prompt = request.prompt.strip()

            if not prompt:
                return False, "⚠️ Prompt cannot be empty", 0

            if len(prompt) > 4000:
                return False, "⚠️ Prompt too long (max 4000 characters)", 0

            # Clamp to the configured ceiling, but never ask for fewer than one
            # new token.
            max_new_tokens = max(1, min(request.max_tokens, self.config.MAX_TOKENS))

            loop = asyncio.get_running_loop()
            result = await loop.run_in_executor(
                None,
                lambda: self.pipeline(
                    prompt,
                    max_new_tokens=max_new_tokens,
                    do_sample=True,
                    temperature=request.temperature,
                    top_k=request.top_k,
                    top_p=request.top_p,
                    repetition_penalty=request.repetition_penalty,
                    pad_token_id=self.tokenizer.eos_token_id,
                    return_full_text=False
                )
            )

            generated_text = result[0]["generated_text"]
            # Skip BOS/EOS markers so usage reflects generated content only.
            tokens_used = len(self.tokenizer.encode(generated_text, add_special_tokens=False))

            return True, generated_text, tokens_used

        except Exception as e:
            logger.error(f"Generation error: {e}")
            return False, f"❌ Generation failed: {str(e)}", 0
|
|
| |
class AnalyticsManager:
    """Builds Plotly figures from a user's stored usage counters."""

    def __init__(self, db: DatabaseManager):
        """Keep a handle on the database layer."""
        self.db = db

    async def generate_usage_plot(self, user_data: Dict) -> go.Figure:
        """Return a two-panel usage figure for one user.

        The figure is illustrative: the line chart spreads the lifetime
        request count evenly over the previous seven days, and the pie
        chart derives its "Prompt" and "Other" slices as fixed fractions of
        ``tokens_used``. No per-day history is read.

        Args:
            user_data: user row; ``requests`` and ``tokens_used`` are read,
                and missing or null values count as 0.
        """
        fig = make_subplots(
            rows=1, cols=2,
            subplot_titles=('Requests Over Time', 'Token Usage Distribution'),
            specs=[[{"type": "scatter"}, {"type": "pie"}]]
        )

        total_requests = user_data.get('requests') or 0
        total_tokens = user_data.get('tokens_used') or 0

        now = datetime.now()
        dates = [now - timedelta(days=offset) for offset in range(7, 0, -1)]
        # Named daily_requests so the imported ``requests`` module is not shadowed.
        daily_requests = [total_requests // 7] * 7

        fig.add_trace(
            go.Scatter(x=dates, y=daily_requests, mode='lines+markers', name='Requests'),
            row=1, col=1
        )

        token_categories = ['Generated', 'Prompt', 'Other']
        token_values = [total_tokens, total_tokens // 3, total_tokens // 5]

        fig.add_trace(
            go.Pie(labels=token_categories, values=token_values, name="Token Usage"),
            row=1, col=2
        )

        fig.update_layout(
            height=400,
            showlegend=True,
            paper_bgcolor='rgba(0,0,0,0)',
            plot_bgcolor='rgba(0,0,0,0)',
            font=dict(color='white')
        )

        return fig
|
|
| |
class GemmaSaaSService:
    """Service facade tying configuration, storage, model and analytics together.

    Every public coroutine returns an ``APIResponse`` and never raises.
    """

    def __init__(self):
        """Build the component graph and verify required settings.

        Raises:
            ValueError: if a required environment variable is missing.
        """
        self.config = Config()
        self.db = DatabaseManager(self.config)
        self.model_manager = ModelManager(self.config)
        self.analytics_manager = AnalyticsManager(self.db)
        self._validate_config()

    def _validate_config(self):
        """Raise ``ValueError`` naming every required setting that is empty."""
        required_fields = ['HF_TOKEN', 'SUPABASE_URL', 'SUPABASE_KEY']
        missing_fields = [field for field in required_fields if not getattr(self.config, field)]

        if missing_fields:
            raise ValueError(f"Missing required environment variables: {', '.join(missing_fields)}")

    @staticmethod
    def _minutes_until(raw_reset: Any) -> int:
        """Return whole minutes from now until ``raw_reset`` (never negative).

        ``raw_reset`` is expected to be an ISO-8601 string. Missing or
        unparseable values yield 0, and timezone-aware values are converted
        to naive local time before comparing with ``datetime.now()``.
        """
        try:
            reset_time = datetime.fromisoformat(str(raw_reset))
        except ValueError:
            return 0
        if reset_time.tzinfo is not None:
            reset_time = reset_time.astimezone().replace(tzinfo=None)
        return max(0, int((reset_time - datetime.now()).total_seconds() / 60))

    async def initialize(self):
        """Load the model; see ``ModelManager.initialize``."""
        await self.model_manager.initialize()

    async def create_user(self, name: str, email: str, plan: str = "free") -> APIResponse:
        """Create a user and return the new API key in ``data['api_key']``."""
        try:
            user_data = UserCreate(name=name, email=email, plan=plan)
            success, message, api_key = await self.db.create_user(user_data)

            return APIResponse(
                success=success,
                data={"api_key": api_key, "message": message} if success else None,
                error=message if not success else None
            )
        except ValidationError as e:
            return APIResponse(
                success=False,
                error=f"Validation error: {str(e)}"
            )
        except Exception as e:
            logger.error(f"Service error creating user: {e}")
            return APIResponse(
                success=False,
                error="Internal service error"
            )

    async def generate_text(self, prompt: str, api_key: str, **kwargs) -> APIResponse:
        """Authenticate, rate-limit, generate text and record usage.

        Args:
            prompt: text to continue.
            api_key: caller's API key.
            **kwargs: optional ``GenerationRequest`` fields.

        Returns:
            On success ``data`` holds ``generated_text``, ``tokens_used``,
            ``user_plan`` and ``requests_remaining``; otherwise ``error``
            explains the failure.
        """
        try:
            user = await self.db.validate_api_key(api_key)
            if not user:
                return APIResponse(
                    success=False,
                    error="⚠️ Invalid API key"
                )

            can_make_request, requests_used = await self.db.check_rate_limit(user['id'])
            if not can_make_request:
                # A missing or odd reset timestamp must not turn this reply into
                # an internal error, hence the tolerant helper.
                mins_remaining = self._minutes_until(user.get('rate_limit_reset'))

                return APIResponse(
                    success=False,
                    error=f"⚠️ Rate limit exceeded ({requests_used}/{self.config.RATE_LIMIT_PER_HOUR}). Try again in {mins_remaining} minutes."
                )

            request = GenerationRequest(prompt=prompt, **kwargs)
            success, text, tokens_used = await self.model_manager.generate(request)

            if not success:
                return APIResponse(
                    success=False,
                    error=text
                )

            await self.db.increment_usage(user['id'], tokens_used)

            return APIResponse(
                success=True,
                data={
                    "generated_text": text,
                    "tokens_used": tokens_used,
                    "user_plan": user.get('plan', 'free'),
                    "requests_remaining": self.config.RATE_LIMIT_PER_HOUR - requests_used - 1
                }
            )

        except ValidationError as e:
            # Bad generation parameters are the caller's error; report them the
            # same way create_user does.
            return APIResponse(
                success=False,
                error=f"Validation error: {str(e)}"
            )
        except Exception as e:
            logger.error(f"Service error generating text: {e}")
            return APIResponse(
                success=False,
                error="Internal service error"
            )

    async def get_user_stats(self, api_key: str) -> APIResponse:
        """Return the caller's account counters plus a usage figure under ``plot``."""
        try:
            user = await self.db.validate_api_key(api_key)
            if not user:
                return APIResponse(
                    success=False,
                    error="Invalid API key"
                )

            plot = await self.analytics_manager.generate_usage_plot(user)

            stats = {
                "name": user.get('name'),
                "email": user.get('email'),
                "plan": user.get('plan', 'free'),
                "total_requests": user.get('requests', 0),
                "tokens_used": user.get('tokens_used', 0),
                "requests_this_hour": user.get('requests_this_hour', 0),
                "rate_limit": self.config.RATE_LIMIT_PER_HOUR,
                "created_at": user.get('created_at'),
                "last_request": user.get('last_request'),
                "plot": plot
            }

            return APIResponse(success=True, data=stats)

        except Exception as e:
            logger.error(f"Error getting user stats: {e}")
            return APIResponse(
                success=False,
                error="Error retrieving stats"
            )

    async def get_admin_stats(self, api_key: str) -> APIResponse:
        """Return platform-wide totals; only the ``ADMIN_EMAIL`` account may call this.

        Access is denied outright when ``ADMIN_EMAIL`` is not configured.
        ``data['users']`` contains the raw user rows as returned by the
        database layer.
        """
        try:
            admin_email = self.config.ADMIN_EMAIL
            user = await self.db.validate_api_key(api_key)
            if not admin_email or not user or user.get('email') != admin_email:
                return APIResponse(
                    success=False,
                    error="Unauthorized: Admin access required"
                )

            all_users = await self.db.get_all_users_stats()

            # ``or 0`` keeps null counters from breaking the sums.
            total_requests = sum(row.get('requests') or 0 for row in all_users)
            total_tokens = sum(row.get('tokens_used') or 0 for row in all_users)
            active_users = sum(1 for row in all_users if row.get('last_request'))

            fig = go.Figure()

            user_names = [row.get('name', 'Unknown') for row in all_users]
            user_requests = [row.get('requests') or 0 for row in all_users]

            fig.add_trace(go.Bar(
                x=user_names,
                y=user_requests,
                name="User Requests",
                marker_color='indianred'
            ))

            fig.update_layout(
                title='User Activity',
                xaxis_title='Users',
                yaxis_title='Number of Requests',
                paper_bgcolor='rgba(0,0,0,0)',
                plot_bgcolor='rgba(0,0,0,0)',
                font=dict(color='white')
            )

            stats = {
                "total_users": len(all_users),
                "active_users": active_users,
                "total_requests": total_requests,
                "total_tokens": total_tokens,
                "users": all_users,
                "plot": fig
            }

            return APIResponse(success=True, data=stats)

        except Exception as e:
            logger.error(f"Error getting admin stats: {e}")
            return APIResponse(
                success=False,
                error="Error retrieving admin stats"
            )
|
|
| |
| class GradioInterface: |
def __init__(self, service: GemmaSaaSService):
    """Remember the backing service and prepare the example prompts.

    ``self.examples`` is a list of one-element lists, one per prompt.
    """
    prompt_texts = (
        "Write a short story about a robot discovering emotions",
        "Explain quantum computing in simple terms",
        "Create a recipe for chocolate chip cookies",
        "Write a poem about the changing seasons",
        "How can I improve my time management skills?",
    )
    self.service = service
    self.examples = [[text] for text in prompt_texts]
| |
| def create_advanced_css(self): |
| return """ |
| :root { |
| --primary: #6366f1; |
| --primary-dark: #4338ca; |
| --secondary: #10b981; |
| --accent: #f59e0b; |
| --danger: #ef4444; |
| --dark: #1f2937; |
| --darker: #111827; |
| --light: #f3f4f6; |
| --lighter: #f9fafb; |
| --card-bg: rgba(255, 255, 255, 0.05); |
| --card-border: rgba(255, 255, 255, 0.1); |
| } |
| |
| body, .gradio-container { |
| background: linear-gradient(135deg, var(--darker) 0%, var(--dark) 100%) !important; |
| color: var(--lighter) !important; |
| font-family: 'Inter', 'Segoe UI', system-ui, sans-serif !important; |
| } |
| |
| .gradio-container { |
| max-width: 1400px !important; |
| margin: 0 auto !important; |
| padding: 1rem !important; |
| } |
| |
| .header { |
| text-align: center; |
| padding: 2rem 1rem; |
| margin-bottom: 2rem; |
| background: var(--card-bg); |
| backdrop-filter: blur(10px); |
| border-radius: 16px; |
| border: 1px solid var(--card-border); |
| box-shadow: 0 8px 32px rgba(0, 0, 0, 0.2); |
| } |
| |
| .header h1 { |
| font-size: 3.5rem; |
| font-weight: 800; |
| background: linear-gradient(135deg, var(--primary) 0%, var(--secondary) 100%); |
| -webkit-background-clip: text; |
| -webkit-text-fill-color: transparent; |
| margin-bottom: 0.5rem; |
| } |
| |
| .header p { |
| font-size: 1.2rem; |
| color: var(--light); |
| opacity: 0.8; |
| max-width: 600px; |
| margin: 0 auto; |
| } |
| |
| .card { |
| background: var(--card-bg) !important; |
| backdrop-filter: blur(10px); |
| border-radius: 16px !important; |
| border: 1px solid var(--card-border) !important; |
| padding: 1.5rem !important; |
| transition: all 0.3s ease !important; |
| box-shadow: 0 4px 20px rgba(0, 0, 0, 0.15) !important; |
| } |
| |
| .card:hover { |
| transform: translateY(-5px); |
| box-shadow: 0 12px 40px rgba(0, 0, 0, 0.25) !important; |
| } |
| |
| .btn { |
| border-radius: 12px !important; |
| padding: 0.75rem 1.5rem !important; |
| font-weight: 600 !important; |
| transition: all 0.2s ease !important; |
| border: none !important; |
| } |
| |
| .btn-primary { |
| background: linear-gradient(135deg, var(--primary) 0%, var(--primary-dark) 100%) !important; |
| color: white !important; |
| } |
| |
| .btn-primary:hover { |
| transform: translateY(-2px); |
| box-shadow: 0 6px 20px rgba(99, 102, 241, 0.4) !important; |
| } |
| |
| .btn-secondary { |
| background: rgba(255, 255, 255, 0.1) !important; |
| color: white !important; |
| border: 1px solid rgba(255, 255, 255, 0.2) !important; |
| } |
| |
| .btn-secondary:hover { |
| background: rgba(255, 255, 255, 0.2) !important; |
| transform: translateY(-2px); |
| } |
| |
| .stat-card { |
| background: linear-gradient(135deg, var(--primary) 0%, var(--primary-dark) 100%); |
| color: white; |
| border-radius: 12px; |
| padding: 1.5rem; |
| text-align: center; |
| } |
| |
| .stat-value { |
| font-size: 2.5rem; |
| font-weight: 700; |
| margin-bottom: 0.5rem; |
| } |
| |
| .stat-label { |
| font-size: 0.9rem; |
| opacity: 0.8; |
| text-transform: uppercase; |
| letter-spacing: 1px; |
| } |
| |
| .alert { |
| padding: 1rem 1.5rem; |
| border-radius: 12px; |
| margin: 1rem 0; |
| } |
| |
| .alert-success { |
| background: linear-gradient(135deg, var(--secondary) 0%, #059669 100%); |
| color: white; |
| } |
| |
| .alert-error { |
| background: linear-gradient(135deg, var(--danger) 0%, #dc2626 100%); |
| color: white; |
| } |
| |
| .tab-nav { |
| background: var(--card-bg) !important; |
| border-radius: 12px !important; |
| padding: 0.5rem !important; |
| margin-bottom: 1.5rem !important; |
| } |
| |
| .tab-nav button { |
| border-radius: 8px !important; |
| padding: 0.75rem 1.5rem !important; |
| font-weight: 600 !important; |
| } |
| |
| .tab-nav button.selected { |
| background: linear-gradient(135deg, var(--primary) 0%, var(--primary-dark) 100%) !important; |
| color: white !important; |
| } |
| |
| input, textarea, select { |
| background: rgba(255, 255, 255, 0.05) !important; |
| border: 1px solid rgba(255, 255, 255, 0.1) !important; |
| border-radius: 12px !important; |
| color: white !important; |
| padding: 0.75rem 1rem !important; |
| } |
| |
| input:focus, textarea:focus, select:focus { |
| border-color: var(--primary) !important; |
| box-shadow: 0 0 0 2px rgba(99, 102, 241, 0.2) !important; |
| } |
| |
| label { |
| color: var(--light) !important; |
| font-weight: 600 !important; |
| margin-bottom: 0.5rem !important; |
| } |
| |
| .footer { |
| text-align: center; |
| padding: 2rem 1rem; |
| margin-top: 3rem; |
| color: var(--light); |
| opacity: 0.7; |
| font-size: 0.9rem; |
| } |
| |
| .badge { |
| display: inline-block; |
| padding: 0.25rem 0.75rem; |
| border-radius: 9999px; |
| font-size: 0.75rem; |
| font-weight: 700; |
| } |
| |
| .badge-free { |
| background: var(--secondary); |
| color: white; |
| } |
| |
| .badge-pro { |
| background: var(--primary); |
| color: white; |
| } |
| |
| .badge-enterprise { |
| background: var(--accent); |
| color: white; |
| } |
| |
| .example-card { |
| cursor: pointer; |
| transition: all 0.2s ease; |
| } |
| |
| .example-card:hover { |
| transform: translateY(-3px); |
| box-shadow: 0 8px 25px rgba(0, 0, 0, 0.2) !important; |
| } |
| |
| .copy-btn { |
| position: absolute; |
| top: 0.5rem; |
| right: 0.5rem; |
| background: rgba(255, 255, 255, 0.1); |
| border: none; |
| border-radius: 8px; |
| padding: 0.5rem; |
| cursor: pointer; |
| opacity: 0.7; |
| transition: all 0.2s ease; |
| } |
| |
| .copy-btn:hover { |
| opacity: 1; |
| background: rgba(255, 255, 255, 0.2); |
| } |
| |
| .dark-mode-plot { |
| background: transparent !important; |
| } |
| """ |
| |
def create_header(self):
    """Build the page banner (title and tagline) as a ``gr.HTML`` component."""
    banner_markup = """
        <div class="header">
            <h1>🚀 Gemma AI Platform</h1>
            <p>Advanced text generation powered by state-of-the-art AI models. Create content, analyze text, and unlock new possibilities.</p>
        </div>
        """
    return gr.HTML(banner_markup)
| |
def create_footer(self):
    """Build the page footer (credits, version, legal links) as a ``gr.HTML`` component."""
    footer_markup = """
        <div class="footer">
            <p>Gemma AI Platform © 2023 | Built with ❤️ using Hugging Face, Gradio, and Supabase</p>
            <p>Version 2.0 | <a href="#" style="color: #ccc; text-decoration: none;">Terms of Service</a> | <a href="#" style="color: #ccc; text-decoration: none;">Privacy Policy</a></p>
        </div>
        """
    return gr.HTML(footer_markup)
| |
def _example_card_html(self, text: str) -> str:
    """Return the markup of one clickable example card for ``text``."""
    # json.dumps yields a valid JavaScript string literal, so quotes,
    # backticks or "${" in the text cannot break out of the handler.
    handler = (
        "const box = document.querySelector('#prompt-input textarea');"
        " if (box) { box.value = " + json.dumps(text) + ";"
        # Fire an input event so the front end registers the new value.
        " box.dispatchEvent(new Event('input', { bubbles: true })); }"
    )
    return (
        f'<div class="card example-card" onclick="{html.escape(handler, quote=True)}">'
        f'<div style="font-size: 0.9rem; opacity: 0.9;">{html.escape(text)}</div>'
        "</div>"
    )


def create_examples_component(self):
    """Render the example prompts as a grid of clickable cards.

    Clicking a card copies its text into the prompt box. The handler
    targets the ``<textarea>`` inside the element with id ``prompt-input``,
    because ``elem_id`` is placed on the component wrapper, not on the
    input control itself.

    Returns:
        The ``gr.HTML`` component holding the grid.
    """
    with gr.Column(elem_classes=["card"]):
        gr.Markdown("### 💡 Example Prompts")

        cards = "".join(self._example_card_html(example[0]) for example in self.examples)
        examples_html = (
            '<div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(250px, 1fr)); gap: 1rem;">'
            + cards
            + "</div>"
        )

        return gr.HTML(examples_html)
| |
| async def create_interface(self): |
| """Create the enhanced Gradio interface""" |
| with gr.Blocks( |
| css=self.create_advanced_css(), |
| title="Gemma AI Platform", |
| theme=gr.themes.Default(primary_hue="indigo", secondary_hue="emerald") |
| ) as app: |
| |
| |
| self.create_header() |
| |
| |
| with gr.Tabs(elem_classes=["tab-nav"]): |
| |
| |
| with gr.Tab("🎮 Playground", elem_classes=["card"]): |
| with gr.Row(): |
| with gr.Column(scale=1): |
| gr.Markdown("### ⚙️ Configuration") |
| api_key_playground = gr.Textbox( |
| label="🔑 API Key", |
| type="password", |
| placeholder="Enter your API key...", |
| elem_classes=["input"] |
| ) |
| |
| with gr.Accordion("Advanced Settings", open=False): |
| max_tokens_input = gr.Slider( |
| minimum=50, maximum=1000, value=200, |
| label="Max Tokens", |
| info="Maximum number of tokens to generate" |
| ) |
| temperature_input = gr.Slider( |
| minimum=0.1, maximum=2.0, value=0.7, |
| label="Temperature", |
| info="Higher values = more creative, lower values = more focused" |
| ) |
| top_k_input = gr.Slider( |
| minimum=1, maximum=100, value=50, |
| label="Top K", |
| info="Consider only the top K tokens" |
| ) |
| top_p_input = gr.Slider( |
| minimum=0.1, maximum=1.0, value=0.95, |
| label="Top P", |
| info="Nucleus sampling: consider only tokens with cumulative probability" |
| ) |
| repetition_penalty_input = gr.Slider( |
| minimum=0.1, maximum=2.0, value=1.0, |
| label="Repetition Penalty", |
| info="Penalize repeated tokens (1.0 = no penalty)" |
| ) |
| |
| with gr.Column(scale=2): |
| |
| self.create_examples_component() |
| |
| gr.Markdown("### 💬 Text Generation") |
| prompt_input = gr.Textbox( |
| label="✍️ Your Prompt", |
| lines=6, |
| placeholder="Enter your prompt here... (e.g., 'Write a short story about a robot discovering emotions')", |
| elem_id="prompt-input", |
| elem_classes=["input"] |
| ) |
| |
| with gr.Row(): |
| generate_btn = gr.Button( |
| "🚀 Generate", |
| elem_classes=["btn", "btn-primary"], |
| variant="primary" |
| ) |
| clear_btn = gr.Button( |
| "🗑️ Clear", |
| elem_classes=["btn", "btn-secondary"] |
| ) |
| |
| output_text = gr.Textbox( |
| label="📝 Generated Text", |
| lines=8, |
| interactive=False, |
| elem_classes=["input"] |
| ) |
| |
| |
| generation_stats = gr.JSON( |
| label="📊 Generation Statistics", |
| visible=False |
| ) |
| |
| |
| with gr.Tab("👤 Profile", elem_classes=["card"]): |
| with gr.Row(): |
| with gr.Column(scale=1): |
| gr.Markdown("### 🆕 Create Account") |
| name_input = gr.Textbox( |
| label="👤 Full Name", |
| elem_classes=["input"] |
| ) |
| email_input = gr.Textbox( |
| label="📧 Email Address", |
| elem_classes=["input"] |
| ) |
| plan_input = gr.Dropdown( |
| choices=["free", "pro", "enterprise"], |
| value="free", |
| label="📋 Plan", |
| elem_classes=["input"] |
| ) |
| |
| create_btn = gr.Button( |
| "✨ Create API Key", |
| elem_classes=["btn", "btn-primary"], |
| variant="primary" |
| ) |
| |
| creation_status = gr.HTML() |
| api_key_display = gr.Textbox( |
| label="🔑 Your API Key", |
| interactive=False, |
| visible=False, |
| elem_classes=["input"] |
| ) |
| |
| with gr.Column(scale=1): |
| gr.Markdown("### 📊 Account Statistics") |
| stats_api_key = gr.Textbox( |
| label="🔑 API Key", |
| type="password", |
| placeholder="Enter API key to view stats", |
| elem_classes=["input"] |
| ) |
| |
| refresh_stats_btn = gr.Button( |
| "🔄 Refresh Stats", |
| elem_classes=["btn", "btn-secondary"] |
| ) |
| |
| user_stats_display = gr.HTML() |
| |
| |
| with gr.Tab("📈 Analytics", elem_classes=["card"]): |
| gr.Markdown("### 📊 Usage Analytics") |
| |
| with gr.Row(): |
| with gr.Column(scale=1): |
| analytics_api_key = gr.Textbox( |
| label="🔑 API Key", |
| type="password", |
| placeholder="Enter API key to view analytics", |
| elem_classes=["input"] |
| ) |
| refresh_analytics_btn = gr.Button( |
| "📈 Generate Analytics", |
| elem_classes=["btn", "btn-primary"] |
| ) |
| |
| analytics_plot = gr.Plot(label="Usage Analytics") |
| |
| |
| with gr.Tab("👑 Admin", elem_classes=["card"], visible=bool(self.service.config.ADMIN_EMAIL)): |
| gr.Markdown("### 👑 Admin Dashboard") |
| |
| with gr.Row(): |
| with gr.Column(scale=1): |
| admin_api_key = gr.Textbox( |
| label="🔑 Admin API Key", |
| type="password", |
| placeholder="Enter admin API key", |
| elem_classes=["input"] |
| ) |
| refresh_admin_btn = gr.Button( |
| "🔄 Refresh Admin Stats", |
| elem_classes=["btn", "btn-primary"] |
| ) |
| |
| admin_plot = gr.Plot(label="Platform Analytics") |
| admin_stats_display = gr.HTML() |
| |
| |
| self.create_footer() |
| |
| |
| async def handle_generation(prompt, api_key, max_tokens, temperature, top_k, top_p, repetition_penalty): |
| if not api_key.strip(): |
| return "⚠️ Please enter your API key", {}, False |
| |
| response = await self.service.generate_text( |
| prompt=prompt, |
| api_key=api_key, |
| max_tokens=max_tokens, |
| temperature=temperature, |
| top_k=top_k, |
| top_p=top_p, |
| repetition_penalty=repetition_penalty |
| ) |
| |
| if response.success: |
| stats_html = f""" |
| <div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(150px, 1fr)); gap: 1rem; margin-top: 1rem;"> |
| <div class="stat-card"> |
| <div class="stat-value">{response.data['tokens_used']}</div> |
| <div class="stat-label">Tokens Used</div> |
| </div> |
| <div class="stat-card"> |
| <div class="stat-value">{response.data['requests_remaining']}</div> |
| <div class="stat-label">Requests Left</div> |
| </div> |
| <div class="stat-card"> |
| <div class="stat-value">{response.data['user_plan'].title()}</div> |
| <div class="stat-label">Current Plan</div> |
| </div> |
| </div> |
| """ |
| return ( |
| response.data["generated_text"], |
| stats_html, |
| True |
| ) |
| else: |
| return ( |
| response.error, |
| f'<div class="alert alert-error">{response.error}</div>', |
| False |
| ) |
| |
| async def handle_user_creation(name, email, plan): |
| if not name or not name.strip(): |
| return ( |
| f'<div class="alert alert-error">❌ Name is required</div>', |
| "", |
| False |
| ) |
| |
| if not email or not email.strip(): |
| return ( |
| f'<div class="alert alert-error">❌ Email is required</div>', |
| "", |
| False |
| ) |
| |
| response = await self.service.create_user(name, email, plan) |
| |
| if response.success: |
| return ( |
| f'<div class="alert alert-success">✅ Account created successfully! Your API key is below.</div>', |
| response.data["api_key"], |
| True |
| ) |
| else: |
| return ( |
| f'<div class="alert alert-error">❌ {response.error}</div>', |
| "", |
| False |
| ) |
| |
| async def handle_stats_refresh(api_key): |
| response = await self.service.get_user_stats(api_key) |
| |
| if response.success: |
| stats = response.data |
| return f""" |
| <div> |
| <div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(150px, 1fr)); gap: 1rem; margin-bottom: 2rem;"> |
| <div class="stat-card"> |
| <div class="stat-value">{stats['total_requests']}</div> |
| <div class="stat-label">Total Requests</div> |
| </div> |
| <div class="stat-card"> |
| <div class="stat-value">{stats['tokens_used']:,}</div> |
| <div class="stat-label">Tokens Used</div> |
| </div> |
| <div class="stat-card"> |
| <div class="stat-value">{stats['requests_this_hour']}/{stats['rate_limit']}</div> |
| <div class="stat-label">Hourly Usage</div> |
| </div> |
| <div class="stat-card"> |
| <div class="stat-value">{stats['plan'].title()}</div> |
| <div class="stat-label">Current Plan</div> |
| </div> |
| </div> |
| <div style="background: var(--card-bg); padding: 1.5rem; border-radius: 12px; border: 1px solid var(--card-border);"> |
| <h3 style="margin-top: 0;">Account Details</h3> |
| <p><strong>Name:</strong> {stats['name']}</p> |
| <p><strong>Email:</strong> {stats['email']}</p> |
| <p><strong>Member since:</strong> {stats['created_at'][:10] if stats['created_at'] else 'N/A'}</p> |
| <p><strong>Last request:</strong> {stats['last_request'][:19] if stats['last_request'] else 'Never'}</p> |
| </div> |
| </div> |
| """ |
| else: |
| return f'<div class="alert alert-error">❌ {response.error}</div>' |
| |
async def handle_analytics_refresh(api_key):
    """Return the analytics plot for the given API key.

    Looks up the user's statistics through ``self.service.get_user_stats``
    and hands back the ``"plot"`` entry of the response data. When the
    lookup is unsuccessful, an empty transparent figure titled
    "Error loading analytics" is returned instead.
    """
    result = await self.service.get_user_stats(api_key)

    if not result.success:
        # Placeholder figure so the plot component still receives a figure.
        fallback = go.Figure()
        fallback.update_layout(
            title="Error loading analytics",
            paper_bgcolor='rgba(0,0,0,0)',
            plot_bgcolor='rgba(0,0,0,0)',
            font=dict(color='white'),
        )
        return fallback

    return result.data["plot"]
| |
def render_admin_stats_html(stats):
    """Render the admin overview (summary cards plus user table) as HTML.

    Reads ``total_users``, ``active_users``, ``total_requests``,
    ``total_tokens`` and ``users`` from *stats*. Each entry of ``users`` is
    a mapping whose ``name``, ``email``, ``plan``, ``requests`` and
    ``tokens_used`` keys are optional.

    User-controlled text is HTML-escaped, and missing or null per-user
    fields fall back to ``N/A`` / ``free`` / ``0`` instead of raising.

    Returns:
        The overview markup as a single string.
    """
    from html import escape  # function-scope import keeps the block self-contained

    def render_row(user):
        # `or` (not a .get default) so keys present with a null value also
        # fall back; a None plan or token count would otherwise break
        # .title() and the thousands-separator format respectively.
        name = escape(str(user.get('name') or 'N/A'))
        email = escape(str(user.get('email') or 'N/A'))
        plan = str(user.get('plan') or 'free')
        requests_made = escape(str(user.get('requests') or 0))
        tokens_used = user.get('tokens_used') or 0
        return f"""
                        <tr style="border-bottom: 1px solid rgba(255,255,255,0.05);">
                            <td style="padding: 0.75rem;">{name}</td>
                            <td style="padding: 0.75rem;">{email}</td>
                            <td style="padding: 0.75rem;"><span class="badge badge-{escape(plan)}">{escape(plan.title())}</span></td>
                            <td style="padding: 0.75rem;">{requests_made}</td>
                            <td style="padding: 0.75rem;">{tokens_used:,}</td>
                        </tr>
        """

    table_head = f"""
    <div style="margin-bottom: 2rem;">
        <div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(150px, 1fr)); gap: 1rem;">
            <div class="stat-card">
                <div class="stat-value">{stats['total_users']}</div>
                <div class="stat-label">Total Users</div>
            </div>
            <div class="stat-card">
                <div class="stat-value">{stats['active_users']}</div>
                <div class="stat-label">Active Users</div>
            </div>
            <div class="stat-card">
                <div class="stat-value">{stats['total_requests']}</div>
                <div class="stat-label">Total Requests</div>
            </div>
            <div class="stat-card">
                <div class="stat-value">{stats['total_tokens']:,}</div>
                <div class="stat-label">Tokens Generated</div>
            </div>
        </div>
    </div>
    <div>
        <h3>User List</h3>
        <div style="max-height: 300px; overflow-y: auto;">
            <table style="width: 100%; border-collapse: collapse;">
                <thead>
                    <tr style="border-bottom: 1px solid rgba(255,255,255,0.1);">
                        <th style="text-align: left; padding: 0.75rem;">Name</th>
                        <th style="text-align: left; padding: 0.75rem;">Email</th>
                        <th style="text-align: left; padding: 0.75rem;">Plan</th>
                        <th style="text-align: left; padding: 0.75rem;">Requests</th>
                        <th style="text-align: left; padding: 0.75rem;">Tokens</th>
                    </tr>
                </thead>
                <tbody>
    """

    table_tail = """
                </tbody>
            </table>
        </div>
    </div>
    """

    # One join instead of repeated `+=` on an ever-growing string.
    rows = [render_row(user) for user in (stats.get('users') or [])]
    return "".join([table_head, *rows, table_tail])


async def handle_admin_refresh(api_key):
    """Fetch platform-wide statistics and return ``(plot, html)``.

    Args:
        api_key: API key forwarded to ``self.service.get_admin_stats``.

    Returns:
        On success, the ``"plot"`` entry of the response data together
        with the rendered overview HTML. On failure, an empty transparent
        figure titled "Error loading admin data" together with an error
        alert containing the (escaped) service error message.
    """
    from html import escape

    response = await self.service.get_admin_stats(api_key)

    if not response.success:
        error_figure = go.Figure().update_layout(
            title="Error loading admin data",
            paper_bgcolor='rgba(0,0,0,0)',
            plot_bgcolor='rgba(0,0,0,0)',
            font=dict(color='white'),
        )
        error_html = f'<div class="alert alert-error">❌ {escape(str(response.error))}</div>'
        return error_figure, error_html

    stats = response.data
    return stats["plot"], render_admin_stats_html(stats)
| |
| |
# ---- Event wiring ----------------------------------------------------------
#
# Several handlers return ``(first, value, visible)`` where the trailing bool
# is meant to show or hide the component that receives ``value``. Listing that
# component twice in ``outputs`` does not do this: the bool is written as the
# component's *value*, overwriting the real one. The adapter below folds the
# last two items into one ``gr.update(value=..., visible=...)`` so each
# component appears in ``outputs`` exactly once.

def _value_with_visibility(handler):
    """Wrap *handler* so its ``(first, value, visible)`` result feeds two outputs.

    The wrapped callable returns ``(first, gr.update(value=value,
    visible=bool(visible)))``. Both plain and coroutine handlers are
    accepted.

    NOTE(review): assumes the handler returns a single 3-tuple (not a
    generator that streams partial results) - confirm for
    ``handle_generation``, whose body is not part of this section.
    """
    @wraps(handler)
    async def adapted(*args):
        outcome = handler(*args)
        if asyncio.iscoroutine(outcome):
            outcome = await outcome
        first, value, visible = outcome
        return first, gr.update(value=value, visible=bool(visible))

    return adapted


generate_btn.click(
    fn=_value_with_visibility(handle_generation),
    inputs=[
        prompt_input, api_key_playground, max_tokens_input,
        temperature_input, top_k_input, top_p_input, repetition_penalty_input,
    ],
    outputs=[output_text, generation_stats],
)

clear_btn.click(
    # Empty both text fields and hide the stats panel (previously a bare
    # ``False`` was written into the panel as its value).
    fn=lambda: ("", "", gr.update(value="", visible=False)),
    inputs=[],
    outputs=[prompt_input, output_text, generation_stats],
)

create_btn.click(
    fn=_value_with_visibility(handle_user_creation),
    inputs=[name_input, email_input, plan_input],
    outputs=[creation_status, api_key_display],
)

refresh_stats_btn.click(
    fn=handle_stats_refresh,
    inputs=[stats_api_key],
    outputs=[user_stats_display],
)

refresh_analytics_btn.click(
    fn=handle_analytics_refresh,
    inputs=[analytics_api_key],
    outputs=[analytics_plot],
)

refresh_admin_btn.click(
    fn=handle_admin_refresh,
    inputs=[admin_api_key],
    outputs=[admin_plot, admin_stats_display],
)
| |
| return app |
|
|
| |
async def main(server_name: str = "0.0.0.0", server_port: int = 7860) -> None:
    """Main application entry point.

    Builds and initializes the ``GemmaSaaSService``, creates the Gradio
    interface around it and launches the web server.

    Args:
        server_name: Interface to bind to. The default ``"0.0.0.0"``
            listens on all interfaces, matching the previous hard-coded
            value.
        server_port: TCP port to serve on (default ``7860``).

    Raises:
        Exception: Any startup failure is logged with its traceback and
            re-raised unchanged.
    """
    try:
        # Backend first: the interface receives an already-initialized service.
        service = GemmaSaaSService()
        await service.initialize()

        interface = GradioInterface(service)
        app = await interface.create_interface()

        # NOTE(review): launch() is invoked from inside the running event
        # loop - confirm it does not stall coroutines scheduled on this loop.
        app.launch(
            server_name=server_name,
            server_port=server_port,
            share=False,
            debug=False,
            show_error=True,
            quiet=False,
            favicon_path=None,
            ssl_keyfile=None,
            ssl_certfile=None,
            auth=None,
            max_threads=10,
        )

    except Exception as e:
        # logger.exception keeps the traceback, which logger.error(f"...")
        # discarded; the message text itself is unchanged.
        logger.exception("Failed to start application: %s", e)
        raise
|
|
if __name__ == "__main__":
    # Script entry: build the startup coroutine, then drive it to completion
    # on a fresh event loop.
    startup = main()
    asyncio.run(startup)