# Source: Hugging Face Space by nvhuynh16 — "Update app.py", commit 26a4c95 (verified)
"""
Gradio demo for Gemma Code Generator.
Loads the fine-tuned model directly using PEFT.
Includes production monitoring and logging.
"""
import ast
import json
import os
import time
from datetime import datetime
from typing import Optional

import gradio as gr
import torch
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
# Model configuration: frozen base checkpoint plus the fine-tuned LoRA adapter.
BASE_MODEL = "google/gemma-2-2b-it"
ADAPTER_MODEL = "nvhuynh16/gemma-2b-code-alpaca-best"
# Optional HF access token (gated models); None falls back to anonymous access.
HF_TOKEN = os.environ.get("HF_TOKEN", None)
# Global variables for lazy loading — populated by load_model() on first request.
tokenizer = None
model = None
# Production monitoring: JSONL request log, toggled via the ENABLE_LOGGING env var
# (any value other than "true", case-insensitive, disables logging).
ENABLE_LOGGING = os.environ.get("ENABLE_LOGGING", "true").lower() == "true"
LOG_FILE = "production_logs.jsonl"
def check_syntax(code: str) -> bool:
"""Check if generated code has valid Python syntax"""
try:
ast.parse(code)
return True
except:
return False
def log_request(
    instruction: str,
    generated_code: str,
    tokens_generated: int,
    latency: float,
    error: Optional[str] = None,
) -> None:
    """Append one JSON-lines record of a request for production monitoring.

    Args:
        instruction: The user's raw prompt (only its length is stored —
            no personal content is persisted).
        generated_code: The text returned to the user.
        tokens_generated: Number of new tokens the model produced.
        latency: End-to-end request time in seconds.
        error: Error/blocked marker for failed requests, or None on success.
            (Annotated ``Optional[str]`` — the original ``str`` annotation
            contradicted the ``None`` default.)

    No-op when ENABLE_LOGGING is false; never raises.
    """
    if not ENABLE_LOGGING:
        return
    log_entry = {
        "timestamp": datetime.now().isoformat(),
        "instruction_length": len(instruction),
        "response_length": len(generated_code),
        "tokens_generated": tokens_generated,
        "latency_seconds": round(latency, 2),
        # Blocked/errored/empty responses are counted as syntax errors so the
        # monitoring dashboard treats them as failed generations.
        "has_syntax_error": not check_syntax(generated_code) if generated_code and not error else True,
        "error": error,
    }
    try:
        with open(LOG_FILE, "a", encoding="utf-8") as f:
            f.write(json.dumps(log_entry) + "\n")
    except OSError as e:
        # Best-effort logging: a full disk or read-only FS (common on Spaces)
        # must never break the request path. OSError covers open/write failures.
        print(f"Logging failed: {e}")
# Safety filters - Layer 1: Input Validation
DANGEROUS_KEYWORDS = [
    "delete all files", "rm -rf", "shutil.rmtree",
    "sql injection", "drop table", "truncate table",
    "keylogger", "backdoor", "exploit",
    "hack into", "steal password", "crack password",
    "ddos", "denial of service", "fork bomb",
    "malware", "ransomware", "trojan"
]
def validate_input(instruction: str) -> tuple:
    """
    Validate input for dangerous keywords.
    Returns: (is_valid: bool, error_message: str)
    """
    lowered = instruction.lower()
    # Find the first blocked keyword, if any (case-insensitive substring match).
    hit = next((kw for kw in DANGEROUS_KEYWORDS if kw in lowered), None)
    if hit is None:
        return True, ""
    return False, f"⚠️ Safety Filter: Request blocked. Your instruction contains potentially unsafe content related to '{hit}'.\n\nPlease rephrase your request to focus on legitimate programming tasks."
