NiviruIns commited on
Commit
cda6349
·
verified ·
1 Parent(s): 0857640

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -46
app.py CHANGED
@@ -6,7 +6,7 @@ import torch
6
 
7
  app = Flask(__name__)
8
 
9
- # --- MODEL SETUP ---
10
  MODEL_NAME = "SEBIS/code_trans_t5_base_commit_generation"
11
  print(f"--- AI Commit Generator Server ---")
12
  print(f"Downloading/Loading Model: {MODEL_NAME}")
@@ -23,7 +23,7 @@ except Exception as e:
23
 
24
  def preprocess_diff(diff_text):
25
  """
26
- Cleans the diff to remove metadata and save token space.
27
  """
28
  if not diff_text:
29
  return ""
@@ -32,75 +32,76 @@ def preprocess_diff(diff_text):
32
  cleaned_lines = []
33
 
34
  for line in lines:
35
- # We only care about changes
36
  if (line.startswith('+') or line.startswith('-')):
37
- # Skip metadata +++ / ---
38
- if line.startswith('+++') or line.startswith('---'):
39
- continue
 
40
 
41
- # Clean generic import lines which confuse the model
42
- if "import " in line or "require(" in line:
43
- continue
44
-
45
  cleaned_lines.append(line.strip())
46
 
47
  return "\n".join(cleaned_lines)
48
 
49
- def is_hallucination(summary, diff_text):
50
  """
51
- Returns True if the summary contains known hallucination patterns.
52
  """
53
- summary_lower = summary.lower()
 
 
 
 
 
54
 
55
- # 1. Linguistic nonsense
56
- forbidden_terms = [
57
- "transitive verb", "intransitive verb", "adjective",
58
- "noun", "pronoun", "metrics collection", "data volume"
59
- ]
60
- if any(term in summary_lower for term in forbidden_terms):
61
- return True
62
-
63
- # 2. Random Jira Tickets (e.g. STORM-123) that are NOT in the diff
64
- ticket_pattern = re.compile(r'\b[A-Z]{2,}-\d+\b')
65
- match = ticket_pattern.search(summary)
66
  if match:
67
  ticket = match.group()
 
68
  if ticket not in diff_text:
69
- return True
70
-
71
- return False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
  def generate_summary(diff_text, filename):
74
- # Preprocess
75
  cleaned_diff = preprocess_diff(diff_text)
76
 
77
- # If the diff is just imports or too small, don't ask the AI
78
- if not cleaned_diff or len(cleaned_diff) < 15:
79
  return f"Update {filename}"
80
 
81
- # Tokenize
82
  input_ids = tokenizer.encode(cleaned_diff, return_tensors="pt", max_length=512, truncation=True).to(device)
83
 
84
  # Generate
85
  outputs = model.generate(
86
  input_ids,
87
- max_length=50, # Shorter max length to prevent rambling
88
- min_length=3,
89
  num_beams=5,
90
  early_stopping=True
91
  )
92
 
93
- summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
94
 
95
- # Validate Output
96
- if is_hallucination(summary, diff_text):
97
- print(f"⚠️ Hallucination caught: '{summary}' -> Reverting to default.")
98
- return f"Update {filename} logic"
99
-
100
- if not summary.strip():
101
- return f"Modify {filename}"
102
-
103
- return summary
104
 
105
  @app.route('/generate', methods=['POST'])
106
  def generate_commit():
@@ -116,9 +117,9 @@ def generate_commit():
116
  name = file_obj.get('name', 'file')
117
  diff = file_obj.get('diff', '')
118
 
119
- # Guard for massive files
120
  if len(diff) > 12000:
121
- final_message_parts.append(f"{name}\nLarge update (chunked)")
122
  continue
123
 
124
  try:
@@ -126,7 +127,7 @@ def generate_commit():
126
  final_message_parts.append(f"{name}\n{summary}")
127
  except Exception as e:
128
  print(f"Error processing {name}: {e}")
129
- final_message_parts.append(f"{name}\nUpdate changes")
130
 
131
  return jsonify({"commit_message": "\n\n".join(final_message_parts)})
132
 
 
6
 
7
  app = Flask(__name__)
8
 
