File size: 13,593 Bytes
0e8a23e
8fb58c0
8e6d932
ce7ff6c
684060d
 
ce7ff6c
8e6d932
ce7ff6c
8fb58c0
8e6d932
684060d
8e6d932
0e8a23e
 
45c0ed8
23a7d96
684060d
 
cdcab04
ce7ff6c
 
23a7d96
9fdb3a0
8e6d932
ce7ff6c
cdcab04
8e6d932
ce7ff6c
 
8e6d932
 
ce7ff6c
 
8e6d932
cdcab04
ce7ff6c
 
 
9d6afff
8e6d932
71a7916
0e8a23e
684060d
 
 
 
 
 
ae13336
0e8a23e
 
 
 
 
 
684060d
71a7916
8e6d932
ce7ff6c
cdcab04
8e6d932
23a7d96
ce7ff6c
 
 
 
 
b9300de
684060d
b9300de
b85befe
b9300de
8e6d932
ce7ff6c
8e6d932
ae13336
684060d
 
8e6d932
 
 
684060d
8e6d932
ae13336
 
684060d
ce7ff6c
23a7d96
8e6d932
ce7ff6c
8e6d932
ce7ff6c
684060d
b9300de
23a7d96
8e6d932
0e8a23e
8e6d932
b85befe
 
684060d
b85befe
 
23a7d96
b85befe
 
ae13336
b85befe
 
 
ce7ff6c
b85befe
 
 
 
 
 
 
 
 
 
 
23a7d96
b85befe
 
 
8e6d932
b85befe
 
 
 
 
 
8e6d932
b85befe
8e6d932
b85befe
 
71a7916
8e6d932
23a7d96
 
ce7ff6c
 
 
b9300de
ce7ff6c
 
b9300de
8e6d932
0e8a23e
8e6d932
7d3e98e
0e8a23e
 
23a7d96
ae13336
23a7d96
 
7d3e98e
684060d
23a7d96
b9300de
8e6d932
b85befe
ce7ff6c
23a7d96
ce7ff6c
b9300de
684060d
23a7d96
ae13336
0e8a23e
8e6d932
b85befe
0e8a23e
 
23a7d96
deb411f
b85befe
 
 
 
23a7d96
8e6d932
23a7d96
0e8a23e
8e6d932
0e8a23e
8e6d932
0e8a23e
8e6d932
 
 
 
 
 
 
 
ae13336
 
 
 
8e6d932
 
ae13336
8e6d932
 
 
b85befe
684060d
8e6d932
 
 
0e8a23e
8e6d932
23a7d96
0e8a23e
cd1e1eb
23a7d96
684060d
b9300de
0e8a23e
8e6d932
 
684060d
8e6d932
684060d
cd1e1eb
b85befe
23a7d96
 
8e6d932
 
 
23a7d96
8e6d932
0e8a23e
8e6d932
0e8a23e
684060d
2401055
b85befe
 
 
 
8e6d932
b85befe
0e8a23e
b85befe
 
 
 
 
0e8a23e
 
b85befe
 
 
 
 
0e8a23e
 
b85befe
 
 
 
 
 
0e8a23e
b85befe
 
 
 
 
 
0e8a23e
b85befe
 
 
8e6d932
 
b85befe
 
 
 
 
 
8e6d932
b85befe
 
 
71a7916
b85befe
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
# app.py
import os
import secrets
import logging
import asyncio
import html
from dataclasses import dataclass
from typing import Any, Optional, Tuple

import gradio as gr
from transformers import pipeline
from dotenv import load_dotenv
from pydantic import BaseModel
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

# ----------------- Configuration & Models -----------------
load_dotenv()


