|
|
|
|
|
import os |
|
|
import secrets |
|
|
import logging |
|
|
import asyncio |
|
|
import html |
|
|
from dataclasses import dataclass |
|
|
from typing import Any, Optional, Tuple |
|
|
|
|
|
import gradio as gr |
|
|
from transformers import pipeline |
|
|
from dotenv import load_dotenv |
|
|
from pydantic import BaseModel |
|
|
from fastapi import FastAPI, Request |
|
|
from fastapi.responses import JSONResponse |
|
|
|
|
|
|
|
|
# Load variables from a local .env file into the process environment before
# any os.getenv() call further down in this module reads them.
load_dotenv()
|
|
|
|
|
|
|
|
@dataclass
class Config:
    """Runtime settings sourced from environment variables.

    Every default is read from the process environment once, when this class
    body is executed. Instances created afterwards share those defaults
    unless a value is passed explicitly to the constructor.
    """

    # Hugging Face access token; empty string when the variable is unset.
    HF_TOKEN: str = os.environ.get("HF_TOKEN", "")
    # Model identifier handed to the text-generation pipeline.
    MODEL_NAME: str = os.environ.get("MODEL_NAME", "google/gemma-3-270m-it")
    # Upper bound applied to the number of newly generated tokens per request.
    MAX_TOKENS: int = int(os.environ.get("MAX_TOKENS", "2048"))
    # Name of a logging level (e.g. "INFO", "DEBUG").
    LOG_LEVEL: str = os.environ.get("LOG_LEVEL", "INFO")
|
|
|
|
|
|
|
|
class GenerationRequest(BaseModel):
    """Validated parameters for a single text-generation call."""

    # Raw user prompt; callers strip surrounding whitespace before use.
    prompt: str
    # Requested number of new tokens (the server applies its own upper cap).
    max_tokens: int = 512
    # Sampling temperature.
    temperature: float = 0.7
    # Top-k sampling cutoff.
    top_k: int = 50
    # Nucleus (top-p) sampling threshold.
    top_p: float = 0.95
|
|
|
|
|
|
|
|
class APIResponse(BaseModel):
    """Uniform envelope returned by the service layer and the HTTP endpoints."""

    # True when the request was served; False when `error` explains a failure.
    success: bool
    # Payload for successful calls (any JSON-compatible value); None otherwise.
    data: Any = None
    # Human-readable error message for failed calls; None on success.
    error: Optional[str] = None
|
|
|
|
|
|
|
|
|
|
|
def setup_logger() -> logging.Logger:
    """Return the shared "gemma_saas" logger, configuring it on first use.

    On the first call the logger gets its level from ``Config.LOG_LEVEL`` and
    two handlers with a common format: a file handler writing to
    ``gemma_saas.log`` and a stream handler. Later calls find the handlers
    already attached and return the logger untouched, so handlers are never
    duplicated.

    Returns:
        The configured ``logging.Logger`` instance.
    """
    cfg = Config()
    log = logging.getLogger("gemma_saas")
    if log.handlers:
        return log

    # getattr() may resolve to a non-level attribute of the logging module
    # (e.g. a format string); only accept real numeric levels.
    level = getattr(logging, cfg.LOG_LEVEL.upper(), None)
    if not isinstance(level, int):
        level = logging.INFO
    log.setLevel(level)

    formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    handlers = (
        # Messages contain emoji and accented characters, so the file must not
        # depend on the platform's default encoding.
        logging.FileHandler("gemma_saas.log", encoding="utf-8"),
        logging.StreamHandler(),
    )
    for handler in handlers:
        handler.setFormatter(formatter)
        log.addHandler(handler)
    return log
