# Teste / app.py
# (Hugging Face Space page header: "GuXSs's picture — Update app.py — b85befe verified")
# app.py
import os
import secrets
import logging
import asyncio
import html
from dataclasses import dataclass
from typing import Any, Optional, Tuple
import gradio as gr
from transformers import pipeline
from dotenv import load_dotenv
from pydantic import BaseModel
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
# ----------------- Configuration & Models -----------------
# Load variables from a local .env file into the process environment
# before Config reads them below. Safe no-op when no .env file exists.
load_dotenv()
@dataclass
class Config:
    """Runtime configuration sourced from environment variables.

    NOTE(review): the ``os.getenv`` defaults are evaluated once, at
    class-definition time — later ``Config()`` instances will not pick up
    environment changes made after this module was imported.
    """

    # Hugging Face access token; empty string means model loading will be skipped.
    HF_TOKEN: str = os.getenv("HF_TOKEN", "")
    # Model repo id passed to transformers.pipeline.
    MODEL_NAME: str = os.getenv("MODEL_NAME", "google/gemma-3-270m-it")
    # Hard server-side cap applied to client-requested max_tokens.
    MAX_TOKENS: int = int(os.getenv("MAX_TOKENS", "2048"))
    # Name of a logging level ("INFO", "DEBUG", ...); invalid names fall back to INFO.
    LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
class GenerationRequest(BaseModel):
    """Validated payload for a single text-generation call."""

    prompt: str
    # Requested completion length; capped server-side at Config.MAX_TOKENS.
    max_tokens: int = 512
    # Sampling parameters forwarded verbatim to the transformers pipeline.
    temperature: float = 0.7
    top_k: int = 50
    top_p: float = 0.95
class APIResponse(BaseModel):
    """Uniform envelope returned by the service layer and the HTTP endpoints."""

    success: bool
    # Payload on success, e.g. {"generated_text": ..., "tokens_used": ...}.
    data: Any = None
    # Human-readable message when success is False.
    error: Optional[str] = None
# ----------------- Logger -----------------
def setup_logger() -> logging.Logger:
    """Configure and return the application-wide "gemma_saas" logger.

    Idempotent: handlers are attached only on the first call, so repeated
    calls (e.g. module re-imports) never duplicate log output. Logs go both
    to ``gemma_saas.log`` and to the console.
    """
    level = getattr(logging, Config().LOG_LEVEL.upper(), logging.INFO)
    app_logger = logging.getLogger("gemma_saas")
    if app_logger.handlers:
        return app_logger
    app_logger.setLevel(level)
    fmt = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    for handler in (logging.FileHandler("gemma_saas.log"), logging.StreamHandler()):
        handler.setFormatter(fmt)
        app_logger.addHandler(handler)
    return app_logger
