Teste / app.py
GuXSs's picture
Update app.py
9d2d92e verified
raw
history blame
17.2 kB
import asyncio
import html
import logging
import os
import secrets
from dataclasses import dataclass
from typing import Any, Optional, Tuple

import gradio as gr
from dotenv import load_dotenv
from pydantic import BaseModel, Field
from transformers import pipeline
# ----------------- Configuration & Models -----------------
# Load variables from a .env file into the process environment (python-dotenv).
# This has to run before the Config class is defined, because Config's field
# defaults call os.getenv while the class body is being executed.
load_dotenv()
@dataclass
class Config:
    """Runtime settings, each defaulting to an environment variable.

    The defaults are computed once, when the class body is executed, so
    changing the environment afterwards does not affect new instances.
    Any field can still be overridden through the constructor.
    """

    # Hugging Face access token; empty string when HF_TOKEN is unset.
    HF_TOKEN: str = os.environ.get("HF_TOKEN", "")
    # Identifier of the model to load.
    MODEL_NAME: str = os.environ.get("MODEL_NAME", "google/gemma-3-270m-it")
    # Server-side cap on newly generated tokens; must parse as an integer.
    MAX_TOKENS: int = int(os.environ.get("MAX_TOKENS", "2048"))
    # Name of a logging level such as "INFO" or "DEBUG".
    LOG_LEVEL: str = os.environ.get("LOG_LEVEL", "INFO")
class GenerationRequest(BaseModel):
    """Validated parameters for one text-generation call.

    Out-of-range sampling values are rejected when the request object is
    built, instead of surfacing later as an error raised by the generation
    backend. Defaults are unchanged.
    """

    # Raw user prompt; emptiness is checked by the consumer, not here.
    prompt: str
    # Requested number of new tokens; at least one token must be asked for.
    max_tokens: int = Field(512, ge=1)
    # Sampling temperature; sampling needs a strictly positive value.
    temperature: float = Field(0.7, gt=0)
    # Top-k cutoff; negative values are rejected.
    top_k: int = Field(50, ge=0)
    # Nucleus-sampling probability mass, constrained to [0, 1].
    top_p: float = Field(0.95, ge=0, le=1)
class APIResponse(BaseModel):
    """Uniform result envelope: a success flag plus either data or an error."""

    # True when the request succeeded.
    success: bool
    # Payload of a successful request; defaults to None.
    data: Any = None
    # Human-readable failure description; defaults to None.
    error: Optional[str] = None
# ----------------- Enhanced Logger -----------------
def setup_logger() -> logging.Logger:
    """Return the "gemma_saas" logger, configuring it the first time.

    The level is taken from ``Config.LOG_LEVEL``; names that do not map to a
    numeric logging level fall back to INFO. Handlers are attached only when
    the logger has none, so repeated calls never duplicate output.

    Returns:
        The configured ``logging.Logger``.
    """
    cfg = Config()
    log_level = getattr(logging, cfg.LOG_LEVEL.upper(), logging.INFO)
    # getattr can also hit non-level attributes of the logging module
    # (e.g. BASIC_FORMAT); setLevel would raise on those, so fall back.
    if not isinstance(log_level, int):
        log_level = logging.INFO

    logger = logging.getLogger("gemma_saas")
    if logger.handlers:
        return logger

    logger.setLevel(log_level)
    formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")

    # Attach the console handler first so the warning below has somewhere to go.
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)

    try:
        # Log messages contain accented characters and emoji; force UTF-8 so
        # writing them does not depend on the platform's default encoding.
        file_handler = logging.FileHandler("gemma_saas.log", encoding="utf-8")
    except OSError as exc:
        # An unwritable working directory should not prevent the app from
        # starting; keep console logging only.
        logger.warning("Não foi possível abrir o ficheiro de log; a usar apenas a consola: %s", exc)
    else:
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
    return logger


