# app.py — AI commit-message generator service (Hugging Face Space, revision 1ae0af0)
import os
import re
from flask import Flask, request, jsonify
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
# --- Flask application and model bootstrap (runs once at import time) ---
app = Flask(__name__)

# CodeTrans T5 checkpoint fine-tuned for git commit message generation.
MODEL_NAME = "SEBIS/code_trans_t5_base_commit_generation"
print("--- AI Commit Generator Server ---")
print(f"Downloading/Loading Model: {MODEL_NAME}")

# CPU-only inference; no GPU assumed in the deployment environment.
device = "cpu"
try:
    # NOTE(review): skip_special_tokens is a decode()-time argument; passing it
    # to from_pretrained() likely has no effect here — confirm and drop.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, skip_special_tokens=True)
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(device)
    print("✅ Model loaded successfully!")
except Exception as e:
    # The server is useless without the model: report and abort startup.
    print(f"❌ Error loading model: {e}")
    # raise SystemExit instead of exit(): exit() is a site-module interactive
    # convenience that may be absent (e.g. `python -S`) and is not meant for programs.
    raise SystemExit(1)
def get_smart_fallback(diff_text, filename):
    """Derive a commit summary from the diff content itself.

    Used when the AI model fails, hallucinates, or the diff is too small to
    summarize. Change categories are checked in priority order; as a last
    resort the first declared function/const name is used.

    Fix: the ``filename`` parameter was previously unused — every message
    contained the literal placeholder text "(unknown)" inside f-strings with
    no interpolation. The filename is now substituted into each message.

    Args:
        diff_text: Raw diff text (may be empty).
        filename:  Name of the changed file, interpolated into the message.

    Returns:
        A short human-readable commit summary string.
    """
    # 1. Debug/logging statements (JS console, VS Code output channel, Python print).
    if "console.log" in diff_text or "outputChannel.append" in diff_text or "print(" in diff_text:
        return f"Add debug logging to {filename}"
    # 2. Error handling. NOTE: plain substring match — "try" can also hit
    # words like "entry"; acceptable for a best-effort heuristic.
    if "try" in diff_text and "catch" in diff_text:
        return f"Improve error handling in {filename}"
    # 3. Timing / async logic.
    if "setTimeout" in diff_text or "debounce" in diff_text or "await" in diff_text:
        return f"Update async/debounce logic in {filename}"
    # 4. Import changes.
    if "import " in diff_text or "require(" in diff_text:
        return f"Update imports in {filename}"
    # 5. Last resort: name the first function declaration or const binding found.
    patterns = [r'function\s+([a-zA-Z0-9_]+)', r'const\s+([a-zA-Z0-9_]+)\s*=']
    for pattern in patterns:
        match = re.search(pattern, diff_text)
        if match:
            return f"Refactor '{match.group(1)}' logic"
    return f"Update logic in {filename}"
def preprocess_diff(diff_text):
    """Reduce a unified diff to just its meaningful changed lines.

    Keeps only added/removed lines (those starting with '+' or '-'),
    discarding file-header markers ('+++'/'---'), context lines, and
    near-empty changes (fewer than 4 characters after stripping).

    Args:
        diff_text: Raw diff text; may be empty or None-ish.

    Returns:
        The surviving lines, stripped and joined with newlines ('' if
        nothing qualifies).
    """
    if not diff_text:
        return ""
    kept = []
    for raw in diff_text.split('\n'):
        # Only actual change lines matter for summarization.
        if not raw.startswith(('+', '-')):
            continue
        # Skip the diff's file-name header markers.
        if raw.startswith(('+++', '---')):
            continue
        stripped = raw.strip()
        # Drop trivially short changes (note: the +/- marker counts here).
        if len(stripped) < 4:
            continue
        kept.append(stripped)
    return "\n".join(kept)
def sanitize_summary(summary, diff_text, filename):
    """Validate a model-generated summary, rejecting likely hallucinations.

    Two rejection rules:
      1. A Jira-style ticket id (e.g. "ABC-123") that does not literally
         appear in the diff — the model invented it.
      2. Dictionary-speak ("transitive verb", etc.) — the model drifted
         into linguistic definitions instead of summarizing code.

    On rejection the heuristic fallback message is returned instead.

    Args:
        summary:   Raw decoded model output.
        diff_text: Original diff, used to verify ticket ids.
        filename:  Passed through to the fallback generator.

    Returns:
        The stripped summary, or a smart-fallback message if rejected.
    """
    text = summary.strip()
    # Rule 1: ticket ids must be grounded in the diff itself.
    found = re.search(r'\b[A-Z]{3,}\s?-\s?\d+\b', text)
    if found:
        ticket = found.group()
        if ticket not in diff_text:
            print(f"⚠️ Hallucination Killed: '{ticket}' -> Using Smart Fallback")
            return get_smart_fallback(diff_text, filename)
    # Rule 2: linguistic nonsense markers.
    lowered = text.lower()
    if any(word in lowered for word in ("transitive verb", "intransitive", "adjective")):
        return get_smart_fallback(diff_text, filename)
    return text
def generate_summary(diff_text, filename):
    """Produce a one-line commit summary for a single file's diff.

    Pipeline: clean the diff -> run the seq2seq model (beam search) ->
    sanitize the decoded output. Diffs that clean down to fewer than 20
    characters skip the model entirely and use the heuristic fallback.

    Args:
        diff_text: Raw diff for one file.
        filename:  File name, forwarded to fallback/sanitization.

    Returns:
        A sanitized summary string.
    """
    cleaned = preprocess_diff(diff_text)
    # Too little signal for the model — go straight to the heuristic.
    if not cleaned or len(cleaned) < 20:
        return get_smart_fallback(diff_text, filename)
    encoded = tokenizer.encode(
        cleaned, return_tensors="pt", max_length=512, truncation=True
    ).to(device)
    generated = model.generate(
        encoded,
        max_length=60,
        min_length=5,
        num_beams=5,
        early_stopping=True,
    )
    decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
    return sanitize_summary(decoded, diff_text, filename)
@app.route('/generate', methods=['POST'])
def generate_commit():
    """POST /generate — build a multi-file commit message.

    Expects JSON: {"files": [{"name": str, "diff": str}, ...]}.
    Each file gets a "<name>\\n<summary>" section; sections are joined with
    blank lines. Oversized diffs (> 12000 chars) and per-file errors get
    generic placeholder summaries so one bad file never fails the request.

    Returns:
        JSON: {"commit_message": str} ('' when no files are supplied).
    """
    payload = request.json
    files = payload.get('files', [])
    if not files:
        return jsonify({"commit_message": ""})
    sections = []
    for entry in files:
        name = entry.get('name', 'file')
        diff = entry.get('diff', '')
        # Very large diffs would blow past the model's context — punt.
        if len(diff) > 12000:
            sections.append(f"{name}\nUpdate large file (chunked)")
            continue
        try:
            sections.append(f"{name}\n{generate_summary(diff, name)}")
        except Exception as e:
            # Best-effort per file: log and substitute a generic message.
            print(f"Error processing {name}: {e}")
            sections.append(f"{name}\nUpdate file")
    return jsonify({"commit_message": "\n\n".join(sections)})
if __name__ == '__main__':
    # Listen on all interfaces; 7860 is the conventional Hugging Face Spaces port.
    app.run(host='0.0.0.0', port=7860)