NiviruIns committed on
Commit
0857640
·
verified ·
1 Parent(s): 7e3ddbd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -40
app.py CHANGED
@@ -6,16 +6,14 @@ import torch
6
 
7
  app = Flask(__name__)
8
 
9
- # --- SWITCH TO THE EXPERT MODEL ---
10
  MODEL_NAME = "SEBIS/code_trans_t5_base_commit_generation"
11
-
12
  print(f"--- AI Commit Generator Server ---")
13
  print(f"Downloading/Loading Model: {MODEL_NAME}")
14
 
15
  device = "cpu"
16
 
17
  try:
18
- # AutoTokenizer handles the specific needs of this model automatically
19
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, skip_special_tokens=True)
20
  model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(device)
21
  print("✅ Model loaded successfully!")
@@ -25,7 +23,7 @@ except Exception as e:
25
 
26
  def preprocess_diff(diff_text):
27
  """
28
- Aggressively cleans the diff to keep ONLY the changes.
29
  """
30
  if not diff_text:
31
  return ""
@@ -34,23 +32,51 @@ def preprocess_diff(diff_text):
34
  cleaned_lines = []
35
 
36
  for line in lines:
37
- # Only keep lines that are actual additions/deletions
38
- # checking length > 1 to avoid empty '+' or '-' lines
39
- if (line.startswith('+') or line.startswith('-')) and len(line.strip()) > 1:
40
- # Skip metadata lines starting with +++ or ---
41
  if line.startswith('+++') or line.startswith('---'):
42
  continue
43
- cleaned_lines.append(line)
 
 
 
 
 
44
 
45
  return "\n".join(cleaned_lines)
46
 
47
- def generate_summary(diff_text):
48
- # Preprocess to get pure code changes
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  cleaned_diff = preprocess_diff(diff_text)
50
 
51
- # If cleaning removed everything (e.g., only whitespace changes), fallback
52
- if not cleaned_diff or len(cleaned_diff.strip()) < 10:
53
- return "Update logic"
54
 
55
  # Tokenize
56
  input_ids = tokenizer.encode(cleaned_diff, return_tensors="pt", max_length=512, truncation=True).to(device)
@@ -58,32 +84,21 @@ def generate_summary(diff_text):
58
  # Generate
59
  outputs = model.generate(
60
  input_ids,
61
- max_length=80,
62
- min_length=5,
63
  num_beams=5,
64
- repetition_penalty=1.5, # Increased penalty to stop loops
65
- no_repeat_ngram_size=2,
66
  early_stopping=True
67
  )
68
 
69
  summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
70
 
71
- # --- HALLUCINATION GUARD ---
72
- # Check for random Jira tickets (e.g., STORM-236, PROJ-123)
73
- # Pattern: Uppercase letters, hyphen, numbers
74
- ticket_pattern = re.compile(r'\b[A-Z]{2,}-\d+\b')
75
- match = ticket_pattern.search(summary)
76
-
77
- if match:
78
- found_ticket = match.group()
79
- # If the ticket ID is NOT in the source code, it's a hallucination
80
- if found_ticket not in diff_text:
81
- print(f"⚠️ Detected hallucination ({found_ticket}). Reverting to fallback.")
82
- return "Refactor code and logic"
83
-
84
- # Fallback if model yields empty string
85
  if not summary.strip():
86
- return "Update logic"
87
 
88
  return summary
89
 
@@ -98,22 +113,20 @@ def generate_commit():
98
  final_message_parts = []
99
 
100
  for file_obj in files:
101
- name = file_obj.get('name', 'Unknown File')
102
  diff = file_obj.get('diff', '')
103
 
104
- print(f"[{name}] Length: {len(diff)}")
105
-
106
- # Guard against massive files
107
  if len(diff) > 12000:
108
- final_message_parts.append(f"{name}\nLarge changes detected (please commit in smaller chunks)")
109
  continue
110
 
111
  try:
112
- summary = generate_summary(diff)
113
  final_message_parts.append(f"{name}\n{summary}")
114
  except Exception as e:
115
  print(f"Error processing {name}: {e}")
116
- final_message_parts.append(f"{name}\nUpdate file")
117
 
118
  return jsonify({"commit_message": "\n\n".join(final_message_parts)})
119
 
 
6
 
7
  app = Flask(__name__)
8
 
9
+ # --- MODEL SETUP ---
10
  MODEL_NAME = "SEBIS/code_trans_t5_base_commit_generation"
 
11
  print(f"--- AI Commit Generator Server ---")
12
  print(f"Downloading/Loading Model: {MODEL_NAME}")
13
 
14
  device = "cpu"
