NiviruIns commited on
Commit
a052544
·
verified ·
1 Parent(s): eb3184c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -4
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import os
 
2
  from flask import Flask, request, jsonify
3
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
  import torch
@@ -22,20 +23,57 @@ except Exception as e:
22
  print(f"❌ Error loading model: {e}")
23
  exit(1)
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  def generate_summary(diff_text):
26
- if not diff_text or len(diff_text.strip()) < 5:
 
 
 
27
  return "Update file"
28
 
29
- # The Expert model just needs the raw diff. No "Summarize:" prefix needed.
30
- input_ids = tokenizer.encode(diff_text, return_tensors="pt", max_length=512, truncation=True).to(device)
31
 
 
32
  outputs = model.generate(
33
  input_ids,
34
  max_length=80,
 
35
  num_beams=5,
 
 
36
  early_stopping=True
37
  )
38
- return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
 
 
 
39
 
40
  @app.route('/generate', methods=['POST'])
41
  def generate_commit():
 
1
  import os
2
+ import re
3
  from flask import Flask, request, jsonify
4
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
5
  import torch
 
23
  print(f"❌ Error loading model: {e}")
24
  exit(1)
25
 
26
+ def preprocess_diff(diff_text):
27
+ """
28
+ Cleans the diff to remove git metadata and save token space for the actual code.
29
+ """
30
+ if not diff_text:
31
+ return ""
32
+
33
+ lines = diff_text.split('\n')
34
+ cleaned_lines = []
35
+
36
+ for line in lines:
37
+ # Remove git metadata lines
38
+ if line.startswith('diff --git') or line.startswith('index ') or line.startswith('+++') or line.startswith('---'):
39
+ continue
40
+ # Remove chunk headers like @@ -1,4 +1,5 @@
41
+ if line.startswith('@@'):
42
+ continue
43
+
44
+ cleaned_lines.append(line)
45
+
46
+ # Join and ensure we don't send an empty string
47
+ return "\n".join(cleaned_lines)
48
+
49
  def generate_summary(diff_text):
50
+ # Preprocess to get pure code changes
51
+ cleaned_diff = preprocess_diff(diff_text)
52
+
53
+ if not cleaned_diff or len(cleaned_diff.strip()) < 5:
54
  return "Update file"
55
 
56
+ # Tokenize
57
+ input_ids = tokenizer.encode(cleaned_diff, return_tensors="pt", max_length=512, truncation=True).to(device)
58
 
59
+ # Generate with better parameters to reduce "dumb" hallucinations
60
  outputs = model.generate(
61
  input_ids,
62
  max_length=80,
63
+ min_length=5,
64
  num_beams=5,
65
+ repetition_penalty=1.2, # Penalize repetition
66
+ no_repeat_ngram_size=2, # Prevent repeating phrases
67
  early_stopping=True
68
  )
69
+
70
+ summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
71
+
72
+ # Fallback if model yields empty string
73
+ if not summary.strip():
74
+ return "Update logic"
75
+
76
+ return summary
77
 
78
  @app.route('/generate', methods=['POST'])
79
  def generate_commit():