Spaces:
Sleeping
Sleeping
File size: 11,801 Bytes
6fc6360 add12a3 9784a84 6fc6360 add12a3 9784a84 add12a3 6fc6360 add12a3 9784a84 26a4c95 add12a3 6fc6360 add12a3 6fc6360 26a4c95 9784a84 6fc6360 26a4c95 9784a84 add12a3 6fc6360 add12a3 9784a84 add12a3 9784a84 add12a3 9784a84 add12a3 9784a84 6fc6360 26a4c95 6fc6360 9784a84 6fc6360 1e5532c 6fc6360 1e5532c 6fc6360 1e5532c 26a4c95 6fc6360 1e5532c 6fc6360 1e5532c 6fc6360 1e5532c 6fc6360 1e5532c 6fc6360 1e5532c 6fc6360 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 |
"""
Gradio demo for Gemma Code Generator.
Loads the fine-tuned model directly using PEFT.
Includes production monitoring and logging.
"""
import gradio as gr
import torch
import os
import json
import ast
import time
from datetime import datetime
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
# Model configuration
BASE_MODEL = "google/gemma-2-2b-it"  # base checkpoint the LoRA adapter was trained on
ADAPTER_MODEL = "nvhuynh16/gemma-2b-code-alpaca-best"  # fine-tuned LoRA adapter repo
HF_TOKEN = os.environ.get("HF_TOKEN", None)  # optional HF auth token (Gemma is a gated model)
# Global variables for lazy loading — populated by load_model() on first request
tokenizer = None
model = None
# Production monitoring — logging is on by default; set ENABLE_LOGGING=false to disable
ENABLE_LOGGING = os.environ.get("ENABLE_LOGGING", "true").lower() == "true"
LOG_FILE = "production_logs.jsonl"  # one JSON object per line (JSONL)
def check_syntax(code: str) -> bool:
    """Return True if *code* parses as valid Python source.

    Uses ast.parse so the code is only parsed, never executed.

    Args:
        code: Source text to validate (an empty string parses as a valid,
            empty module and therefore returns True).

    Returns:
        True when parsing succeeds, False on any parse failure.
    """
    try:
        ast.parse(code)
    # Narrowed from a bare `except:` which also swallowed KeyboardInterrupt
    # and SystemExit. ast.parse raises SyntaxError for bad code and
    # ValueError for e.g. null bytes; pathological nesting can raise
    # MemoryError/RecursionError — all are treated as "invalid".
    except (SyntaxError, ValueError, MemoryError, RecursionError):
        return False
    return True
def log_request(instruction: str, generated_code: str, tokens_generated: int, latency: float, error: str = None):
    """Append one JSONL monitoring record for a generation request.

    Only lengths and derived statistics are stored — never the raw text —
    so the log stays privacy-respecting. A no-op when ENABLE_LOGGING is
    false; a write failure is printed rather than raised so logging can
    never break a user request.
    """
    if not ENABLE_LOGGING:
        return
    # A request counts as having a syntax error unless it produced output,
    # finished without an error, AND that output parses as Python.
    if generated_code and not error:
        syntax_error = not check_syntax(generated_code)
    else:
        syntax_error = True
    entry = {
        "timestamp": datetime.now().isoformat(),
        "instruction_length": len(instruction),
        "response_length": len(generated_code),
        "tokens_generated": tokens_generated,
        "latency_seconds": round(latency, 2),
        "has_syntax_error": syntax_error,
        "error": error,
    }
    try:
        with open(LOG_FILE, "a", encoding="utf-8") as sink:
            sink.write(json.dumps(entry) + "\n")
    except Exception as e:
        print(f"Logging failed: {e}")
# Safety filters - Layer 1: Input Validation
# Case-insensitive substring blocklist applied to the user's instruction.
DANGEROUS_KEYWORDS = [
    "delete all files", "rm -rf", "shutil.rmtree",
    "sql injection", "drop table", "truncate table",
    "keylogger", "backdoor", "exploit",
    "hack into", "steal password", "crack password",
    "ddos", "denial of service", "fork bomb",
    "malware", "ransomware", "trojan"
]

def validate_input(instruction: str) -> tuple:
    """Screen an instruction against the dangerous-keyword blocklist.

    Returns:
        (is_valid, error_message) — is_valid is False and error_message is
        a user-facing refusal when any blocklisted keyword appears as a
        case-insensitive substring; otherwise (True, "").
    """
    lowered = instruction.lower()
    # First blocklisted keyword found (list order), or None when clean.
    hit = next((kw for kw in DANGEROUS_KEYWORDS if kw in lowered), None)
    if hit is None:
        return True, ""
    return False, f"⚠️ Safety Filter: Request blocked. Your instruction contains potentially unsafe content related to '{hit}'.\n\nPlease rephrase your request to focus on legitimate programming tasks."