@dataclass
class Config:
    """Runtime settings for the service, taken from environment variables.

    NOTE: every default below is computed by ``os.getenv`` when this class
    body executes (i.e. when the module is imported, after ``load_dotenv()``),
    not each time ``Config()`` is instantiated. Environment changes made
    later in the process are therefore not picked up by new instances.
    """

    # Hugging Face access token; empty string when HF_TOKEN is unset.
    HF_TOKEN: str = os.getenv("HF_TOKEN", "")
    # Model identifier handed to the transformers pipeline.
    MODEL_NAME: str = os.getenv("MODEL_NAME", "google/gemma-3-270m-it")
    # Upper bound on generated tokens; `int(...)` raises ValueError at import
    # time if the environment value is not an integer literal.
    MAX_TOKENS: int = int(os.getenv("MAX_TOKENS", "2048"))
    # Name of a `logging` level (e.g. "INFO", "DEBUG").
    LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")


class GenerationRequest(BaseModel):
    """Validated parameters for a single text-generation call.

    Instances are built by the service layer and consumed by
    ``ModelManager.generate``, which forwards the sampling fields to the
    text-generation pipeline.
    """

    # User message; the model manager strips it and rejects it when blank.
    prompt: str
    # Requested number of new tokens; the model manager caps it at Config.MAX_TOKENS.
    max_tokens: int = 512
    # Sampling parameters forwarded to the pipeline call.
    temperature: float = 0.7
    top_k: int = 50
    top_p: float = 0.95


class APIResponse(BaseModel):
    """Uniform result envelope returned by the service layer and HTTP endpoints."""

    # True when the request was fulfilled; False when `error` describes a failure.
    success: bool
    # Payload on success (the service layer stores a dict with
    # "generated_text" and "tokens_used"); None otherwise.
    data: Any = None
    # Human-readable failure message; None on success.
    error: Optional[str] = None


# ----------------- Logger -----------------
def setup_logger() -> logging.Logger:
    """Return the shared ``gemma_saas`` logger, configuring it on first use.

    On the first call the logger gets a console handler and a file handler
    (``gemma_saas.log``) sharing one format, and its level is set from
    ``Config.LOG_LEVEL`` (falling back to ``INFO`` for unknown names).
    Later calls return the already-configured logger untouched, so handlers
    are never duplicated.

    Returns:
        The configured ``logging.Logger`` instance.
    """
    logger = logging.getLogger("gemma_saas")
    if logger.handlers:
        # Already configured by an earlier call; avoid stacking handlers.
        return logger

    cfg = Config()
    log_level = getattr(logging, cfg.LOG_LEVEL.upper(), logging.INFO)
    if not isinstance(log_level, int):
        # LOG_LEVEL named some other attribute of `logging` (e.g. a format
        # string constant) rather than a numeric level.
        log_level = logging.INFO
    logger.setLevel(log_level)

    formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")

    # Console handler first, so the warning below has somewhere to go.
    sh = logging.StreamHandler()
    sh.setFormatter(formatter)
    logger.addHandler(sh)

    try:
        # Explicit UTF-8: log messages contain accented text and emoji, which
        # the platform default encoding cannot always represent.
        fh = logging.FileHandler("gemma_saas.log", encoding="utf-8")
    except OSError as exc:
        # A read-only or missing working directory must not prevent startup.
        logger.warning(
            "Não foi possível abrir o ficheiro de log 'gemma_saas.log'; "
            "a registar apenas na consola: %s",
            exc,
        )
    else:
        fh.setFormatter(formatter)
        logger.addHandler(fh)

    return logger


logger = setup_logger()