logger = setup_logger()
# ----------------- Model Manager -----------------
class ModelManager:
    """Owns the text-generation pipeline and runs it off the event loop.

    Attributes:
        config: Settings providing the model name, access token and token cap.
        pipeline: The loaded pipeline object, or None until loading succeeds.
        model_loaded: True once ``initialize`` has completed successfully.
    """

    def __init__(self, config: "Config"):
        self.config = config
        self.pipeline = None
        self.model_loaded = False

    async def initialize(self) -> None:
        """Load the pipeline in a worker thread.

        Loading problems are logged rather than raised: when the token is
        missing or loading fails, the manager simply stays in the
        "not loaded" state and ``generate`` reports the model as unavailable.
        """
        if not self.config.HF_TOKEN:
            logger.error("Token do Hugging Face não encontrado. O carregamento do modelo irá falhar.")
            return

        def load_pipeline():
            return pipeline(
                "text-generation",
                model=self.config.MODEL_NAME,
                token=self.config.HF_TOKEN,
                torch_dtype="auto",
                device_map="auto",
            )

        try:
            logger.info(f"A carregar o modelo: {self.config.MODEL_NAME}...")
            os.environ.setdefault("HF_TOKEN", self.config.HF_TOKEN)
            # Model loading blocks for a long time; keep it off the event loop.
            loop = asyncio.get_running_loop()
            self.pipeline = await loop.run_in_executor(None, load_pipeline)
            self.model_loaded = True
            logger.info("✅ Modelo carregado com sucesso!")
        except Exception as e:
            logger.error(f"❌ Erro ao carregar o modelo: {e}", exc_info=True)

    async def generate(self, request: "GenerationRequest") -> Tuple[bool, str, int]:
        """Generate a completion for ``request.prompt``.

        Args:
            request: Prompt and sampling parameters. ``max_tokens`` is capped
                by ``config.MAX_TOKENS``.

        Returns:
            ``(success, text, tokens_used)``. On success ``text`` is the
            generated continuation with the prompt removed and
            ``tokens_used`` counts its tokens (0 when no tokenizer is
            available or counting fails). On failure ``text`` is a
            user-facing error message and ``tokens_used`` is 0.
        """
        if not self.model_loaded or self.pipeline is None:
            return (
                False,
                "❌ O modelo não está disponível. Por favor, verifique os logs do servidor.",
                0,
            )
        try:
            user_prompt = request.prompt.strip()
            if not user_prompt:
                return False, "⚠️ O prompt não pode estar vazio.", 0
            messages = [{"role": "user", "content": user_prompt}]

            def do_generation():
                tokenizer = getattr(self.pipeline, "tokenizer", None)
                # Use the model's chat template when the tokenizer offers one;
                # otherwise feed the raw prompt.
                if tokenizer and hasattr(tokenizer, "apply_chat_template"):
                    prompt_text = tokenizer.apply_chat_template(
                        messages, tokenize=False, add_generation_prompt=True
                    )
                else:
                    prompt_text = user_prompt
                outputs = self.pipeline(
                    prompt_text,
                    max_new_tokens=min(request.max_tokens, self.config.MAX_TOKENS),
                    do_sample=True,
                    temperature=request.temperature,
                    top_k=request.top_k,
                    top_p=request.top_p,
                )
                generated_text = outputs[0].get("generated_text", "")
                # The output echoes the prompt; return only the continuation.
                if generated_text.startswith(prompt_text):
                    generated_text = generated_text[len(prompt_text):]
                tokens_used = 0
                if tokenizer and hasattr(tokenizer, "encode"):
                    try:
                        # Count only the generated tokens: without this flag
                        # the tokenizer's own special tokens inflate the total.
                        tokens_used = len(
                            tokenizer.encode(generated_text, add_special_tokens=False)
                        )
                    except Exception:
                        # Token counting is best-effort and must not fail the request.
                        tokens_used = 0
                return generated_text, tokens_used

            loop = asyncio.get_running_loop()
            generated_text, tokens_used = await loop.run_in_executor(None, do_generation)
            return True, generated_text, tokens_used
        except Exception as e:
            logger.error(f"Erro na geração: {e}", exc_info=True)
            return False, f"❌ A geração falhou: {str(e)}", 0
