NiviruIns commited on
Commit
1a06520
·
verified ·
1 Parent(s): cda6349

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -44
app.py CHANGED
@@ -21,73 +21,70 @@ except Exception as e:
21
  print(f"❌ Error loading model: {e}")
22
  exit(1)
23
 
24
- def preprocess_diff(diff_text):
25
  """
26
- Strips all metadata to ensure the model focuses ONLY on code changes.
27
  """
28
- if not diff_text:
29
- return ""
30
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  lines = diff_text.split('\n')
32
  cleaned_lines = []
33
-
34
  for line in lines:
35
- # Keep only added (+) or removed (-) lines
36
  if (line.startswith('+') or line.startswith('-')):
37
- # Remove metadata markers and noisy imports
38
  if line.startswith('+++') or line.startswith('---'): continue
39
  if "import " in line or "require(" in line: continue
40
- if len(line.strip()) < 5: continue # Skip braces/empty lines
41
-
42
  cleaned_lines.append(line.strip())
43
-
44
  return "\n".join(cleaned_lines)
45
 
46
  def sanitize_summary(summary, diff_text, filename):
47
- """
48
- The 'Scorched Earth' filter. If it smells like a hallucination, kill it.
49
- """
50
  summary_clean = summary.strip()
51
 
52
- # 1. Catch Jira Tickets (e.g., STORM-1404, JIRA - 123)
53
- # The regex allows for optional spaces around the hyphen
54
  ticket_pattern = re.compile(r'\b[A-Z]{3,}\s?-\s?\d+\b')
55
  match = ticket_pattern.search(summary_clean)
56
 
57
  if match:
58
  ticket = match.group()
59
- # If this exact ticket string isn't in the source code, it's fake.
60
  if ticket not in diff_text:
61
- print(f"⚠️ Hallucination Killed: '{ticket}' in '{filename}'")
62
- return f"Update {filename}"
63
 
64
  # 2. Catch Linguistic Nonsense
65
- forbidden_words = [
66
- "transitive verb", "intransitive", "adjective",
67
- "CHANGELOG", "readme", "documentation"
68
- ]
69
-
70
- # Only block "CHANGELOG" if the file itself isn't a changelog
71
- if "changelog" not in filename.lower():
72
- for word in forbidden_words:
73
- if word in summary_clean.lower():
74
- print(f"⚠️ Nonsense Killed: '{word}' in '{filename}'")
75
- return f"Update {filename} logic"
76
 
77
  return summary_clean
78
 
79
  def generate_summary(diff_text, filename):
80
- # Aggressively clean the input
81
  cleaned_diff = preprocess_diff(diff_text)
82
 
83
- # If the diff is too small (e.g., just whitespace), skip the AI
84
  if not cleaned_diff or len(cleaned_diff) < 20:
85
- return f"Update {filename}"
86
 
87
- # Encode
88
  input_ids = tokenizer.encode(cleaned_diff, return_tensors="pt", max_length=512, truncation=True).to(device)
89
 
90
- # Generate
91
  outputs = model.generate(
92
  input_ids,
93
  max_length=60,
@@ -97,19 +94,14 @@ def generate_summary(diff_text, filename):
97
  )
98
 
99
  raw_summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
100
-
101
- # Apply the Sanitizer
102
- final_summary = sanitize_summary(raw_summary, diff_text, filename)
103
-
104
- return final_summary
105
 
106
  @app.route('/generate', methods=['POST'])
107
  def generate_commit():
108
  data = request.json
109
  files = data.get('files', [])
110
 
111
- if not files:
112
- return jsonify({"commit_message": ""})
113
 
114
  final_message_parts = []
115
 
@@ -117,7 +109,6 @@ def generate_commit():
117
  name = file_obj.get('name', 'file')
118
  diff = file_obj.get('diff', '')
119
 
120
- # Hard limit on huge files
121
  if len(diff) > 12000:
122
  final_message_parts.append(f"{name}\nUpdate large file (chunked)")
123
  continue
@@ -127,7 +118,7 @@ def generate_commit():
127
  final_message_parts.append(f"{name}\n{summary}")
128
  except Exception as e:
129
  print(f"Error processing {name}: {e}")
130
- final_message_parts.append(f"{name}\nRefactor code")
131
 