# ----------------- Model Manager -----------------
class ModelManager:
    """Owns the text-generation pipeline and runs it off the event loop.

    Both loading and inference are blocking calls, so they are dispatched to
    the default executor via ``run_in_executor``.

    Attributes:
        config: Settings object supplying the token, model name and token cap.
        pipeline: The loaded transformers pipeline, or ``None`` before loading.
        model_loaded: ``True`` once ``initialize`` has loaded the pipeline.
    """

    def __init__(self, config: Config):
        self.config = config
        self.pipeline = None
        self.model_loaded = False

    async def initialize(self) -> None:
        """Load the pipeline in a worker thread.

        Failures are logged rather than raised; callers can inspect
        ``model_loaded`` to find out whether loading succeeded.
        """
        if not self.config.HF_TOKEN:
            # NOTE(review): the message says loading "may fail", but this
            # branch returns without attempting to load. Behaviour kept as-is.
            logger.error("Token do Hugging Face não encontrado. O carregamento do modelo poderá falhar.")
            return

        def load_pipeline():
            return pipeline(
                "text-generation",
                model=self.config.MODEL_NAME,
                token=self.config.HF_TOKEN,
                torch_dtype="auto",
                device_map="auto",
            )

        try:
            logger.info(f"A carregar o modelo: {self.config.MODEL_NAME}...")
            os.environ.setdefault("HF_TOKEN", self.config.HF_TOKEN)

            # We are inside a coroutine, so a running loop is guaranteed.
            loop = asyncio.get_running_loop()
            self.pipeline = await loop.run_in_executor(None, load_pipeline)
            self.model_loaded = True
            logger.info("✅ Modelo carregado com sucesso!")
        except Exception as e:
            logger.error(f"❌ Erro ao carregar o modelo: {e}", exc_info=True)

    async def generate(self, request: GenerationRequest) -> Tuple[bool, str, int]:
        """Generate a completion for ``request.prompt``.

        Args:
            request: Prompt and sampling parameters.

        Returns:
            ``(success, text, tokens_used)``. On failure ``text`` holds a
            user-facing error message and ``tokens_used`` is 0. On success
            ``text`` is the generated continuation (prompt prefix removed)
            and ``tokens_used`` is its token count, or 0 if counting failed.

        Raises:
            Exception: whatever the pipeline call raises is propagated to the
                caller; it is not caught here.
        """
        if not self.model_loaded or self.pipeline is None:
            return False, "❌ O modelo não está disponível. Por favor, verifique os logs do servidor.", 0

        user_prompt = request.prompt.strip()
        if not user_prompt:
            return False, "⚠️ O prompt não pode estar vazio.", 0

        messages = [{"role": "user", "content": user_prompt}]

        # Cap at the configured maximum, but never ask for fewer than one
        # token: zero or negative values are rejected by the generation call.
        max_new_tokens = max(1, min(request.max_tokens, self.config.MAX_TOKENS))

        generation_kwargs = {"max_new_tokens": max_new_tokens}
        if request.temperature > 0:
            generation_kwargs.update(
                do_sample=True,
                temperature=request.temperature,
                top_k=request.top_k,
                top_p=request.top_p,
            )
        else:
            # Sampling requires a strictly positive temperature; treat
            # temperature <= 0 as a request for deterministic (greedy) decoding.
            generation_kwargs["do_sample"] = False

        def do_generation():
            tokenizer = getattr(self.pipeline, "tokenizer", None)

            if tokenizer and hasattr(tokenizer, "apply_chat_template"):
                prompt_text = tokenizer.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True
                )
            else:
                prompt_text = user_prompt

            outputs = self.pipeline(prompt_text, **generation_kwargs)

            generated_text = outputs[0].get("generated_text", "")
            # The pipeline output may echo the prompt; keep only the continuation.
            if generated_text.startswith(prompt_text):
                generated_text = generated_text[len(prompt_text):]

            tokens_used = 0
            if tokenizer and hasattr(tokenizer, "encode"):
                try:
                    # Exclude special tokens so only the generated text is counted.
                    tokens_used = len(tokenizer.encode(generated_text, add_special_tokens=False))
                except Exception:
                    # Token counting is best-effort; the text is still returned.
                    tokens_used = 0

            return generated_text, tokens_used

        loop = asyncio.get_running_loop()
        generated_text, tokens_used = await loop.run_in_executor(None, do_generation)
        return True, generated_text, tokens_used