9
+ # --- MODEL LOADING ---
10
  MODEL_NAME = "SEBIS/code_trans_t5_base_commit_generation"
11
  print(f"--- AI Commit Generator Server ---")
12
  print(f"Downloading/Loading Model: {MODEL_NAME}")
 
23
 
24
  def preprocess_diff(diff_text):
25
  """
26
+ Strips all metadata to ensure the model focuses ONLY on code changes.
27
  """
28
  if not diff_text:
29
  return ""
 
32
  cleaned_lines = []
33
 
34
  for line in lines:
35
+ # Keep only added (+) or removed (-) lines
36
  if (line.startswith('+') or line.startswith('-')):
37
+ # Remove metadata markers and noisy imports
38
+ if line.startswith('+++') or line.startswith('---'): continue
39
+ if "import " in line or "require(" in line: continue
40
+ if len(line.strip()) < 5: continue # Skip braces/empty lines
41
 
 
 
 
 
42
  cleaned_lines.append(line.strip())
43
 
44
  return "\n".join(cleaned_lines)
45
 
46
+ def sanitize_summary(summary, diff_text, filename):
47
  """
48
+ The 'Scorched Earth' filter. If it smells like a hallucination, kill it.
49
  """
50
+ summary_clean = summary.strip()
51
+
52
+ # 1. Catch Jira Tickets (e.g., STORM-1404, JIRA - 123)
53
+ # The regex allows for optional spaces around the hyphen
54
+ ticket_pattern = re.compile(r'\b[A-Z]{3,}\s?-\s?\d+\b')
55
+ match = ticket_pattern.search(summary_clean)
56
 
 
 
 
 
 
 
 
 
 
 
 
57
  if match:
58
  ticket = match.group()
59
+ # If this exact ticket string isn't in the source code, it's fake.
60
  if ticket not in diff_text:
61
+ print(f"⚠️ Hallucination Killed: '{ticket}' in '{filename}'")
62
+ return f"Update {filename}"
63
+
64
+ # 2. Catch Linguistic Nonsense
65
+ forbidden_words = [
66
+ "transitive verb", "intransitive", "adjective",
67
+ "CHANGELOG", "readme", "documentation"
68
+ ]
69
+
70
+ # Only block "CHANGELOG" if the file itself isn't a changelog
71
+ if "changelog" not in filename.lower():
72
+ for word in forbidden_words:
73
+ if word in summary_clean.lower():
74
+ print(f"⚠️ Nonsense Killed: '{word}' in '{filename}'")
75
+ return f"Update {filename} logic"
76
+
77
+ return summary_clean
78
 
79
  def generate_summary(diff_text, filename):
80
+ # Aggressively clean the input
81
  cleaned_diff = preprocess_diff(diff_text)
82
 
83
+ # If the diff is too small (e.g., just whitespace), skip the AI
84
+ if not cleaned_diff or len(cleaned_diff) < 20:
85
  return f"Update {filename}"
86
 
87
+ # Encode
88
  input_ids = tokenizer.encode(cleaned_diff, return_tensors="pt", max_length=512, truncation=True).to(device)
89
 
90
  # Generate
91
  outputs = model.generate(
92
  input_ids,
93
+ max_length=60,
94
+ min_length=5,
95
  num_beams=5,
96
  early_stopping=True
97
  )
98
 
99
+ raw_summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
100
 
101
+ # Apply the Sanitizer
102
+ final_summary = sanitize_summary(raw_summary, diff_text, filename)
103
+
104
+ return final_summary
 
 
 
 
 
105
 
106
  @app.route('/generate', methods=['POST'])
107
  def generate_commit():
 
117
  name = file_obj.get('name', 'file')
118
  diff = file_obj.get('diff', '')
119
 
120
+ # Hard limit on huge files
121
  if len(diff) > 12000:
122
+ final_message_parts.append(f"{name}\nUpdate large file (chunked)")
123
  continue
124
 
125
  try:
 
127
  final_message_parts.append(f"{name}\n{summary}")
128
  except Exception as e:
129
  print(f"Error processing {name}: {e}")
130
+ final_message_parts.append(f"{name}\nRefactor code")
131
 
132
  return jsonify({"commit_message": "\n\n".join(final_message_parts)})
133