File size: 4,499 Bytes
fddb21f
a052544
fddb21f
eb3184c
fddb21f
 
 
 
cda6349
eb3184c
fddb21f
d8b7758
fddb21f
cbccf44
fddb21f
 
eb3184c
 
d8b7758
fddb21f
 
 
 
1a06520
a052544
1ae0af0
 
a052544
1ae0af0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1a06520
 
 
1ae0af0
1a06520
 
 
 
 
a052544
 
 
0857640
cda6349
1ae0af0
 
0857640
a052544
 
cda6349
 
 
1a06520
cda6349
 
0857640
 
 
 
1ae0af0
1a06520
cda6349
 
1a06520
 
 
 
cda6349
 
0857640
 
a052544
 
cda6349
1a06520
fddb21f
a052544
fddb21f
 
 
cda6349
 
fddb21f
d8b7758
fddb21f
a052544
cda6349
1a06520
fddb21f
 
 
 
 
 
1a06520
fddb21f
 
 
 
0857640
fddb21f
 
eb3184c
cda6349
fddb21f
 
 
0857640
fddb21f
 
 
1a06520
fddb21f
 
 
 
3474fd8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import os
import re
from flask import Flask, request, jsonify
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

app = Flask(__name__)

# --- MODEL LOADING ---
# Seq2seq model fine-tuned for commit-message generation; downloaded from
# the Hugging Face hub on first run, then loaded from the local cache.
MODEL_NAME = "SEBIS/code_trans_t5_base_commit_generation"
print(f"--- AI Commit Generator Server ---")
print(f"Downloading/Loading Model: {MODEL_NAME}")

# CPU inference only — no GPU assumed on the host.
device = "cpu"

try:
    # Tokenizer and model are module-level globals shared by all requests.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, skip_special_tokens=True)
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(device)
    print("✅ Model loaded successfully!")
except Exception as e:
    # The server is useless without the model, so fail hard at startup.
    print(f"❌ Error loading model: {e}")
    exit(1)

def get_smart_fallback(diff_text, filename):
    """Generate a heuristic commit summary directly from the diff text.

    Used when the AI model fails, hallucinates, or the diff is too small
    to summarize. Scans the raw diff for recognizable patterns (logging,
    error handling, async logic, imports, function definitions) and builds
    a specific one-line message referencing *filename*.

    Args:
        diff_text: Raw diff text for a single file.
        filename: Name of the file the diff belongs to.

    Returns:
        A one-line commit summary string.
    """
    # 1. Detect Logging / Debugging
    if "console.log" in diff_text or "outputChannel.append" in diff_text or "print(" in diff_text:
        return f"Add debug logging to {filename}"

    # 2. Detect Error Handling
    if "try" in diff_text and "catch" in diff_text:
        return f"Improve error handling in {filename}"

    # 3. Detect Timing / Async Logic
    if "setTimeout" in diff_text or "debounce" in diff_text or "await" in diff_text:
        return f"Update async/debounce logic in {filename}"

    # 4. Detect Import Changes
    if "import " in diff_text or "require(" in diff_text:
        return f"Update imports in {filename}"

    # 5. Last Resort: name the first JS function/const declaration found.
    patterns = [r'function\s+([a-zA-Z0-9_]+)', r'const\s+([a-zA-Z0-9_]+)\s*=']
    for pattern in patterns:
        match = re.search(pattern, diff_text)
        if match:
            return f"Refactor '{match.group(1)}' logic"

    return f"Update logic in {filename}"

def preprocess_diff(diff_text):
    """Reduce a raw diff to its meaningful added/removed lines.

    Keeps only the '+'/'-' change lines (dropping the '+++'/'---' file
    headers), discards trivially short lines, and strips surrounding
    whitespace from each kept line.
    """
    if not diff_text:
        return ""
    kept = []
    for raw in diff_text.split('\n'):
        if not (raw.startswith('+') or raw.startswith('-')):
            continue
        # Skip the file-header lines of unified-diff output.
        if raw.startswith(('+++', '---')):
            continue
        stripped = raw.strip()
        # Don't strip imports here, we might need them for context;
        # only drop lines too short to carry any signal.
        if len(stripped) < 4:
            continue
        kept.append(stripped)
    return "\n".join(kept)

