import asyncio
import logging
import os
import secrets
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Optional, Tuple

import gradio as gr
from dotenv import load_dotenv
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
# ----------------- Configuration & Models -----------------
load_dotenv()
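# load_dotenv() pulls key=value pairs from a local .env file into the process
# environment. A minimal example (placeholder values, not real credentials):
#
#   HF_TOKEN="hf_xxxxxxxxxxxxxxxxxxxx"
#   MODEL_NAME="google/gemma-2-9b-it"
#   MAX_TOKENS="1500"
#   LOG_LEVEL="INFO"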
@dataclass
class Config:
HF_TOKEN: str = os.getenv("HF_TOKEN", "")
MODEL_NAME: str = os.getenv("MODEL_NAME", "google/gemma-2-9b-it")
MAX_TOKENS: int = int(os.getenv("MAX_TOKENS", "1500"))
LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
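# Note: these dataclass defaults are evaluated once, at class-definition time,
# so load_dotenv() above must run before this class body executes; every
# Config() instance then sees the same captured values.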
class GenerationRequest(BaseModel):
prompt: str
max_tokens: int = 500
temperature: float = 0.75
top_k: int = 50
top_p: float = 0.95
repetition_penalty: float = 1.1
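# Pydantic validates types on construction: GenerationRequest(prompt="Hi",
# temperature="0.7") coerces the numeric string to a float, while a
# non-numeric value such as temperature="hot" raises a ValidationError.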
class APIResponse(BaseModel):
success: bool
data: Any = None
error: Optional[str] = None
    timestamp: datetime = Field(default_factory=datetime.now)  # fresh per response, not frozen at import time
# ----------------- Enhanced Logger -----------------
def setup_logger():
logging.basicConfig(
        level=getattr(logging, Config().LOG_LEVEL.upper(), logging.INFO),  # fall back to INFO on invalid values
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler('gemma_saas.log'),
logging.StreamHandler()
]
)
return logging.getLogger(__name__)
logger = setup_logger()
# ----------------- Model Manager -----------------
class ModelManager:
def __init__(self, config: Config):
self.config = config
self.tokenizer = None
self.model = None
self.pipeline = None
self.model_loaded = False
async def initialize(self):
"""Initialize the model, tokenizer, and pipeline asynchronously."""
if not self.config.HF_TOKEN:
logger.error("Hugging Face token not found. Model loading will fail.")
self.model_loaded = False
return
try:
logger.info(f"Loading model: {self.config.MODEL_NAME}...")
            loop = asyncio.get_running_loop()
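            # from_pretrained() blocks for the full download/load, so it is run
            # in the default thread-pool executor to keep the event loop responsive.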
def load_components():
tokenizer = AutoTokenizer.from_pretrained(self.config.MODEL_NAME, token=self.config.HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
self.config.MODEL_NAME,
token=self.config.HF_TOKEN,
device_map="auto",
torch_dtype="auto"
)
text_pipeline = pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
)
return tokenizer, model, text_pipeline
self.tokenizer, self.model, self.pipeline = await loop.run_in_executor(None, load_components)
self.model_loaded = True
logger.info("✅ Model loaded successfully!")
except Exception as e:
logger.error(f"❌ Error loading model: {e}")
self.model_loaded = False
async def generate(self, request: GenerationRequest) -> Tuple[bool, str, int]:
"""Generate text based on the provided request."""
if not self.model_loaded:
return False, "❌ O modelo não está disponível. Por favor, verifique os logs do servidor.", 0
try:
            if not request.prompt.strip():
                return False, "⚠️ The prompt cannot be empty.", 0
            if len(request.prompt) > 8000:
                return False, "⚠️ The prompt is too long (maximum of 8000 characters).", 0
            loop = asyncio.get_running_loop()
messages = [
{"role": "user", "content": request.prompt.strip()},
]
def do_generation():
prompt = self.pipeline.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
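                # For Gemma-style chat models the rendered prompt looks roughly like:
                #   <start_of_turn>user
                #   {user message}<end_of_turn>
                #   <start_of_turn>model
                # add_generation_prompt=True appends that final model-turn marker.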
                outputs = self.pipeline(
                    prompt,
                    max_new_tokens=min(request.max_tokens, self.config.MAX_TOKENS),
                    do_sample=True,
                    temperature=request.temperature,
                    top_k=request.top_k,
                    top_p=request.top_p,
                    repetition_penalty=request.repetition_penalty,  # accepted by the request model but previously never passed through
                )
return outputs[0]["generated_text"][len(prompt):]
generated_text = await loop.run_in_executor(None, do_generation)
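            # Counts only the newly generated tokens; tokenizer.encode may add
            # special tokens (e.g. BOS), so treat this as an approximate figure.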
tokens_used = len(self.tokenizer.encode(generated_text))
return True, generated_text, tokens_used
except Exception as e:
logger.error(f"Generation error: {e}")
return False, f"❌ A geração falhou: {str(e)}", 0
# ----------------- Service Layer -----------------
class GemmaService:
def __init__(self):
self.config = Config()
self.model_manager = ModelManager(self.config)
self._validate_config()
def _validate_config(self):
"""Validate that required environment variables are set."""
if not self.config.HF_TOKEN:
raise ValueError("Missing required environment variable: HF_TOKEN")
async def initialize(self):
await self.model_manager.initialize()
async def generate_text(self, prompt: str, **kwargs) -> APIResponse:
"""Generate text directly."""
try:
request = GenerationRequest(prompt=prompt, **kwargs)
success, text, tokens_used = await self.model_manager.generate(request)
if success:
return APIResponse(
success=True,
data={"generated_text": text, "tokens_used": tokens_used}
)
else:
return APIResponse(success=False, error=text)
except Exception as e:
logger.error(f"Service error during text generation: {e}")
return APIResponse(success=False, error="Ocorreu um erro interno no serviço.")
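# A minimal usage sketch for the service layer (run inside an event loop):
#
#   service = GemmaService()
#   await service.initialize()
#   resp = await service.generate_text("Explain attention briefly", max_tokens=200)
#   if resp.success:
#       print(resp.data["generated_text"], resp.data["tokens_used"])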
# ----------------- Enhanced UI -----------------
class GradioInterface:
def __init__(self, service: GemmaService):
self.service = service
def create_custom_css(self):
return """
:root {
--dark-bg: #111111;
--panel-bg: #1C1C1C;
--border-color: #333333;
--text-color: #E0E0E0;
--text-light: #A0A0A0;
--accent-orange: #FF4500;
--accent-orange-hover: #FF6347;
}
.gradio-container { background-color: var(--dark-bg) !important; }
#main_layout { background-color: transparent; border: none !important; box-shadow: none !important; }
#right_panel { background-color: var(--panel-bg); border-left: 1px solid var(--border-color); border-radius: 12px; padding: 2rem !important; }
#left_panel { background-color: var(--panel-bg); border-radius: 12px; padding: 1rem !important; display: flex !important; flex-direction: column !important; height: 70vh; }
#output_display { flex-grow: 1; overflow-y: auto; padding: 1rem; color: var(--text-color); }
#output_display p { margin-bottom: 1rem; line-height: 1.6; }
#prompt_row { border-top: 1px solid var(--border-color); padding-top: 1rem; }
#prompt_input textarea { background-color: #2C2C2C !important; border-color: var(--border-color) !important; color: var(--text-color) !important; border-radius: 8px !important; }
#send_button { background-color: var(--accent-orange); color: white; border: none; border-radius: 50% !important; width: 50px !important; height: 50px !important; min-width: 50px !important; transition: background-color 0.3s ease; }
#send_button:hover { background-color: var(--accent-orange-hover); }
#generate_button {
background: linear-gradient(135deg, var(--accent-orange), var(--accent-orange-hover));
color: white !important;
font-size: 1.2rem !important;
font-weight: bold !important;
border: none;
border-radius: 12px !important;
padding: 1rem !important;
box-shadow: 0 4px 15px rgba(255, 69, 0, 0.4);
transition: all 0.3s ease;
}
#generate_button:hover {
transform: translateY(-2px);
box-shadow: 0 6px 20px rgba(255, 69, 0, 0.6);
}
.gr-label { color: var(--text-light) !important; }
h2 { color: white; border-bottom: 1px solid var(--border-color); padding-bottom: 0.5rem; margin-bottom: 1rem; }
#info_text { color: var(--text-light); line-height: 1.7; }
"""
    def create_interface(self):
with gr.Blocks(css=self.create_custom_css(), theme=None) as app:
with gr.Row(elem_id="main_layout", equal_height=False):
with gr.Column(scale=2, elem_id="left_panel_container"):
with gr.Column(elem_id="left_panel"):
                        output_display = gr.Markdown(elem_id="output_display", value="<p style='color: #A0A0A0;'>Your response will appear here...</p>")
with gr.Row(elem_id="prompt_row"):
prompt_input = gr.Textbox(
show_label=False,
placeholder="Digite sua mensagem aqui...",
elem_id="prompt_input",
scale=10
)
send_button = gr.Button("➤", elem_id="send_button", scale=1)
with gr.Column(scale=1, elem_id="right_panel"):
gr.Markdown("## Informações")
gr.Markdown(
"""
                        This is an interactive playground for the **Gemma** language model.
                        - **How to use:** Type your prompt in the text box on the left and click the send button to generate a response.
                        - **Generate Key:** Use the button below to generate a sample API key.
""",
elem_id="info_text"
)
                    key_button = gr.Button("✨ Generate Key", elem_id="generate_button")
key_display = gr.Markdown()
# --- Event Handlers ---
            async def handle_generation(prompt):
                if not prompt:
                    # Yield the notice, then return to end the generator early.
                    yield "<p style='color: #FFCC00;'>Please enter a prompt to get started.</p>"
                    return
                # Show a loading indicator while generation runs.
                yield "<p style='color: #A0A0A0;'>Generating a response...</p>"
response = await self.service.generate_text(prompt=prompt)
if response.success:
yield response.data["generated_text"]
else:
yield f"<p style='color: #FF4500;'>{response.error}</p>"
def handle_key_generation():
"""Generates a random API key in the specified format."""
random_part = secrets.token_urlsafe(24).replace("_", "").replace("-", "")
key = f"gsk-{random_part}"
return f"<p style='color: #A0A0A0; text-align: center; margin-top: 1rem;'>Sua chave de exemplo:</p><pre style='background: #2C2C2C; padding: 1rem; border-radius: 8px; text-align: center; word-wrap: break-word;'><code>{key}</code></pre>"
# --- Wiring ---
key_button.click(
handle_key_generation,
inputs=[],
outputs=[key_display]
)
send_button.click(
handle_generation,
inputs=[prompt_input],
outputs=[output_display]
)
prompt_input.submit(
handle_generation,
inputs=[prompt_input],
outputs=[output_display]
)
return app
# ----------------- Main Application -----------------
async def main():
"""Main application entry point"""
try:
service = GemmaService()
await service.initialize()
interface = GradioInterface(service)
        app = interface.create_interface()
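        # server_name="0.0.0.0" binds all interfaces so the app is reachable from
        # inside a container (e.g. a Hugging Face Space); 7860 is the Spaces default port.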
app.launch(
server_name="0.0.0.0",
server_port=7860,
share=False,
debug=False,
show_error=True
)
except Exception as e:
logger.critical(f"Failed to start application: {e}", exc_info=True)
raise
if __name__ == "__main__":
# To run this, you need a .env file with:
# HF_TOKEN="your_hugging_face_token"
asyncio.run(main())