# ----------------- Service Layer -----------------
class GemmaService:
    """Service facade: checks the API key format and delegates to the model manager."""

    def __init__(self):
        self.config = Config()
        self.model_manager = ModelManager(self.config)

    async def initialize(self):
        """Load the underlying model (delegates to the model manager)."""
        await self.model_manager.initialize()

    async def generate_text(self, api_key: str, prompt: str, **kwargs) -> APIResponse:
        """Generate text for ``prompt`` if ``api_key`` looks valid.

        Args:
            api_key: Caller-supplied key; accepted only if it is a string
                starting with ``"gsk-"`` (no other check is performed here).
            prompt: The user prompt.
            **kwargs: Extra fields forwarded to ``GenerationRequest``.

        Returns:
            An ``APIResponse``; on success ``data`` holds ``generated_text``
            and ``tokens_used``, otherwise ``error`` holds a message.
        """
        key_looks_valid = isinstance(api_key, str) and api_key.startswith("gsk-")
        if not key_looks_valid:
            return APIResponse(success=False, error="Chave de API inválida ou ausente.")

        try:
            generation_request = GenerationRequest(prompt=prompt, **kwargs)
            ok, payload, token_count = await self.model_manager.generate(generation_request)
            if not ok:
                # On failure the manager returns the error message in the text slot.
                return APIResponse(success=False, error=payload)
            result = {"generated_text": payload, "tokens_used": token_count}
            return APIResponse(success=True, data=result)
        except Exception as e:
            logger.error(f"Erro de serviço durante a geração de texto: {e}", exc_info=True)
            return APIResponse(success=False, error="Ocorreu um erro interno no serviço.")