logger = setup_logger()
# ----------------- Model Manager -----------------
class ModelManager:
    """Loads the text-generation pipeline and runs inference off the event loop.

    Blocking transformers work (model download/load, generation) is pushed to
    the default thread-pool executor so the asyncio event loop serving
    Gradio/FastAPI stays responsive.
    """

    def __init__(self, config: Config):
        self.config = config
        # transformers text-generation pipeline, set once loading succeeds.
        self.pipeline = None
        # Flips to True only after a successful load; checked before every call.
        self.model_loaded = False

    async def initialize(self) -> None:
        """Load the model in a worker thread; log and degrade gracefully on failure.

        Without an HF token the load is skipped entirely and the manager stays
        in its "model unavailable" state.
        """
        if not self.config.HF_TOKEN:
            logger.error("Token do Hugging Face não encontrado. O carregamento do modelo poderá falhar.")
            return
        try:
            logger.info(f"A carregar o modelo: {self.config.MODEL_NAME}...")
            os.environ.setdefault("HF_TOKEN", self.config.HF_TOKEN)
            # get_running_loop() is the supported call inside a coroutine;
            # asyncio.get_event_loop() here is deprecated since Python 3.10.
            loop = asyncio.get_running_loop()

            def load_pipeline():
                # One-line purpose: blocking pipeline construction, run in executor.
                return pipeline(
                    "text-generation",
                    model=self.config.MODEL_NAME,
                    token=self.config.HF_TOKEN,
                    torch_dtype="auto",
                    device_map="auto",
                )

            self.pipeline = await loop.run_in_executor(None, load_pipeline)
            self.model_loaded = True
            logger.info("✅ Modelo carregado com sucesso!")
        except Exception as e:
            # Deliberate best-effort: the app still serves its UI/API and reports
            # "model unavailable" instead of crashing at startup.
            logger.error(f"❌ Erro ao carregar o modelo: {e}", exc_info=True)

    async def generate(self, request: GenerationRequest) -> Tuple[bool, str, int]:
        """Run one generation for *request*.

        Returns:
            (success, text_or_error_message, tokens_used) — on failure the
            second element is a user-facing (Portuguese) error message.
        """
        if not self.model_loaded or self.pipeline is None:
            return False, "❌ O modelo não está disponível. Por favor, verifique os logs do servidor.", 0
        if not request.prompt.strip():
            return False, "⚠️ O prompt não pode estar vazio.", 0
        loop = asyncio.get_running_loop()
        messages = [{"role": "user", "content": request.prompt.strip()}]

        def do_generation():
            # One-line purpose: blocking inference body, run in executor.
            tokenizer = getattr(self.pipeline, "tokenizer", None)
            # Prefer the model's chat template when available; otherwise send
            # the raw prompt text.
            if tokenizer and hasattr(tokenizer, "apply_chat_template"):
                prompt_text = tokenizer.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True
                )
            else:
                prompt_text = request.prompt.strip()
            outputs = self.pipeline(
                prompt_text,
                # Server-side cap so clients cannot exceed MAX_TOKENS.
                max_new_tokens=min(request.max_tokens, self.config.MAX_TOKENS),
                do_sample=True,
                temperature=request.temperature,
                top_k=request.top_k,
                top_p=request.top_p,
            )
            generated_text = outputs[0].get("generated_text", "")
            # text-generation pipelines echo the prompt; strip it so only the
            # completion remains.
            if generated_text.startswith(prompt_text):
                generated_text = generated_text[len(prompt_text):]
            tokens_used = 0
            if tokenizer and hasattr(tokenizer, "encode"):
                try:
                    tokens_used = len(tokenizer.encode(generated_text))
                except Exception:
                    tokens_used = 0  # best-effort count; never fail generation over it
            return generated_text, tokens_used

        generated_text, tokens_used = await loop.run_in_executor(None, do_generation)
        return True, generated_text, tokens_used
# ----------------- Service Layer -----------------
class GemmaService:
    """Facade tying configuration, model management and API-key checks together."""

    def __init__(self):
        self.config = Config()
        self.model_manager = ModelManager(self.config)

    async def initialize(self):
        """Delegate model loading to the manager."""
        await self.model_manager.initialize()

    async def generate_text(self, api_key: str, prompt: str, **kwargs) -> APIResponse:
        """Validate the API key, run generation, and wrap the outcome in an APIResponse."""
        key_is_valid = isinstance(api_key, str) and api_key.startswith("gsk-")
        if not (api_key and key_is_valid):
            return APIResponse(success=False, error="Chave de API inválida ou ausente.")
        try:
            request = GenerationRequest(prompt=prompt, **kwargs)
            ok, payload, tokens = await self.model_manager.generate(request)
            if ok:
                return APIResponse(success=True, data={"generated_text": payload, "tokens_used": tokens})
            return APIResponse(success=False, error=payload)
        except Exception as e:
            logger.error(f"Erro de serviço durante a geração de texto: {e}", exc_info=True)
            return APIResponse(success=False, error="Ocorreu um erro interno no serviço.")