def sanitize_summary(summary, diff_text, filename):
    """Validate the model's summary, deferring to the fallback on hallucination.

    Two rejection rules: (1) the summary mentions a Jira-style ticket id
    (e.g. "ABC-123") that does not appear anywhere in the diff; (2) the
    summary contains dictionary-definition vocabulary the model sometimes
    emits. Both cases return get_smart_fallback() instead.
    """
    cleaned = summary.strip()

    # 1. Catch Hallucinated Jira Tickets
    ticket_match = re.search(r'\b[A-Z]{3,}\s?-\s?\d+\b', cleaned)
    if ticket_match:
        ticket = ticket_match.group()
        if ticket not in diff_text:
            print(f"⚠️ Hallucination Killed: '{ticket}' -> Using Smart Fallback")
            return get_smart_fallback(diff_text, filename)

    # 2. Catch Linguistic Nonsense
    lowered = cleaned.lower()
    if any(word in lowered for word in ("transitive verb", "intransitive", "adjective")):
        return get_smart_fallback(diff_text, filename)

    return cleaned

def generate_summary(diff_text, filename):
    """Produce a one-line commit summary for a single file's diff.

    Cleans the diff, runs it through the seq2seq model with beam search,
    then sanitizes the decoded output. Tiny or empty diffs skip the model
    entirely and go straight to the heuristic fallback.
    """
    cleaned = preprocess_diff(diff_text)

    # Too little signal for the model to work with — use the fallback.
    if not cleaned or len(cleaned) < 20:
        return get_smart_fallback(diff_text, filename)

    encoded = tokenizer.encode(
        cleaned, return_tensors="pt", max_length=512, truncation=True
    ).to(device)

    generated = model.generate(
        encoded,
        max_length=60,
        min_length=5,
        num_beams=5,
        early_stopping=True,
    )

    decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
    return sanitize_summary(decoded, diff_text, filename)

@app.route('/generate', methods=['POST'])
def generate_commit():
    """HTTP endpoint: build a commit message from a list of file diffs.

    Expects a JSON body of the form:
        {"files": [{"name": str, "diff": str}, ...]}
    Returns JSON {"commit_message": str} containing one "<name>\\n<summary>"
    section per file, sections separated by blank lines.
    """
    # request.json raises (or returns None, depending on Flask version)
    # when the body is missing or not JSON; get_json(silent=True) makes
    # that case an empty request instead of a 500.
    data = request.get_json(silent=True) or {}
    files = data.get('files', [])

    if not files:
        return jsonify({"commit_message": ""})

    final_message_parts = []

    for file_obj in files:
        name = file_obj.get('name', 'file')
        diff = file_obj.get('diff', '')

        # Oversized diffs would be truncated to uselessness by the model's
        # 512-token input window; emit a generic message instead.
        if len(diff) > 12000:
            final_message_parts.append(f"{name}\nUpdate large file (chunked)")
            continue

        try:
            summary = generate_summary(diff, name)
            final_message_parts.append(f"{name}\n{summary}")
        except Exception as e:
            # One bad file must not fail the whole request — log and
            # fall back to a stock message for that file only.
            print(f"Error processing {name}: {e}")
            final_message_parts.append(f"{name}\nUpdate file")

    return jsonify({"commit_message": "\n\n".join(final_message_parts)})

if __name__ == '__main__':
    # Bind to all interfaces so clients outside a container can reach the
    # server; 7860 matches the Hugging Face Spaces convention — presumably
    # deployed there, verify against the deployment config.
    app.run(host='0.0.0.0', port=7860)