|
|
|
|
|
|
|
|
# Module-wide logger shared by every component defined below.
logger = setup_logger()
|
|
|
|
|
|
|
|
|
|
|
class ModelManager:
    """Owns the text-generation pipeline and runs it off the event loop.

    Attributes:
        config: Settings object providing ``HF_TOKEN``, ``MODEL_NAME`` and
            ``MAX_TOKENS``.
        pipeline: The loaded pipeline, or ``None`` until ``initialize``
            succeeds.
        model_loaded: ``True`` once the pipeline has been loaded.
    """

    def __init__(self, config: "Config"):
        self.config = config
        self.pipeline = None
        self.model_loaded = False

    async def initialize(self) -> None:
        """Load the pipeline in the default executor.

        Without an ``HF_TOKEN`` an error is logged and nothing is loaded.
        Any exception raised while loading is logged with its traceback and
        swallowed, leaving ``model_loaded`` as ``False``.
        """
        if not self.config.HF_TOKEN:
            logger.error("Token do Hugging Face não encontrado. O carregamento do modelo poderá falhar.")
            return

        try:
            logger.info("A carregar o modelo: %s...", self.config.MODEL_NAME)
            os.environ.setdefault("HF_TOKEN", self.config.HF_TOKEN)

            def load_pipeline():
                return pipeline(
                    "text-generation",
                    model=self.config.MODEL_NAME,
                    token=self.config.HF_TOKEN,
                    torch_dtype="auto",
                    device_map="auto",
                )

            # Loading blocks for a long time; keep it off the event loop.
            loop = asyncio.get_running_loop()
            self.pipeline = await loop.run_in_executor(None, load_pipeline)
            self.model_loaded = True
            logger.info("✅ Modelo carregado com sucesso!")
        except Exception as e:
            logger.error("❌ Erro ao carregar o modelo: %s", e, exc_info=True)

    async def generate(self, request: "GenerationRequest") -> Tuple[bool, str, int]:
        """Generate text for ``request`` in the default executor.

        Args:
            request: Object exposing ``prompt``, ``max_tokens``,
                ``temperature``, ``top_k`` and ``top_p``.

        Returns:
            ``(success, text, tokens_used)``. On failure ``text`` is a
            user-facing error message and ``tokens_used`` is 0. On success
            ``text`` is the generated continuation with the prompt removed
            and ``tokens_used`` is the tokenizer's count for that text
            (0 when no tokenizer is available or counting fails).

        Raises:
            Whatever the underlying pipeline raises during generation.
        """
        if not self.model_loaded or self.pipeline is None:
            return False, "❌ O modelo não está disponível. Por favor, verifique os logs do servidor.", 0

        user_prompt = request.prompt.strip()
        if not user_prompt:
            return False, "⚠️ O prompt não pode estar vazio.", 0

        # Cap at the configured maximum, but never ask for fewer than one token.
        generation_kwargs = {
            "max_new_tokens": max(1, min(request.max_tokens, self.config.MAX_TOKENS)),
        }
        if request.temperature > 0:
            generation_kwargs.update(
                do_sample=True,
                temperature=request.temperature,
                top_k=request.top_k,
                top_p=request.top_p,
            )
        else:
            # Sampling needs a strictly positive temperature; treat a
            # non-positive value as a request for greedy decoding.
            generation_kwargs["do_sample"] = False

        pipe = self.pipeline
        messages = [{"role": "user", "content": user_prompt}]

        def do_generation():
            tokenizer = getattr(pipe, "tokenizer", None)

            if tokenizer and hasattr(tokenizer, "apply_chat_template"):
                prompt_text = tokenizer.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True
                )
            else:
                prompt_text = user_prompt

            outputs = pipe(prompt_text, **generation_kwargs)

            generated_text = outputs[0].get("generated_text", "")
            # The pipeline output echoes the prompt; return only the new text.
            if generated_text.startswith(prompt_text):
                generated_text = generated_text[len(prompt_text):]

            tokens_used = 0
            if tokenizer and hasattr(tokenizer, "encode"):
                try:
                    # Count only the generated text, not BOS/EOS markers the
                    # tokenizer would otherwise add.
                    tokens_used = len(tokenizer.encode(generated_text, add_special_tokens=False))
                except Exception:
                    # Token counting is best-effort; never fail the request.
                    logger.debug("Token counting failed", exc_info=True)
                    tokens_used = 0

            return generated_text, tokens_used

        loop = asyncio.get_running_loop()
        generated_text, tokens_used = await loop.run_in_executor(None, do_generation)
        return True, generated_text, tokens_used
