# Hugging Face Space status banner ("Spaces: Sleeping") captured during extraction — not part of the program.
import os
import re

from flask import Flask, request, jsonify
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

app = Flask(__name__)

# --- MODEL LOADING ---
# CodeTrans T5 checkpoint fine-tuned for git commit message generation.
MODEL_NAME = "SEBIS/code_trans_t5_base_commit_generation"

print("--- AI Commit Generator Server ---")
print(f"Downloading/Loading Model: {MODEL_NAME}")

# CPU-only inference; no GPU assumed in the deployment environment.
device = "cpu"

try:
    # NOTE(review): `skip_special_tokens` is a decode-time argument; passing it
    # to from_pretrained() has no effect here. Kept for compatibility.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, skip_special_tokens=True)
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(device)
    print("✅ Model loaded successfully!")
except Exception as e:
    print(f"❌ Error loading model: {e}")
    # Fix: raise SystemExit instead of calling the site-provided exit()
    # builtin, which is not guaranteed to exist in every execution context
    # (e.g. `python -S`, embedded interpreters).
    raise SystemExit(1)
def get_smart_fallback(diff_text, filename):
    """
    Generate a specific commit-message summary by inspecting the diff itself.

    Used when the AI model fails or hallucinates. Checks run from most to
    least specific; the first match wins.

    Args:
        diff_text: Raw diff text for a single file.
        filename: Display name of the file, interpolated into the message.

    Returns:
        A short, imperative commit-message summary string.

    Fix: `filename` was accepted but never used — every message hard-coded
    the literal "(unknown)" inside placeholder-free f-strings, which reads
    as a mangled `{filename}` interpolation. Restored.
    """
    # 1. Detect logging / debugging statements.
    if "console.log" in diff_text or "outputChannel.append" in diff_text or "print(" in diff_text:
        return f"Add debug logging to {filename}"
    # 2. Detect error handling.
    if "try" in diff_text and "catch" in diff_text:
        return f"Improve error handling in {filename}"
    # 3. Detect timing / async logic.
    if "setTimeout" in diff_text or "debounce" in diff_text or "await" in diff_text:
        return f"Update async/debounce logic in {filename}"
    # 4. Detect import changes.
    if "import " in diff_text or "require(" in diff_text:
        return f"Update imports in {filename}"
    # 5. Last resort: name the first function/const the diff touches.
    patterns = [r'function\s+([a-zA-Z0-9_]+)', r'const\s+([a-zA-Z0-9_]+)\s*=']
    for pattern in patterns:
        match = re.search(pattern, diff_text)
        if match:
            return f"Refactor '{match.group(1)}' logic"
    return f"Update logic in {filename}"
def preprocess_diff(diff_text):
    """Reduce a unified diff to just its meaningful +/- change lines.

    Keeps added/removed lines (excluding the '+++' / '---' file headers),
    strips surrounding whitespace, and drops lines too short to carry any
    signal (fewer than 4 characters after stripping). Imports are kept on
    purpose — they are useful context for the model.
    """
    if not diff_text:
        return ""
    kept = [
        raw.strip()
        for raw in diff_text.split('\n')
        if raw.startswith(('+', '-'))
        and not raw.startswith(('+++', '---'))
        and len(raw.strip()) >= 4
    ]
    return "\n".join(kept)
def sanitize_summary(summary, diff_text, filename):
    """Validate the model's summary; fall back to a heuristic one if it's junk.

    Rejects summaries containing Jira-style ticket IDs that never appear in
    the diff (hallucinations), and summaries containing dictionary-definition
    vocabulary the model sometimes emits.
    """
    cleaned = summary.strip()

    # 1. Hallucinated Jira tickets: a ticket ID not present in the diff.
    ticket_match = re.search(r'\b[A-Z]{3,}\s?-\s?\d+\b', cleaned)
    if ticket_match is not None:
        ticket = ticket_match.group()
        if ticket not in diff_text:
            print(f"⚠️ Hallucination Killed: '{ticket}' -> Using Smart Fallback")
            return get_smart_fallback(diff_text, filename)

    # 2. Linguistic nonsense (dictionary-entry vocabulary).
    lowered = cleaned.lower()
    if any(word in lowered for word in ("transitive verb", "intransitive", "adjective")):
        return get_smart_fallback(diff_text, filename)

    return cleaned
def generate_summary(diff_text, filename):
    """Produce a one-line commit summary for a single file's diff.

    Falls back to diff heuristics when the cleaned diff is empty or trivial
    (< 20 chars); otherwise runs the seq2seq model and sanitizes its output.
    """
    cleaned_diff = preprocess_diff(diff_text)
    # Trivial diffs give the model nothing to work with — use heuristics.
    if not cleaned_diff or len(cleaned_diff) < 20:
        return get_smart_fallback(diff_text, filename)

    input_ids = tokenizer.encode(
        cleaned_diff, return_tensors="pt", max_length=512, truncation=True
    ).to(device)

    # Fix: run generation without autograd bookkeeping — this is pure
    # inference, and tracking gradients only wastes time and memory.
    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            max_length=60,
            min_length=5,
            num_beams=5,
            early_stopping=True,
        )
    raw_summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return sanitize_summary(raw_summary, diff_text, filename)
# Fix: the original function was never registered with Flask, so the endpoint
# was unreachable. NOTE(review): the route path is assumed from the function
# name — confirm it against the extension/client that calls this server.
@app.route('/generate_commit', methods=['POST'])
def generate_commit():
    """POST endpoint: build a multi-file commit message from per-file diffs.

    Expects JSON of the form {"files": [{"name": str, "diff": str}, ...]}.
    Returns JSON {"commit_message": str}: one name-plus-summary section per
    file, sections separated by blank lines.
    """
    # get_json(silent=True) avoids a 400/415 crash on a missing/invalid body.
    data = request.get_json(silent=True) or {}
    files = data.get('files', [])
    if not files:
        return jsonify({"commit_message": ""})

    final_message_parts = []
    for file_obj in files:
        name = file_obj.get('name', 'file')
        diff = file_obj.get('diff', '')
        # Very large diffs would blow far past the model's context window.
        if len(diff) > 12000:
            final_message_parts.append(f"{name}\nUpdate large file (chunked)")
            continue
        try:
            summary = generate_summary(diff, name)
            final_message_parts.append(f"{name}\n{summary}")
        except Exception as e:
            # Best-effort per file: one bad diff must not fail the request.
            print(f"Error processing {name}: {e}")
            final_message_parts.append(f"{name}\nUpdate file")
    return jsonify({"commit_message": "\n\n".join(final_message_parts)})
if __name__ == '__main__':
    # Bind to all interfaces. NOTE(review): 7860 is the conventional Hugging
    # Face Spaces port — presumably this server is deployed as a Space; confirm.
    app.run(host='0.0.0.0', port=7860)