# Hugging Face Space: Code Review Assistant (app.py)
# (Spaces status banner removed from scraped page content.)
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from huggingface_hub import login
import os
import logging
from datetime import datetime
import json
from typing import List, Dict
import warnings
import spaces

# Silence library warnings (deliberate: transformers/torch emit many benign ones).
warnings.filterwarnings('ignore')

# Configure module-level logging.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Environment variables: HF auth token and the model to serve (Gemma 2B by default).
HF_TOKEN = os.getenv("HUGGING_FACE_TOKEN")
MODEL_NAME = os.getenv("MODEL_NAME", "google/gemma-2b-it")

# Cache directory for model weights inside the Space container.
CACHE_DIR = "/home/user/.cache/huggingface"
os.makedirs(CACHE_DIR, exist_ok=True)

# JSON file where review history and metrics are persisted across restarts.
HISTORY_FILE = "/home/user/review_history.json"
class Review:
    """A single code-review record with timing metadata."""

    def __init__(self, code: str, language: str, suggestions: str):
        self.code = code
        self.language = language
        self.suggestions = suggestions
        # ISO-8601 creation time; overwritten when restored via from_dict().
        self.timestamp = datetime.now().isoformat()
        # Seconds the model took to produce the review; set by the caller.
        self.response_time = 0.0

    def to_dict(self) -> Dict:
        """Serialize this review to a JSON-compatible dict."""
        return {
            'timestamp': self.timestamp,
            'language': self.language,
            'code': self.code,
            'suggestions': self.suggestions,
            'response_time': self.response_time
        }

    @classmethod
    def from_dict(cls, data: Dict) -> "Review":
        """Restore a Review from a dict produced by to_dict().

        BUG FIX: the original definition lacked @classmethod, so
        Review.from_dict(record) bound the dict to ``cls`` and raised
        TypeError — history loading always failed.
        """
        review = cls(data['code'], data['language'], data['suggestions'])
        review.timestamp = data['timestamp']
        # Tolerate records written before response_time existed.
        review.response_time = data.get('response_time', 0.0)
        return review
