File size: 11,801 Bytes
6fc6360
add12a3
 
9784a84
6fc6360
 
 
add12a3
 
9784a84
 
 
 
add12a3
 
6fc6360
 
add12a3
 
 
 
 
 
 
 
9784a84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26a4c95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
add12a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6fc6360
add12a3
6fc6360
 
 
26a4c95
9784a84
6fc6360
 
 
 
26a4c95
 
 
 
 
 
 
9784a84
 
 
 
add12a3
 
 
 
 
 
6fc6360
 
 
 
 
 
 
 
add12a3
 
9784a84
add12a3
 
 
 
 
 
 
 
 
 
 
 
9784a84
 
 
add12a3
 
 
 
 
9784a84
add12a3
9784a84
6fc6360
26a4c95
 
 
6fc6360
9784a84
 
 
 
 
 
 
 
 
6fc6360
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e5532c
6fc6360
1e5532c
6fc6360
1e5532c
26a4c95
 
 
 
 
 
 
 
 
 
 
6fc6360
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e5532c
 
 
 
6fc6360
 
1e5532c
 
6fc6360
 
 
 
1e5532c
 
 
6fc6360
 
 
 
1e5532c
 
 
6fc6360
 
 
 
 
 
1e5532c
6fc6360
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
"""
Gradio demo for Gemma Code Generator.
Loads the fine-tuned model directly using PEFT.
Includes production monitoring and logging.
"""

import gradio as gr
import torch
import os
import json
import ast
import time
from datetime import datetime
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

# Model configuration
# BASE_MODEL is the pretrained checkpoint; ADAPTER_MODEL is the LoRA adapter
# fine-tuned on CodeAlpaca that gets applied on top of it in load_model().
BASE_MODEL = "google/gemma-2-2b-it"
ADAPTER_MODEL = "nvhuynh16/gemma-2b-code-alpaca-best"
# Optional HF access token (gated model); None falls back to anonymous access.
HF_TOKEN = os.environ.get("HF_TOKEN", None)

# Global variables for lazy loading — populated by load_model() on first request.
tokenizer = None
model = None

# Production monitoring — logging is on unless ENABLE_LOGGING env var says otherwise.
ENABLE_LOGGING = os.environ.get("ENABLE_LOGGING", "true").lower() == "true"
# JSON-lines log file appended to by log_request().
LOG_FILE = "production_logs.jsonl"

def check_syntax(code: str) -> bool:
    """Return True if *code* parses as valid Python source.

    Uses ast.parse for a syntax-only check — the code is never executed.
    An empty string parses successfully and therefore counts as valid.
    """
    try:
        ast.parse(code)
        return True
    except (SyntaxError, ValueError):
        # SyntaxError for malformed code; ValueError for inputs such as
        # null bytes (version-dependent). Narrow catch instead of a bare
        # `except:` so genuine problems (KeyboardInterrupt, MemoryError)
        # are not silently swallowed.
        return False

def log_request(instruction: str, generated_code: str, tokens_generated: int, latency: float, error: str = None):
    """Append one JSON line describing this request to the production log.

    No-op when logging is disabled via ENABLE_LOGGING. Only aggregate
    metadata is recorded (lengths, token count, latency) — never the raw
    instruction or response text.
    """
    if not ENABLE_LOGGING:
        return

    # Syntax is only meaningful for a successful, non-empty generation;
    # anything else (error, blocked, empty) is recorded as a failure.
    if generated_code and not error:
        syntax_failed = not check_syntax(generated_code)
    else:
        syntax_failed = True

    entry = {
        "timestamp": datetime.now().isoformat(),
        "instruction_length": len(instruction),
        "response_length": len(generated_code),
        "tokens_generated": tokens_generated,
        "latency_seconds": round(latency, 2),
        "has_syntax_error": syntax_failed,
        "error": error,
    }

    try:
        with open(LOG_FILE, "a", encoding="utf-8") as f:
            f.write(json.dumps(entry) + "\n")
    except Exception as e:
        # Logging must never take down the user-facing request path.
        print(f"Logging failed: {e}")

# Safety filters - Layer 1: Input Validation
DANGEROUS_KEYWORDS = [
    "delete all files", "rm -rf", "shutil.rmtree",
    "sql injection", "drop table", "truncate table",
    "keylogger", "backdoor", "exploit",
    "hack into", "steal password", "crack password",
    "ddos", "denial of service", "fork bomb",
    "malware", "ransomware", "trojan"
]

def validate_input(instruction: str) -> tuple:
    """Check an instruction against the dangerous-keyword denylist.

    Matching is case-insensitive substring search over DANGEROUS_KEYWORDS.

    Returns: (is_valid: bool, error_message: str) — the message is empty
    when the instruction is allowed.
    """
    lowered = instruction.lower()

    # First matching keyword (if any) determines the block message.
    hit = next((kw for kw in DANGEROUS_KEYWORDS if kw in lowered), None)
    if hit is not None:
        return False, f"⚠️ Safety Filter: Request blocked. Your instruction contains potentially unsafe content related to '{hit}'.\n\nPlease rephrase your request to focus on legitimate programming tasks."

    return True, ""