# Safety filters - Layer 2: Output Filtering
DANGEROUS_PATTERNS = [
    ("os.remove", "file deletion"),
    ("shutil.rmtree", "directory deletion"),
    ("os.unlink", "file deletion"),
    ("DROP TABLE", "database destruction"),
    ("TRUNCATE TABLE", "database destruction"),
    ("DELETE FROM", "database deletion"),
    ("eval(", "arbitrary code execution"),
    ("exec(", "arbitrary code execution"),
    ("__import__", "dynamic imports"),
    ("os.system", "system command execution"),
    ("subprocess.call", "system command execution"),
    ("subprocess.run", "system command execution"),
]
def filter_dangerous_code(code: str) -> str:
    """
    Filter dangerous code patterns from output.
    Returns: filtered code or safety warning
    """
    lowered = code.lower()
    for pattern, reason in DANGEROUS_PATTERNS:
        # Case-insensitive substring scan; skip patterns that don't appear.
        if pattern.lower() not in lowered:
            continue
        # First match wins: replace the entire output with a comment-only warning.
        return f"""# ⚠️ SAFETY FILTER ACTIVATED
#
# Code generation blocked: Potentially dangerous pattern detected ({reason})
# Pattern: {pattern}
#
# This is a safety feature to prevent generating code that could:
# - Delete files or data
# - Execute arbitrary system commands
# - Compromise system security
#
# Please rephrase your request with safer requirements.
# For educational purposes, consult official documentation or security resources.
"""
    return code
def load_model():
    """Lazily load the tokenizer and fine-tuned model on first request.

    Returns:
        (tokenizer, model): cached in module globals after the first call,
        so subsequent requests skip the expensive load entirely.
    """
    global tokenizer, model
    if model is None:
        print("Loading model for the first time...")
        # Tokenizer comes from the base model; the LoRA adapter doesn't change it.
        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, token=HF_TOKEN)
        # 4-bit quantization via BitsAndBytesConfig: the bare `load_in_4bit=True`
        # kwarg is deprecated (and removed in recent transformers releases).
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
        )
        base_model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            device_map="auto",
            torch_dtype=torch.float16,
            quantization_config=quant_config,
            token=HF_TOKEN,
        )
        # Attach the LoRA adapter weights on top of the quantized base model.
        model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL, token=HF_TOKEN)
        model.eval()
        print("Model loaded successfully!")
    return tokenizer, model
def generate_code(instruction: str, max_tokens: int = 256, temperature: float = 0.7):
    """Generate code for *instruction* with monitoring and safety filters.

    Applies input validation (Layer 1) before inference and output
    filtering (Layer 2) after; every non-trivial request is logged,
    including blocked and errored ones.
    """
    start_time = time.time()
    if not instruction.strip():
        return "Please enter an instruction."

    # Layer 1: refuse obviously dangerous requests before touching the model.
    is_valid, validation_error = validate_input(instruction)
    if not is_valid:
        log_request(
            instruction, validation_error, 0,
            time.time() - start_time, "BLOCKED_BY_SAFETY_FILTER",
        )
        return validation_error

    generated_code, tokens_generated, error = "", 0, None
    try:
        # Model and tokenizer are cached after the first call.
        tok, mdl = load_model()
        # Alpaca-style prompt matching the fine-tuning format.
        prompt = f"""### Instruction:
{instruction}
### Input:
### Response:
"""
        inputs = tok(prompt, return_tensors="pt").to(mdl.device)
        prompt_len = len(inputs.input_ids[0])
        with torch.no_grad():
            outputs = mdl.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=0.9,
                do_sample=True,
                pad_token_id=tok.eos_token_id,
            )
        tokens_generated = len(outputs[0]) - prompt_len
        decoded = tok.decode(outputs[0], skip_special_tokens=True)
        # Keep only the model's answer; fall back to the full text when the
        # response marker is absent.
        marker = "### Response:"
        if marker in decoded:
            generated_code = decoded.split(marker)[-1].strip()
        else:
            generated_code = decoded.strip()
        # Layer 2: scrub dangerous patterns from the model output.
        generated_code = filter_dangerous_code(generated_code)
    except Exception as e:
        error = str(e)
        generated_code = f"Error: {error}\n\nPlease try again."
    finally:
        # Always log, even when generation failed.
        log_request(
            instruction, generated_code, tokens_generated,
            time.time() - start_time, error,
        )
    return generated_code