|
|
|
|
|
|
|
|
|
|
|
class GemmaService:
    """Service facade: checks the API key and delegates to the model manager."""

    def __init__(self):
        self.config = Config()
        self.model_manager = ModelManager(self.config)

    async def initialize(self):
        """Load the underlying model via the model manager."""
        await self.model_manager.initialize()

    async def generate_text(self, api_key: str, prompt: str, **kwargs) -> APIResponse:
        """Validate ``api_key`` and run one generation.

        Args:
            api_key: Caller-supplied key; accepted only when it is a string
                starting with ``"gsk-"``.
            prompt: Text to generate from.
            **kwargs: Extra ``GenerationRequest`` fields (``max_tokens``,
                ``temperature``, ``top_k``, ``top_p``).

        Returns:
            An ``APIResponse``. On success ``data`` holds ``generated_text``
            and ``tokens_used``; otherwise ``error`` carries the message.
            Unexpected exceptions are logged and reported as a generic
            internal error instead of propagating.
        """
        key_is_valid = isinstance(api_key, str) and api_key.startswith("gsk-")
        if not key_is_valid:
            return APIResponse(success=False, error="Chave de API inválida ou ausente.")

        try:
            generation_request = GenerationRequest(prompt=prompt, **kwargs)
            ok, payload, token_count = await self.model_manager.generate(generation_request)
            if not ok:
                # On failure the manager returns the error message as payload.
                return APIResponse(success=False, error=payload)
            result = {"generated_text": payload, "tokens_used": token_count}
            return APIResponse(success=True, data=result)
        except Exception as e:
            logger.error("Erro de serviço durante a geração de texto: %s", e, exc_info=True)
            return APIResponse(success=False, error="Ocorreu um erro interno no serviço.")