132
  return jsonify({"commit_message": "\n\n".join(final_message_parts)})
133
 
 
21
  print(f"❌ Error loading model: {e}")
22
  exit(1)
23
 
24
+ def get_smart_fallback(diff_text, filename):
25
  """
26
+ If AI fails, look at the code to see WHICH function was touched.
27
  """
28
+ # Look for function definitions or modifications
29
+ # Regex matches: function name(), const name =, class Name
30
+ patterns = [
31
+ r'function\s+([a-zA-Z0-9_]+)',
32
+ r'const\s+([a-zA-Z0-9_]+)\s*=',
33
+ r'let\s+([a-zA-Z0-9_]+)\s*=',
34
+ r'class\s+([a-zA-Z0-9_]+)',
35
+ r'def\s+([a-zA-Z0-9_]+)'
36
+ ]
37
+
38
+ for pattern in patterns:
39
+ match = re.search(pattern, diff_text)
40
+ if match:
41
+ func_name = match.group(1)
42
+ return f"Refactor '{func_name}' in {filename}"
43
+
44
+ return f"Update logic in {filename}"
45
+
46
+ def preprocess_diff(diff_text):
47
+ if not diff_text: return ""
48
  lines = diff_text.split('\n')
49
  cleaned_lines = []
 
50
  for line in lines:
 
51
  if (line.startswith('+') or line.startswith('-')):
 
52
  if line.startswith('+++') or line.startswith('---'): continue
53
  if "import " in line or "require(" in line: continue
54
+ if len(line.strip()) < 5: continue
 
55
  cleaned_lines.append(line.strip())
 
56
  return "\n".join(cleaned_lines)
57
 
58
  def sanitize_summary(summary, diff_text, filename):
 
 
 
59
  summary_clean = summary.strip()
60
 
61
+ # 1. Catch Hallucinated Jira Tickets
 
62
  ticket_pattern = re.compile(r'\b[A-Z]{3,}\s?-\s?\d+\b')
63
  match = ticket_pattern.search(summary_clean)
64
 
65
  if match:
66
  ticket = match.group()
 
67
  if ticket not in diff_text:
68
+ print(f"⚠️ Hallucination Killed: '{ticket}' -> Switching to Smart Fallback")
69
+ return get_smart_fallback(diff_text, filename)
70
 
71
  # 2. Catch Linguistic Nonsense
72
+ forbidden_words = ["transitive verb", "intransitive", "adjective"]
73
+ for word in forbidden_words:
74
+ if word in summary_clean.lower():
75
+ print(f"⚠️ Nonsense Killed: '{word}' -> Switching to Smart Fallback")
76
+ return get_smart_fallback(diff_text, filename)
 
 
 
 
 
 
77
 
78
  return summary_clean
79
 
80
  def generate_summary(diff_text, filename):
 
81
  cleaned_diff = preprocess_diff(diff_text)
82
 
 
83
  if not cleaned_diff or len(cleaned_diff) < 20:
84
+ return get_smart_fallback(diff_text, filename)
85
 
 
86
  input_ids = tokenizer.encode(cleaned_diff, return_tensors="pt", max_length=512, truncation=True).to(device)
87
 
 
88
  outputs = model.generate(
89
  input_ids,
90
  max_length=60,
 
94
  )
95
 
96
  raw_summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
97
+ return sanitize_summary(raw_summary, diff_text, filename)
 
 
 
 
98
 
99
  @app.route('/generate', methods=['POST'])
100
  def generate_commit():
101
  data = request.json
102
  files = data.get('files', [])
103
 
104
+ if not files: return jsonify({"commit_message": ""})
 
105
 
106
  final_message_parts = []
107
 
 
109
  name = file_obj.get('name', 'file')
110
  diff = file_obj.get('diff', '')
111
 
 
112
  if len(diff) > 12000:
113
  final_message_parts.append(f"{name}\nUpdate large file (chunked)")
114
  continue
 
118
  final_message_parts.append(f"{name}\n{summary}")
119
  except Exception as e:
120
  print(f"Error processing {name}: {e}")
121
+ final_message_parts.append(f"{name}\nUpdate file")
122
 
123
  return jsonify({"commit_message": "\n\n".join(final_message_parts)})
124