Teste / app.py
GuXSs's picture
Update app.py
7d3e98e verified
raw
history blame
52.8 kB
import asyncio
import hashlib
import html
import json
import logging
import os
import secrets
import time
import uuid
from contextlib import asynccontextmanager
from dataclasses import dataclass
from datetime import datetime, timedelta
from functools import wraps
from typing import Any, Dict, List, Optional, Tuple

import aiohttp
import asyncpg
import gradio as gr
import jwt
import plotly.graph_objects as go
import requests
from dotenv import load_dotenv
from plotly.subplots import make_subplots
from pydantic import BaseModel, Field, ValidationError
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# ----------------- Configuration & Models -----------------
# Populate the process environment via python-dotenv before the Config class
# body below runs, because Config's field defaults call os.getenv() at that
# point.
load_dotenv()
@dataclass
class Config:
    """Runtime settings, each taken from an environment variable.

    The ``os.getenv`` calls are field defaults, so they are evaluated once when
    this class body executes; every ``Config()`` created afterwards shares the
    values captured at that moment.
    """

    # Passed as ``token=`` when the tokenizer and model are loaded.
    HF_TOKEN: str = os.getenv("HF_TOKEN", "")
    # Model identifier handed to ``from_pretrained``.
    MODEL_NAME: str = os.getenv("MODEL_NAME", "google/gemma-3-270m-it")
    # Base URL of the Supabase project; ``/rest/v1/users`` is appended to it.
    SUPABASE_URL: str = os.getenv("SUPABASE_URL", "")
    # Sent as both the ``apikey`` header and the Bearer token on every request.
    SUPABASE_KEY: str = os.getenv("SUPABASE_KEY", "")
    # Falls back to a random value generated when the class is defined, so it
    # differs between process runs unless JWT_SECRET is set.
    JWT_SECRET: str = os.getenv("JWT_SECRET", secrets.token_urlsafe(32))
    # Requests allowed per user within one rate-limit window.
    RATE_LIMIT_PER_HOUR: int = int(os.getenv("RATE_LIMIT_PER_HOUR", "100"))
    # Upper bound applied to ``max_new_tokens`` at generation time.
    MAX_TOKENS: int = int(os.getenv("MAX_TOKENS", "500"))
    # Name of a ``logging`` level; looked up with getattr() in setup_logger().
    LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
    # E-mail of the account allowed to read admin statistics; a non-empty value
    # also makes the Admin tab visible.
    ADMIN_EMAIL: str = os.getenv("ADMIN_EMAIL", "")
class GenerationRequest(BaseModel):
    """Validated parameters for one text-generation call."""

    # Text sent to the model; ModelManager.generate strips it and rejects
    # empty prompts and prompts longer than 4000 characters.
    prompt: str
    # Requested number of new tokens; clamped to Config.MAX_TOKENS at
    # generation time.
    max_tokens: int = 200
    # Sampling parameters, forwarded unchanged to the text-generation pipeline.
    temperature: float = 0.7
    top_k: int = 50
    top_p: float = 0.95
    repetition_penalty: float = 1.0
class UserCreate(BaseModel):
    """Input accepted when registering a new user."""

    # Both are whitespace-stripped by DatabaseManager.create_user before being
    # stored.
    name: str
    email: str
    # Stored as given; the UI dropdown offers "free", "pro" and "enterprise".
    plan: str = "free"
class APIResponse(BaseModel):
    """Uniform envelope returned by the service layer.

    ``success`` tells the caller which of ``data`` / ``error`` to read;
    ``timestamp`` records when the response object was created.
    """

    success: bool
    # Payload on success; its shape depends on the service method.
    data: Any = None
    # Human-readable message on failure.
    error: Optional[str] = None
    # default_factory stamps every response at creation time. A plain
    # ``datetime.now()`` default is evaluated once, when the class is defined,
    # so every response would report the process start time.
    timestamp: datetime = Field(default_factory=datetime.now)
# ----------------- Enhanced Logger -----------------
def setup_logger():
    """Configure logging to ``gemma_saas.log`` and the console.

    The level is read from ``Config().LOG_LEVEL``. The name is upper-cased and
    anything that does not resolve to a numeric ``logging`` level falls back to
    INFO, so a value such as ``info`` or a typo cannot abort start-up with an
    AttributeError.

    Returns:
        The logger named after this module.
    """
    level_name = str(Config().LOG_LEVEL).strip().upper()
    level = getattr(logging, level_name, None)
    if not isinstance(level, int):
        # Covers unknown names and non-level attributes of the logging module.
        level = logging.INFO

    logging.basicConfig(
        level=level,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler('gemma_saas.log'),
            logging.StreamHandler(),
        ],
    )
    return logging.getLogger(__name__)