|
|
|
|
|
|
|
|
|
|
|
class GradioInterface:
    """Builds the Gradio Blocks UI that fronts a ``GemmaService``."""

    def __init__(self, service: GemmaService):
        self.service = service

    def create_custom_css(self) -> str:
        """Return the stylesheet injected into the Blocks app."""
        return """
        @import url('https://fonts.googleapis.com/css2?family=Material+Icons&display=swap');
        :root { --dark-bg:#0a0a0a; --panel-bg:#1a1a1a; --border-color:#333; --text-color:#f0f0f0; --text-light:#a0a0a0; --accent-orange:#FF4500; --accent-orange-hover:#FF6347; --code-bg:#282c34; }
        .gradio-container { background: var(--dark-bg) !important; color: var(--text-color); }
        /* ... rest of CSS (trimmed for brevity) ... */
        #send_button::before { content: "send"; font-family: 'Material Icons', sans-serif; position:absolute; left:12px; top:50%; transform:translateY(-50%); font-size:18px; opacity:0.95; }
        #generate_button::before { content: "auto_awesome"; font-family: 'Material Icons', sans-serif; position:absolute; left:12px; top:50%; transform:translateY(-50%); font-size:18px; opacity:0.95; }
        """

    def create_interface(self) -> gr.Blocks:
        """Assemble the two-column layout and wire its event handlers.

        The left column shows the model output above the API-key and prompt
        inputs; the right column holds key generation, the sampling sliders
        and the API usage hint.

        Returns:
            The configured ``gr.Blocks`` instance (not launched).
        """
        # NOTE(review): the nesting of the layout containers below was
        # reconstructed from statement order — confirm against the intended UI.
        api_hint_html = "<p style='color: #a0a0a0;'>Clique em 'Gerar Nova Chave' para ver um exemplo de código.</p>"

        with gr.Blocks(css=self.create_custom_css(), theme=None) as demo:
            with gr.Row(elem_id="main_layout", equal_height=False):
                # Left side: conversation output and inputs.
                with gr.Column(scale=2):
                    with gr.Column(elem_id="left_panel"):
                        answer_view = gr.Markdown(
                            elem_id="output_display",
                            value="<p style='color: #a0a0a0;'>A sua resposta aparecerá aqui...</p>",
                        )
                    with gr.Column(elem_id="input_area"):
                        key_box = gr.Textbox(
                            label="A Sua Chave de API",
                            placeholder="Cole a sua chave gsk-... aqui",
                            type="password",
                            elem_id="api_key_input",
                        )
                        with gr.Row():
                            prompt_box = gr.Textbox(
                                show_label=False,
                                placeholder="Digite a sua mensagem...",
                                elem_id="prompt_input",
                                scale=10,
                            )
                            submit_btn = gr.Button("➤ Enviar", elem_id="send_button", scale=2)

                # Right side: key generation, sampling controls, API help.
                with gr.Column(scale=1):
                    with gr.Column(elem_id="right_panel"):
                        gr.Markdown("## Controlo")
                        new_key_btn = gr.Button("✨ Gerar Nova Chave", elem_id="generate_button")

                        with gr.Accordion("Parâmetros Avançados", open=False):
                            temperature_ctl = gr.Slider(
                                minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperatura"
                            )
                            max_tokens_ctl = gr.Slider(
                                minimum=64,
                                maximum=self.service.config.MAX_TOKENS,
                                value=512,
                                step=64,
                                label="Max Tokens",
                            )
                            top_k_ctl = gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-K")
                            top_p_ctl = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-P")

                        gr.Markdown("### Como Usar a API")
                        api_help_view = gr.HTML(api_hint_html)

            def handle_key_generation():
                """Create a fresh ``gsk-`` key and the matching usage snippet."""
                random_part = secrets.token_urlsafe(24).replace('_', '').replace('-', '')
                key = "gsk-" + random_part
                code_html = "<div class='code-snippet'> ... </div>"
                return key, gr.update(value=code_html)

            async def handle_generation(api_key, prompt, temp, max_tokens, top_k, top_p, btn):
                """Stream UI updates for one generation: status first, then result."""

                def idle_button():
                    # Restore the send button to its clickable default state.
                    return gr.update(value="➤ Enviar", interactive=True)

                missing_notice = None
                if not api_key:
                    missing_notice = "<p style='color: #FFCC00;'>Por favor, insira a sua chave de API para começar.</p>"
                elif not prompt:
                    missing_notice = "<p style='color: #FFCC00;'>Por favor, digite um prompt.</p>"
                if missing_notice is not None:
                    yield missing_notice, idle_button()
                    return

                # Show progress and lock the button while the model runs.
                yield "<p style='color: #a0a0a0;'>A gerar resposta...</p>", gr.update(value="A gerar...", interactive=False)

                response = await self.service.generate_text(
                    api_key=api_key,
                    prompt=prompt,
                    temperature=temp,
                    max_tokens=int(max_tokens),
                    top_k=int(top_k),
                    top_p=top_p,
                )
                if not response.success:
                    yield f"<p style='color: #FF4500;'>{response.error}</p>", idle_button()
                    return

                # Escape model output before turning newlines into <br> tags.
                formatted_text = html.escape(response.data["generated_text"]).replace("\n", "<br>")
                yield formatted_text, idle_button()

            generation_inputs = [
                key_box,
                prompt_box,
                temperature_ctl,
                max_tokens_ctl,
                top_k_ctl,
                top_p_ctl,
                submit_btn,
            ]
            submit_btn.click(
                handle_generation,
                inputs=generation_inputs,
                outputs=[answer_view, submit_btn],
                api_name="generate",
            )
            new_key_btn.click(handle_key_generation, outputs=[key_box, api_help_view])
            demo.load(lambda: gr.update(value=api_hint_html), [], [api_help_view])

        return demo
|
|
|
|
|
|
|
|
|
|
|
# Module-level singletons: the service, the Gradio UI built around it, and
# the FastAPI application that exposes both.
service = GemmaService()
gradio_interface = GradioInterface(service)
gradio_blocks = gradio_interface.create_interface()