# ----------------- Build Gradio UI (síncrono) -----------------
class GradioInterface:
    """Builds the Gradio Blocks UI on top of a GemmaService instance."""

    def __init__(self, service: GemmaService):
        self.service = service

    def create_custom_css(self) -> str:
        # Returned verbatim to gr.Blocks(css=...). Dark theme variables plus
        # Material-Icons glyphs injected via ::before on the two buttons.
        return """
        @import url('https://fonts.googleapis.com/css2?family=Material+Icons&display=swap');
        :root { --dark-bg:#0a0a0a; --panel-bg:#1a1a1a; --border-color:#333; --text-color:#f0f0f0; --text-light:#a0a0a0; --accent-orange:#FF4500; --accent-orange-hover:#FF6347; --code-bg:#282c34; }
        .gradio-container { background: var(--dark-bg) !important; color: var(--text-color); }
        /* ... rest of CSS (trimmed for brevity) ... */
        #send_button::before { content: "send"; font-family: 'Material Icons', sans-serif; position:absolute; left:12px; top:50%; transform:translateY(-50%); font-size:18px; opacity:0.95; }
        #generate_button::before { content: "auto_awesome"; font-family: 'Material Icons', sans-serif; position:absolute; left:12px; top:50%; transform:translateY(-50%); font-size:18px; opacity:0.95; }
        """

    def create_interface(self) -> gr.Blocks:
        """Assemble the Blocks layout and wire its event handlers.

        Returns the (not yet launched) gr.Blocks app; mounting/serving is
        done by the caller.
        """
        # Build the interface synchronously (not awaited).
        demo = gr.Blocks(css=self.create_custom_css(), theme=None)
        with demo:
            with gr.Row(elem_id="main_layout", equal_height=False):
                with gr.Column(scale=2):
                    with gr.Column(elem_id="left_panel"):
                        output_display = gr.Markdown(elem_id="output_display", value="<p style='color: #a0a0a0;'>A sua resposta aparecerá aqui...</p>")
                    with gr.Column(elem_id="input_area"):
                        api_key_input = gr.Textbox(label="A Sua Chave de API", placeholder="Cole a sua chave gsk-... aqui", type="password", elem_id="api_key_input")
                        with gr.Row():
                            prompt_input = gr.Textbox(show_label=False, placeholder="Digite a sua mensagem...", elem_id="prompt_input", scale=10)
                            send_button = gr.Button("➤ Enviar", elem_id="send_button", scale=2)
                with gr.Column(scale=1):
                    with gr.Column(elem_id="right_panel"):
                        gr.Markdown("## Controlo")
                        key_button = gr.Button("✨ Gerar Nova Chave", elem_id="generate_button")
                        with gr.Accordion("Parâmetros Avançados", open=False):
                            temp_slider = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperatura")
                            max_tokens_slider = gr.Slider(minimum=64, maximum=self.service.config.MAX_TOKENS, value=512, step=64, label="Max Tokens")
                            top_k_slider = gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-K")
                            top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-P")
                        gr.Markdown("### Como Usar a API")
                        api_example_display = gr.HTML("<p style='color: #a0a0a0;'>Clique em 'Gerar Nova Chave' para ver um exemplo de código.</p>")

            def handle_key_generation():
                # Produce a fresh demo key; urlsafe chars minus '_'/'-' so the
                # key is purely alphanumeric after the "gsk-" prefix.
                key = f"gsk-{secrets.token_urlsafe(24).replace('_', '').replace('-', '')}"
                # NOTE(review): f-string has no placeholders — the snippet body
                # appears trimmed; the generated key is not embedded in it.
                code_html = f"<div class='code-snippet'> ... </div>"
                return key, gr.update(value=code_html)

            # Async generator callback: each `yield` streams an intermediate
            # (output HTML, send-button state) update to the UI.
            # NOTE(review): the trailing `btn` parameter mirrors send_button in
            # the click() inputs but is never read here.
            async def handle_generation(api_key, prompt, temp, max_tokens, top_k, top_p, btn):
                if not api_key:
                    yield "<p style='color: #FFCC00;'>Por favor, insira a sua chave de API para começar.</p>", gr.update(value="➤ Enviar", interactive=True)
                    return
                if not prompt:
                    yield "<p style='color: #FFCC00;'>Por favor, digite um prompt.</p>", gr.update(value="➤ Enviar", interactive=True)
                    return
                # Disable the button while the model runs, then re-enable below.
                yield "<p style='color: #a0a0a0;'>A gerar resposta...</p>", gr.update(value="A gerar...", interactive=False)
                response = await self.service.generate_text(api_key=api_key, prompt=prompt, temperature=temp, max_tokens=int(max_tokens), top_k=int(top_k), top_p=top_p)
                if response.success:
                    # Escape model output before rendering it as HTML.
                    formatted_text = html.escape(response.data["generated_text"]).replace("\n", "<br>")
                    yield formatted_text, gr.update(value="➤ Enviar", interactive=True)
                else:
                    yield f"<p style='color: #FF4500;'>{response.error}</p>", gr.update(value="➤ Enviar", interactive=True)

            # Wire up the callbacks.
            send_button.click(
                handle_generation,
                inputs=[api_key_input, prompt_input, temp_slider, max_tokens_slider, top_k_slider, top_p_slider, send_button],
                outputs=[output_display, send_button],
                api_name="generate",
            )
            key_button.click(handle_key_generation, outputs=[api_key_input, api_example_display])
            # Reset the API example panel to its placeholder on page load.
            demo.load(lambda: gr.update(value="<p style='color: #a0a0a0;'>Clique em 'Gerar Nova Chave' para ver um exemplo de código.</p>"), [], [api_example_display])
        return demo