# ----------------- Build Gradio UI (synchronous) -----------------
class GradioInterface:
    """Builds the Gradio Blocks UI that fronts a ``GemmaService``."""

    def __init__(self, service: GemmaService):
        self.service = service

    def create_custom_css(self) -> str:
        """Return the CSS string passed to ``gr.Blocks(css=...)``."""
        return """
        @import url('https://fonts.googleapis.com/css2?family=Material+Icons&display=swap');
        :root { --dark-bg:#0a0a0a; --panel-bg:#1a1a1a; --border-color:#333; --text-color:#f0f0f0; --text-light:#a0a0a0; --accent-orange:#FF4500; --accent-orange-hover:#FF6347; --code-bg:#282c34; }
        .gradio-container { background: var(--dark-bg) !important; color: var(--text-color); }
        /* ... rest of CSS (trimmed for brevity) ... */
        #send_button::before { content: "send"; font-family: 'Material Icons', sans-serif; position:absolute; left:12px; top:50%; transform:translateY(-50%); font-size:18px; opacity:0.95; }
        #generate_button::before { content: "auto_awesome"; font-family: 'Material Icons', sans-serif; position:absolute; left:12px; top:50%; transform:translateY(-50%); font-size:18px; opacity:0.95; }
        """

    def create_interface(self) -> gr.Blocks:
        """Assemble the layout, define the event handlers and wire them up.

        Returns:
            The constructed ``gr.Blocks`` app (not launched).
        """
        # Build the interface synchronously (no await)
        demo = gr.Blocks(css=self.create_custom_css(), theme=None)
        with demo:
            with gr.Row(elem_id="main_layout", equal_height=False):
                # Left side: response display plus API key / prompt inputs.
                with gr.Column(scale=2):
                    with gr.Column(elem_id="left_panel"):
                        output_display = gr.Markdown(elem_id="output_display", value="<p style='color: #a0a0a0;'>A sua resposta aparecerá aqui...</p>")
                        with gr.Column(elem_id="input_area"):
                            api_key_input = gr.Textbox(label="A Sua Chave de API", placeholder="Cole a sua chave gsk-... aqui", type="password", elem_id="api_key_input")
                            with gr.Row():
                                prompt_input = gr.Textbox(show_label=False, placeholder="Digite a sua mensagem...", elem_id="prompt_input", scale=10)
                                send_button = gr.Button("➤ Enviar", elem_id="send_button", scale=2)

                # Right side: key generation, sampling controls, API usage hint.
                with gr.Column(scale=1):
                    with gr.Column(elem_id="right_panel"):
                        gr.Markdown("## Controlo")
                        key_button = gr.Button("✨ Gerar Nova Chave", elem_id="generate_button")

                        with gr.Accordion("Parâmetros Avançados", open=False):
                            temp_slider = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperatura")
                            # Slider ceiling follows the configured token cap.
                            max_tokens_slider = gr.Slider(minimum=64, maximum=self.service.config.MAX_TOKENS, value=512, step=64, label="Max Tokens")
                            top_k_slider = gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-K")
                            top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-P")

                        gr.Markdown("### Como Usar a API")
                        api_example_display = gr.HTML("<p style='color: #a0a0a0;'>Clique em 'Gerar Nova Chave' para ver um exemplo de código.</p>")

            def handle_key_generation():
                """Create a fresh ``gsk-`` key and the HTML for the usage example.

                The key is ``"gsk-"`` plus a URL-safe random token with ``_`` and
                ``-`` removed, so its length can vary. This handler does not
                record the key anywhere; it only returns it to the UI.
                """
                key = f"gsk-{secrets.token_urlsafe(24).replace('_', '').replace('-', '')}"
                code_html = f"<div class='code-snippet'> ... </div>"
                return key, gr.update(value=code_html)

            async def handle_generation(api_key, prompt, temp, max_tokens, top_k, top_p, btn):
                """Stream UI updates for one generation request.

                Yields ``(output_html, send_button_update)`` pairs: a validation
                message if the key or prompt is missing; otherwise a "generating"
                placeholder that disables the button, followed by the final
                result or error with the button re-enabled.

                ``btn`` receives the send button's value (the button is listed
                among the inputs) and is not used.
                """
                if not api_key:
                    yield "<p style='color: #FFCC00;'>Por favor, insira a sua chave de API para começar.</p>", gr.update(value="➤ Enviar", interactive=True)
                    return
                if not prompt:
                    yield "<p style='color: #FFCC00;'>Por favor, digite um prompt.</p>", gr.update(value="➤ Enviar", interactive=True)
                    return

                # Interim state: show progress text and lock the button.
                yield "<p style='color: #a0a0a0;'>A gerar resposta...</p>", gr.update(value="A gerar...", interactive=False)

                response = await self.service.generate_text(api_key=api_key, prompt=prompt, temperature=temp, max_tokens=int(max_tokens), top_k=int(top_k), top_p=top_p)
                if response.success:
                    # Model output is escaped before being rendered as HTML/Markdown.
                    formatted_text = html.escape(response.data["generated_text"]).replace("\n", "<br>")
                    yield formatted_text, gr.update(value="➤ Enviar", interactive=True)
                else:
                    # NOTE(review): the error text is interpolated without escaping.
                    yield f"<p style='color: #FF4500;'>{response.error}</p>", gr.update(value="➤ Enviar", interactive=True)

            # Wire up the callbacks
            send_button.click(
                handle_generation,
                inputs=[api_key_input, prompt_input, temp_slider, max_tokens_slider, top_k_slider, top_p_slider, send_button],
                outputs=[output_display, send_button],
                api_name="generate",
            )
            key_button.click(handle_key_generation, outputs=[api_key_input, api_example_display])
            # On page load, reset the API example panel to its placeholder text.
            demo.load(lambda: gr.update(value="<p style='color: #a0a0a0;'>Clique em 'Gerar Nova Chave' para ver um exemplo de código.</p>"), [], [api_example_display])

        return demo


# ----------------- FastAPI app and endpoints -----------------
service = GemmaService()
gradio_interface = GradioInterface(service)
gradio_blocks = gradio_interface.create_interface()

app = FastAPI(title="Gemma Service (Gradio + API)")