# Module-wide logger used by every class below.
logger = setup_logger()
# ----------------- Database Manager -----------------
class DatabaseManager:
    """Async client for the ``users`` table exposed by the Supabase REST API.

    Each public method opens its own short-lived ``aiohttp.ClientSession`` and
    authenticates with ``config.SUPABASE_KEY``. Row filters are sent through
    ``params=`` so the HTTP client URL-encodes the values. Interpolating them
    into the URL meant a ``+`` in an e-mail address reached the server as a
    space and an ``&`` in user input could append extra query parameters.

    Database failures are logged and reported through return values; no public
    method raises.
    """

    # PostgREST answers a PATCH with 204 unless a representation is requested,
    # so both codes mean the update was accepted.
    _PATCH_OK = (200, 204)

    def __init__(self, config: "Config"):
        self.config = config
        self.headers = {
            "apikey": config.SUPABASE_KEY,
            "Authorization": f"Bearer {config.SUPABASE_KEY}",
            "Content-Type": "application/json",
        }

    @property
    def _users_url(self) -> str:
        """Endpoint of the ``users`` table."""
        return f"{self.config.SUPABASE_URL}/rest/v1/users"

    async def _select_users(self, session, **filters) -> Optional[List[Dict]]:
        """Return the rows matching every ``column == value`` filter.

        Returns None when the server does not answer 200.
        """
        params = {column: f"eq.{value}" for column, value in filters.items()}
        async with session.get(
            self._users_url, headers=self.headers, params=params
        ) as response:
            if response.status != 200:
                return None
            return await response.json()

    @staticmethod
    def _window_expired(raw_reset: Optional[str]) -> bool:
        """Tell whether the rate-limit window ending at ``raw_reset`` is over.

        A missing or null timestamp counts as expired so the window gets
        re-initialised. An unparseable string still raises ValueError, which
        the caller treats as "deny", as before.
        """
        if not raw_reset:
            return True
        reset_time = datetime.fromisoformat(raw_reset)
        # Compare like with like: an offset-aware value from the database
        # cannot be compared with a naive ``datetime.now()``.
        return datetime.now(reset_time.tzinfo) > reset_time

    async def create_user(
        self, user_data: "UserCreate", hf_user_id: Optional[str] = None
    ) -> Tuple[bool, str, str]:
        """Insert a new user and issue an API key.

        Args:
            user_data: Name, e-mail and plan of the new user.
            hf_user_id: Optional Hugging Face user id stored with the row.

        Returns:
            ``(success, message, api_key)``; ``api_key`` is ``""`` on failure.
        """
        try:
            name = user_data.name.strip()
            # The same stripped value is used for the duplicate check and the
            # insert, so " a@b.c " cannot slip past the check.
            email = user_data.email.strip()

            async with aiohttp.ClientSession() as session:
                existing_users = await self._select_users(session, email=email)
                if existing_users:
                    return False, "❌ User with this email already exists", ""

                api_key = self._generate_api_key()
                now = datetime.now()
                data = {
                    "name": name,
                    "email": email,
                    "api_key": api_key,
                    "hf_user_id": hf_user_id,
                    "requests": 0,
                    "plan": user_data.plan,
                    "created_at": now.isoformat(),
                    "last_request": None,
                    "requests_this_hour": 0,
                    "rate_limit_reset": (now + timedelta(hours=1)).isoformat(),
                    "tokens_used": 0,
                }
                async with session.post(
                    self._users_url,
                    headers=self.headers,
                    data=json.dumps(data),
                ) as response:
                    if response.status == 201:
                        logger.info(f"User created successfully: {user_data.email}")
                        return True, f"✅ User created successfully for {user_data.name}", api_key
                    error_text = await response.text()
                    logger.error(f"Error creating user: {error_text}")
                    return False, f"❌ Error creating user: {error_text}", ""
        except Exception as e:
            logger.error(f"Database error creating user: {e}")
            return False, f"❌ Database error: {str(e)}", ""

    def _generate_api_key(self) -> str:
        """Generate secure API key with prefix"""
        return f"gsa_{secrets.token_urlsafe(32)}"

    async def validate_api_key(self, api_key: str) -> Optional[Dict]:
        """Return the user row owning ``api_key``, or None if there is none."""
        # Nothing to look up for a missing or blank key; avoids a round trip.
        if not api_key or not api_key.strip():
            return None
        try:
            async with aiohttp.ClientSession() as session:
                rows = await self._select_users(session, api_key=api_key.strip())
            return rows[0] if rows else None
        except Exception as e:
            logger.error(f"Error validating API key: {e}")
            return None

    async def check_rate_limit(self, user_id: int) -> Tuple[bool, int]:
        """Check whether the user may make another request.

        Returns:
            ``(allowed, requests_this_hour)``. ``(False, 0)`` is also returned
            when the user cannot be read, so callers deny on database errors.
        """
        try:
            async with aiohttp.ClientSession() as session:
                rows = await self._select_users(session, id=user_id)
            if not rows:
                return False, 0

            user = rows[0]
            if self._window_expired(user.get('rate_limit_reset')):
                await self._reset_rate_limit(user_id)
                return True, 0

            # A null counter is treated as zero instead of raising TypeError.
            requests_this_hour = user.get('requests_this_hour') or 0
            return requests_this_hour < self.config.RATE_LIMIT_PER_HOUR, requests_this_hour
        except Exception as e:
            logger.error(f"Error checking rate limit: {e}")
            return False, 0

    async def _reset_rate_limit(self, user_id: int):
        """Zero the hourly counter and start a new one-hour window."""
        try:
            data = {
                "requests_this_hour": 0,
                "rate_limit_reset": (datetime.now() + timedelta(hours=1)).isoformat(),
            }
            async with aiohttp.ClientSession() as session:
                async with session.patch(
                    self._users_url,
                    headers=self.headers,
                    params={"id": f"eq.{user_id}"},
                    data=json.dumps(data),
                ) as response:
                    if response.status not in self._PATCH_OK:
                        logger.error(f"Failed to reset rate limit for user {user_id}")
        except Exception as e:
            logger.error(f"Error resetting rate limit: {e}")

    async def increment_usage(self, user_id: int, tokens_used: int):
        """Add one request and ``tokens_used`` tokens to the user's counters.

        This is a read-modify-write sequence, so concurrent requests from the
        same user can lose an update.
        """
        try:
            async with aiohttp.ClientSession() as session:
                rows = await self._select_users(session, id=user_id)
                if not rows:
                    return
                user = rows[0]
                new_data = {
                    "requests": (user.get('requests') or 0) + 1,
                    "requests_this_hour": (user.get('requests_this_hour') or 0) + 1,
                    "tokens_used": (user.get('tokens_used') or 0) + tokens_used,
                    "last_request": datetime.now().isoformat(),
                }
                # The context manager releases the connection, and the status
                # is checked so a rejected update is no longer silent.
                async with session.patch(
                    self._users_url,
                    headers=self.headers,
                    params={"id": f"eq.{user_id}"},
                    data=json.dumps(new_data),
                ) as response:
                    if response.status not in self._PATCH_OK:
                        logger.error(f"Failed to update usage for user {user_id}")
        except Exception as e:
            logger.error(f"Error incrementing usage: {e}")

    async def get_all_users_stats(self):
        """Return every user row (all columns), or ``[]`` on any failure."""
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(
                    self._users_url,
                    headers=self.headers,
                    params={"select": "*"},
                ) as response:
                    if response.status == 200:
                        return await response.json()
                    return []
        except Exception as e:
            logger.error(f"Error getting all users stats: {e}")
            return []
