import gc
import hmac
import logging
import os
import time
from collections import defaultdict

import gradio as gr
import torch
from fastapi import FastAPI, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from transformers import AutoTokenizer

from best import ModelConfig, IndonesianLLM, generate_text, _extract_thinking
|
|
| |
# Configure logging once at import time and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
logger.info(f"Device: {device}")
|
|
| |
# Location of the serialized checkpoint. Kept in a single constant so the
# existence check, the size report and the actual load cannot drift apart.
MODEL_PATH = "indonesian_llm_model (42).pt"

# Fail fast, before the slower tokenizer/model setup, when the file is missing.
model_file_present = os.path.exists(MODEL_PATH)
logger.info(f"model.pt ada: {model_file_present}")
if not model_file_present:
    raise FileNotFoundError("model.pt tidak ditemukan! Upload dulu ke Space.")
logger.info(f"model.pt size: {os.path.getsize(MODEL_PATH) / 1e6:.1f} MB")

# Tokenizer: pretrained vocabulary plus the two chain-of-thought delimiters
# registered as additional special tokens.
logger.info("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained("indolem/indobert-base-uncased")
tokenizer.add_special_tokens({"additional_special_tokens": ["<cot>", "</cot>"]})
logger.info("Tokenizer OK")

logger.info("Loading checkpoint...")
# SECURITY NOTE: weights_only=False makes torch.load unpickle arbitrary Python
# objects, which can execute code. Only load checkpoint files you trust.
# NOTE(review): the checkpoint carries a 'config' object (used below), which
# presumably is why full unpickling is needed -- confirm before tightening.
checkpoint = torch.load(MODEL_PATH, map_location='cpu', weights_only=False)
logger.info(f"Checkpoint keys: {list(checkpoint.keys())}")

logger.info("Building model...")
config = checkpoint['config']
model = IndonesianLLM(config)
logger.info(f"Model params: {model.count_parameters():,}")

logger.info("Loading weights...")
state_dict = checkpoint['model_state_dict']
# Upcast half-precision tensors to float32 before loading them into the model.
# Only existing keys are reassigned, so iterating the dict directly is safe.
for param_name, tensor in state_dict.items():
    if tensor.dtype == torch.float16:
        state_dict[param_name] = tensor.float()

model.load_state_dict(state_dict)
# Drop the CPU copies of the checkpoint so their memory can be reclaimed.
del checkpoint, state_dict
gc.collect()

model.eval()
model.to(device)
logger.info("Model siap!")
|
|
| |
# Accepted API keys.
# Keys are read from the API_KEYS environment variable (comma-separated) so
# real secrets do not have to live in source control. When the variable is
# unset or empty, the original built-in key is used so existing clients keep
# working; that fallback is a placeholder and should be replaced in production.
_DEFAULT_API_KEY = "kunci-rahasia-kamu-123"
_configured_keys = {
    key.strip()
    for key in os.environ.get("API_KEYS", "").split(",")
    if key.strip()
}
if not _configured_keys:
    logging.getLogger(__name__).warning(
        "API_KEYS environment variable is not set; using the built-in placeholder key."
    )
API_KEYS = _configured_keys or {_DEFAULT_API_KEY}

# In-memory rate-limit state (lost on restart, not shared between processes).
# Maps a client key to the list of its recent request timestamps.
ip_request_count = defaultdict(list)
# Maps a banned client IP to the UNIX time at which its ban expires.
ip_banned_until = {}

# A client sending more than BLACKLIST_THRESHOLD requests within
# BLACKLIST_WINDOW seconds is banned for BLACKLIST_DURATION seconds.
BLACKLIST_THRESHOLD = 100
BLACKLIST_WINDOW = 60
BLACKLIST_DURATION = 3600
|
|
| |
| |
| |
# FastAPI application exposing the JSON API.
app = FastAPI(
    title="Indonesian LLM API",
    version="1.0.0",
    description="API untuk model bahasa Indonesia dengan Chain-of-Thought",
)

# CORS is fully open: any origin, any HTTP method, any request header.
cors_options = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **cors_options)
|
|
| |
@app.middleware("http")
async def ddos_protection(request: Request, call_next):
    """Reject banned clients and ban clients that exceed the request threshold.

    Each request is recorded under the client's IP address. When an IP has
    more than BLACKLIST_THRESHOLD requests inside the last BLACKLIST_WINDOW
    seconds it is banned for BLACKLIST_DURATION seconds; banned IPs receive
    HTTP 429 until the ban expires. All state lives in the module-level
    ``ip_request_count`` and ``ip_banned_until`` containers.

    Args:
        request: The incoming HTTP request.
        call_next: Callable that forwards the request to the next handler.

    Returns:
        A 429 JSONResponse for banned or newly banned clients, otherwise the
        response produced by ``call_next``.
    """
    # Requests without client information share the single "unknown" bucket.
    # NOTE(review): behind a reverse proxy this may be the proxy's address
    # rather than the end user's -- confirm for the deployment in use.
    client_ip = request.client.host if request.client else "unknown"
    now = time.time()

    ban_expiry = ip_banned_until.get(client_ip)
    if ban_expiry is not None:
        if now < ban_expiry:
            remaining = int(ban_expiry - now)
            return JSONResponse(
                status_code=429,
                content={"error": f"IP dibanned. Coba lagi dalam {remaining} detik."},
            )
        # The ban has expired: lift it and start counting from scratch.
        ip_banned_until.pop(client_ip)
        ip_request_count[client_ip] = []

    # Record this request and keep only timestamps still inside the window.
    history = ip_request_count[client_ip]
    in_window = [t for t in (*history, now) if now - t < BLACKLIST_WINDOW]
    ip_request_count[client_ip] = in_window

    if len(in_window) <= BLACKLIST_THRESHOLD:
        return await call_next(request)

    # Threshold exceeded: ban the client and clear its request history.
    ip_banned_until[client_ip] = now + BLACKLIST_DURATION
    ip_request_count[client_ip] = []
    return JSONResponse(
        status_code=429,
        content={"error": f"Terlalu banyak request. IP dibanned selama {BLACKLIST_DURATION // 60} menit."},
    )