# ----------------- Service Layer -----------------
class GemmaService:
    """Service layer: checks the API key and delegates to the model manager.

    NOTE(review): only the key's "gsk-" prefix is checked. Keys are not
    compared against any record of issued keys, so this is a format check
    rather than real authentication.
    """

    def __init__(self):
        self.config = Config()
        self.model_manager = ModelManager(self.config)

    async def initialize(self):
        """Load the underlying model."""
        await self.model_manager.initialize()

    async def generate_text(self, api_key: str, prompt: str, **kwargs) -> APIResponse:
        """Validate the key, build a request and run generation.

        Args:
            api_key: Client key; must be a string starting with "gsk-".
                Surrounding whitespace is ignored.
            prompt: User prompt.
            **kwargs: Extra ``GenerationRequest`` fields (sampling options).

        Returns:
            An ``APIResponse`` whose ``data`` holds ``generated_text`` and
            ``tokens_used`` on success, or whose ``error`` is set on failure.
            This method does not raise.
        """
        # API clients may send any JSON value as the key; treat non-strings as
        # missing instead of failing on .startswith. Stripping tolerates the
        # whitespace that tends to come along when a key is pasted.
        key = api_key.strip() if isinstance(api_key, str) else ""
        if not key.startswith("gsk-"):
            return APIResponse(success=False, error="Chave de API inválida ou ausente.")
        try:
            request = GenerationRequest(prompt=prompt, **kwargs)
            success, text, tokens_used = await self.model_manager.generate(request)
            if success:
                return APIResponse(success=True, data={"generated_text": text, "tokens_used": tokens_used})
            # On failure, `text` carries the user-facing error message.
            return APIResponse(success=False, error=text)
        except Exception as e:
            logger.error(f"Erro de serviço durante a geração de texto: {e}", exc_info=True)
            return APIResponse(success=False, error="Ocorreu um erro interno no serviço.")