app = FastAPI(title="Gemma Service (Gradio + API)")
|
|
|
|
|
|
|
|
def _bad_request(message: str) -> JSONResponse:
    """Build the 400 error envelope shared by the JSON endpoints."""
    return JSONResponse(status_code=400, content={"success": False, "error": message})


async def _read_json_object(req: Request):
    """Parse the request body as a JSON object.

    Returns ``(body, None)`` on success or ``(None, error_response)`` when the
    body is not valid JSON or is not a JSON object.
    """
    try:
        body = await req.json()
    except Exception:
        return None, _bad_request("Payload inválido (JSON esperado).")
    if not isinstance(body, dict):
        # Valid JSON, but e.g. a list or string: `.get` would fail below.
        return None, _bad_request("Payload inválido (objeto JSON esperado).")
    return body, None


def _service_response(resp) -> JSONResponse:
    """Serialize a service ``APIResponse``: 200 on success, 400 otherwise."""
    return JSONResponse(status_code=200 if resp.success else 400, content=resp.dict())


@app.on_event("startup")
async def startup_event():
    """Start loading the model in the background so startup is not blocked."""
    # Keep a reference: the event loop only holds weak references to tasks,
    # so an unreferenced task may be garbage-collected before it finishes.
    app.state.model_init_task = asyncio.create_task(service.initialize())


@app.post("/api/generate")
async def api_generate(req: Request):
    """Generate text from a JSON object body.

    Expected keys: ``api_key``, ``prompt`` and optionally ``max_tokens``,
    ``temperature``, ``top_k``, ``top_p``. Responds with the serialized
    ``APIResponse`` (HTTP 200 on success, 400 on any failure, including
    malformed input).
    """
    body, error = await _read_json_object(req)
    if error is not None:
        return error

    try:
        params = {
            "max_tokens": int(body.get("max_tokens", 512)),
            "temperature": float(body.get("temperature", 0.7)),
            "top_k": int(body.get("top_k", 50)),
            "top_p": float(body.get("top_p", 0.95)),
        }
    except (TypeError, ValueError, OverflowError) as e:
        # Non-numeric parameters are a client error, not a server fault.
        return _bad_request(f"Parâmetros inválidos: {e}")

    resp = await service.generate_text(
        api_key=body.get("api_key"),
        prompt=body.get("prompt", ""),
        **params,
    )
    return _service_response(resp)


@app.post("/run/generate")
async def gradio_compatible_generate(req: Request):
    """Generate text from a Gradio-style ``{"data": [...]}`` payload.

    ``data`` is positional: ``[api_key, prompt, max_tokens, temperature,
    top_k, top_p]``; everything after ``api_key`` is optional. Responds with
    the serialized ``APIResponse`` (HTTP 200 on success, 400 otherwise).
    """
    # NOTE(review): the UI's click handler orders its inputs as
    # (..., temperature, max_tokens, ...), while this endpoint expects
    # max_tokens before temperature — confirm which order clients rely on.
    body, error = await _read_json_object(req)
    if error is not None:
        return error

    data = body.get("data")
    if not isinstance(data, list):
        return _bad_request("Campo 'data' inválido. Esperado array.")

    try:
        api_key = data[0]
        prompt = data[1] if len(data) > 1 else ""
        max_tokens = int(data[2]) if len(data) > 2 else 512
        temperature = float(data[3]) if len(data) > 3 else 0.7
        top_k = int(data[4]) if len(data) > 4 else 50
        top_p = float(data[5]) if len(data) > 5 else 0.95
    except Exception as e:
        return _bad_request(f"Erro ao parsear 'data': {e}")

    resp = await service.generate_text(
        api_key=api_key,
        prompt=prompt,
        max_tokens=max_tokens,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
    )
    return _service_response(resp)


# Mount Gradio LAST. A mount at "/" matches every path, and routes are tried
# in registration order, so mounting before the endpoints above would shadow
# them and they would never be reached.
try:
    gr.mount_gradio_app(app, gradio_blocks, path="/")
except Exception as exc:
    logger.warning("Não foi possível montar Gradio automaticamente: %s", exc)
|
|
|