# Strong references to fire-and-forget tasks. The event loop only keeps weak
# references, so an unreferenced task may be garbage-collected mid-execution.
_background_tasks = set()


def _bad_request(message: str) -> JSONResponse:
    """Build the standard 400 response carrying ``message`` as the error."""
    return JSONResponse(status_code=400, content={"success": False, "error": message})


@app.on_event("startup")
async def startup_event():
    """Start loading the model in the background without blocking startup.

    To wait for the model before accepting requests, ``await
    service.initialize()`` directly instead of scheduling a task.
    """
    task = asyncio.create_task(service.initialize())
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)


@app.post("/api/generate")
async def api_generate(req: Request):
    """Generate text from a JSON object body.

    Expected keys: ``api_key``, ``prompt`` and optionally ``max_tokens``,
    ``temperature``, ``top_k``, ``top_p``. Responds 200 with the serialized
    ``APIResponse`` on success and 400 on any failure, including a malformed
    body or non-numeric parameters.
    """
    try:
        body = await req.json()
    except Exception:
        return _bad_request("Payload inválido (JSON esperado).")

    # Valid JSON that is not an object (e.g. a list or string) has no `.get`.
    if not isinstance(body, dict):
        return _bad_request("Payload inválido (objeto JSON esperado).")

    try:
        api_key = body.get("api_key")
        prompt = body.get("prompt", "")
        max_tokens = int(body.get("max_tokens", 512))
        temperature = float(body.get("temperature", 0.7))
        top_k = int(body.get("top_k", 50))
        top_p = float(body.get("top_p", 0.95))
    except (TypeError, ValueError, OverflowError) as e:
        # Client sent a value that cannot be converted; report it as a 400.
        return _bad_request(f"Parâmetros inválidos: {e}")

    resp = await service.generate_text(api_key=api_key, prompt=prompt, max_tokens=max_tokens, temperature=temperature, top_k=top_k, top_p=top_p)
    status = 200 if resp.success else 400
    return JSONResponse(status_code=status, content=resp.dict())


@app.post("/run/generate")
async def gradio_compatible_generate(req: Request):
    """Generate text from a Gradio-style body: ``{"data": [...]}``.

    Positional layout read from ``data``:
    ``[api_key, prompt, max_tokens, temperature, top_k, top_p]``; trailing
    items may be omitted and fall back to defaults.

    NOTE(review): the Gradio click handler in this file receives its inputs
    as ``[api_key, prompt, temperature, max_tokens, top_k, top_p]`` —
    temperature and max_tokens are swapped relative to this endpoint.
    Confirm which order clients rely on before aligning them.
    """
    try:
        body = await req.json()
    except Exception:
        return _bad_request("Payload inválido (JSON esperado).")

    data = body.get("data") if isinstance(body, dict) else None
    if not isinstance(data, list):
        return _bad_request("Campo 'data' inválido. Esperado array.")

    try:
        api_key = data[0]
        prompt = data[1] if len(data) > 1 else ""
        max_tokens = int(data[2]) if len(data) > 2 else 512
        temperature = float(data[3]) if len(data) > 3 else 0.7
        top_k = int(data[4]) if len(data) > 4 else 50
        top_p = float(data[5]) if len(data) > 5 else 0.95
    except Exception as e:
        return _bad_request(f"Erro ao parsear 'data': {e}")

    resp = await service.generate_text(api_key=api_key, prompt=prompt, max_tokens=max_tokens, temperature=temperature, top_k=top_k, top_p=top_p)
    status = 200 if resp.success else 400
    return JSONResponse(status_code=status, content=resp.dict())


# Mount Gradio at the root "/" LAST. A mount at "/" matches every path and
# routes are tried in registration order, so mounting before the endpoints
# above would make them unreachable. If mounting fails, the UI may still be
# served by the hosting Space.
try:
    gr.mount_gradio_app(app, gradio_blocks, path="/")
except Exception as exc:
    logger.warning("Não foi possível montar Gradio automaticamente: %s", exc)