from flask import Flask, render_template, request, jsonify, stream_template
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
import gc
import threading
import time
import os
import re
from datetime import datetime
import json
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = Flask(__name__)
app.secret_key = os.urandom(24)

# Markdown code-fence "info string" (the language tag right after the opening
# ```), e.g. "python", "c++", "c#", "objective-c" -- or empty.
_FENCE_INFO_RE = re.compile(r"[\w+#.-]*")


class CodeLlamaService:
    """Lazily loads Code Llama and generates code/explanations from prompts.

    The flags ``is_loading`` / ``is_loaded`` and ``load_error`` are read by the
    ``/api/status`` route while ``load_model`` runs on a background thread.
    """

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.pipeline = None
        self.is_loading = False
        self.is_loaded = False
        # Message from the most recent failed load; cleared when a new load starts.
        self.load_error = None
        self.load_lock = threading.Lock()

    def load_model(self):
        """Load Code Llama model with memory optimization for HF Spaces.

        Returns immediately if a load is already running or has completed.
        On failure, partially-created objects are released (attributes reset to
        None), the error message is stored in ``self.load_error`` and
        ``is_loaded`` stays False, so a later call can retry.
        """
        if self.is_loaded or self.is_loading:
            return

        with self.load_lock:
            # Re-check under the lock: another thread may have started/finished first.
            if self.is_loaded or self.is_loading:
                return
            self.is_loading = True
            self.load_error = None
            logger.info("Loading Code Llama model...")

            try:
                # Use the smallest Code Llama model that fits in 16GB
                model_name = "codellama/CodeLlama-7b-Instruct-hf"

                # Check if CUDA is available
                device = "cuda" if torch.cuda.is_available() else "cpu"
                logger.info("Using device: %s", device)

                # NOTE(review): CodeLlama uses the built-in Llama architecture, so
                # trust_remote_code is likely unnecessary; it permits executing
                # code shipped in the Hub repo -- consider removing it.
                self.tokenizer = AutoTokenizer.from_pretrained(
                    model_name,
                    use_fast=True,
                    trust_remote_code=True,
                )

                if device == "cuda":
                    # GPU: Use float16 for memory efficiency; accelerate places weights.
                    self.model = AutoModelForCausalLM.from_pretrained(
                        model_name,
                        torch_dtype=torch.float16,
                        low_cpu_mem_usage=True,
                        trust_remote_code=True,
                        device_map="auto",
                    )
                    # The model was dispatched with device_map, so the pipeline
                    # must not be given a `device` (transformers refuses to move
                    # an accelerate-dispatched model).
                    self.pipeline = pipeline(
                        "text-generation",
                        model=self.model,
                        tokenizer=self.tokenizer,
                    )
                else:
                    # CPU: Use float32 to avoid Half precision errors
                    self.model = AutoModelForCausalLM.from_pretrained(
                        model_name,
                        torch_dtype=torch.float32,
                        low_cpu_mem_usage=True,
                        trust_remote_code=True,
                    )
                    # Move model to CPU explicitly
                    self.model = self.model.to('cpu')
                    self.pipeline = pipeline(
                        "text-generation",
                        model=self.model,
                        tokenizer=self.tokenizer,
                        device=-1,  # CPU device
                    )

                self.is_loaded = True
                logger.info("Model loaded successfully!")

            except Exception as e:
                logger.exception("Error loading model: %s", e)
                self.is_loaded = False
                self.load_error = str(e)
                # Drop references (pipeline first, since it holds the model) so the
                # memory can be reclaimed, but keep the attributes defined.
                self.pipeline = None
                self.model = None
                self.tokenizer = None
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
            finally:
                self.is_loading = False

    def generate_code(self, prompt, max_length=1024, temperature=0.3):
        """Generate a response for ``prompt`` and split it into code/explanation.

        Args:
            prompt: Instruction text; it is wrapped in ``[INST] ... [/INST]`` here.
            max_length: Maximum number of *new* tokens to generate.
            temperature: Sampling temperature; ``<= 0`` selects greedy decoding.

        Returns:
            On success: dict with ``success`` (True), ``code``, ``explanation``
            and ``full_response``. If the model is not loaded or generation
            raises: ``{"error": <message>, "code": "", "explanation": ""}``.
        """
        if not self.is_loaded:
            return {"error": "Model not loaded", "code": "", "explanation": ""}

        try:
            # Format prompt for instruction following
            formatted_prompt = f"[INST] {prompt} [/INST]"

            do_sample = temperature > 0
            generation_kwargs = {
                "max_new_tokens": max_length,
                "do_sample": do_sample,
                # Sampling-only knobs are dropped for greedy decoding (see filter below).
                "temperature": temperature if do_sample else None,
                "top_p": 0.9 if do_sample else None,
                "repetition_penalty": 1.1,
                "return_full_text": False,
                "pad_token_id": self.tokenizer.eos_token_id,
            }
            # Remove None values to avoid warnings
            generation_kwargs = {k: v for k, v in generation_kwargs.items() if v is not None}

            outputs = self.pipeline(formatted_prompt, **generation_kwargs)

            # Extract generated text
            if isinstance(outputs, list) and outputs:
                first = outputs[0]
                if 'generated_text' in first:
                    response = first['generated_text']
                else:
                    response = str(first)
            else:
                response = str(outputs)
            response = response.strip()

            code, explanation = self._parse_response(response)

            return {
                "success": True,
                "code": code,
                "explanation": explanation,
                "full_response": response,
            }

        except Exception as e:
            logger.exception("Error generating code: %s", e)
            return {"error": str(e), "code": "", "explanation": ""}

    def _parse_response(self, response):
        """Split a model response into ``(code, explanation)``; both are stripped.

        With Markdown fences, every fenced segment is code, minus the language
        tag on its opening fence. The text between fences is explanation. An
        unterminated final fence (e.g. output cut off at the token limit) still
        counts as code. Without fences, a line-based heuristic guesses which
        lines are code.
        """
        if "```" in response:
            code_parts = []
            explanation_parts = []
            for i, part in enumerate(response.split("```")):
                if i % 2 == 1:
                    # Odd segments follow an opening fence. Their first line is the
                    # fence info string; drop it only when it looks like a tag, so
                    # a real first line of code is never discarded.
                    first_line, newline, rest = part.partition('\n')
                    if newline and _FENCE_INFO_RE.fullmatch(first_line.strip()):
                        code_parts.append(rest.lstrip('\r\n').rstrip())
                    else:
                        code_parts.append(part.strip())
                elif part.strip():
                    # Even segments are prose outside any fence.
                    explanation_parts.append(part.strip())
            code = '\n\n'.join(code_parts)
            explanation = '\n\n'.join(explanation_parts)
        else:
            # If no code blocks, try to separate by common patterns
            code_lines = []
            explanation_lines = []
            in_code_block = False

            for line in response.split('\n'):
                stripped = line.strip()
                # Simple heuristic to detect code vs explanation
                if (stripped.startswith(('def ', 'class ', 'import ', 'from ', 'if ', 'for ',
                                         'while ', 'function', 'var ', 'let ', 'const '))
                        or line.startswith((' ', '\t'))
                        or ('=' in line
                            and not stripped.startswith('#')
                            and not stripped.startswith('//'))):
                    code_lines.append(line)
                    in_code_block = True
                elif in_code_block and stripped == '':
                    code_lines.append(line)  # Keep empty lines in code blocks
                elif in_code_block and stripped:
                    # Check if this line looks like code or explanation
                    if (any(token in line for token in ['{', '}', ';', '()', '[]'])
                            and not stripped.endswith('.')):
                        code_lines.append(line)
                    else:
                        explanation_lines.append(line)
                        in_code_block = False
                else:
                    explanation_lines.append(line)
                    in_code_block = False

            code = '\n'.join(code_lines)
            explanation = '\n'.join(explanation_lines)

        return code.strip(), explanation.strip()