class CodeReviewer:
    """Lazily loads the LLM and produces structured code reviews.

    Keeps an on-disk history (HISTORY_FILE) plus simple aggregate metrics
    (total reviews, average response time, reviews today).
    """

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.device = None
        self.review_history: List[Review] = []
        self.metrics = {
            'total_reviews': 0,
            'avg_response_time': 0.0,
            'reviews_today': 0
        }
        # Model loading is deferred until the first review request.
        self._initialized = False
        self.load_history()

    def load_history(self):
        """Load review history and metrics from HISTORY_FILE, if present."""
        try:
            if os.path.exists(HISTORY_FILE):
                with open(HISTORY_FILE, 'r') as f:
                    data = json.load(f)
                self.review_history = [Review.from_dict(r) for r in data['history']]
                self.metrics = data['metrics']
                logger.info(f"Loaded {len(self.review_history)} reviews from history")
        except Exception as e:
            # Best-effort: a corrupt or unreadable history must not block startup.
            logger.error(f"Error loading history: {e}")

    def save_history(self):
        """Persist review history and metrics to HISTORY_FILE."""
        try:
            data = {
                'history': [r.to_dict() for r in self.review_history],
                'metrics': self.metrics
            }
            with open(HISTORY_FILE, 'w') as f:
                json.dump(data, f)
            logger.info("Saved review history")
        except Exception as e:
            logger.error(f"Error saving history: {e}")

    def ensure_initialized(self):
        """Initialize the model on first use.

        BUG FIX: the original unconditionally set ``_initialized = True``
        even when initialize_model() failed, so a failed load was recorded
        as success and never retried. initialize_model() itself maintains
        the flag, so no extra assignment is needed here.
        """
        if not self._initialized:
            self.initialize_model()

    def initialize_model(self):
        """Load tokenizer and model; returns True on success, False on failure."""
        try:
            if HF_TOKEN:
                login(token=HF_TOKEN, add_to_git_credential=False)

            logger.info("Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                MODEL_NAME,
                token=HF_TOKEN,
                trust_remote_code=True,
                cache_dir=CACHE_DIR
            )
            # Ensure pad/eos/bos tokens exist; add_special_tokens returns how
            # many were newly added (0 if the tokenizer already had them).
            special_tokens = {
                'pad_token': '[PAD]',
                'eos_token': '</s>',
                'bos_token': '<s>'
            }
            num_added = self.tokenizer.add_special_tokens(special_tokens)
            logger.info(f"Added {num_added} special tokens")
            logger.info("Tokenizer loaded successfully")

            logger.info("Loading model...")
            self.model = AutoModelForCausalLM.from_pretrained(
                MODEL_NAME,
                device_map="auto",
                torch_dtype=torch.float16,
                trust_remote_code=True,
                low_cpu_mem_usage=True,
                cache_dir=CACHE_DIR,
                token=HF_TOKEN
            )
            if num_added > 0:
                # Embedding table must grow to cover the new token ids.
                logger.info("Resizing model embeddings for special tokens")
                self.model.resize_token_embeddings(len(self.tokenizer))

            self.device = next(self.model.parameters()).device
            logger.info(f"Model loaded successfully on {self.device}")
            self._initialized = True
            return True
        except Exception as e:
            logger.error(f"Error initializing model: {e}")
            self._initialized = False
            return False

    def create_review_prompt(self, code: str, language: str) -> str:
        """Create a structured prompt for code review."""
        return f"""Review this {language} code. List specific points in these sections:
Issues:
Improvements:
Best Practices:
Security:
Code:
```{language}
{code}
```"""

    def review_code(self, code: str, language: str) -> str:
        """Perform code review using the model.

        Returns the model's suggestions, or a user-facing error string on
        any failure (the Gradio callback displays the return value as-is).
        """
        try:
            if not self._initialized and not self.initialize_model():
                return "Error: Model initialization failed. Please try again later."

            start_time = datetime.now()
            prompt = self.create_review_prompt(code, language)

            try:
                inputs = self.tokenizer(
                    prompt,
                    return_tensors="pt",
                    truncation=True,
                    max_length=512,
                    padding=True
                )
                if inputs is None:
                    raise ValueError("Failed to tokenize input")
                inputs = inputs.to(self.device)
            except Exception as token_error:
                logger.error(f"Tokenization error: {token_error}")
                return "Error: Failed to process input code. Please try again."

            try:
                with torch.no_grad():
                    # FIX: dropped early_stopping=True — it only applies to
                    # beam search and triggers warnings with num_beams=1.
                    outputs = self.model.generate(
                        **inputs,
                        max_new_tokens=512,
                        do_sample=True,
                        temperature=0.7,
                        top_p=0.95,
                        num_beams=1,
                        pad_token_id=self.tokenizer.pad_token_id,
                        eos_token_id=self.tokenizer.eos_token_id
                    )
            except Exception as gen_error:
                logger.error(f"Generation error: {gen_error}")
                return "Error: Failed to generate review. Please try again."

            try:
                response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
                # Strip the echoed prompt; assumes decode reproduces the prompt
                # prefix verbatim — TODO confirm for this tokenizer.
                suggestions = response[len(prompt):].strip()
            except Exception as decode_error:
                logger.error(f"Decoding error: {decode_error}")
                return "Error: Failed to decode model output. Please try again."

            # Record the review, then refresh aggregate metrics.
            end_time = datetime.now()
            review = Review(code, language, suggestions)
            review.response_time = (end_time - start_time).total_seconds()
            # FIX: delegate to update_metrics() instead of duplicating its
            # running-average / reviews-today logic inline. Append first so
            # today's count includes this review.
            self.review_history.append(review)
            self.update_metrics(review)
            self.save_history()

            # Free GPU memory held by this request's tensors.
            if self.device and self.device.type == "cuda":
                del inputs, outputs
                torch.cuda.empty_cache()

            return suggestions
        except Exception as e:
            logger.error(f"Error during code review: {e}")
            return f"Error performing code review: {str(e)}"

    def update_metrics(self, review: Review):
        """Update metrics with a newly completed review (already in history)."""
        self.metrics['total_reviews'] += 1
        # Running average: rebuild the total from the previous average.
        total_time = self.metrics['avg_response_time'] * (self.metrics['total_reviews'] - 1)
        total_time += review.response_time
        self.metrics['avg_response_time'] = total_time / self.metrics['total_reviews']
        today = datetime.now().date()
        self.metrics['reviews_today'] = sum(
            1 for r in self.review_history
            if datetime.fromisoformat(r.timestamp).date() == today
        )

    def get_history(self) -> List[Dict]:
        """Get the last 10 reviews, most recent first, formatted for display."""
        return [
            {
                'timestamp': r.timestamp,
                'language': r.language,
                'code': r.code,
                'suggestions': r.suggestions,
                'response_time': f"{r.response_time:.2f}s"
            }
            for r in reversed(self.review_history[-10:])
        ]

    def get_metrics(self) -> Dict:
        """Get current metrics with display-friendly keys."""
        return {
            'Total Reviews': self.metrics['total_reviews'],
            'Average Response Time': f"{self.metrics['avg_response_time']:.2f}s",
            'Reviews Today': self.metrics['reviews_today'],
            'Device': str(self.device) if self.device else "Not initialized"
        }
