Spaces:
Sleeping
Sleeping
| """ | |
| Gradio demo for Gemma Code Generator. | |
| Loads the fine-tuned model directly using PEFT. | |
| Includes production monitoring and logging. | |
| """ | |
| import gradio as gr | |
| import torch | |
| import os | |
| import json | |
| import ast | |
| import time | |
| from datetime import datetime | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| from peft import PeftModel | |
| # Model configuration | |
| BASE_MODEL = "google/gemma-2-2b-it" | |
| ADAPTER_MODEL = "nvhuynh16/gemma-2b-code-alpaca-best" | |
| HF_TOKEN = os.environ.get("HF_TOKEN", None) | |
| # Global variables for lazy loading | |
| tokenizer = None | |
| model = None | |
| # Production monitoring | |
| ENABLE_LOGGING = os.environ.get("ENABLE_LOGGING", "true").lower() == "true" | |
| LOG_FILE = "production_logs.jsonl" | |
| def check_syntax(code: str) -> bool: | |
| """Check if generated code has valid Python syntax""" | |
| try: | |
| ast.parse(code) | |
| return True | |
| except: | |
| return False | |
| def log_request(instruction: str, generated_code: str, tokens_generated: int, latency: float, error: str = None): | |
| """Log request for production monitoring""" | |
| if not ENABLE_LOGGING: | |
| return | |
| log_entry = { | |
| "timestamp": datetime.now().isoformat(), | |
| "instruction_length": len(instruction), | |
| "response_length": len(generated_code), | |
| "tokens_generated": tokens_generated, | |
| "latency_seconds": round(latency, 2), | |
| "has_syntax_error": not check_syntax(generated_code) if generated_code and not error else True, | |
| "error": error, | |
| } | |
| try: | |
| with open(LOG_FILE, "a", encoding="utf-8") as f: | |
| f.write(json.dumps(log_entry) + "\n") | |
| except Exception as e: | |
| print(f"Logging failed: {e}") | |
| # Safety filters - Layer 1: Input Validation | |
| DANGEROUS_KEYWORDS = [ | |
| "delete all files", "rm -rf", "shutil.rmtree", | |
| "sql injection", "drop table", "truncate table", | |
| "keylogger", "backdoor", "exploit", | |
| "hack into", "steal password", "crack password", | |
| "ddos", "denial of service", "fork bomb", | |
| "malware", "ransomware", "trojan" | |
| ] | |
| def validate_input(instruction: str) -> tuple: | |
| """ | |
| Validate input for dangerous keywords. | |
| Returns: (is_valid: bool, error_message: str) | |
| """ | |
| instruction_lower = instruction.lower() | |
| for keyword in DANGEROUS_KEYWORDS: | |
| if keyword in instruction_lower: | |
| return False, f"⚠️ Safety Filter: Request blocked. Your instruction contains potentially unsafe content related to '{keyword}'.\n\nPlease rephrase your request to focus on legitimate programming tasks." | |
| return True, "" | |
| # Safety filters - Layer 2: Output Filtering | |
| DANGEROUS_PATTERNS = [ | |
| ("os.remove", "file deletion"), | |
| ("shutil.rmtree", "directory deletion"), | |
| ("os.unlink", "file deletion"), | |
| ("DROP TABLE", "database destruction"), | |
| ("TRUNCATE TABLE", "database destruction"), | |
| ("DELETE FROM", "database deletion"), | |
| ("eval(", "arbitrary code execution"), | |
| ("exec(", "arbitrary code execution"), | |
| ("__import__", "dynamic imports"), | |
| ("os.system", "system command execution"), | |
| ("subprocess.call", "system command execution"), | |
| ("subprocess.run", "system command execution"), | |
| ] | |
| def filter_dangerous_code(code: str) -> str: | |
| """ | |
| Filter dangerous code patterns from output. | |
| Returns: filtered code or safety warning | |
| """ | |
| code_lower = code.lower() | |
| for pattern, reason in DANGEROUS_PATTERNS: | |
| if pattern.lower() in code_lower: | |
| return f"""# ⚠️ SAFETY FILTER ACTIVATED | |
| # | |
| # Code generation blocked: Potentially dangerous pattern detected ({reason}) | |
| # Pattern: {pattern} | |
| # | |
| # This is a safety feature to prevent generating code that could: | |
| # - Delete files or data | |
| # - Execute arbitrary system commands | |
| # - Compromise system security | |
| # | |
| # Please rephrase your request with safer requirements. | |
| # For educational purposes, consult official documentation or security resources. | |
| """ | |
| return code | |
| def load_model(): | |
| """Lazy load model on first request""" | |
| global tokenizer, model | |
| if model is None: | |
| print("Loading model for the first time...") | |
| # Load tokenizer | |
| tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, token=HF_TOKEN) | |
| # Load base model with 4-bit quantization | |
| base_model = AutoModelForCausalLM.from_pretrained( | |
| BASE_MODEL, | |
| device_map="auto", | |
| torch_dtype=torch.float16, | |
| load_in_4bit=True, | |
| token=HF_TOKEN | |
| ) | |
| # Load LoRA adapter | |
| model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL, token=HF_TOKEN) | |
| model.eval() | |
| print("Model loaded successfully!") | |
| return tokenizer, model | |
| def generate_code(instruction: str, max_tokens: int = 256, temperature: float = 0.7): | |
| """Generate code from instruction with monitoring and safety filters""" | |
| start_time = time.time() | |
| if not instruction.strip(): | |
| return "Please enter an instruction." | |
| # Layer 1: Input validation | |
| is_valid, validation_error = validate_input(instruction) | |
| if not is_valid: | |
| # Log blocked request | |
| log_request(instruction, validation_error, 0, time.time() - start_time, "BLOCKED_BY_SAFETY_FILTER") | |
| return validation_error | |
| generated_code = "" | |
| tokens_generated = 0 | |
| error = None | |
| try: | |
| # Load model (cached after first call) | |
| tok, mdl = load_model() | |
| # Format prompt in Alpaca style | |
| prompt = f"""### Instruction: | |
| {instruction} | |
| ### Input: | |
| ### Response: | |
| """ | |
| # Tokenize | |
| inputs = tok(prompt, return_tensors="pt").to(mdl.device) | |
| input_length = len(inputs.input_ids[0]) | |
| # Generate | |
| with torch.no_grad(): | |
| outputs = mdl.generate( | |
| **inputs, | |
| max_new_tokens=max_tokens, | |
| temperature=temperature, | |
| top_p=0.9, | |
| do_sample=True, | |
| pad_token_id=tok.eos_token_id, | |
| ) | |
| # Calculate tokens generated | |
| tokens_generated = len(outputs[0]) - input_length | |
| # Decode | |
| generated = tok.decode(outputs[0], skip_special_tokens=True) | |
| # Extract code after "### Response:" | |
| if "### Response:" in generated: | |
| generated_code = generated.split("### Response:")[-1].strip() | |
| else: | |
| generated_code = generated.strip() | |
| # Layer 2: Output filtering for dangerous patterns | |
| generated_code = filter_dangerous_code(generated_code) | |
| except Exception as e: | |
| error = str(e) | |
| generated_code = f"Error: {error}\n\nPlease try again." | |
| finally: | |
| # Log request | |
| latency = time.time() - start_time | |
| log_request(instruction, generated_code, tokens_generated, latency, error) | |
| return generated_code | |
| # Custom CSS for better appearance | |
| custom_css = """ | |
| .container { | |
| max-width: 900px; | |
| margin: auto; | |
| } | |
| .output-code { | |
| font-family: 'Courier New', monospace; | |
| font-size: 14px; | |
| } | |
| """ | |
| # Create Gradio interface | |
| with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo: | |
| gr.Markdown( | |
| """ | |
| # 🤖 Gemma Code Generator | |
| Fine-tuned Gemma-2B model for multi-language code generation using QLoRA. | |
| **Performance**: 76% syntax correctness | **BLEU Score: 16.83** (+53% improvement over baseline 11.00) | |
| **Note**: First request may take 1-2 minutes as the model loads on HuggingFace servers. Subsequent requests are instant! | |
| --- | |
| ### 🛡️ Safety Features | |
| This demo includes production-grade safety filters: | |
| - **Input Validation**: Blocks requests with potentially dangerous keywords | |
| - **Output Filtering**: Prevents generation of code that could delete files, execute arbitrary commands, or compromise security | |
| - **Production Monitoring**: All requests are logged for quality tracking (privacy-respecting, no personal data stored) | |
| ⚠️ **AI-Generated Code Disclaimer**: Always review generated code before use. AI models can make mistakes. | |
| """ | |
| ) | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| instruction_input = gr.Textbox( | |
| label="Code Instruction", | |
| placeholder="Describe the function you want to create...", | |
| lines=3, | |
| ) | |
| with gr.Accordion("Advanced Settings", open=False): | |
| max_tokens_slider = gr.Slider( | |
| minimum=64, | |
| maximum=512, | |
| value=256, | |
| step=64, | |
| label="Max Tokens", | |
| info="Maximum length of generated code" | |
| ) | |
| temperature_slider = gr.Slider( | |
| minimum=0.1, | |
| maximum=1.5, | |
| value=0.7, | |
| step=0.1, | |
| label="Temperature", | |
| info="Higher = more creative, Lower = more deterministic" | |
| ) | |
| generate_btn = gr.Button("Generate Code", variant="primary", size="lg") | |
| with gr.Column(scale=1): | |
| output_code = gr.Code( | |
| label="Generated Code", | |
| language="python", | |
| elem_classes="output-code" | |
| ) | |
| # Examples | |
| gr.Examples( | |
| examples=[ | |
| ["Write a function to check if a number is prime"], | |
| ["Create a function to reverse a string"], | |
| ["Write a function to find the factorial of a number"], | |
| ["Implement binary search on a sorted list"], | |
| ["Create a function to merge two sorted lists"], | |
| ["Write a function to calculate Fibonacci numbers"], | |
| ["Implement a function to find the longest common subsequence"], | |
| ["Create a function to validate an email address using regex"], | |
| ["Write a function to convert a decimal number to binary"], | |
| ["Implement a simple LRU cache using OrderedDict"], | |
| ], | |
| inputs=[instruction_input], | |
| label="Example Prompts (Click to use)" | |
| ) | |
| # Event handler | |
| generate_btn.click( | |
| fn=generate_code, | |
| inputs=[instruction_input, max_tokens_slider, temperature_slider], | |
| outputs=[output_code], | |
| ) | |
| # Model information footer | |
| gr.Markdown( | |
| """ | |
| --- | |
| ### 📊 Model Performance | |
| | Metric | Baseline (Pretrained) | Fine-Tuned (Actual) | Improvement | | |
| |--------|----------------------|---------------------|-------------| | |
| | **BLEU Score** | 11.00 | **16.83** | **+53%** ✅ | | |
| | **Syntax Correctness** | 81% | 76% | -5% | | |
| | **Trainable Parameters** | 2.5B | 3.2M (0.12%) | 100x fewer | | |
| *Evaluated on 100 multi-language test samples (72 Python, 28 other languages)* | |
| ### 🛠️ Technical Details | |
| - **Base Model**: google/gemma-2-2b-it (2.5B parameters) | |
| - **Fine-tuning**: QLoRA (4-bit quantization + LoRA rank 16) | |
| - **Dataset**: CodeAlpaca-20k (18,000 training examples) | |
| - **Checkpoint**: Step 2000 (~1.8 epochs, selected for best BLEU score) | |
| - **Training**: 10-15 hours on free Google Colab T4 GPU | |
| - **Cost**: $0 (free Colab + free HF Spaces hosting) | |
| ### 🔗 Links | |
| [Model on HuggingFace](https://huggingface.co/nvhuynh16/gemma-2b-code-alpaca-best) • | |
| [GitHub Repository](https://github.com/nvhuynh16/Gemma-Code-Fine-Tuning) • | |
| [Portfolio](https://portfolio-nvhuynh.vercel.app) • | |
| [Base Model](https://huggingface.co/google/gemma-2-2b-it) | |
| --- | |
| **Built for portfolio demonstration** • Targeting AI/ML Applied Scientist roles | |
| *This demo uses HuggingFace Inference API for serverless, cost-free inference* | |
| """ | |
| ) | |
| if __name__ == "__main__": | |
| demo.launch() | |