# ----------------- FastAPI app and endpoints -----------------
# Module-level wiring: build the service, the Gradio UI, and the FastAPI app.
service = GemmaService()
gradio_interface = GradioInterface(service)
gradio_blocks = gradio_interface.create_interface()
app = FastAPI(title="Gemma Service (Gradio + API)")
# Mount Gradio at the root "/" — if mounting fails, the UI may still be served
# by the hosting Space itself.
try:
    gr.mount_gradio_app(app, gradio_blocks, path="/")
except Exception as exc:
    logger.warning("Não foi possível montar Gradio automaticamente: %s", exc)
@app.on_event("startup")
async def startup_event():
    # Initialize the model in the background (does not block startup).
    # To wait for the model before accepting requests, replace create_task with await.
    # NOTE(review): @app.on_event is deprecated in recent FastAPI versions in
    # favor of lifespan handlers — confirm the pinned FastAPI version.
    asyncio.create_task(service.initialize())
@app.post("/api/generate")
async def api_generate(req: Request):
    """JSON API endpoint.

    Expects a JSON object: {"api_key", "prompt", and optional "max_tokens",
    "temperature", "top_k", "top_p"}. Returns the APIResponse envelope with
    status 200 on success and 400 on any client-side problem.
    """
    try:
        body = await req.json()
    except Exception:
        return JSONResponse(status_code=400, content={"success": False, "error": "Payload inválido (JSON esperado)."})
    # A JSON array/string body would previously crash .get() with a 500;
    # reject anything that is not an object up front.
    if not isinstance(body, dict):
        return JSONResponse(status_code=400, content={"success": False, "error": "Payload inválido (JSON esperado)."})
    api_key = body.get("api_key")
    prompt = body.get("prompt", "")
    try:
        # Coerce generation parameters here so a malformed value (e.g.
        # max_tokens="abc") yields a 400 instead of an unhandled 500.
        max_tokens = int(body.get("max_tokens", 512))
        temperature = float(body.get("temperature", 0.7))
        top_k = int(body.get("top_k", 50))
        top_p = float(body.get("top_p", 0.95))
    except (TypeError, ValueError) as e:
        return JSONResponse(status_code=400, content={"success": False, "error": f"Parâmetros inválidos: {e}"})
    resp = await service.generate_text(api_key=api_key, prompt=prompt, max_tokens=max_tokens, temperature=temperature, top_k=top_k, top_p=top_p)
    status = 200 if resp.success else 400
    return JSONResponse(status_code=status, content=resp.dict())
@app.post("/run/generate")
async def gradio_compatible_generate(req: Request):
    """Gradio-style endpoint: positional arguments arrive as body["data"].

    Expected layout: [api_key, prompt, max_tokens, temperature, top_k, top_p];
    trailing entries are optional and fall back to defaults.
    """
    try:
        body = await req.json()
    except Exception:
        return JSONResponse(status_code=400, content={"success": False, "error": "Payload inválido (JSON esperado)."})
    data = body.get("data")
    if not isinstance(data, list):
        return JSONResponse(status_code=400, content={"success": False, "error": "Campo 'data' inválido. Esperado array."})
    try:
        def positional(index, cast, fallback):
            # Coerce data[index] when present, else use the default value.
            return cast(data[index]) if len(data) > index else fallback

        api_key = data[0]  # mandatory: IndexError on an empty array is reported below
        prompt = positional(1, lambda v: v, "")
        max_tokens = positional(2, int, 512)
        temperature = positional(3, float, 0.7)
        top_k = positional(4, int, 50)
        top_p = positional(5, float, 0.95)
    except Exception as e:
        return JSONResponse(status_code=400, content={"success": False, "error": f"Erro ao parsear 'data': {e}"})
    resp = await service.generate_text(api_key=api_key, prompt=prompt, max_tokens=max_tokens, temperature=temperature, top_k=top_k, top_p=top_p)
    status = 200 if resp.success else 400
    return JSONResponse(status_code=status, content=resp.dict())