15
 
16
  try:
 
17
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, skip_special_tokens=True)
18
  model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(device)
19
  print("✅ Model loaded successfully!")
 
23
 
24
  def preprocess_diff(diff_text):
25
  """
26
+ Cleans the diff to remove metadata and save token space.
27
  """
28
  if not diff_text:
29
  return ""
 
32
  cleaned_lines = []
33
 
34
  for line in lines:
35
+ # We only care about changes
36
+ if (line.startswith('+') or line.startswith('-')):
37
+ # Skip metadata +++ / ---
 
38
  if line.startswith('+++') or line.startswith('---'):
39
  continue
40
+
41
+ # Clean generic import lines which confuse the model
42
+ if "import " in line or "require(" in line:
43
+ continue
44
+
45
+ cleaned_lines.append(line.strip())
46
 
47
  return "\n".join(cleaned_lines)
48
 
49
# Phrases the commit-message model is known to emit when it derails into
# unrelated text. Matched on word boundaries (with an optional plural "s")
# so that ordinary words such as "announce" are not flagged by "noun".
_FORBIDDEN_TERMS = (
    "transitive verb", "intransitive verb", "adjective",
    "noun", "pronoun", "metrics collection", "data volume",
)
_FORBIDDEN_TERMS_RE = re.compile(
    r"\b(?:" + "|".join(re.escape(term) for term in _FORBIDDEN_TERMS) + r")s?\b"
)

# Jira-style ticket ids: two or more uppercase letters, a hyphen, digits
# (e.g. STORM-123).
_TICKET_RE = re.compile(r"\b[A-Z]{2,}-\d+\b")


def is_hallucination(summary, diff_text):
    """
    Returns True if the summary contains known hallucination patterns.

    Two checks are applied:
      1. The summary mentions one of the forbidden "linguistic nonsense"
         terms as a whole word or phrase (case-insensitive).
      2. The summary mentions a Jira-style ticket id that does not appear
         anywhere in ``diff_text``. Every ticket id in the summary is
         checked, not only the first one.

    Args:
        summary: Generated commit message. Empty or None is treated as
            "nothing to flag" and returns False.
        diff_text: The diff the summary was generated from. None is
            treated as an empty diff.

    Returns:
        bool: True when the summary looks hallucinated, False otherwise.
    """
    if not summary:
        return False
    diff_text = diff_text or ""

    # 1. Linguistic nonsense
    if _FORBIDDEN_TERMS_RE.search(summary.lower()):
        return True

    # 2. Random Jira tickets (e.g. STORM-123) that are NOT in the diff.
    #    A single invented ticket is enough to reject the summary.
    return any(
        found.group() not in diff_text
        for found in _TICKET_RE.finditer(summary)
    )
72
+
73
+ def generate_summary(diff_text, filename):
74
+ # Preprocess
75
  cleaned_diff = preprocess_diff(diff_text)
76
 
77
+ # If the diff is just imports or too small, don't ask the AI
78
+ if not cleaned_diff or len(cleaned_diff) < 15:
79
+ return f"Update {filename}"
80
 
81
  # Tokenize
82
  input_ids = tokenizer.encode(cleaned_diff, return_tensors="pt", max_length=512, truncation=True).to(device)
 
84
  # Generate
85
  outputs = model.generate(
86
  input_ids,
87
+ max_length=50, # Shorter max length to prevent rambling
88
+ min_length=3,
89
  num_beams=5,
 
 
90
  early_stopping=True
91
  )
92
 
93
  summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
94
 
95
+ # Validate Output
96
+ if is_hallucination(summary, diff_text):
97
+ print(f"⚠️ Hallucination caught: '{summary}' -> Reverting to default.")
98
+ return f"Update {filename} logic"
99
+
 
 
 
 
 
 
 
 
 
100
  if not summary.strip():
101
+ return f"Modify {filename}"
102
 
103
  return summary
104
 
 
113
  final_message_parts = []
114
 
115
  for file_obj in files:
116
+ name = file_obj.get('name', 'file')
117
  diff = file_obj.get('diff', '')
118
 
119
+ # Guard for massive files
 
 
120
  if len(diff) > 12000:
121
+ final_message_parts.append(f"{name}\nLarge update (chunked)")
122
  continue
123
 
124
  try:
125
+ summary = generate_summary(diff, name)
126
  final_message_parts.append(f"{name}\n{summary}")
127
  except Exception as e:
128
  print(f"Error processing {name}: {e}")
129
+ final_message_parts.append(f"{name}\nUpdate changes")
130
 
131
  return jsonify({"commit_message": "\n\n".join(final_message_parts)})
132