# Safety filters - Layer 2: Output Filtering
DANGEROUS_PATTERNS = [
    ("os.remove", "file deletion"),
    ("shutil.rmtree", "directory deletion"),
    ("os.unlink", "file deletion"),
    ("DROP TABLE", "database destruction"),
    ("TRUNCATE TABLE", "database destruction"),
    ("DELETE FROM", "database deletion"),
    ("eval(", "arbitrary code execution"),
    ("exec(", "arbitrary code execution"),
    ("__import__", "dynamic imports"),
    ("os.system", "system command execution"),
    ("subprocess.call", "system command execution"),
    ("subprocess.run", "system command execution"),
]

def filter_dangerous_code(code: str) -> str:
    """Scan generated code for denylisted patterns (case-insensitive).

    Returns the code unchanged when clean; otherwise returns a commented
    safety warning naming the first pattern that matched and why it is
    considered dangerous.
    """
    lowered = code.lower()

    for pattern, reason in DANGEROUS_PATTERNS:
        if pattern.lower() not in lowered:
            continue
        # Replace the entire output with an explanatory warning block.
        return f"""# ⚠️ SAFETY FILTER ACTIVATED
#
# Code generation blocked: Potentially dangerous pattern detected ({reason})
# Pattern: {pattern}
#
# This is a safety feature to prevent generating code that could:
# - Delete files or data
# - Execute arbitrary system commands
# - Compromise system security
#
# Please rephrase your request with safer requirements.
# For educational purposes, consult official documentation or security resources.
"""

    return code

def load_model():
    """Lazily load the tokenizer and fine-tuned model on the first request.

    The base model is loaded with 4-bit quantization and the LoRA adapter
    is applied on top. Both are cached in module globals, so every call
    after the first returns immediately.

    Returns:
        (tokenizer, model) tuple.
    """
    global tokenizer, model

    if model is None:
        print("Loading model for the first time...")

        # Function-scope import: only needed on the first (cold) load.
        from transformers import BitsAndBytesConfig

        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, token=HF_TOKEN)

        # Explicit quantization config — passing the bare `load_in_4bit=True`
        # kwarg to from_pretrained is deprecated in recent transformers
        # releases in favor of BitsAndBytesConfig.
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
        )

        # Load base model with 4-bit quantization
        base_model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            device_map="auto",
            torch_dtype=torch.float16,
            quantization_config=quant_config,
            token=HF_TOKEN
        )

        # Load LoRA adapter on top of the quantized base model
        model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL, token=HF_TOKEN)
        model.eval()  # inference mode: disables dropout

        print("Model loaded successfully!")

    return tokenizer, model


def generate_code(instruction: str, max_tokens: int = 256, temperature: float = 0.7) -> str:
    """Generate code from instruction with monitoring and safety filters.

    Pipeline: input validation (layer 1) -> Alpaca-style prompt formatting
    -> sampled generation -> response extraction -> output filtering
    (layer 2). Successful, failed, and blocked requests are all logged via
    log_request; only empty instructions bypass logging.

    Args:
        instruction: Natural-language description of the code to generate.
        max_tokens: Maximum number of new tokens to sample.
        temperature: Sampling temperature (higher = more varied output).

    Returns:
        The generated code, a safety-filter message, or an error string.
    """
    start_time = time.time()

    # NOTE(review): blank instructions return early and are NOT logged.
    if not instruction.strip():
        return "Please enter an instruction."

    # Layer 1: Input validation
    is_valid, validation_error = validate_input(instruction)
    if not is_valid:
        # Log blocked request
        log_request(instruction, validation_error, 0, time.time() - start_time, "BLOCKED_BY_SAFETY_FILTER")
        return validation_error

    generated_code = ""
    tokens_generated = 0
    error = None

    try:
        # Load model (cached after first call)
        tok, mdl = load_model()

        # Format prompt in Alpaca style (empty "### Input:" section)
        prompt = f"""### Instruction:
{instruction}

### Input:


### Response:
"""

        # Tokenize and move tensors to the model's device
        inputs = tok(prompt, return_tensors="pt").to(mdl.device)
        input_length = len(inputs.input_ids[0])

        # Generate with nucleus sampling; pad_token_id is set to EOS
        # (common for models without a dedicated pad token).
        with torch.no_grad():
            outputs = mdl.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=0.9,
                do_sample=True,
                pad_token_id=tok.eos_token_id,
            )

        # Calculate tokens generated (output includes the prompt tokens)
        tokens_generated = len(outputs[0]) - input_length

        # Decode
        generated = tok.decode(outputs[0], skip_special_tokens=True)

        # Extract code after "### Response:" — the decoded text still
        # contains the prompt, so keep only the model's continuation.
        if "### Response:" in generated:
            generated_code = generated.split("### Response:")[-1].strip()
        else:
            generated_code = generated.strip()

        # Layer 2: Output filtering for dangerous patterns
        generated_code = filter_dangerous_code(generated_code)

    except Exception as e:
        error = str(e)
        generated_code = f"Error: {error}\n\nPlease try again."

    finally:
        # Log request — runs on both the success and the error path.
        latency = time.time() - start_time
        log_request(instruction, generated_code, tokens_generated, latency, error)

    return generated_code