# Safety filters - Layer 2: Output Filtering
# (pattern, human-readable reason) pairs; matched case-insensitively against
# the generated code, first match wins.
DANGEROUS_PATTERNS = [
    ("os.remove", "file deletion"),
    ("shutil.rmtree", "directory deletion"),
    ("os.unlink", "file deletion"),
    ("DROP TABLE", "database destruction"),
    ("TRUNCATE TABLE", "database destruction"),
    ("DELETE FROM", "database deletion"),
    ("eval(", "arbitrary code execution"),
    ("exec(", "arbitrary code execution"),
    ("__import__", "dynamic imports"),
    ("os.system", "system command execution"),
    ("subprocess.call", "system command execution"),
    ("subprocess.run", "system command execution"),
]

def filter_dangerous_code(code: str) -> str:
    """Replace output containing a dangerous pattern with a safety warning.

    Returns:
        The code unchanged when no pattern matches; otherwise a commented
        warning block naming the first pattern (in list order) that hit.
    """
    haystack = code.lower()
    for needle, reason in DANGEROUS_PATTERNS:
        if needle.lower() not in haystack:
            continue
        return f"""# ⚠️ SAFETY FILTER ACTIVATED
#
# Code generation blocked: Potentially dangerous pattern detected ({reason})
# Pattern: {needle}
#
# This is a safety feature to prevent generating code that could:
# - Delete files or data
# - Execute arbitrary system commands
# - Compromise system security
#
# Please rephrase your request with safer requirements.
# For educational purposes, consult official documentation or security resources.
"""
    return code
def load_model():
    """Lazily load and cache the tokenizer and LoRA-adapted model.

    On the first call this downloads/loads the 4-bit-quantized base model
    and applies the PEFT adapter; results are cached in the module globals
    so every later call is instant.

    Returns:
        (tokenizer, model) tuple ready for inference.
    """
    global tokenizer, model
    if model is None:
        # Function-scope import: only needed once, at actual load time.
        from transformers import BitsAndBytesConfig

        print("Loading model for the first time...")
        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, token=HF_TOKEN)
        # Load base model with 4-bit quantization.
        # NOTE: passing `load_in_4bit=True` directly to from_pretrained is
        # deprecated and removed in recent transformers releases; the
        # supported API is a BitsAndBytesConfig quantization_config.
        base_model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            device_map="auto",
            torch_dtype=torch.float16,
            quantization_config=BitsAndBytesConfig(load_in_4bit=True),
            token=HF_TOKEN,
        )
        # Load LoRA adapter on top of the quantized base model
        model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL, token=HF_TOKEN)
        model.eval()  # inference mode: disables dropout etc.
        print("Model loaded successfully!")
    return tokenizer, model
def generate_code(instruction: str, max_tokens: int = 256, temperature: float = 0.7) -> str:
    """Generate code from an instruction, with monitoring and safety filters.

    Pipeline: input validation (safety layer 1) -> Alpaca-style prompt ->
    sampled generation -> response extraction -> output filtering (safety
    layer 2). Every request — including blocked and failed ones — is
    logged via log_request in the finally block.

    Args:
        instruction: Natural-language description of the code to generate.
        max_tokens: Upper bound on newly generated tokens.
        temperature: Sampling temperature (sampling is always enabled).

    Returns:
        Generated code, a safety-filter message, or an error string.
    """
    start_time = time.time()
    if not instruction.strip():
        return "Please enter an instruction."
    # Layer 1: Input validation — refuse blocklisted instructions up front
    is_valid, validation_error = validate_input(instruction)
    if not is_valid:
        # Log blocked request (latency recorded, zero tokens generated)
        log_request(instruction, validation_error, 0, time.time() - start_time, "BLOCKED_BY_SAFETY_FILTER")
        return validation_error
    # Pre-initialize so the finally-block logging works on any code path.
    generated_code = ""
    tokens_generated = 0
    error = None
    try:
        # Load model (cached after first call)
        tok, mdl = load_model()
        # Format prompt in Alpaca style; the "### Response:" marker below is
        # also what the extraction step splits on, so the two must stay in sync.
        prompt = f"""### Instruction:
{instruction}
### Input:
### Response:
"""
        # Tokenize and move tensors onto the model's device
        inputs = tok(prompt, return_tensors="pt").to(mdl.device)
        input_length = len(inputs.input_ids[0])
        # Generate (no_grad: inference only, no autograd bookkeeping)
        with torch.no_grad():
            outputs = mdl.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=0.9,
                do_sample=True,
                pad_token_id=tok.eos_token_id,
            )
        # Calculate tokens generated (output includes the prompt tokens)
        tokens_generated = len(outputs[0]) - input_length
        # Decode the full sequence, dropping special tokens
        generated = tok.decode(outputs[0], skip_special_tokens=True)
        # Extract code after "### Response:" (fall back to the whole text)
        if "### Response:" in generated:
            generated_code = generated.split("### Response:")[-1].strip()
        else:
            generated_code = generated.strip()
        # Layer 2: Output filtering for dangerous patterns
        generated_code = filter_dangerous_code(generated_code)
    except Exception as e:
        # Surface the failure to the user instead of crashing the UI
        error = str(e)
        generated_code = f"Error: {error}\n\nPlease try again."
    finally:
        # Log request — runs on success, failure, and early exceptions alike
        latency = time.time() - start_time
        log_request(instruction, generated_code, tokens_generated, latency, error)
    return generated_code
