import os
import re
import sys

import torch
from flask import Flask, jsonify, request
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

app = Flask(__name__)

# --- MODEL LOADING ---
# Seq2seq model fine-tuned to summarize code diffs into commit messages.
MODEL_NAME = "SEBIS/code_trans_t5_base_commit_generation"

print("--- AI Commit Generator Server ---")
print(f"Downloading/Loading Model: {MODEL_NAME}")

device = "cpu"
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, skip_special_tokens=True)
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(device)
    print("✅ Model loaded successfully!")
except Exception as e:
    # Model download/load failure is fatal: the server cannot serve requests.
    print(f"❌ Error loading model: {e}")
    sys.exit(1)


def get_smart_fallback(diff_text, filename):
    """
    Analyze the CODE itself to generate a specific commit message
    when the AI model fails or hallucinates.

    Checks a series of keyword heuristics from most to least specific
    and returns the first match; the final fallback is a generic message.
    """
    # 1. Detect logging / debugging
    if ("console.log" in diff_text
            or "outputChannel.append" in diff_text
            or "print(" in diff_text):
        return f"Add debug logging to {filename}"

    # 2. Detect error handling
    if "try" in diff_text and "catch" in diff_text:
        return f"Improve error handling in {filename}"

    # 3. Detect timing / async logic
    if "setTimeout" in diff_text or "debounce" in diff_text or "await" in diff_text:
        return f"Update async/debounce logic in {filename}"

    # 4. Detect import changes
    if "import " in diff_text or "require(" in diff_text:
        return f"Update imports in {filename}"

    # 5. Last resort: pull a function/variable name out of the diff
    #    (JS-flavored patterns: `function foo` or `const foo =`).
    patterns = [r'function\s+([a-zA-Z0-9_]+)', r'const\s+([a-zA-Z0-9_]+)\s*=']
    for pattern in patterns:
        match = re.search(pattern, diff_text)
        if match:
            return f"Refactor '{match.group(1)}' logic"

    return f"Update logic in {filename}"


def preprocess_diff(diff_text):
    """
    Strip a unified diff down to its changed lines.

    Keeps only '+'/'-' lines (excluding the '+++'/'---' file headers),
    drops near-empty changes (< 4 chars after stripping), and returns
    the stripped lines joined by newlines. Empty/None input yields "".
    """
    if not diff_text:
        return ""
    cleaned_lines = []
    for line in diff_text.split('\n'):
        if line.startswith('+') or line.startswith('-'):
            # Skip the file-header markers ('+++ b/...', '--- a/...').
            if line.startswith('+++') or line.startswith('---'):
                continue
            # Don't strip imports here, we might need them for context
            if len(line.strip()) < 4:
                continue
            cleaned_lines.append(line.strip())
    return "\n".join(cleaned_lines)


def sanitize_summary(summary, diff_text, filename):
    """
    Reject obviously hallucinated model output.

    If the summary contains a Jira-style ticket id that does not appear in
    the diff, or dictionary part-of-speech noise, fall back to the keyword
    heuristics in get_smart_fallback(); otherwise return the stripped summary.
    """
    summary_clean = summary.strip()

    # 1. Catch hallucinated Jira tickets (e.g. "ABC-123") that are not
    #    actually present in the diff text.
    ticket_pattern = re.compile(r'\b[A-Z]{3,}\s?-\s?\d+\b')
    match = ticket_pattern.search(summary_clean)
    if match:
        ticket = match.group()
        if ticket not in diff_text:
            print(f"⚠️ Hallucination Killed: '{ticket}' -> Using Smart Fallback")
            return get_smart_fallback(diff_text, filename)

    # 2. Catch linguistic nonsense — the model sometimes emits
    #    dictionary-style part-of-speech labels.
    forbidden_words = ["transitive verb", "intransitive", "adjective"]
    for word in forbidden_words:
        if word in summary_clean.lower():
            return get_smart_fallback(diff_text, filename)

    return summary_clean


def generate_summary(diff_text, filename):
    """
    Run the seq2seq model over a cleaned diff and return a sanitized
    one-line summary.

    Trivial diffs (< 20 cleaned chars) skip the model and go straight to
    the heuristic fallback.
    """
    cleaned_diff = preprocess_diff(diff_text)
    if not cleaned_diff or len(cleaned_diff) < 20:
        return get_smart_fallback(diff_text, filename)

    input_ids = tokenizer.encode(
        cleaned_diff, return_tensors="pt", max_length=512, truncation=True
    ).to(device)

    # Inference only: disable autograd bookkeeping for speed/memory.
    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            max_length=60,
            min_length=5,
            num_beams=5,
            early_stopping=True,
        )

    raw_summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return sanitize_summary(raw_summary, diff_text, filename)


@app.route('/generate', methods=['POST'])
def generate_commit():
    """
    POST /generate

    Body: {"files": [{"name": str, "diff": str}, ...]}
    Returns: {"commit_message": "<name>\\n<summary>" blocks joined by
    blank lines}. Per-file failures degrade to a generic "Update file"
    entry rather than failing the whole request.
    """
    # silent=True avoids Flask raising on a missing/invalid JSON body.
    data = request.get_json(silent=True) or {}
    files = data.get('files', [])
    if not files:
        return jsonify({"commit_message": ""})

    final_message_parts = []
    for file_obj in files:
        name = file_obj.get('name', 'file')
        diff = file_obj.get('diff', '')
        # Very large diffs would be truncated to uselessness by the
        # 512-token encoder limit; short-circuit with a generic message.
        if len(diff) > 12000:
            final_message_parts.append(f"{name}\nUpdate large file (chunked)")
            continue
        try:
            summary = generate_summary(diff, name)
            final_message_parts.append(f"{name}\n{summary}")
        except Exception as e:
            print(f"Error processing {name}: {e}")
            final_message_parts.append(f"{name}\nUpdate file")

    return jsonify({"commit_message": "\n\n".join(final_message_parts)})


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)