# Custom CSS for better appearance — injected into the page via the
# `css=` argument of gr.Blocks; `.output-code` styles the code panel.
custom_css = """
.container {
    max-width: 900px;
    margin: auto;
}
.output-code {
    font-family: 'Courier New', monospace;
    font-size: 14px;
}
"""

# Create Gradio interface
with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:

    # Page header: model description, performance summary, safety notes.
    gr.Markdown(
        """
        # 🤖 Gemma Code Generator

        Fine-tuned Gemma-2B model for multi-language code generation using QLoRA.

        **Performance**: 76% syntax correctness | **BLEU Score: 16.83** (+53% improvement over baseline 11.00)

        **Note**: First request may take 1-2 minutes as the model loads on HuggingFace servers. Subsequent requests are instant!

        ---

        ### 🛡️ Safety Features

        This demo includes production-grade safety filters:
        - **Input Validation**: Blocks requests with potentially dangerous keywords
        - **Output Filtering**: Prevents generation of code that could delete files, execute arbitrary commands, or compromise security
        - **Production Monitoring**: All requests are logged for quality tracking (privacy-respecting, no personal data stored)

        ⚠️ **AI-Generated Code Disclaimer**: Always review generated code before use. AI models can make mistakes.
        """
    )

    # Two-column layout: controls on the left, generated code on the right.
    with gr.Row():
        with gr.Column(scale=1):
            instruction_input = gr.Textbox(
                label="Code Instruction",
                placeholder="Describe the function you want to create...",
                lines=3,
            )

            # Generation parameters, collapsed by default.
            with gr.Accordion("Advanced Settings", open=False):
                max_tokens_slider = gr.Slider(
                    minimum=64,
                    maximum=512,
                    value=256,
                    step=64,
                    label="Max Tokens",
                    info="Maximum length of generated code"
                )

                temperature_slider = gr.Slider(
                    minimum=0.1,
                    maximum=1.5,
                    value=0.7,
                    step=0.1,
                    label="Temperature",
                    info="Higher = more creative, Lower = more deterministic"
                )

            generate_btn = gr.Button("Generate Code", variant="primary", size="lg")

        with gr.Column(scale=1):
            output_code = gr.Code(
                label="Generated Code",
                language="python",
                elem_classes="output-code"
            )

    # Examples — clicking one fills the instruction textbox.
    gr.Examples(
        examples=[
            ["Write a function to check if a number is prime"],
            ["Create a function to reverse a string"],
            ["Write a function to find the factorial of a number"],
            ["Implement binary search on a sorted list"],
            ["Create a function to merge two sorted lists"],
            ["Write a function to calculate Fibonacci numbers"],
            ["Implement a function to find the longest common subsequence"],
            ["Create a function to validate an email address using regex"],
            ["Write a function to convert a decimal number to binary"],
            ["Implement a simple LRU cache using OrderedDict"],
        ],
        inputs=[instruction_input],
        label="Example Prompts (Click to use)"
    )

    # Event handler: button click calls generate_code(instruction, max_tokens,
    # temperature) and renders the returned string in the code panel.
    generate_btn.click(
        fn=generate_code,
        inputs=[instruction_input, max_tokens_slider, temperature_slider],
        outputs=[output_code],
    )

    # Model information footer
    gr.Markdown(
        """
        ---

        ### 📊 Model Performance

        | Metric | Baseline (Pretrained) | Fine-Tuned (Actual) | Improvement |
        |--------|----------------------|---------------------|-------------|
        | **BLEU Score** | 11.00 | **16.83** | **+53%** ✅ |
        | **Syntax Correctness** | 81% | 76% | -5% |
        | **Trainable Parameters** | 2.5B | 3.2M (0.12%) | 100x fewer |

        *Evaluated on 100 multi-language test samples (72 Python, 28 other languages)*

        ### 🛠️ Technical Details

        - **Base Model**: google/gemma-2-2b-it (2.5B parameters)
        - **Fine-tuning**: QLoRA (4-bit quantization + LoRA rank 16)
        - **Dataset**: CodeAlpaca-20k (18,000 training examples)
        - **Checkpoint**: Step 2000 (~1.8 epochs, selected for best BLEU score)
        - **Training**: 10-15 hours on free Google Colab T4 GPU
        - **Cost**: $0 (free Colab + free HF Spaces hosting)

        ### 🔗 Links

        [Model on HuggingFace](https://huggingface.co/nvhuynh16/gemma-2b-code-alpaca-best) •
        [GitHub Repository](https://github.com/nvhuynh16/Gemma-Code-Fine-Tuning) •
        [Portfolio](https://portfolio-nvhuynh.vercel.app) •
        [Base Model](https://huggingface.co/google/gemma-2-2b-it)

        ---

        **Built for portfolio demonstration** • Targeting AI/ML Applied Scientist roles

        *This demo uses HuggingFace Inference API for serverless, cost-free inference*
        """
    )

# Script entry point: launch the Gradio server when run directly.
if __name__ == "__main__":
    demo.launch()