# Initialize service
llama_service = CodeLlamaService()


def _json_body():
    """Return the request's JSON body if it is an object, else an empty dict.

    A missing/non-JSON body or a non-object payload (e.g. ``null`` or a list)
    would otherwise make ``.get`` fail before the route can validate fields.
    """
    data = request.get_json(silent=True)
    return data if isinstance(data, dict) else {}


def _text_field(data, key):
    """Return ``data[key]`` as a stripped string ('' when missing or null)."""
    value = data.get(key)
    return '' if value is None else str(value).strip()


@app.route('/')
def index():
    """Serve the UI template."""
    return render_template('index.html')


@app.route('/api/status')
def status():
    """Report model loading state, plus the last load error (None if none)."""
    return jsonify({
        'is_loaded': llama_service.is_loaded,
        'is_loading': llama_service.is_loading,
        'load_error': llama_service.load_error,
    })


@app.route('/api/load_model', methods=['POST'])
def load_model():
    """Start a background model load unless one is finished or in progress."""
    if llama_service.is_loaded:
        return jsonify({'status': 'loaded'})
    if not llama_service.is_loading:
        threading.Thread(target=llama_service.load_model).start()
    return jsonify({'status': 'loading'})


@app.route('/api/generate', methods=['POST'])
def generate():
    """Generate code for ``instruction``, optionally modifying ``existing_code``."""
    data = _json_body()
    existing_code = _text_field(data, 'existing_code')
    instruction = _text_field(data, 'instruction')

    if not instruction:
        return jsonify({'error': 'Instruction is required'})

    # Build prompt
    if existing_code:
        prompt = f"""Here is the existing code:
```
{existing_code}
```

Instruction: {instruction}

Please provide the modified/complete code and explain what changes you made."""
    else:
        prompt = f"""Instruction: {instruction}

Please provide the code and explain what it does."""

    # Generate response
    result = llama_service.generate_code(
        prompt,
        max_length=2048,
        temperature=0.3,
    )

    return jsonify(result)


@app.route('/api/explain', methods=['POST'])
def explain_code():
    """Ask the model to explain ``code``; includes ``error`` if generation failed."""
    data = _json_body()
    code = _text_field(data, 'code')

    if not code:
        return jsonify({'error': 'Code is required'})

    prompt = f"""Please explain this code in detail:
```
{code}
```

Provide a clear explanation of what this code does, how it works, and any important details."""

    result = llama_service.generate_code(prompt, max_length=1024, temperature=0.1)

    # 'explanation' is always present but may be empty (e.g. the whole answer
    # was fenced), so fall back to the raw response on an empty value.
    payload = {
        'explanation': result.get('explanation') or result.get('full_response', '')
    }
    if 'error' in result:
        payload['error'] = result['error']
    return jsonify(payload)


if __name__ == '__main__':
    # Load model on startup
    threading.Thread(target=llama_service.load_model).start()
    app.run(host='0.0.0.0', port=7860, debug=False, use_reloader=False)