import os
import re
from flask import Flask, request, jsonify
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
app = Flask(__name__)

# --- MODEL LOADING ---
# CodeTrans T5 checkpoint fine-tuned for generating commit messages from diffs.
MODEL_NAME = "SEBIS/code_trans_t5_base_commit_generation"

print(f"--- AI Commit Generator Server ---")
print(f"Downloading/Loading Model: {MODEL_NAME}")

# CPU inference only — keeps the server runnable on hosts without a GPU.
device = "cpu"

try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, skip_special_tokens=True)
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(device)
    print("✅ Model loaded successfully!")
except Exception as e:
    print(f"❌ Error loading model: {e}")
    # `exit()` is an interactive helper injected by the `site` module and is
    # not guaranteed to exist (e.g. under `python -S`); SystemExit is the
    # reliable way to terminate with a non-zero status.
    raise SystemExit(1)
def get_smart_fallback(diff_text, filename):
    """
    Analyzes the CODE itself to generate a specific message
    when the AI model fails or hallucinates.

    Heuristics are tried in priority order (logging, error handling,
    async/timing, imports, function names); the first hit wins.

    Args:
        diff_text: Raw diff text for a single file.
        filename: Name of the file, interpolated into the message.

    Returns:
        A short commit-message line describing the change.
    """
    # BUGFIX: `filename` was accepted but never interpolated — every message
    # contained a literal placeholder instead of the actual file name.
    # 1. Detect Logging / Debugging
    if "console.log" in diff_text or "outputChannel.append" in diff_text or "print(" in diff_text:
        return f"Add debug logging to {filename}"
    # 2. Detect Error Handling
    if "try" in diff_text and "catch" in diff_text:
        return f"Improve error handling in {filename}"
    # 3. Detect Timing / Async Logic
    if "setTimeout" in diff_text or "debounce" in diff_text or "await" in diff_text:
        return f"Update async/debounce logic in {filename}"
    # 4. Detect Import Changes
    if "import " in diff_text or "require(" in diff_text:
        return f"Update imports in {filename}"
    # 5. Last Resort: Find the function name
    patterns = [r'function\s+([a-zA-Z0-9_]+)', r'const\s+([a-zA-Z0-9_]+)\s*=']
    for pattern in patterns:
        match = re.search(pattern, diff_text)
        if match:
            return f"Refactor '{match.group(1)}' logic"
    return f"Update logic in {filename}"
def preprocess_diff(diff_text):
    """Reduce a diff to its meaningful added/removed lines.

    Keeps only lines starting with '+' or '-', drops the '+++'/'---'
    file headers and anything shorter than 4 characters once stripped,
    and returns the surviving lines whitespace-trimmed and newline-joined.
    """
    if not diff_text:
        return ""
    kept = []
    for raw in diff_text.split('\n'):
        # Only actual change lines carry signal for the model.
        if not raw.startswith(('+', '-')):
            continue
        if raw.startswith(('+++', '---')):
            continue
        trimmed = raw.strip()
        # Don't strip imports here, we might need them for context
        if len(trimmed) >= 4:
            kept.append(trimmed)
    return "\n".join(kept)
def sanitize_summary(summary, diff_text, filename):
    """Validate a model-generated summary, replacing it when it looks hallucinated.

    Two rejection rules: (1) a Jira-style ticket id in the summary that
    does not literally appear in the diff; (2) dictionary-definition
    gibberish ("transitive verb", etc.). Either triggers the heuristic
    fallback; otherwise the stripped summary is returned unchanged.
    """
    cleaned = summary.strip()

    # 1. Catch Hallucinated Jira Tickets
    ticket_match = re.search(r'\b[A-Z]{3,}\s?-\s?\d+\b', cleaned)
    if ticket_match and ticket_match.group() not in diff_text:
        # A ticket id the diff never mentions is a fabrication.
        print(f"⚠️ Hallucination Killed: '{ticket_match.group()}' -> Using Smart Fallback")
        return get_smart_fallback(diff_text, filename)

    # 2. Catch Linguistic Nonsense
    lowered = cleaned.lower()
    if any(term in lowered for term in ("transitive verb", "intransitive", "adjective")):
        return get_smart_fallback(diff_text, filename)

    return cleaned
def generate_summary(diff_text, filename):
    """Produce a one-line commit summary for a single file's diff.

    Empty or very short cleaned diffs skip the model and go straight to
    the heuristic fallback; otherwise the cleaned diff is encoded, run
    through the seq2seq model with beam search, decoded, and sanitized
    before being returned.
    """
    condensed = preprocess_diff(diff_text)

    # Too little signal for the model — heuristics are more reliable here.
    if not condensed or len(condensed) < 20:
        return get_smart_fallback(diff_text, filename)

    encoded = tokenizer.encode(
        condensed,
        return_tensors="pt",
        max_length=512,
        truncation=True,
    ).to(device)
    generated = model.generate(
        encoded,
        max_length=60,
        min_length=5,
        num_beams=5,
        early_stopping=True,
    )
    decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
    return sanitize_summary(decoded, diff_text, filename)
@app.route('/generate', methods=['POST'])
def generate_commit():
    """POST /generate: build a commit message from a list of file diffs.

    Expects JSON of the form {"files": [{"name": ..., "diff": ...}, ...]}
    and responds with {"commit_message": ...}, one "name + newline +
    summary" section per file, sections separated by blank lines.
    """
    payload = request.json
    file_entries = payload.get('files', [])
    if not file_entries:
        return jsonify({"commit_message": ""})

    sections = []
    for entry in file_entries:
        name = entry.get('name', 'file')
        diff = entry.get('diff', '')
        # Oversized diffs would overwhelm the model's context; short-circuit.
        if len(diff) > 12000:
            sections.append(f"{name}\nUpdate large file (chunked)")
            continue
        try:
            summary = generate_summary(diff, name)
            sections.append(f"{name}\n{summary}")
        except Exception as e:
            # Best-effort per file: one failure must not sink the whole request.
            print(f"Error processing {name}: {e}")
            sections.append(f"{name}\nUpdate file")
    return jsonify({"commit_message": "\n\n".join(sections)})
if __name__ == '__main__':
    # Listen on all interfaces; 7860 is presumably the hosting platform's
    # expected port (looks like a Hugging Face Space) — confirm before changing.
    app.run(port=7860, host='0.0.0.0')