# Custom CSS for better appearance (applied to the Blocks container below).
custom_css = """
.container {
max-width: 900px;
margin: auto;
}
.output-code {
font-family: 'Courier New', monospace;
font-size: 14px;
}
"""
# Create Gradio interface: two-column layout (inputs left, code output right),
# followed by click-to-use examples and a performance/details footer.
with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
    # Header: model summary, cold-start note, and safety-feature disclosure.
    gr.Markdown(
        """
# 🤖 Gemma Code Generator
Fine-tuned Gemma-2B model for multi-language code generation using QLoRA.
**Performance**: 76% syntax correctness | **BLEU Score: 16.83** (+53% improvement over baseline 11.00)
**Note**: First request may take 1-2 minutes as the model loads on HuggingFace servers. Subsequent requests are instant!
---
### 🛡️ Safety Features
This demo includes production-grade safety filters:
- **Input Validation**: Blocks requests with potentially dangerous keywords
- **Output Filtering**: Prevents generation of code that could delete files, execute arbitrary commands, or compromise security
- **Production Monitoring**: All requests are logged for quality tracking (privacy-respecting, no personal data stored)
⚠️ **AI-Generated Code Disclaimer**: Always review generated code before use. AI models can make mistakes.
"""
    )
    with gr.Row():
        # Left column: instruction input, sampling controls, and the trigger button.
        with gr.Column(scale=1):
            instruction_input = gr.Textbox(
                label="Code Instruction",
                placeholder="Describe the function you want to create...",
                lines=3,
            )
            # Advanced sampling knobs, collapsed by default.
            with gr.Accordion("Advanced Settings", open=False):
                max_tokens_slider = gr.Slider(
                    minimum=64,
                    maximum=512,
                    value=256,
                    step=64,
                    label="Max Tokens",
                    info="Maximum length of generated code"
                )
                temperature_slider = gr.Slider(
                    minimum=0.1,
                    maximum=1.5,
                    value=0.7,
                    step=0.1,
                    label="Temperature",
                    info="Higher = more creative, Lower = more deterministic"
                )
            generate_btn = gr.Button("Generate Code", variant="primary", size="lg")
        # Right column: syntax-highlighted output panel.
        with gr.Column(scale=1):
            output_code = gr.Code(
                label="Generated Code",
                language="python",
                elem_classes="output-code"
            )
    # Examples: clicking one fills the instruction textbox.
    gr.Examples(
        examples=[
            ["Write a function to check if a number is prime"],
            ["Create a function to reverse a string"],
            ["Write a function to find the factorial of a number"],
            ["Implement binary search on a sorted list"],
            ["Create a function to merge two sorted lists"],
            ["Write a function to calculate Fibonacci numbers"],
            ["Implement a function to find the longest common subsequence"],
            ["Create a function to validate an email address using regex"],
            ["Write a function to convert a decimal number to binary"],
            ["Implement a simple LRU cache using OrderedDict"],
        ],
        inputs=[instruction_input],
        label="Example Prompts (Click to use)"
    )
    # Event handler: wire the button to generate_code with both sliders as args.
    generate_btn.click(
        fn=generate_code,
        inputs=[instruction_input, max_tokens_slider, temperature_slider],
        outputs=[output_code],
    )
    # Model information footer: metrics table, training details, and links.
    gr.Markdown(
        """
---
### 📊 Model Performance
| Metric | Baseline (Pretrained) | Fine-Tuned (Actual) | Improvement |
|--------|----------------------|---------------------|-------------|
| **BLEU Score** | 11.00 | **16.83** | **+53%** ✅ |
| **Syntax Correctness** | 81% | 76% | -5% |
| **Trainable Parameters** | 2.5B | 3.2M (0.12%) | 100x fewer |
*Evaluated on 100 multi-language test samples (72 Python, 28 other languages)*
### 🛠️ Technical Details
- **Base Model**: google/gemma-2-2b-it (2.5B parameters)
- **Fine-tuning**: QLoRA (4-bit quantization + LoRA rank 16)
- **Dataset**: CodeAlpaca-20k (18,000 training examples)
- **Checkpoint**: Step 2000 (~1.8 epochs, selected for best BLEU score)
- **Training**: 10-15 hours on free Google Colab T4 GPU
- **Cost**: $0 (free Colab + free HF Spaces hosting)
### 🔗 Links
[Model on HuggingFace](https://huggingface.co/nvhuynh16/gemma-2b-code-alpaca-best) •
[GitHub Repository](https://github.com/nvhuynh16/Gemma-Code-Fine-Tuning) •
[Portfolio](https://portfolio-nvhuynh.vercel.app) •
[Base Model](https://huggingface.co/google/gemma-2-2b-it)
---
**Built for portfolio demonstration** • Targeting AI/ML Applied Scientist roles
*This demo uses HuggingFace Inference API for serverless, cost-free inference*
"""
    )
# Script entry point: launch the Gradio server when run directly.
if __name__ == "__main__":
    demo.launch()