import os
import asyncio
import logging
import secrets
from datetime import datetime
from typing import Any, Optional, Tuple
from dataclasses import dataclass

import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from dotenv import load_dotenv
from pydantic import BaseModel, Field

# Load environment variables from a local .env file, if present.
load_dotenv()
|
|
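# Runtime configuration, read from environment variables with sensible defaults.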
@dataclass
class Config:
    HF_TOKEN: str = os.getenv("HF_TOKEN", "")
    MODEL_NAME: str = os.getenv("MODEL_NAME", "google/gemma-2-9b-it")
    MAX_TOKENS: int = int(os.getenv("MAX_TOKENS", "1500"))
    LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
|
|
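# Validated schema for a generation request and its sampling parameters.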
class GenerationRequest(BaseModel):
    prompt: str
    max_tokens: int = 500
    temperature: float = 0.75
    top_k: int = 50
    top_p: float = 0.95
    repetition_penalty: float = 1.1
|
|
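# Uniform response envelope returned by the service layer.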
class APIResponse(BaseModel):
    success: bool
    data: Any = None
    error: Optional[str] = None
    # default_factory evaluates at instantiation time; a plain datetime.now()
    # default would be frozen once at class-definition time.
    timestamp: datetime = Field(default_factory=datetime.now)
|
|
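# Log to both a file and the console; the level comes from LOG_LEVEL.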
def setup_logger():
    logging.basicConfig(
        level=getattr(logging, Config().LOG_LEVEL),
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler('gemma_saas.log'),
            logging.StreamHandler()
        ]
    )
    return logging.getLogger(__name__)
|
|
logger = setup_logger()
|
|
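# Owns the tokenizer/model/pipeline lifecycle; heavy work runs off the event loop.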
class ModelManager:
    def __init__(self, config: Config):
        self.config = config
        self.tokenizer = None
        self.model = None
        self.pipeline = None
        self.model_loaded = False
|
|
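    # Loading a multi-billion-parameter model can take minutes; running it in
    # a thread executor keeps the event loop responsive in the meantime.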
    async def initialize(self):
        """Initialize the model, tokenizer, and pipeline asynchronously."""
        if not self.config.HF_TOKEN:
            logger.error("Hugging Face token not found. Model loading will fail.")
            self.model_loaded = False
            return
        try:
            logger.info(f"Loading model: {self.config.MODEL_NAME}...")
            loop = asyncio.get_running_loop()

            def load_components():
                tokenizer = AutoTokenizer.from_pretrained(self.config.MODEL_NAME, token=self.config.HF_TOKEN)
                model = AutoModelForCausalLM.from_pretrained(
                    self.config.MODEL_NAME,
                    token=self.config.HF_TOKEN,
                    device_map="auto",
                    torch_dtype="auto"
                )
                text_pipeline = pipeline(
                    "text-generation",
                    model=model,
                    tokenizer=tokenizer,
                )
                return tokenizer, model, text_pipeline

            self.tokenizer, self.model, self.pipeline = await loop.run_in_executor(None, load_components)
            self.model_loaded = True
            logger.info("✅ Model loaded successfully!")
        except Exception as e:
            logger.error(f"❌ Error loading model: {e}")
            self.model_loaded = False
|
|
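    # Requests are validated up front so malformed input never reaches the pipeline.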
    async def generate(self, request: GenerationRequest) -> Tuple[bool, str, int]:
        """Generate text based on the provided request."""
        if not self.model_loaded:
            return False, "❌ The model is not available. Please check the server logs.", 0
        try:
            if not request.prompt.strip():
                return False, "⚠️ The prompt cannot be empty.", 0
            if len(request.prompt) > 8000:
                return False, "⚠️ The prompt is too long (8000 characters maximum).", 0

            loop = asyncio.get_running_loop()

            messages = [
                {"role": "user", "content": request.prompt.strip()},
            ]

            def do_generation():
                prompt = self.pipeline.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
                outputs = self.pipeline(
                    prompt,
                    max_new_tokens=min(request.max_tokens, self.config.MAX_TOKENS),
                    do_sample=True,
                    temperature=request.temperature,
                    top_k=request.top_k,
                    top_p=request.top_p,
                    repetition_penalty=request.repetition_penalty,
                )
                # The pipeline echoes the prompt; slice it off to keep only the new text.
                return outputs[0]["generated_text"][len(prompt):]

            generated_text = await loop.run_in_executor(None, do_generation)
            tokens_used = len(self.tokenizer.encode(generated_text))
            return True, generated_text, tokens_used
        except Exception as e:
            logger.error(f"Generation error: {e}")
            return False, f"❌ Generation failed: {str(e)}", 0
|
|
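# Thin service layer: validates configuration and delegates to the model manager.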
class GemmaService:
    def __init__(self):
        self.config = Config()
        self.model_manager = ModelManager(self.config)
        self._validate_config()

    def _validate_config(self):
        """Validate that required environment variables are set."""
        if not self.config.HF_TOKEN:
            raise ValueError("Missing required environment variable: HF_TOKEN")
|
|
    async def initialize(self):
        await self.model_manager.initialize()
|
|
    async def generate_text(self, prompt: str, **kwargs) -> APIResponse:
        """Generate text and wrap the result in an APIResponse."""
        try:
            request = GenerationRequest(prompt=prompt, **kwargs)
            success, text, tokens_used = await self.model_manager.generate(request)

            if success:
                return APIResponse(
                    success=True,
                    data={"generated_text": text, "tokens_used": tokens_used}
                )
            else:
                return APIResponse(success=False, error=text)

        except Exception as e:
            logger.error(f"Service error during text generation: {e}")
            return APIResponse(success=False, error="An internal service error occurred.")
|
|
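# Builds the two-panel Gradio UI and wires its event handlers.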
class GradioInterface:
    def __init__(self, service: GemmaService):
        self.service = service
|
|
    def create_custom_css(self):
        return """
        :root {
            --dark-bg: #111111;
            --panel-bg: #1C1C1C;
            --border-color: #333333;
            --text-color: #E0E0E0;
            --text-light: #A0A0A0;
            --accent-orange: #FF4500;
            --accent-orange-hover: #FF6347;
        }
        .gradio-container { background-color: var(--dark-bg) !important; }
        #main_layout { background-color: transparent; border: none !important; box-shadow: none !important; }
        #right_panel { background-color: var(--panel-bg); border-left: 1px solid var(--border-color); border-radius: 12px; padding: 2rem !important; }
        #left_panel { background-color: var(--panel-bg); border-radius: 12px; padding: 1rem !important; display: flex !important; flex-direction: column !important; height: 70vh; }
        #output_display { flex-grow: 1; overflow-y: auto; padding: 1rem; color: var(--text-color); }
        #output_display p { margin-bottom: 1rem; line-height: 1.6; }
        #prompt_row { border-top: 1px solid var(--border-color); padding-top: 1rem; }
        #prompt_input textarea { background-color: #2C2C2C !important; border-color: var(--border-color) !important; color: var(--text-color) !important; border-radius: 8px !important; }
        #send_button { background-color: var(--accent-orange); color: white; border: none; border-radius: 50% !important; width: 50px !important; height: 50px !important; min-width: 50px !important; transition: background-color 0.3s ease; }
        #send_button:hover { background-color: var(--accent-orange-hover); }
        #generate_button {
            background: linear-gradient(135deg, var(--accent-orange), var(--accent-orange-hover));
            color: white !important;
            font-size: 1.2rem !important;
            font-weight: bold !important;
            border: none;
            border-radius: 12px !important;
            padding: 1rem !important;
            box-shadow: 0 4px 15px rgba(255, 69, 0, 0.4);
            transition: all 0.3s ease;
        }
        #generate_button:hover {
            transform: translateY(-2px);
            box-shadow: 0 6px 20px rgba(255, 69, 0, 0.6);
        }
        .gr-label { color: var(--text-light) !important; }
        h2 { color: white; border-bottom: 1px solid var(--border-color); padding-bottom: 0.5rem; margin-bottom: 1rem; }
        #info_text { color: var(--text-light); line-height: 1.7; }
        """
|
|
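    # Lays out the chat panel (left) and the info/key panel (right).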
    async def create_interface(self):
        with gr.Blocks(css=self.create_custom_css(), theme=None) as app:
            with gr.Row(elem_id="main_layout", equal_height=False):
                with gr.Column(scale=2, elem_id="left_panel_container"):
                    with gr.Column(elem_id="left_panel"):
                        output_display = gr.Markdown(elem_id="output_display", value="<p style='color: #A0A0A0;'>Your response will appear here...</p>")
                        with gr.Row(elem_id="prompt_row"):
                            prompt_input = gr.Textbox(
                                show_label=False,
                                placeholder="Type your message here...",
                                elem_id="prompt_input",
                                scale=10
                            )
                            send_button = gr.Button("➤", elem_id="send_button", scale=1)

                with gr.Column(scale=1, elem_id="right_panel"):
                    gr.Markdown("## Information")
                    gr.Markdown(
                        """
                        This is an interactive playground for the **Gemma** language model.

                        - **How to use:** Type your prompt in the text box on the left and click the send button to generate a response.
                        - **Generate Key:** Use the button below to generate a sample API key.
                        """,
                        elem_id="info_text"
                    )
                    key_button = gr.Button("✨ Generate Key", elem_id="generate_button")
                    key_display = gr.Markdown()
|
|
|
|
            async def handle_generation(prompt):
                if not prompt:
                    yield "<p style='color: #FFCC00;'>Please type a prompt to get started.</p>"
                    return

                # Yield an interim status message before the final result so the
                # UI updates while generation is in flight.
                yield "<p style='color: #A0A0A0;'>Generating response...</p>"

                response = await self.service.generate_text(prompt=prompt)

                if response.success:
                    yield response.data["generated_text"]
                else:
                    yield f"<p style='color: #FF4500;'>{response.error}</p>"

            def handle_key_generation():
                """Generates a random API key in the specified format."""
                random_part = secrets.token_urlsafe(24).replace("_", "").replace("-", "")
                key = f"gsk-{random_part}"
                return f"<p style='color: #A0A0A0; text-align: center; margin-top: 1rem;'>Your sample key:</p><pre style='background: #2C2C2C; padding: 1rem; border-radius: 8px; text-align: center; word-wrap: break-word;'><code>{key}</code></pre>"
|
|
|
|
            key_button.click(
                handle_key_generation,
                inputs=[],
                outputs=[key_display]
            )
            send_button.click(
                handle_generation,
                inputs=[prompt_input],
                outputs=[output_display]
            )
            prompt_input.submit(
                handle_generation,
                inputs=[prompt_input],
                outputs=[output_display]
            )
|
|
        return app
|
|
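# Entry point: build the service, load the model, then serve the UI.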
async def main():
    """Main application entry point"""
    try:
        service = GemmaService()
        await service.initialize()

        interface = GradioInterface(service)
        app = await interface.create_interface()

        app.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=False,
            debug=False,
            show_error=True
        )
    except Exception as e:
        logger.critical(f"Failed to start application: {e}", exc_info=True)
        raise
|
|
if __name__ == "__main__":
    asyncio.run(main())
|
|