Spaces:
Runtime error
Runtime error
import dataclasses
import json
import logging
import os
import re
import tempfile
import threading
import time
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import black
import faiss
import gradio as gr
import numpy as np
import torch
from PIL import Image
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# Configure logging: records at INFO and above are written both to the
# console (a StreamHandler, i.e. stderr) and to "gradio_builder.log" in the
# current working directory.
logging.basicConfig(
    handlers=(
        logging.StreamHandler(),
        logging.FileHandler("gradio_builder.log"),
    ),
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
| # Configuration | |
| class Config: | |
| port: int = 7860 | |
| debug: bool = False | |
| share: bool = False | |
| model_name: str = "gpt2" | |
| embedding_model: str = "all-mpnet-base-v2" | |
| theme: str = "default" | |
| max_code_length: int = 5000 | |
| def from_file(cls, path: str) -> "Config": | |
| try: | |
| with open(path) as f: | |
| return cls(**json.load(f)) | |
| except Exception as e: | |
| logger.warning(f"Failed to load config from {path}: {e}. Using defaults.") | |
| return cls() | |
# Constants: file-system locations, all relative to the current working
# directory.
CONFIG_PATH = Path("config.json")  # optional JSON overrides for Config
MODEL_CACHE_DIR = Path("model_cache")  # cache_dir for from_pretrained downloads
TEMPLATE_DIR = Path("templates")  # holds one <name>.json file per template
TEMP_DIR = Path("temp")  # NOTE(review): not referenced elsewhere in this file
DATABASE_PATH = Path("code_database.json")  # stored code snippets + embeddings

# Ensure directories exist. This runs eagerly, at import time.
for directory in (MODEL_CACHE_DIR, TEMPLATE_DIR, TEMP_DIR):
    directory.mkdir(parents=True, exist_ok=True)
@dataclass
class Template:
    """A reusable code snippet plus the metadata persisted alongside it.

    Instances round-trip through JSON via ``dataclasses.asdict`` and
    ``Template(**data)``.
    """

    code: str  # source text of the template
    description: str  # free-text label for the template
    # e.g. names of functions/classes defined in ``code``; fresh list per instance
    components: List[str] = field(default_factory=list)
    # local wall-clock time at construction, "YYYY-MM-DD HH:MM:SS"
    created_at: str = field(default_factory=lambda: time.strftime("%Y-%m-%d %H:%M:%S"))
    tags: List[str] = field(default_factory=list)  # fresh list per instance
class TemplateManager:
    """Persist :class:`Template` objects as ``<name>.json`` files in a directory.

    The in-memory ``templates`` dict is keyed by the same ``name`` used for
    the file stem, so save/get/delete/load all agree on the key.
    """

    def __init__(self, template_dir: Path):
        self.template_dir = template_dir
        self.templates: Dict[str, Template] = {}

    def _template_path(self, name: str) -> Path:
        """Return the JSON file path for *name*.

        Raises:
            ValueError: if *name* is empty or is not a plain file stem.
        """
        # Names come from the web UI; refuse anything that could resolve
        # outside template_dir (e.g. "../x" or "a/b").
        if not name or name in (".", "..") or Path(name).name != name:
            raise ValueError(f"Invalid template name: {name!r}")
        return self.template_dir / f"{name}.json"

    def load_templates(self) -> None:
        """Load every ``*.json`` file in ``template_dir`` into ``templates``.

        Each template is keyed by its file stem, matching how
        :meth:`save_template` names files. A file that cannot be read or
        parsed is logged and skipped.
        """
        for file_path in self.template_dir.glob("*.json"):
            try:
                with open(file_path, "r", encoding="utf-8") as f:
                    template_data = json.load(f)
                self.templates[file_path.stem] = Template(**template_data)
                logger.info(f"Loaded template: {file_path.stem}")
            except Exception as e:  # best-effort: one bad file must not block the rest
                logger.error(f"Error loading template from {file_path}: {e}")

    def save_template(self, name: str, template: "Template") -> bool:
        """Write *template* to ``<name>.json`` and register it under *name*.

        Returns:
            True on success. False if the name is invalid or the write fails;
            the error is logged.
        """
        try:
            file_path = self._template_path(name)
            with open(file_path, "w", encoding="utf-8") as f:
                json.dump(dataclasses.asdict(template), f, indent=2)
        except Exception as e:
            logger.error(f"Error saving template {name!r}: {e}")
            return False
        self.templates[name] = template
        return True

    def get_template(self, name: str) -> Optional[str]:
        """Return the code of template *name*, or ``""`` if it is unknown."""
        template = self.templates.get(name)
        return template.code if template else ""

    def delete_template(self, name: str) -> bool:
        """Delete ``<name>.json`` and forget template *name*.

        Returns:
            True on success. False (logged) if the name is invalid or the
            file cannot be removed.
        """
        try:
            self._template_path(name).unlink()
        except Exception as e:
            logger.error(f"Error deleting template {name}: {e}")
            return False
        self.templates.pop(name, None)
        return True
class RAGSystem:
    """Retrieval-augmented code generator.

    Combines a Hugging Face text-generation pipeline (``config.model_name``),
    a SentenceTransformer embedder (``config.embedding_model``), and a FAISS
    L2 index over embeddings of stored snippets. Snippets are persisted as
    JSON at ``DATABASE_PATH``.

    Model-loading failures are logged, not raised, so ``pipe`` and
    ``embedding_model`` may be ``None`` after construction.
    """

    def __init__(self, config: "Config"):
        self.config = config
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = None
        self.model = None
        self.pipe = None
        self.embedding_model = None
        self.code_embeddings = None
        self.index = None
        self.database = {'codes': [], 'embeddings': []}

        # Load the generator and the embedder independently, so that a failure
        # in one (e.g. an unknown model id) does not disable the other.
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                config.model_name,
                cache_dir=MODEL_CACHE_DIR,
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                config.model_name,
                cache_dir=MODEL_CACHE_DIR,
            ).to(self.device)
            self.pipe = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                device=self.device,
            )
        except Exception:
            logger.exception("Error loading text-generation model")
        try:
            self.embedding_model = SentenceTransformer(config.embedding_model)
        except Exception:
            logger.exception("Error loading embedding model")

        # load_database handles its own errors and falls back to an empty DB.
        self.load_database()
        if self.pipe is not None and self.embedding_model is not None:
            logger.info("RAG system initialized successfully.")
        else:
            logger.warning("RAG system initialized with missing components.")

    def load_database(self) -> None:
        """Load snippets/embeddings from DATABASE_PATH and rebuild the index.

        A missing or unreadable file leaves an empty database.
        """
        if not DATABASE_PATH.exists():
            logger.info("Creating new database.")
            self._initialize_empty_database()
            return
        try:
            with open(DATABASE_PATH, 'r', encoding='utf-8') as f:
                self.database = json.load(f)
            # JSON numbers load as float64; FAISS needs float32.
            self.code_embeddings = np.asarray(self.database['embeddings'], dtype=np.float32)
            logger.info(f"Loaded {len(self.database['codes'])} code snippets from database.")
            self._build_index()
        except Exception as e:
            logger.error(f"Error loading database: {e}")
            self._initialize_empty_database()

    def _initialize_empty_database(self) -> None:
        """Reset to an empty database and drop any existing index."""
        self.database = {'codes': [], 'embeddings': []}
        self.code_embeddings = np.array([], dtype=np.float32)
        self._build_index()

    def _build_index(self) -> None:
        """(Re)build the FAISS index from ``code_embeddings``.

        ``index`` is set to None when there are no embeddings or no embedding
        model, so a stale index never outlives the data it was built from.
        """
        has_vectors = self.code_embeddings is not None and len(self.code_embeddings) > 0
        if not (has_vectors and self.embedding_model is not None):
            self.index = None
            return
        # FAISS accepts only C-contiguous float32 matrices of shape (n, dim).
        vectors = np.ascontiguousarray(self.code_embeddings, dtype=np.float32)
        self.index = faiss.IndexFlatL2(vectors.shape[1])
        self.index.add(vectors)
        logger.info(f"Built FAISS index with {len(vectors)} vectors")

    def _save_database(self) -> None:
        """Write the snippet database to DATABASE_PATH (errors are logged)."""
        try:
            with open(DATABASE_PATH, 'w', encoding='utf-8') as f:
                json.dump(self.database, f)
        except OSError as e:
            logger.error(f"Error saving database: {e}")

    def add_to_database(self, code: str) -> None:
        """Embed *code*, append it to the database, rebuild the index and persist.

        Does nothing (with a warning) if *code* is empty or no embedding
        model is loaded.
        """
        if not code or self.embedding_model is None:
            logger.warning("Skipping database update: empty code or embedding model unavailable.")
            return
        vector = np.asarray(self.embedding_model.encode([code]), dtype=np.float32)[0]
        self.database['codes'].append(code)
        self.database['embeddings'].append(vector.tolist())
        self.code_embeddings = np.asarray(self.database['embeddings'], dtype=np.float32)
        self._build_index()
        self._save_database()

    def _retrieve_similar(self, query: str, k: int = 3) -> List[str]:
        """Return up to *k* stored snippets nearest to *query* (empty if no index)."""
        codes = self.database.get('codes', [])
        if self.index is None or self.embedding_model is None or not codes:
            return []
        query_vec = np.asarray(self.embedding_model.encode([query]), dtype=np.float32)
        _, ids = self.index.search(query_vec, min(k, len(codes)))
        # FAISS pads missing results with -1; also guard index/database drift.
        return [codes[i] for i in ids[0] if 0 <= i < len(codes)]

    def generate_code(self, description: str, template_code: str = "",
                      max_new_tokens: int = 256) -> str:
        """Generate code for *description*, using similar stored snippets as context.

        Args:
            description: Natural-language request.
            template_code: Optional starting template included in the prompt.
            max_new_tokens: Upper bound on generated tokens.

        Returns:
            The generated continuation only (not the prompt), truncated to
            ``config.max_code_length`` characters.

        Raises:
            RuntimeError: if the text-generation model failed to load.
        """
        if self.pipe is None:
            raise RuntimeError("Text-generation model is not available.")
        sections = [
            f"# Example {i}:\n{snippet}"
            for i, snippet in enumerate(self._retrieve_similar(description), 1)
        ]
        if template_code:
            sections.append(f"# Template:\n{template_code}")
        sections.append(f"# Task: {description}\n# Python code:\n")
        # NOTE(review): long examples/templates can exceed the model's context
        # window (1024 tokens for gpt2); consider truncating the prompt.
        prompt = "\n\n".join(sections)
        outputs = self.pipe(
            prompt,
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            do_sample=True,
            temperature=0.7,
            return_full_text=False,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        return outputs[0]["generated_text"][: self.config.max_code_length]
class GradioInterface:
    """Gradio web front end for generating code and managing saved templates.

    Wires a TemplateManager (templates in ``TEMPLATE_DIR``) and a RAGSystem
    to a single ``gr.Blocks`` page.
    """

    # Captures the name of each line-leading ``def`` / ``async def`` / ``class``
    # statement. Anchoring at the line start avoids mid-line matches such as
    # the keywords appearing inside strings or comments.
    _COMPONENT_PATTERN = re.compile(
        r"^[ \t]*(?:async[ \t]+)?(?:def|class)[ \t]+(\w+)", re.MULTILINE
    )

    # Page header: inline CSS plus a title banner. Lines are kept flush-left
    # so Markdown does not render them as an indented code block.
    _HEADER_MARKDOWN = """
<style>
.header {
  text-align: center;
  background-color: #f0f0f0;
  padding: 20px;
  border-radius: 10px;
  margin-bottom: 20px;
}
.container {
  max-width: 1200px;
  margin: 0 auto;
}
</style>
<div class="header">
<h1>Code Generation Interface</h1>
<p>Generate and manage code templates easily</p>
</div>
"""

    def __init__(self, config: "Config"):
        self.config = config
        self.template_manager = TemplateManager(TEMPLATE_DIR)
        self.template_manager.load_templates()
        self.rag_system = RAGSystem(config)

    def format_code(self, code: str) -> str:
        """Return *code* formatted by black, or unchanged if black rejects it."""
        try:
            return black.format_str(code, mode=black.FileMode())
        except Exception as e:
            # Model output is often not valid Python; formatting is best-effort.
            logger.warning(f"Code formatting failed: {e}")
            return code

    def _extract_components(self, code: str) -> List[str]:
        """Return names of functions/classes defined in *code*.

        Names are de-duplicated and kept in first-seen source order.
        """
        names = self._COMPONENT_PATTERN.findall(code or "")
        return list(dict.fromkeys(names))

    def _handle_generate(self, description: Optional[str],
                         template_name: Optional[str]) -> Tuple[str, str]:
        """Click handler for "Generate Code"; returns (code, status message)."""
        if not description or not description.strip():
            return "", "Please provide a description"
        try:
            template_code = self.template_manager.get_template(template_name) if template_name else ""
            generated_code = self.rag_system.generate_code(description, template_code)
            formatted_code = self.format_code(generated_code)
        except Exception as e:
            logger.exception("Error in code generation")
            return "", f"Error: {e}"
        if not formatted_code:
            return "", "Failed to generate code. Please try again."
        return formatted_code, "Code generated successfully."

    def _handle_save(self, code: Optional[str], name: Optional[str],
                     description: Optional[str],
                     selected: Optional[str]) -> Tuple[Optional[str], str, Any]:
        """Click handler for "Save as Template".

        Args:
            code: The current editor contents.
            name: The typed template name. When blank, the selected template
                is used as the name, so that template is overwritten.
            description: Text split on commas to form the template's tags.
            selected: The currently selected template in the dropdown.

        Returns:
            (code, status message, dropdown update). The dropdown update
            refreshes its choices after a successful save.
        """
        name = (name or "").strip() or (selected or "")
        if not name or not code:
            return code, "Template name and code are required.", gr.update()
        try:
            template = Template(
                code=code,
                description=name,
                components=self._extract_components(code),
                tags=[t.strip() for t in (description or "").split(',') if t.strip()],
            )
            saved = self.template_manager.save_template(name, template)
        except Exception as e:
            logger.exception("Error saving template")
            return code, f"Error saving template: {e}", gr.update()
        if not saved:
            return code, "Failed to save template.", gr.update()

        status = f"Template '{name}' saved successfully."
        try:
            self.rag_system.add_to_database(code)
        except Exception as e:
            # The template file is already written; report the indexing failure
            # without hiding the successful save.
            logger.exception("Error adding template to the RAG database")
            status = f"Template '{name}' saved, but indexing failed: {e}"
        # Returning an update is what actually refreshes the dropdown in the
        # browser; assigning component attributes inside a handler does not.
        choices = list(self.template_manager.templates.keys())
        return code, status, gr.update(choices=choices, value=name)

    def _handle_clear(self) -> Tuple[str, str, str, str]:
        """Click handler for "Clear": blank the description, code, status and name."""
        return "", "", "", ""

    def _build_interface(self) -> "gr.Blocks":
        """Create the Blocks layout and wire the event handlers (without launching)."""
        with gr.Blocks(theme=gr.themes.Base()) as interface:
            gr.Markdown(self._HEADER_MARKDOWN)
            with gr.Row():
                with gr.Column(scale=2):
                    description_input = gr.Textbox(
                        label="Description",
                        placeholder="Enter a description for the code you want to generate",
                        lines=3,
                    )
                    template_choice = gr.Dropdown(
                        label="Select Template",
                        choices=list(self.template_manager.templates.keys()),
                        value=None,
                    )
                    template_name_input = gr.Textbox(
                        label="Template Name",
                        placeholder="Name to save the code under (blank = overwrite selected template)",
                    )
                    with gr.Row():
                        generate_button = gr.Button("Generate Code", variant="primary")
                        save_button = gr.Button("Save as Template", variant="secondary")
                        clear_button = gr.Button("Clear", variant="stop")
            with gr.Row():
                code_output = gr.Code(
                    label="Generated Code",
                    language="python",
                    interactive=True,
                )
                status_output = gr.Textbox(
                    label="Status",
                    interactive=False,
                )

            generate_button.click(
                fn=self._handle_generate,
                inputs=[description_input, template_choice],
                outputs=[code_output, status_output],
                api_name="generate_code",
            )
            save_button.click(
                fn=self._handle_save,
                inputs=[code_output, template_name_input, description_input, template_choice],
                outputs=[code_output, status_output, template_choice],
            )
            clear_button.click(
                fn=self._handle_clear,
                inputs=[],
                outputs=[description_input, code_output, status_output, template_name_input],
            )
        return interface

    def launch(self) -> None:
        """Build the UI and start the Gradio server using the port/share/debug settings."""
        interface = self._build_interface()
        interface.launch(
            server_port=self.config.port,
            share=self.config.share,
            debug=self.config.debug,
        )
def main():
    """Application entry point.

    Loads settings from CONFIG_PATH, builds the Gradio interface and launches
    it. Startup and shutdown are always logged. Any exception is logged and
    then re-raised.
    """
    logger.info("=== Application Startup ===")
    try:
        GradioInterface(Config.from_file(CONFIG_PATH)).launch()
    except Exception as e:
        logger.error(f"Application error: {e}")
        raise
    finally:
        logger.info("=== Application Shutdown ===")


if __name__ == "__main__":
    main()