# ----------------- Model Manager -----------------
class ModelManager:
    """Loads the language model once and serves generation requests.

    The blocking ``transformers`` calls run in the default executor so the
    event loop stays responsive.
    """

    def __init__(self, config: "Config"):
        self.config = config
        self.tokenizer = None
        self.model = None
        self.pipeline = None
        # Set to True only after tokenizer, model and pipeline are all built.
        self.model_loaded = False

    async def initialize(self):
        """Load tokenizer, model and pipeline.

        Failures are logged and leave ``model_loaded`` False; nothing is
        raised, so the caller must check the flag (``generate`` does).
        """
        try:
            logger.info("Loading model...")
            # get_running_loop() is the supported call inside a coroutine.
            loop = asyncio.get_running_loop()

            # NOTE(review): trust_remote_code=True lets the model repository
            # run its own Python code during loading - confirm it is required
            # for the configured MODEL_NAME.
            self.tokenizer = await loop.run_in_executor(
                None,
                lambda: AutoTokenizer.from_pretrained(
                    self.config.MODEL_NAME,
                    token=self.config.HF_TOKEN,
                    trust_remote_code=True,
                ),
            )
            self.model = await loop.run_in_executor(
                None,
                lambda: AutoModelForCausalLM.from_pretrained(
                    self.config.MODEL_NAME,
                    token=self.config.HF_TOKEN,
                    device_map="auto",
                    torch_dtype="auto",
                    trust_remote_code=True,
                ),
            )
            self.pipeline = await loop.run_in_executor(
                None,
                lambda: pipeline(
                    "text-generation",
                    model=self.model,
                    tokenizer=self.tokenizer,
                    pad_token_id=self.tokenizer.eos_token_id,
                ),
            )
            self.model_loaded = True
            logger.info("✅ Model loaded successfully!")
        except Exception as e:
            logger.error(f"❌ Error loading model: {e}")
            self.model_loaded = False

    async def generate(self, request: "GenerationRequest") -> Tuple[bool, str, int]:
        """Generate a completion for ``request.prompt``.

        Returns:
            ``(success, text_or_error_message, tokens_used)``. ``tokens_used``
            counts only the generated text and is 0 on failure.
        """
        if not self.model_loaded:
            return False, "❌ Model not loaded", 0
        try:
            prompt = request.prompt.strip()
            if not prompt:
                return False, "⚠️ Prompt cannot be empty", 0
            if len(request.prompt) > 4000:
                return False, "⚠️ Prompt too long (max 4000 characters)", 0

            # The per-request value can never exceed the configured ceiling.
            max_new_tokens = min(request.max_tokens, self.config.MAX_TOKENS)
            loop = asyncio.get_running_loop()
            result = await loop.run_in_executor(
                None,
                lambda: self.pipeline(
                    prompt,
                    max_new_tokens=max_new_tokens,
                    do_sample=True,
                    temperature=request.temperature,
                    top_k=request.top_k,
                    top_p=request.top_p,
                    repetition_penalty=request.repetition_penalty,
                    pad_token_id=self.tokenizer.eos_token_id,
                    return_full_text=False,
                ),
            )
            generated_text = result[0]["generated_text"]
            # add_special_tokens=False keeps tokenizer-inserted markers (for
            # example a BOS token) out of the usage count.
            tokens_used = len(
                self.tokenizer.encode(generated_text, add_special_tokens=False)
            )
            return True, generated_text, tokens_used
        except Exception as e:
            logger.error(f"Generation error: {e}")
            return False, f"❌ Generation failed: {str(e)}", 0
