from flask import Flask, render_template, request, jsonify
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
import gc
import threading
import os
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = Flask(__name__)
app.secret_key = os.urandom(24)


class CodeLlamaService:
    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.pipeline = None
        self.is_loading = False
        self.is_loaded = False
        self.load_lock = threading.Lock()

    def load_model(self):
        """Load the Code Llama model with memory optimization for HF Spaces."""
        if self.is_loaded or self.is_loading:
            return

        with self.load_lock:
            # Double-check inside the lock so only one thread loads the model.
            if self.is_loaded or self.is_loading:
                return

            self.is_loading = True
            logger.info("Loading Code Llama model...")
            try:
                model_name = "codellama/CodeLlama-7b-Instruct-hf"

                device = "cuda" if torch.cuda.is_available() else "cpu"
                logger.info(f"Using device: {device}")

                self.tokenizer = AutoTokenizer.from_pretrained(
                    model_name,
                    use_fast=True,
                    trust_remote_code=True
                )

                if device == "cuda":
                    # Half precision plus device_map="auto" keeps GPU memory usage low.
                    self.model = AutoModelForCausalLM.from_pretrained(
                        model_name,
                        torch_dtype=torch.float16,
                        low_cpu_mem_usage=True,
                        trust_remote_code=True,
                        device_map="auto"
                    )
                    torch_dtype = torch.float16
                else:
                    # CPU fallback: full precision, no device_map.
                    self.model = AutoModelForCausalLM.from_pretrained(
                        model_name,
                        torch_dtype=torch.float32,
                        low_cpu_mem_usage=True,
                        trust_remote_code=True
                    )
                    self.model = self.model.to('cpu')
                    torch_dtype = torch.float32

                if device == "cuda":
                    # The model is already placed by device_map="auto", so the
                    # pipeline must not receive an explicit `device` argument.
                    self.pipeline = pipeline(
                        "text-generation",
                        model=self.model,
                        tokenizer=self.tokenizer,
                        torch_dtype=torch_dtype
                    )
                else:
                    self.pipeline = pipeline(
                        "text-generation",
                        model=self.model,
                        tokenizer=self.tokenizer,
                        device=-1
                    )

                self.is_loaded = True
                logger.info("Model loaded successfully!")
            except Exception as e:
                logger.error(f"Error loading model: {str(e)}")
                self.is_loaded = False
                # Release any partially loaded weights so a retry starts clean.
                self.model = None
                self.tokenizer = None
                self.pipeline = None
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
            finally:
                self.is_loading = False

    def generate_code(self, prompt, max_length=1024, temperature=0.3):
        """Generate code based on a prompt."""
        if not self.is_loaded:
            return {"error": "Model not loaded", "code": "", "explanation": ""}

        try:
            # Code Llama Instruct expects the [INST] ... [/INST] prompt format.
            formatted_prompt = f"<s>[INST] {prompt} [/INST]"

            generation_kwargs = {
                "max_new_tokens": max_length,
                "do_sample": temperature > 0,
                "temperature": temperature if temperature > 0 else None,
                "top_p": 0.9 if temperature > 0 else None,
                "repetition_penalty": 1.1,
                "return_full_text": False,
                "pad_token_id": self.tokenizer.eos_token_id
            }

            # Drop sampling arguments that do not apply to greedy decoding.
            generation_kwargs = {k: v for k, v in generation_kwargs.items() if v is not None}

            outputs = self.pipeline(formatted_prompt, **generation_kwargs)

            if isinstance(outputs, list) and len(outputs) > 0:
                if 'generated_text' in outputs[0]:
                    response = outputs[0]['generated_text']
                else:
                    response = str(outputs[0])
            else:
                response = str(outputs)

            response = response.strip()

            code, explanation = self._parse_response(response)

            return {
                "success": True,
                "code": code,
                "explanation": explanation,
                "full_response": response
            }

        except Exception as e:
            logger.error(f"Error generating code: {str(e)}")
            return {"error": str(e), "code": "", "explanation": ""}

    def _parse_response(self, response):
        """Parse a response to separate code from explanation."""
        if "```" in response:
            # Fenced output: odd-indexed segments between ``` markers are code.
            parts = response.split("```")
            code_parts = []
            explanation_parts = []

            for i, part in enumerate(parts):
                if i % 2 == 1:
                    # Strip a leading language tag such as "python" after the fence.
                    lines = part.strip().split('\n')
                    if lines and any(lang in lines[0].lower() for lang in ['python', 'javascript', 'java', 'cpp', 'c++', 'html', 'css']):
                        code_parts.append('\n'.join(lines[1:]))
                    else:
                        code_parts.append(part.strip())
                else:
                    if part.strip():
                        explanation_parts.append(part.strip())

            code = '\n\n'.join(code_parts)
            explanation = '\n\n'.join(explanation_parts)
        else:
            # No fences: fall back to a line-by-line heuristic.
            lines = response.split('\n')
            code_lines = []
            explanation_lines = []

            in_code_block = False
            for line in lines:
                # Lines that look like statements or are indented count as code.
                if (line.strip().startswith(('def ', 'class ', 'import ', 'from ', 'if ', 'for ', 'while ', 'function', 'var ', 'let ', 'const ')) or
                        line.startswith((' ', '\t')) or
                        ('=' in line and not line.strip().startswith('#') and not line.strip().startswith('//'))):
                    code_lines.append(line)
                    in_code_block = True
                elif in_code_block and line.strip() == '':
                    code_lines.append(line)
                else:
                    if in_code_block and line.strip():
                        # Keep lines that still contain code-like punctuation.
                        if any(char in line for char in ['{', '}', ';', '()', '[]']) and not line.strip().endswith('.'):
                            code_lines.append(line)
                        else:
                            explanation_lines.append(line)
                            in_code_block = False
                    else:
                        explanation_lines.append(line)
                        in_code_block = False

            code = '\n'.join(code_lines)
            explanation = '\n'.join(explanation_lines)

        return code.strip(), explanation.strip()


# Single shared service instance; the model loads lazily in a background thread.
llama_service = CodeLlamaService()


@app.route('/')
def index():
    return render_template('index.html')


@app.route('/api/status')
def status():
    return jsonify({
        'is_loaded': llama_service.is_loaded,
        'is_loading': llama_service.is_loading
    })


@app.route('/api/load_model', methods=['POST'])
def load_model():
    if not llama_service.is_loaded and not llama_service.is_loading:
        threading.Thread(target=llama_service.load_model).start()
        return jsonify({'status': 'loading'})
    elif llama_service.is_loaded:
        return jsonify({'status': 'loaded'})
    else:
        return jsonify({'status': 'loading'})


@app.route('/api/generate', methods=['POST'])
def generate():
    data = request.get_json(silent=True) or {}

    existing_code = data.get('existing_code', '').strip()
    instruction = data.get('instruction', '').strip()

    if not instruction:
        return jsonify({'error': 'Instruction is required'}), 400

    # Build an instruction prompt, embedding the existing code when provided.
    if existing_code:
        prompt = f"""Here is the existing code:

```
{existing_code}
```

Instruction: {instruction}

Please provide the modified/complete code and explain what changes you made."""
    else:
        prompt = f"""Instruction: {instruction}

Please provide the code and explain what it does."""

    result = llama_service.generate_code(
        prompt,
        max_length=2048,
        temperature=0.3
    )

    return jsonify(result)


@app.route('/api/explain', methods=['POST'])
def explain_code():
    data = request.get_json(silent=True) or {}
    code = data.get('code', '').strip()

    if not code:
        return jsonify({'error': 'Code is required'}), 400

    prompt = f"""Please explain this code in detail:

```
{code}
```

Provide a clear explanation of what this code does, how it works, and any important details."""

    # A low temperature keeps explanations close to deterministic.
    result = llama_service.generate_code(prompt, max_length=1024, temperature=0.1)

    return jsonify({
        'explanation': result.get('explanation') or result.get('full_response', '')
    })


if __name__ == '__main__':
    # Start loading the model in the background so the server is reachable immediately.
    threading.Thread(target=llama_service.load_model).start()
    app.run(host='0.0.0.0', port=7860, debug=False, use_reloader=False)
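
# A minimal sketch of how a client might exercise these endpoints once the app is
# running locally on port 7860 (payload keys taken from the handlers above):
#
#   curl -X POST http://localhost:7860/api/load_model
#   curl http://localhost:7860/api/status
#   curl -X POST http://localhost:7860/api/generate \
#        -H "Content-Type: application/json" \
#        -d '{"instruction": "Write a function that reverses a string"}'
#   curl -X POST http://localhost:7860/api/explain \
#        -H "Content-Type: application/json" \
#        -d '{"code": "print(1 + 1)"}'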