# Custom CSS for better appearance
custom_css = """
.container {
max-width: 900px;
margin: auto;
}
.output-code {
font-family: 'Courier New', monospace;
font-size: 14px;
}
"""
# Create Gradio interface.
# Layout: header markdown, then a two-column row (inputs left, code output
# right), clickable examples, the click handler wiring, and a footer.
with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
    # --- Header: title, headline metrics, and safety-feature notes ---
    gr.Markdown(
        """
# 🤖 Gemma Code Generator
Fine-tuned Gemma-2B model for multi-language code generation using QLoRA.
**Performance**: 76% syntax correctness | **BLEU Score: 16.83** (+53% improvement over baseline 11.00)
**Note**: First request may take 1-2 minutes as the model loads on HuggingFace servers. Subsequent requests are instant!
---
### 🛡️ Safety Features
This demo includes production-grade safety filters:
- **Input Validation**: Blocks requests with potentially dangerous keywords
- **Output Filtering**: Prevents generation of code that could delete files, execute arbitrary commands, or compromise security
- **Production Monitoring**: All requests are logged for quality tracking (privacy-respecting, no personal data stored)
⚠️ **AI-Generated Code Disclaimer**: Always review generated code before use. AI models can make mistakes.
"""
    )
    with gr.Row():
        # Left column: instruction input, generation settings, submit button
        with gr.Column(scale=1):
            instruction_input = gr.Textbox(
                label="Code Instruction",
                placeholder="Describe the function you want to create...",
                lines=3,
            )
            # Collapsed by default; exposes the generate_code() knobs
            with gr.Accordion("Advanced Settings", open=False):
                max_tokens_slider = gr.Slider(
                    minimum=64,
                    maximum=512,
                    value=256,
                    step=64,
                    label="Max Tokens",
                    info="Maximum length of generated code"
                )
                temperature_slider = gr.Slider(
                    minimum=0.1,
                    maximum=1.5,
                    value=0.7,
                    step=0.1,
                    label="Temperature",
                    info="Higher = more creative, Lower = more deterministic"
                )
            generate_btn = gr.Button("Generate Code", variant="primary", size="lg")
        # Right column: syntax-highlighted output
        with gr.Column(scale=1):
            output_code = gr.Code(
                label="Generated Code",
                language="python",
                elem_classes="output-code"
            )
    # Examples — clicking one fills instruction_input
    gr.Examples(
        examples=[
            ["Write a function to check if a number is prime"],
            ["Create a function to reverse a string"],
            ["Write a function to find the factorial of a number"],
            ["Implement binary search on a sorted list"],
            ["Create a function to merge two sorted lists"],
            ["Write a function to calculate Fibonacci numbers"],
            ["Implement a function to find the longest common subsequence"],
            ["Create a function to validate an email address using regex"],
            ["Write a function to convert a decimal number to binary"],
            ["Implement a simple LRU cache using OrderedDict"],
        ],
        inputs=[instruction_input],
        label="Example Prompts (Click to use)"
    )
    # Event handler — wires the button to generate_code()
    generate_btn.click(
        fn=generate_code,
        inputs=[instruction_input, max_tokens_slider, temperature_slider],
        outputs=[output_code],
    )
    # Model information footer: metrics table, training details, links
    gr.Markdown(
        """
---
### 📊 Model Performance
| Metric | Baseline (Pretrained) | Fine-Tuned (Actual) | Improvement |
|--------|----------------------|---------------------|-------------|
| **BLEU Score** | 11.00 | **16.83** | **+53%** ✅ |
| **Syntax Correctness** | 81% | 76% | -5% |
| **Trainable Parameters** | 2.5B | 3.2M (0.12%) | 100x fewer |
*Evaluated on 100 multi-language test samples (72 Python, 28 other languages)*
### 🛠️ Technical Details
- **Base Model**: google/gemma-2-2b-it (2.5B parameters)
- **Fine-tuning**: QLoRA (4-bit quantization + LoRA rank 16)
- **Dataset**: CodeAlpaca-20k (18,000 training examples)
- **Checkpoint**: Step 2000 (~1.8 epochs, selected for best BLEU score)
- **Training**: 10-15 hours on free Google Colab T4 GPU
- **Cost**: $0 (free Colab + free HF Spaces hosting)
### 🔗 Links
[Model on HuggingFace](https://huggingface.co/nvhuynh16/gemma-2b-code-alpaca-best) •
[GitHub Repository](https://github.com/nvhuynh16/Gemma-Code-Fine-Tuning) •
[Portfolio](https://portfolio-nvhuynh.vercel.app) •
[Base Model](https://huggingface.co/google/gemma-2-2b-it)
---
**Built for portfolio demonstration** • Targeting AI/ML Applied Scientist roles
*This demo uses HuggingFace Inference API for serverless, cost-free inference*
"""
    )
if __name__ == "__main__":
    # Start the Gradio server when run as a script (Spaces runs this entry point).
    demo.launch()
|