# ----------------- Analytics Manager -----------------
class AnalyticsManager:
    """Builds Plotly figures that summarise one user's usage."""

    def __init__(self, db: DatabaseManager):
        self.db = db

    async def generate_usage_plot(self, user_data: Dict) -> go.Figure:
        """Return a two-panel figure: requests per day and a token breakdown.

        Both panels are placeholders derived from the user's running totals:
        the request total is spread evenly over the last seven days, and the
        pie slices are fixed fractions of ``tokens_used``.
        """
        # Line panel: seven points, oldest first, each with total // 7.
        per_day = user_data.get('requests', 0) // 7
        day_axis = [datetime.now() - timedelta(days=offset) for offset in reversed(range(1, 8))]
        request_series = [per_day for _ in range(7)]

        # Pie panel: the total, a third of it and a fifth of it.
        token_total = user_data.get('tokens_used', 0)
        slice_labels = ['Generated', 'Prompt', 'Other']
        slice_values = [token_total, token_total // 3, token_total // 5]

        figure = make_subplots(
            rows=1,
            cols=2,
            subplot_titles=('Requests Over Time', 'Token Usage Distribution'),
            specs=[[{"type": "scatter"}, {"type": "pie"}]],
        )
        figure.add_trace(
            go.Scatter(x=day_axis, y=request_series, mode='lines+markers', name='Requests'),
            row=1,
            col=1,
        )
        figure.add_trace(
            go.Pie(labels=slice_labels, values=slice_values, name="Token Usage"),
            row=1,
            col=2,
        )
        # Transparent backgrounds and white text to sit on the dark theme.
        figure.update_layout(
            height=400,
            showlegend=True,
            paper_bgcolor='rgba(0,0,0,0)',
            plot_bgcolor='rgba(0,0,0,0)',
            font=dict(color='white'),
        )
        return figure
# ----------------- Service Layer -----------------
class GemmaSaaSService:
    """Service layer combining the database, model and analytics managers.

    Every public coroutine returns an ``APIResponse`` and never raises;
    unexpected errors are logged and reported as a generic failure.
    """

    def __init__(self):
        """Build the managers from the environment configuration.

        Raises:
            ValueError: if HF_TOKEN, SUPABASE_URL or SUPABASE_KEY is unset.
        """
        self.config = Config()
        self.db = DatabaseManager(self.config)
        self.model_manager = ModelManager(self.config)
        self.analytics_manager = AnalyticsManager(self.db)
        self._validate_config()

    def _validate_config(self):
        """Raise ValueError naming every required setting that is empty."""
        required_fields = ['HF_TOKEN', 'SUPABASE_URL', 'SUPABASE_KEY']
        missing_fields = [field for field in required_fields if not getattr(self.config, field)]
        if missing_fields:
            raise ValueError(f"Missing required environment variables: {', '.join(missing_fields)}")

    @staticmethod
    def _minutes_until_reset(raw_reset: Optional[str]) -> int:
        """Whole minutes left until ``raw_reset``; 0 if past, missing or invalid.

        Kept tolerant on purpose: this only feeds an error message, so a bad
        timestamp must not turn "rate limit exceeded" into an internal error.
        """
        if not raw_reset:
            return 0
        try:
            reset_time = datetime.fromisoformat(raw_reset)
        except (TypeError, ValueError):
            return 0
        # Use the timestamp's own tzinfo so naive and aware values both work.
        remaining = reset_time - datetime.now(reset_time.tzinfo)
        return max(0, int(remaining.total_seconds() / 60))

    async def initialize(self):
        """Initialize all services"""
        await self.model_manager.initialize()

    async def create_user(self, name: str, email: str, plan: str = "free") -> "APIResponse":
        """Register a user; on success ``data`` holds ``api_key`` and ``message``."""
        try:
            user_data = UserCreate(name=name, email=email, plan=plan)
            success, message, api_key = await self.db.create_user(user_data)
            return APIResponse(
                success=success,
                data={"api_key": api_key, "message": message} if success else None,
                error=message if not success else None,
            )
        except ValidationError as e:
            return APIResponse(success=False, error=f"Validation error: {str(e)}")
        except Exception as e:
            logger.error(f"Service error creating user: {e}")
            return APIResponse(success=False, error="Internal service error")

    async def generate_text(self, prompt: str, api_key: str, **kwargs) -> "APIResponse":
        """Authenticate, enforce the rate limit, then generate text.

        Args:
            prompt: Text to complete.
            api_key: Key identifying the calling user.
            **kwargs: Extra ``GenerationRequest`` fields (``max_tokens``,
                ``temperature``, ``top_k``, ``top_p``, ``repetition_penalty``).

        Returns:
            On success ``data`` holds ``generated_text``, ``tokens_used``,
            ``user_plan`` and ``requests_remaining``.
        """
        try:
            user = await self.db.validate_api_key(api_key)
            if not user:
                return APIResponse(success=False, error="⚠️ Invalid API key")

            can_make_request, requests_used = await self.db.check_rate_limit(user['id'])
            if not can_make_request:
                mins_remaining = self._minutes_until_reset(user.get('rate_limit_reset'))
                return APIResponse(
                    success=False,
                    error=f"⚠️ Rate limit exceeded ({requests_used}/{self.config.RATE_LIMIT_PER_HOUR}). Try again in {mins_remaining} minutes."
                )

            request = GenerationRequest(prompt=prompt, **kwargs)
            success, text, tokens_used = await self.model_manager.generate(request)
            if not success:
                return APIResponse(success=False, error=text)

            await self.db.increment_usage(user['id'], tokens_used)
            return APIResponse(
                success=True,
                data={
                    "generated_text": text,
                    "tokens_used": tokens_used,
                    "user_plan": user.get('plan', 'free'),
                    "requests_remaining": self.config.RATE_LIMIT_PER_HOUR - requests_used - 1,
                },
            )
        except ValidationError as e:
            # Bad generation parameters are the caller's fault; report them the
            # same way create_user does instead of as an internal error.
            return APIResponse(success=False, error=f"Validation error: {str(e)}")
        except Exception as e:
            logger.error(f"Service error generating text: {e}")
            return APIResponse(success=False, error="Internal service error")

    async def get_user_stats(self, api_key: str) -> "APIResponse":
        """Return the caller's account figures plus a usage plot in ``data``."""
        try:
            user = await self.db.validate_api_key(api_key)
            if not user:
                return APIResponse(success=False, error="Invalid API key")

            plot = await self.analytics_manager.generate_usage_plot(user)
            stats = {
                "name": user.get('name'),
                "email": user.get('email'),
                "plan": user.get('plan', 'free'),
                "total_requests": user.get('requests', 0),
                "tokens_used": user.get('tokens_used', 0),
                "requests_this_hour": user.get('requests_this_hour', 0),
                "rate_limit": self.config.RATE_LIMIT_PER_HOUR,
                "created_at": user.get('created_at'),
                "last_request": user.get('last_request'),
                "plot": plot,
            }
            return APIResponse(success=True, data=stats)
        except Exception as e:
            logger.error(f"Error getting user stats: {e}")
            return APIResponse(success=False, error="Error retrieving stats")

    async def get_admin_stats(self, api_key: str) -> "APIResponse":
        """Return platform-wide totals, the user list and a bar chart.

        Access requires ADMIN_EMAIL to be configured and the key's owner to
        have exactly that e-mail.
        """
        try:
            caller = await self.db.validate_api_key(api_key)
            admin_email = self.config.ADMIN_EMAIL
            # An unset ADMIN_EMAIL ("") must never match a user whose stored
            # e-mail is also empty.
            if not admin_email or not caller or caller.get('email') != admin_email:
                return APIResponse(success=False, error="Unauthorized: Admin access required")

            all_users = await self.db.get_all_users_stats()
            # ``or 0`` tolerates null counters coming back from the database.
            total_requests = sum(row.get('requests') or 0 for row in all_users)
            total_tokens = sum(row.get('tokens_used') or 0 for row in all_users)
            active_users = sum(1 for row in all_users if row.get('last_request'))

            fig = go.Figure()
            fig.add_trace(go.Bar(
                x=[row.get('name', 'Unknown') for row in all_users],
                y=[row.get('requests', 0) for row in all_users],
                name="User Requests",
                marker_color='indianred',
            ))
            fig.update_layout(
                title='User Activity',
                xaxis_title='Users',
                yaxis_title='Number of Requests',
                paper_bgcolor='rgba(0,0,0,0)',
                plot_bgcolor='rgba(0,0,0,0)',
                font=dict(color='white'),
            )

            # NOTE(review): ``users`` carries every column of every row,
            # including each user's api_key - consider trimming it.
            stats = {
                "total_users": len(all_users),
                "active_users": active_users,
                "total_requests": total_requests,
                "total_tokens": total_tokens,
                "users": all_users,
                "plot": fig,
            }
            return APIResponse(success=True, data=stats)
        except Exception as e:
            logger.error(f"Error getting admin stats: {e}")
            return APIResponse(success=False, error="Error retrieving admin stats")
# ----------------- Enhanced UI -----------------
class GradioInterface:
    """Builds the Gradio Blocks UI on top of a ``GemmaSaaSService``.

    All user-controlled text (names, e-mails, plans, error messages) is passed
    through ``html.escape`` before it is placed into ``gr.HTML`` output.
    """

    _GRID_STYLE = "display: grid; grid-template-columns: repeat(auto-fit, minmax(150px, 1fr)); gap: 1rem;"

    def __init__(self, service: "GemmaSaaSService"):
        self.service = service
        # One-element lists: each holds a single example prompt.
        self.examples = [
            ["Write a short story about a robot discovering emotions"],
            ["Explain quantum computing in simple terms"],
            ["Create a recipe for chocolate chip cookies"],
            ["Write a poem about the changing seasons"],
            ["How can I improve my time management skills?"]
        ]

    # ----- small HTML helpers -------------------------------------------

    @staticmethod
    def _alert(kind: str, message) -> str:
        """Return an alert box of class ``alert-<kind>`` with escaped text."""
        return f'<div class="alert alert-{kind}">{html.escape(str(message))}</div>'

    @staticmethod
    def _stat_card(value, label: str) -> str:
        """Return one stat card; ``value`` is escaped, ``label`` is a constant."""
        return (
            '<div class="stat-card">'
            f'<div class="stat-value">{html.escape(str(value))}</div>'
            f'<div class="stat-label">{label}</div>'
            '</div>'
        )

    @classmethod
    def _stat_grid(cls, cards, extra_style: str = "") -> str:
        """Wrap already-rendered cards in the responsive grid container."""
        style = f"{cls._GRID_STYLE} {extra_style}".strip()
        return f'<div style="{style}">{"".join(cards)}</div>'

    @staticmethod
    def _error_figure(title: str):
        """Return an empty figure whose title carries the error text."""
        return go.Figure().update_layout(
            title=title,
            paper_bgcolor='rgba(0,0,0,0)',
            plot_bgcolor='rgba(0,0,0,0)',
            font=dict(color='white')
        )

    @classmethod
    def _render_user_stats(cls, stats: dict) -> str:
        """Render the account statistics panel for one user."""
        plan = str(stats['plan'] or 'free')
        created = stats['created_at'][:10] if stats['created_at'] else 'N/A'
        last_seen = stats['last_request'][:19] if stats['last_request'] else 'Never'
        grid = cls._stat_grid(
            [
                cls._stat_card(stats['total_requests'], "Total Requests"),
                cls._stat_card(f"{(stats['tokens_used'] or 0):,}", "Tokens Used"),
                cls._stat_card(f"{stats['requests_this_hour']}/{stats['rate_limit']}", "Hourly Usage"),
                cls._stat_card(plan.title(), "Current Plan"),
            ],
            extra_style="margin-bottom: 2rem;",
        )
        details = (
            '<div style="background: var(--card-bg); padding: 1.5rem; border-radius: 12px; border: 1px solid var(--card-border);">'
            '<h3 style="margin-top: 0;">Account Details</h3>'
            f'<p><strong>Name:</strong> {html.escape(str(stats["name"]))}</p>'
            f'<p><strong>Email:</strong> {html.escape(str(stats["email"]))}</p>'
            f'<p><strong>Member since:</strong> {html.escape(created)}</p>'
            f'<p><strong>Last request:</strong> {html.escape(last_seen)}</p>'
            '</div>'
        )
        return f'<div>{grid}{details}</div>'

    @classmethod
    def _render_admin_stats(cls, stats: dict) -> str:
        """Render the platform totals and the user table."""
        totals = cls._stat_grid([
            cls._stat_card(stats['total_users'], "Total Users"),
            cls._stat_card(stats['active_users'], "Active Users"),
            cls._stat_card(stats['total_requests'], "Total Requests"),
            cls._stat_card(f"{stats['total_tokens']:,}", "Tokens Generated"),
        ])
        header = "".join(
            f'<th style="text-align: left; padding: 0.75rem;">{title}</th>'
            for title in ("Name", "Email", "Plan", "Requests", "Tokens")
        )
        rows = []
        for row in stats['users']:
            plan = str(row.get('plan') or 'free')
            cells = (
                html.escape(str(row.get('name', 'N/A'))),
                html.escape(str(row.get('email', 'N/A'))),
                # Title-case before escaping so entities are not altered; the
                # class attribute is escaped with quotes included.
                f'<span class="badge badge-{html.escape(plan, quote=True)}">{html.escape(plan.title())}</span>',
                html.escape(str(row.get('requests', 0))),
                f"{(row.get('tokens_used') or 0):,}",
            )
            rows.append(
                '<tr style="border-bottom: 1px solid rgba(255,255,255,0.05);">'
                + "".join(f'<td style="padding: 0.75rem;">{cell}</td>' for cell in cells)
                + '</tr>'
            )
        table = (
            '<div><h3>User List</h3>'
            '<div style="max-height: 300px; overflow-y: auto;">'
            '<table style="width: 100%; border-collapse: collapse;">'
            f'<thead><tr style="border-bottom: 1px solid rgba(255,255,255,0.1);">{header}</tr></thead>'
            f'<tbody>{"".join(rows)}</tbody>'
            '</table></div></div>'
        )
        return f'<div style="margin-bottom: 2rem;">{totals}</div>{table}'

    # ----- static page pieces -------------------------------------------

    def create_advanced_css(self):
        """Return the custom stylesheet passed to ``gr.Blocks(css=...)``."""
        return """
:root {
--primary: #6366f1;
--primary-dark: #4338ca;
--secondary: #10b981;
--accent: #f59e0b;
--danger: #ef4444;
--dark: #1f2937;
--darker: #111827;
--light: #f3f4f6;
--lighter: #f9fafb;
--card-bg: rgba(255, 255, 255, 0.05);
--card-border: rgba(255, 255, 255, 0.1);
}
body, .gradio-container {
background: linear-gradient(135deg, var(--darker) 0%, var(--dark) 100%) !important;
color: var(--lighter) !important;
font-family: 'Inter', 'Segoe UI', system-ui, sans-serif !important;
}
.gradio-container {
max-width: 1400px !important;
margin: 0 auto !important;
padding: 1rem !important;
}
.header {
text-align: center;
padding: 2rem 1rem;
margin-bottom: 2rem;
background: var(--card-bg);
backdrop-filter: blur(10px);
border-radius: 16px;
border: 1px solid var(--card-border);
box-shadow: 0 8px 32px rgba(0, 0, 0, 0.2);
}
.header h1 {
font-size: 3.5rem;
font-weight: 800;
background: linear-gradient(135deg, var(--primary) 0%, var(--secondary) 100%);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
margin-bottom: 0.5rem;
}
.header p {
font-size: 1.2rem;
color: var(--light);
opacity: 0.8;
max-width: 600px;
margin: 0 auto;
}
.card {
background: var(--card-bg) !important;
backdrop-filter: blur(10px);
border-radius: 16px !important;
border: 1px solid var(--card-border) !important;
padding: 1.5rem !important;
transition: all 0.3s ease !important;
box-shadow: 0 4px 20px rgba(0, 0, 0, 0.15) !important;
}
.card:hover {
transform: translateY(-5px);
box-shadow: 0 12px 40px rgba(0, 0, 0, 0.25) !important;
}
.btn {
border-radius: 12px !important;
padding: 0.75rem 1.5rem !important;
font-weight: 600 !important;
transition: all 0.2s ease !important;
border: none !important;
}
.btn-primary {
background: linear-gradient(135deg, var(--primary) 0%, var(--primary-dark) 100%) !important;
color: white !important;
}
.btn-primary:hover {
transform: translateY(-2px);
box-shadow: 0 6px 20px rgba(99, 102, 241, 0.4) !important;
}
.btn-secondary {
background: rgba(255, 255, 255, 0.1) !important;
color: white !important;
border: 1px solid rgba(255, 255, 255, 0.2) !important;
}
.btn-secondary:hover {
background: rgba(255, 255, 255, 0.2) !important;
transform: translateY(-2px);
}
.stat-card {
background: linear-gradient(135deg, var(--primary) 0%, var(--primary-dark) 100%);
color: white;
border-radius: 12px;
padding: 1.5rem;
text-align: center;
}
.stat-value {
font-size: 2.5rem;
font-weight: 700;
margin-bottom: 0.5rem;
}
.stat-label {
font-size: 0.9rem;
opacity: 0.8;
text-transform: uppercase;
letter-spacing: 1px;
}
.alert {
padding: 1rem 1.5rem;
border-radius: 12px;
margin: 1rem 0;
}
.alert-success {
background: linear-gradient(135deg, var(--secondary) 0%, #059669 100%);
color: white;
}
.alert-error {
background: linear-gradient(135deg, var(--danger) 0%, #dc2626 100%);
color: white;
}
.tab-nav {
background: var(--card-bg) !important;
border-radius: 12px !important;
padding: 0.5rem !important;
margin-bottom: 1.5rem !important;
}
.tab-nav button {
border-radius: 8px !important;
padding: 0.75rem 1.5rem !important;
font-weight: 600 !important;
}
.tab-nav button.selected {
background: linear-gradient(135deg, var(--primary) 0%, var(--primary-dark) 100%) !important;
color: white !important;
}
input, textarea, select {
background: rgba(255, 255, 255, 0.05) !important;
border: 1px solid rgba(255, 255, 255, 0.1) !important;
border-radius: 12px !important;
color: white !important;
padding: 0.75rem 1rem !important;
}
input:focus, textarea:focus, select:focus {
border-color: var(--primary) !important;
box-shadow: 0 0 0 2px rgba(99, 102, 241, 0.2) !important;
}
label {
color: var(--light) !important;
font-weight: 600 !important;
margin-bottom: 0.5rem !important;
}
.footer {
text-align: center;
padding: 2rem 1rem;
margin-top: 3rem;
color: var(--light);
opacity: 0.7;
font-size: 0.9rem;
}
.badge {
display: inline-block;
padding: 0.25rem 0.75rem;
border-radius: 9999px;
font-size: 0.75rem;
font-weight: 700;
}
.badge-free {
background: var(--secondary);
color: white;
}
.badge-pro {
background: var(--primary);
color: white;
}
.badge-enterprise {
background: var(--accent);
color: white;
}
.example-card {
cursor: pointer;
transition: all 0.2s ease;
}
.example-card:hover {
transform: translateY(-3px);
box-shadow: 0 8px 25px rgba(0, 0, 0, 0.2) !important;
}
.copy-btn {
position: absolute;
top: 0.5rem;
right: 0.5rem;
background: rgba(255, 255, 255, 0.1);
border: none;
border-radius: 8px;
padding: 0.5rem;
cursor: pointer;
opacity: 0.7;
transition: all 0.2s ease;
}
.copy-btn:hover {
opacity: 1;
background: rgba(255, 255, 255, 0.2);
}
.dark-mode-plot {
background: transparent !important;
}
"""

    def create_header(self):
        """Render the page title banner."""
        return gr.HTML("""
<div class="header">
<h1>🚀 Gemma AI Platform</h1>
<p>Advanced text generation powered by state-of-the-art AI models. Create content, analyze text, and unlock new possibilities.</p>
</div>
""")

    def create_footer(self):
        """Render the page footer."""
        return gr.HTML("""
<div class="footer">
<p>Gemma AI Platform © 2023 | Built with ❤️ using Hugging Face, Gradio, and Supabase</p>
<p>Version 2.0 | <a href="#" style="color: #ccc; text-decoration: none;">Terms of Service</a> | <a href="#" style="color: #ccc; text-decoration: none;">Privacy Policy</a></p>
</div>
""")

    def create_examples_component(self):
        """Render the example prompts as clickable cards.

        NOTE(review): the onclick handler assigns ``.value`` on the element
        with id ``prompt-input``; confirm that this id lands on the textarea
        itself and that Gradio picks up the change.
        """
        with gr.Column(elem_classes=["card"]):
            gr.Markdown("### 💡 Example Prompts")
            pieces = ["""
<div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(250px, 1fr)); gap: 1rem;">
"""]
            for example in self.examples:
                pieces.append(f"""
<div class="card example-card" onclick="document.getElementById('prompt-input').value = `{example[0]}`">
<div style="font-size: 0.9rem; opacity: 0.9;">{example[0]}</div>
</div>
""")
            pieces.append("</div>")
            return gr.HTML("".join(pieces))

    # ----- main interface -----------------------------------------------

    async def create_interface(self):
        """Build the Blocks app: four tabs plus their event handlers.

        Returns:
            The ``gr.Blocks`` instance, ready for ``launch()``.
        """
        # NOTE(review): the layout nesting below was reconstructed from text
        # that carried no indentation - confirm where the two gr.Plot
        # components and admin_stats_display sit relative to their rows.
        with gr.Blocks(
            css=self.create_advanced_css(),
            title="Gemma AI Platform",
            theme=gr.themes.Default(primary_hue="indigo", secondary_hue="emerald")
        ) as app:
            # Header
            self.create_header()

            # Main Content
            with gr.Tabs(elem_classes=["tab-nav"]):
                # Playground Tab
                with gr.Tab("🎮 Playground", elem_classes=["card"]):
                    with gr.Row():
                        with gr.Column(scale=1):
                            gr.Markdown("### ⚙️ Configuration")
                            api_key_playground = gr.Textbox(
                                label="🔑 API Key",
                                type="password",
                                placeholder="Enter your API key...",
                                elem_classes=["input"]
                            )
                            with gr.Accordion("Advanced Settings", open=False):
                                max_tokens_input = gr.Slider(
                                    minimum=50, maximum=1000, value=200,
                                    label="Max Tokens",
                                    info="Maximum number of tokens to generate"
                                )
                                temperature_input = gr.Slider(
                                    minimum=0.1, maximum=2.0, value=0.7,
                                    label="Temperature",
                                    info="Higher values = more creative, lower values = more focused"
                                )
                                top_k_input = gr.Slider(
                                    minimum=1, maximum=100, value=50,
                                    label="Top K",
                                    info="Consider only the top K tokens"
                                )
                                top_p_input = gr.Slider(
                                    minimum=0.1, maximum=1.0, value=0.95,
                                    label="Top P",
                                    info="Nucleus sampling: consider only tokens with cumulative probability"
                                )
                                repetition_penalty_input = gr.Slider(
                                    minimum=0.1, maximum=2.0, value=1.0,
                                    label="Repetition Penalty",
                                    info="Penalize repeated tokens (1.0 = no penalty)"
                                )
                        with gr.Column(scale=2):
                            # Examples
                            self.create_examples_component()
                            gr.Markdown("### 💬 Text Generation")
                            prompt_input = gr.Textbox(
                                label="✍️ Your Prompt",
                                lines=6,
                                placeholder="Enter your prompt here... (e.g., 'Write a short story about a robot discovering emotions')",
                                elem_id="prompt-input",
                                elem_classes=["input"]
                            )
                            with gr.Row():
                                generate_btn = gr.Button(
                                    "🚀 Generate",
                                    elem_classes=["btn", "btn-primary"],
                                    variant="primary"
                                )
                                clear_btn = gr.Button(
                                    "🗑️ Clear",
                                    elem_classes=["btn", "btn-secondary"]
                                )
                            output_text = gr.Textbox(
                                label="📝 Generated Text",
                                lines=8,
                                interactive=False,
                                elem_classes=["input"]
                            )
                            # Stats display; hidden until a generation succeeds.
                            generation_stats = gr.JSON(
                                label="📊 Generation Statistics",
                                visible=False
                            )

                # Profile Tab
                with gr.Tab("👤 Profile", elem_classes=["card"]):
                    with gr.Row():
                        with gr.Column(scale=1):
                            gr.Markdown("### 🆕 Create Account")
                            name_input = gr.Textbox(
                                label="👤 Full Name",
                                elem_classes=["input"]
                            )
                            email_input = gr.Textbox(
                                label="📧 Email Address",
                                elem_classes=["input"]
                            )
                            plan_input = gr.Dropdown(
                                choices=["free", "pro", "enterprise"],
                                value="free",
                                label="📋 Plan",
                                elem_classes=["input"]
                            )
                            create_btn = gr.Button(
                                "✨ Create API Key",
                                elem_classes=["btn", "btn-primary"],
                                variant="primary"
                            )
                            creation_status = gr.HTML()
                            # Hidden until an account has been created.
                            api_key_display = gr.Textbox(
                                label="🔑 Your API Key",
                                interactive=False,
                                visible=False,
                                elem_classes=["input"]
                            )
                        with gr.Column(scale=1):
                            gr.Markdown("### 📊 Account Statistics")
                            stats_api_key = gr.Textbox(
                                label="🔑 API Key",
                                type="password",
                                placeholder="Enter API key to view stats",
                                elem_classes=["input"]
                            )
                            refresh_stats_btn = gr.Button(
                                "🔄 Refresh Stats",
                                elem_classes=["btn", "btn-secondary"]
                            )
                            user_stats_display = gr.HTML()

                # Analytics Tab
                with gr.Tab("📈 Analytics", elem_classes=["card"]):
                    gr.Markdown("### 📊 Usage Analytics")
                    with gr.Row():
                        with gr.Column(scale=1):
                            analytics_api_key = gr.Textbox(
                                label="🔑 API Key",
                                type="password",
                                placeholder="Enter API key to view analytics",
                                elem_classes=["input"]
                            )
                            refresh_analytics_btn = gr.Button(
                                "📈 Generate Analytics",
                                elem_classes=["btn", "btn-primary"]
                            )
                    analytics_plot = gr.Plot(label="Usage Analytics")

                # Admin Tab (shown only when ADMIN_EMAIL is configured)
                with gr.Tab("👑 Admin", elem_classes=["card"], visible=bool(self.service.config.ADMIN_EMAIL)):
                    gr.Markdown("### 👑 Admin Dashboard")
                    with gr.Row():
                        with gr.Column(scale=1):
                            admin_api_key = gr.Textbox(
                                label="🔑 Admin API Key",
                                type="password",
                                placeholder="Enter admin API key",
                                elem_classes=["input"]
                            )
                            refresh_admin_btn = gr.Button(
                                "🔄 Refresh Admin Stats",
                                elem_classes=["btn", "btn-primary"]
                            )
                    admin_plot = gr.Plot(label="Platform Analytics")
                    admin_stats_display = gr.HTML()

            # Footer
            self.create_footer()

            # Event Handlers
            async def handle_generation(prompt, api_key, max_tokens, temperature, top_k, top_p, repetition_penalty):
                """Return (output text, update for the statistics panel)."""
                hidden_stats = gr.update(value=None, visible=False)
                if not api_key or not api_key.strip():
                    return "⚠️ Please enter your API key", hidden_stats
                response = await self.service.generate_text(
                    prompt=prompt,
                    api_key=api_key.strip(),
                    max_tokens=max_tokens,
                    temperature=temperature,
                    top_k=top_k,
                    top_p=top_p,
                    repetition_penalty=repetition_penalty
                )
                if not response.success:
                    return response.error, hidden_stats
                # The panel is a gr.JSON component, so it receives a plain
                # dict and is made visible through the same update.
                shown = {
                    key: response.data[key]
                    for key in ("tokens_used", "requests_remaining", "user_plan")
                }
                return response.data["generated_text"], gr.update(value=shown, visible=True)

            async def handle_user_creation(name, email, plan):
                """Return (status HTML, update for the API-key textbox)."""
                hidden_key = gr.update(value="", visible=False)
                if not name or not name.strip():
                    return self._alert("error", "❌ Name is required"), hidden_key
                if not email or not email.strip():
                    return self._alert("error", "❌ Email is required"), hidden_key
                response = await self.service.create_user(name, email, plan)
                if not response.success:
                    return self._alert("error", f"❌ {response.error}"), hidden_key
                return (
                    self._alert("success", "✅ Account created successfully! Your API key is below."),
                    gr.update(value=response.data["api_key"], visible=True),
                )

            async def handle_stats_refresh(api_key):
                """Return the account statistics panel or an error alert."""
                response = await self.service.get_user_stats(api_key)
                if not response.success:
                    return self._alert("error", f"❌ {response.error}")
                return self._render_user_stats(response.data)

            async def handle_analytics_refresh(api_key):
                """Return the usage figure, or a titled empty figure on error."""
                response = await self.service.get_user_stats(api_key)
                if response.success:
                    return response.data["plot"]
                return self._error_figure("Error loading analytics")

            async def handle_admin_refresh(api_key):
                """Return (platform figure, admin HTML) or their error forms."""
                response = await self.service.get_admin_stats(api_key)
                if not response.success:
                    return (
                        self._error_figure("Error loading admin data"),
                        self._alert("error", f"❌ {response.error}"),
                    )
                stats = response.data
                return stats["plot"], self._render_admin_stats(stats)

            # Wire up events. Each component appears once per outputs list;
            # value and visibility travel together in a gr.update().
            generate_btn.click(
                fn=handle_generation,
                inputs=[
                    prompt_input, api_key_playground, max_tokens_input,
                    temperature_input, top_k_input, top_p_input, repetition_penalty_input
                ],
                outputs=[output_text, generation_stats]
            )
            clear_btn.click(
                fn=lambda: ("", "", gr.update(value=None, visible=False)),
                inputs=[],
                outputs=[prompt_input, output_text, generation_stats]
            )
            create_btn.click(
                fn=handle_user_creation,
                inputs=[name_input, email_input, plan_input],
                outputs=[creation_status, api_key_display]
            )
            refresh_stats_btn.click(
                fn=handle_stats_refresh,
                inputs=[stats_api_key],
                outputs=[user_stats_display]
            )
            refresh_analytics_btn.click(
                fn=handle_analytics_refresh,
                inputs=[analytics_api_key],
                outputs=[analytics_plot]
            )
            refresh_admin_btn.click(
                fn=handle_admin_refresh,
                inputs=[admin_api_key],
                outputs=[admin_plot, admin_stats_display]
            )
        return app
# ----------------- Main Application -----------------
async def main():
    """Start the service, build the UI and serve it on port 7860.

    Any start-up failure is logged and then re-raised to the caller.
    """
    # Server settings, collected in one place and expanded into launch().
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "debug": False,
        "show_error": True,
        "quiet": False,
        "favicon_path": None,
        "ssl_keyfile": None,
        "ssl_certfile": None,
        "auth": None,
        "max_threads": 10,
    }
    try:
        # Loading the model happens here, before any UI exists.
        backend = GemmaSaaSService()
        await backend.initialize()

        ui_builder = GradioInterface(backend)
        blocks_app = await ui_builder.create_interface()
        blocks_app.launch(**launch_options)
    except Exception as e:
        logger.error(f"Failed to start application: {e}")
        raise
if __name__ == "__main__":
    # Run the async entry point only when this file is executed as a script,
    # not when it is imported as a module.
    asyncio.run(main())