# ----------------- Enhanced UI -----------------
class GradioInterface:
    """Builds the Gradio Blocks UI on top of a ``GemmaService``."""

    def __init__(self, service: GemmaService):
        self.service = service

    def create_custom_css(self) -> str:
        """Return the stylesheet passed to ``gr.Blocks``.

        The rules target the ``elem_id`` values assigned in
        ``create_interface`` and draw button icons with the Material Icons
        font through ``::before`` pseudo-elements.
        """
        # Imports Material Icons and adds icons to the buttons via pseudo-elements.
        return """
/* importar Material Icons */
@import url('https://fonts.googleapis.com/css2?family=Material+Icons&display=swap');
:root {
--dark-bg: #0a0a0a; --panel-bg: #1a1a1a; --border-color: #333;
--text-color: #f0f0f0; --text-light: #a0a0a0; --accent-orange: #FF4500;
--accent-orange-hover: #FF6347; --code-bg: #282c34;
}
.gradio-container { background: var(--dark-bg) !important; color: var(--text-color); }
#main_layout { background: transparent; border: none !important; box-shadow: none !important; gap: 2rem; }
#right_panel, #left_panel { background: var(--panel-bg); border: 1px solid var(--border-color); border-radius: 16px; padding: 2rem !important; }
#left_panel { display: flex !important; flex-direction: column !important; height: 80vh; }
#output_display { flex-grow: 1; overflow-y: auto; padding-right: 1rem; color: var(--text-color); }
#output_display p { margin-bottom: 1rem; line-height: 1.7; }
#input_area { margin-top: 1rem; }
#api_key_input textarea, #prompt_input textarea { background-color: #2C2C2C !important; border-color: var(--border-color) !important; color: var(--text-color) !important; border-radius: 12px !important; }
#send_button { background: var(--accent-orange); color: white; border: none; border-radius: 12px !important; transition: background-color 0.3s ease; position: relative; padding-left: 3rem; }
#send_button:hover { background-color: var(--accent-orange-hover); }
#generate_button {
background: linear-gradient(135deg, var(--accent-orange), var(--accent-orange-hover)); color: white !important;
font-size: 1.1rem !important; font-weight: bold !important; border: none; border-radius: 12px !important;
padding: 1rem 1.25rem !important; box-shadow: 0 4px 15px rgba(255, 69, 0, 0.4); transition: all 0.3s ease; position: relative; padding-left: 3rem;
}
#generate_button:hover { transform: translateY(-2px); box-shadow: 0 6px 20px rgba(255, 69, 0, 0.6); }
h2, h3 { color: white; border-bottom: 1px solid var(--border-color); padding-bottom: 0.75rem; margin-bottom: 1.5rem; font-weight: 600; }
.code-snippet { background-color: var(--code-bg); color: #abb2bf; padding: 1.5rem; border-radius: 12px; font-family: 'Courier New', monospace; white-space: pre-wrap; word-wrap: break-word; border: 1px solid var(--border-color); }
.code-snippet .keyword { color: #c678dd; } .code-snippet .string { color: #98c379; } .code-snippet .number { color: #d19a66; }
.gr-slider { color: var(--text-light); }
/* estilo para usar as Material Icons como ligatures */
.material-icon {
font-family: 'Material Icons', sans-serif;
font-weight: normal;
font-style: normal;
font-size: 20px;
line-height: 1;
letter-spacing: normal;
text-transform: none;
display: inline-block;
white-space: nowrap;
word-wrap: normal;
direction: ltr;
-webkit-font-feature-settings: 'liga';
-webkit-font-smoothing: antialiased;
}
/* adicionar ícones antes dos botões (usando ligatures) */
#send_button::before {
content: "send"; /* ligature do ícone */
font-family: 'Material Icons', sans-serif;
position: absolute;
left: 12px;
top: 50%;
transform: translateY(-50%);
font-size: 18px;
line-height: 1;
opacity: 0.95;
}
#generate_button::before {
content: "auto_awesome";
font-family: 'Material Icons', sans-serif;
position: absolute;
left: 12px;
top: 50%;
transform: translateY(-50%);
font-size: 18px;
line-height: 1;
opacity: 0.95;
}
/* ícone para o botão de gerar chave (se usar outro botão, adapte o id) */
#generate_button[aria-label], #generate_button[title] { /* fallback */
padding-left: 3rem;
}
/* ícone ao lado do exemplo de código (vpn_key) */
#right_panel .code-snippet::before {
content: "vpn_key";
font-family: 'Material Icons', sans-serif;
display: inline-block;
margin-right: 0.5rem;
vertical-align: middle;
font-size: 18px;
opacity: 0.9;
}
"""

    async def create_interface(self) -> gr.Blocks:
        """Assemble the two-panel UI and wire its event handlers.

        The left panel holds the output area, the API-key field and the
        prompt row; the right panel holds the key generator, the sampling
        sliders and a usage example. The send button's handler is also
        exposed under the API name "generate".

        Returns:
            The constructed ``gr.Blocks`` application (not yet launched).
        """
        with gr.Blocks(css=self.create_custom_css(), theme=None) as app:
            with gr.Row(elem_id="main_layout", equal_height=False):
                with gr.Column(scale=2):
                    with gr.Column(elem_id="left_panel"):
                        output_display = gr.Markdown(
                            elem_id="output_display",
                            value="<p style='color: #a0a0a0;'>A sua resposta aparecerá aqui...</p>",
                        )
                        with gr.Column(elem_id="input_area"):
                            api_key_input = gr.Textbox(
                                label="A Sua Chave de API",
                                placeholder="Cole a sua chave gsk-... aqui",
                                type="password",
                                elem_id="api_key_input",
                            )
                            with gr.Row():
                                prompt_input = gr.Textbox(
                                    show_label=False,
                                    placeholder="Digite a sua mensagem...",
                                    elem_id="prompt_input",
                                    scale=10,
                                )
                                send_button = gr.Button("➤ Enviar", elem_id="send_button", scale=2)
                with gr.Column(scale=1):
                    with gr.Column(elem_id="right_panel"):
                        gr.Markdown("## Controlo")
                        key_button = gr.Button("✨ Gerar Nova Chave", elem_id="generate_button")
                        with gr.Accordion("Parâmetros Avançados", open=False):
                            temp_slider = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperatura")
                            max_tokens_slider = gr.Slider(minimum=64, maximum=self.service.config.MAX_TOKENS, value=512, step=64, label="Max Tokens")
                            top_k_slider = gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-K")
                            top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-P")
                        gr.Markdown("### Como Usar a API")
                        api_example_display = gr.HTML("<p style='color: #a0a0a0;'>Clique em 'Gerar Nova Chave' para ver um exemplo de código.</p>")

            def handle_key_generation():
                """Create a fresh "gsk-" key and an HTML usage example containing it."""
                # token_urlsafe may emit '_' and '-'; they are removed so the
                # key is purely alphanumeric after the prefix.
                key = f"gsk-{secrets.token_urlsafe(24).replace('_', '').replace('-', '')}"
                code_html = f"""
<div class="code-snippet">
<div><span class="keyword">import</span> requests</div>
<div>&nbsp;</div>
<div>url = <span class="string">"https://GuXSs.hf.space/run/generate"</span></div>
<div>payload = {{</div>
<div>&nbsp;&nbsp;&nbsp;&nbsp;<span class="string">"api_key"</span>: <span class="string">"{key}"</span>,</div>
<div>&nbsp;&nbsp;&nbsp;&nbsp;<span class="string">"prompt"</span>: <span class="string">"Escreva um haikai sobre o universo"</span>,</div>
<div>&nbsp;&nbsp;&nbsp;&nbsp;<span class="string">"max_tokens"</span>: <span class="number">50</span></div>
<div>}}</div>
<div>&nbsp;</div>
<div>response = requests.post(url, json=payload)</div>
<div><span class="keyword">print</span>(response.json())</div>
</div>
"""
                return key, gr.update(value=code_html)

            def idle_button():
                """Button state shown whenever no generation is in progress."""
                return gr.update(value="➤ Enviar", interactive=True)

            async def handle_generation(api_key, prompt, temp, max_tokens, top_k, top_p, btn):
                """Stream UI updates for one generation request.

                Yields ``(output_html, send_button_update)`` pairs: first a
                "working" state with the button disabled, then the final
                result or error with the button re-enabled. ``btn`` receives
                the send button's value and is not used.
                """
                if not api_key:
                    yield (
                        "<p style='color: #FFCC00;'>Por favor, insira a sua chave de API para começar.</p>",
                        idle_button(),
                    )
                    return
                if not prompt:
                    yield (
                        "<p style='color: #FFCC00;'>Por favor, digite um prompt.</p>",
                        idle_button(),
                    )
                    return
                yield "<p style='color: #a0a0a0;'>A gerar resposta...</p>", gr.update(value="A gerar...", interactive=False)
                response = await self.service.generate_text(
                    api_key=api_key,
                    prompt=prompt,
                    temperature=temp,
                    max_tokens=int(max_tokens),
                    top_k=int(top_k),
                    top_p=top_p,
                )
                if response.success:
                    formatted_text = html.escape(response.data["generated_text"]).replace("\n", "<br>")
                    yield formatted_text, idle_button()
                else:
                    # Error text may embed exception details, so escape it
                    # like the generated text before placing it in HTML.
                    safe_error = html.escape(response.error or "")
                    yield f"<p style='color: #FF4500;'>{safe_error}</p>", idle_button()

            send_button.click(
                handle_generation,
                inputs=[api_key_input, prompt_input, temp_slider, max_tokens_slider, top_k_slider, top_p_slider, send_button],
                outputs=[output_display, send_button],
                api_name="generate",
            )
            key_button.click(handle_key_generation, outputs=[api_key_input, api_example_display])
            app.load(
                lambda: gr.update(value="<p style='color: #a0a0a0;'>Clique em 'Gerar Nova Chave' para ver um exemplo de código.</p>"),
                [],
                [api_example_display],
            )
        return app
# ----------------- Main Application -----------------
async def main():
    """Start the application: load the model, build the UI and serve it.

    The UI is launched on 0.0.0.0:7860 without a public share link. Any
    exception raised during startup is logged at CRITICAL level with its
    traceback and is not re-raised.
    """
    try:
        gemma = GemmaService()
        await gemma.initialize()
        blocks = await GradioInterface(gemma).create_interface()
        blocks.launch(server_name="0.0.0.0", server_port=7860, share=False, debug=False)
    except Exception as exc:
        logger.critical(f"Falha ao iniciar a aplicação: {exc}", exc_info=True)


if __name__ == "__main__":
    asyncio.run(main())