|
|
| |
| |
| |
|
|
def check_api_key(request: Request) -> bool:
    """Return True when the request carries a valid ``X-API-Key`` header.

    The supplied key is compared against every entry of ``API_KEYS`` with
    ``hmac.compare_digest`` so the comparison time does not reveal how much
    of a key matched. Keys are compared as UTF-8 bytes because
    ``compare_digest`` rejects non-ASCII ``str`` arguments.

    Args:
        request: The incoming HTTP request.

    Returns:
        True if the header is present, non-empty and equal to a known key;
        False otherwise.
    """
    provided = request.headers.get("X-API-Key")
    if not provided:
        return False

    provided_bytes = provided.encode("utf-8")
    is_valid = False
    for known_key in API_KEYS:
        # No early exit: every key is always checked, so the run time does
        # not depend on which key (if any) matched.
        is_valid |= hmac.compare_digest(provided_bytes, known_key.encode("utf-8"))
    return is_valid
|
|
@app.get("/api/health")
def health():
    """Report service status, the torch device in use and the model's parameter count."""
    return dict(
        status="ok",
        device=str(device),
        model_params=model.count_parameters(),
    )
|
|
@app.post("/api/chat")
async def api_chat(request: Request):
    """Generate a reply for one chat message.

    Request JSON object fields:
        message (required): the user message, at most 500 characters.
        max_tokens (optional, default 200): integer between 10 and 500.
        temperature (optional, default 0.7): float between 0.1 and 1.5.
        show_thinking (optional, default False): include the extracted
            thinking text in the response.

    Responses:
        200: ``{"answer", "thinking", "processing_time_ms"}``.
        400: malformed body or a parameter outside its allowed range.
        401: missing or invalid ``X-API-Key`` header.
        429: more than 10 chat requests from the same IP within 60 seconds.
        500: text generation raised an exception.
    """
    # --- Authentication -----------------------------------------------------
    if not check_api_key(request):
        return JSONResponse(
            status_code=401,
            content={"error": "API key tidak valid. Tambahkan header X-API-Key."},
        )

    # --- Per-IP rate limit for this endpoint -------------------------------
    # Uses the shared timestamp store under a "<ip>_chat" key so it does not
    # mix with the entries the global middleware keeps under the bare IP.
    ip = request.client.host if request.client else "unknown"
    now = time.time()
    rate_key = f"{ip}_chat"
    recent = [t for t in ip_request_count[rate_key] if now - t < 60]
    ip_request_count[rate_key] = recent
    if len(recent) >= 10:
        return JSONResponse(
            status_code=429,
            content={"error": "Rate limit: maksimal 10 request per menit."},
        )
    recent.append(now)

    # --- Parse the request body --------------------------------------------
    try:
        body = await request.json()
        if not isinstance(body, dict):
            # A JSON array/string/number has no named fields to read.
            raise TypeError("JSON body must be an object")
        raw_message = body.get("message")
        # An explicit JSON null must count as "no message", not the text "None".
        message = "" if raw_message is None else str(raw_message).strip()
        max_tokens = int(body.get("max_tokens", 200))
        temperature = float(body.get("temperature", 0.7))
        show_think = bool(body.get("show_thinking", False))
    except Exception:
        # Deliberately broad: any parsing or conversion failure (invalid JSON,
        # wrong types, overflow, ...) is reported to the client as a bad request.
        return JSONResponse(
            status_code=400,
            content={"error": "Request body tidak valid. Gunakan JSON."},
        )

    # --- Validate parameters -----------------------------------------------
    if not message:
        return JSONResponse(status_code=400, content={"error": "Pesan tidak boleh kosong."})
    if len(message) > 500:
        return JSONResponse(status_code=400, content={"error": "Pesan terlalu panjang. Maksimal 500 karakter."})
    if not (10 <= max_tokens <= 500):
        return JSONResponse(status_code=400, content={"error": "max_tokens harus antara 10 dan 500."})
    if not (0.1 <= temperature <= 1.5):
        return JSONResponse(status_code=400, content={"error": "temperature harus antara 0.1 dan 1.5."})

    # --- Generate -------------------------------------------------------------
    try:
        start = time.time()
        prompt = f"{message} <cot>"
        # generate_text is a blocking call; running it in the threadpool keeps
        # the event loop free to serve other requests while it works.
        full = await run_in_threadpool(
            generate_text,
            model=model,
            tokenizer=tokenizer,
            prompt=prompt,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_k=50,
            top_p=0.9,
            device=device,
        )
        # NOTE(review): assumes generate_text returns the prompt verbatim
        # followed by the continuation -- confirm; an uncased tokenizer may
        # not reproduce the prompt text exactly.
        raw = full[len(prompt):].strip()
        thinking, answer = _extract_thinking(raw)
        elapsed_ms = int((time.time() - start) * 1000)

        logger.info(f"[{ip}] '{message[:40]}' β {elapsed_ms}ms")

        return JSONResponse(content={
            "answer": answer if answer else "Maaf, saya tidak mengerti.",
            "thinking": thinking if show_think else None,
            "processing_time_ms": elapsed_ms,
        })

    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception(f"Generate error: {e}")
        return JSONResponse(
            status_code=500,
            content={"error": f"Gagal generate: {str(e)}"},
        )