# Single module-level reviewer shared by all Gradio callbacks.
reviewer = CodeReviewer()

# Create Gradio interface: three tabs (review, history, metrics).
with gr.Blocks(theme=gr.themes.Soft()) as iface:
    gr.Markdown("# Code Review Assistant")
    gr.Markdown("An automated code review system powered by Gemma-2b")
    with gr.Tabs():
        with gr.Tab("Review Code"):
            with gr.Row():
                with gr.Column():
                    code_input = gr.Textbox(
                        lines=10,
                        placeholder="Enter your code here...",
                        label="Code"
                    )
                    language_input = gr.Dropdown(
                        choices=["python", "javascript", "java", "cpp", "typescript", "go", "rust"],
                        value="python",
                        label="Language"
                    )
                    submit_btn = gr.Button("Submit for Review", variant="primary")
                with gr.Column():
                    output = gr.Textbox(
                        label="Review Results",
                        lines=10
                    )
        with gr.Tab("History"):
            with gr.Row():
                refresh_history = gr.Button("Refresh History", variant="secondary")
            history_output = gr.Textbox(
                label="Review History",
                lines=20,
                value="Click 'Refresh History' to view review history"
            )
        with gr.Tab("Metrics"):
            with gr.Row():
                refresh_metrics = gr.Button("Refresh Metrics", variant="secondary")
            metrics_output = gr.JSON(
                label="Performance Metrics"
            )

    def review_code_interface(code: str, language: str) -> str:
        # Callback for the submit button: validate, lazily init, review.
        # Returns either the suggestions or a user-facing error string.
        if not code.strip():
            return "Please enter some code to review."
        try:
            reviewer.ensure_initialized()
            result = reviewer.review_code(code, language)
            return result
        except Exception as e:
            logger.error(f"Interface error: {e}")
            return f"Error: {str(e)}"

    def get_history_interface() -> str:
        # Render the recent-review history as one plain-text report.
        try:
            history = reviewer.get_history()
            if not history:
                return "No reviews yet."
            result = ""
            for review in history:
                result += f"Time: {review['timestamp']}\n"
                result += f"Language: {review['language']}\n"
                result += f"Response Time: {review['response_time']}\n"
                result += "Code:\n```\n" + review['code'] + "\n```\n"
                result += "Suggestions:\n" + review['suggestions'] + "\n"
                result += "-" * 80 + "\n\n"
            return result
        except Exception as e:
            logger.error(f"History error: {e}")
            return "Error retrieving history"

    def get_metrics_interface() -> Dict:
        # Return metrics for the JSON component; falls back to zeros
        # when the reviewer reports nothing.
        try:
            metrics = reviewer.get_metrics()
            if not metrics:
                return {
                    'Total Reviews': 0,
                    'Average Response Time': '0.00s',
                    'Reviews Today': 0,
                    'Device': str(reviewer.device) if reviewer.device else "Not initialized"
                }
            return metrics
        except Exception as e:
            logger.error(f"Metrics error: {e}")
            return {"error": str(e)}

    def update_all_outputs(code: str, language: str) -> tuple:
        """Update all outputs after code review."""
        # One submit refreshes review result, history, and metrics together.
        result = review_code_interface(code, language)
        history = get_history_interface()
        metrics = get_metrics_interface()
        return result, history, metrics

    # Connect the interface: submit drives all three outputs; the two
    # refresh buttons update their own panel only.
    submit_btn.click(
        update_all_outputs,
        inputs=[code_input, language_input],
        outputs=[output, history_output, metrics_output]
    )
    refresh_history.click(
        get_history_interface,
        outputs=history_output
    )
    refresh_metrics.click(
        get_metrics_interface,
        outputs=metrics_output
    )

    # Add example inputs users can click to pre-fill the form.
    gr.Examples(
        examples=[
            ["""def add_numbers(a, b):
    return a + b""", "python"],
            ["""function calculateSum(numbers) {
    let sum = 0;
    for(let i = 0; i < numbers.length; i++) {
        sum += numbers[i];
    }
    return sum;
}""", "javascript"]
        ],
        inputs=[code_input, language_input]
    )

# Launch the app (0.0.0.0:7860 is the standard HF Spaces binding).
if __name__ == "__main__":
    iface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
        quiet=False
    )