|
|
| |
| |
| |
def gradio_chat(message, history):
    """Chat callback for the Gradio UI.

    Args:
        message: Text typed by the user.
        history: Previous conversation turns; not used by this function.

    Returns:
        The extracted answer, a fallback sentence when the answer is empty,
        a hint when the message is blank, or an ``Error: ...`` string when
        generation raises.
    """
    if not message.strip():
        return "Silakan ketik pesan."

    # Fixed sampling settings used for every UI request.
    sampling = dict(max_new_tokens=200, temperature=0.7, top_k=50, top_p=0.9)
    try:
        prompt = f"{message} <cot>"
        generated = generate_text(
            model=model,
            tokenizer=tokenizer,
            prompt=prompt,
            device=device,
            **sampling,
        )
        # Drop the leading prompt-length prefix; only the continuation is parsed.
        continuation = generated[len(prompt):].strip()
        _thinking, reply = _extract_thinking(continuation)
        return reply or "Maaf, saya tidak mengerti."
    except Exception as e:
        logger.error(f"Gradio error: {e}")
        return f"Error: {str(e)}"
|
|
# Example prompts offered in the chat UI.
_example_prompts = (
    "Halo, apa kabar?",
    "Jelaskan cara kerja internet",
    "Berapa hasil dari 25 dikali 4?",
    "Apa ibu kota Indonesia?",
)

# Chat interface backed by gradio_chat; each example is a one-element list.
gradio_ui = gr.ChatInterface(
    fn=gradio_chat,
    title="Indonesian LLM",
    description="Model bahasa Indonesia dengan Chain-of-Thought reasoning. API tersedia di /api/chat",
    examples=[[text] for text in _example_prompts],
)
|
|
| |
| |
| |
# Mount the Gradio UI on the FastAPI app at the root path.
demo = gr.mount_gradio_app(app, gradio_ui, path="/")


# Start the server only when this file is executed as a script. Importing the
# module (for example from an external ASGI server or a test) must not start a
# second, blocking server as a side effect.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(demo, host="0.0.0.0", port=7860)