NiviruIns committed on
Commit
d8b7758
·
verified ·
1 Parent(s): fddb21f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -23
app.py CHANGED
@@ -1,25 +1,24 @@
1
  import os
2
  from flask import Flask, request, jsonify
3
- from transformers import RobertaTokenizer, T5ForConditionalGeneration
4
  import torch
5
 
6
  app = Flask(__name__)
7
 
8
- # --- CRITICAL CHANGE ---
9
- # Instead of your local folder, we point to a PUBLIC Expert Model
10
- # This model has read millions of commit messages and knows exactly what to do.
11
- MODEL_NAME = "ncoop57/commit-t5"
12
 
13
  print(f"--- AI Commit Generator Server ---")
14
- print(f"Downloading/Loading Expert Model: {MODEL_NAME}")
15
 
16
- device = "cpu" # HF Spaces free tier is CPU
17
 
18
  try:
19
- # This will download the model automatically the first time it runs
20
- tokenizer = RobertaTokenizer.from_pretrained(MODEL_NAME)
21
- model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME).to(device)
22
- print("✅ Expert Model loaded successfully!")
23
  except Exception as e:
24
  print(f"❌ Error loading model: {e}")
25
  exit(1)
@@ -28,19 +27,14 @@ def generate_summary(diff_text):
28
  if not diff_text or len(diff_text.strip()) < 5:
29
  return "Update file"
30
 
31
- # Preprocess: "commit-t5" just expects the raw code diff, no "Summarize:" prefix needed usually,
32
- # but let's keep it simple.
33
- input_text = diff_text + " </s>"
34
-
35
- input_ids = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True).input_ids.to(device)
36
 
37
  outputs = model.generate(
38
  input_ids,
39
- max_length=80, # Commits are usually short
40
- min_length=5,
41
  num_beams=5,
42
- early_stopping=True,
43
- no_repeat_ngram_size=2 # Stops it from saying "update update update"
44
  )
45
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
46
 
@@ -58,14 +52,13 @@ def generate_commit():
58
  name = file_obj.get('name', 'Unknown File')
59
  diff = file_obj.get('diff', '')
60
 
61
- # Skip binary files or huge diffs
62
- if len(diff) > 4000:
63
  final_message_parts.append(f"{name}\nLarge changes detected")
64
  continue
65
 
66
  try:
67
  summary = generate_summary(diff)
68
- # Format: File Name -> The generated message
69
  final_message_parts.append(f"{name}\n{summary}")
70
  except Exception as e:
71
  print(f"Error processing {name}: {e}")
 
1
  import os
2
  from flask import Flask, request, jsonify
3
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
  import torch
5
 
6
  app = Flask(__name__)
7
 
8
+ # --- UPDATED MODEL ---
9
+ # This model is specifically trained for git commit generation and is active.
10
+ MODEL_NAME = "SEBIS/code_trans_t5_base_commit_generation"
 
11
 
12
  print(f"--- AI Commit Generator Server ---")
13
+ print(f"Downloading/Loading Model: {MODEL_NAME}")
14
 
15
+ device = "cpu"
16
 
17
  try:
18
+ # Use AutoTokenizer and AutoModelForSeq2SeqLM for better compatibility
19
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, skip_special_tokens=True)
20
+ model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(device)
21
+ print("✅ Model loaded successfully!")
22
  except Exception as e:
23
  print(f"❌ Error loading model: {e}")
24
  exit(1)
 
27
  if not diff_text or len(diff_text.strip()) < 5:
28
  return "Update file"
29
 
30
+ # This model works best with raw code, but we tokenize it first
31
+ input_ids = tokenizer.encode(diff_text, return_tensors="pt", max_length=512, truncation=True).to(device)
 
 
 
32
 
33
  outputs = model.generate(
34
  input_ids,
35
+ max_length=80,
 
36
  num_beams=5,
37
+ early_stopping=True
 
38
  )
39
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
40
 
 
52
  name = file_obj.get('name', 'Unknown File')
53
  diff = file_obj.get('diff', '')
54
 
55
+ # Skip huge files to prevent crashing CPU
56
+ if len(diff) > 6000:
57
  final_message_parts.append(f"{name}\nLarge changes detected")
58
  continue
59
 
60
  try:
61
  summary = generate_summary(diff)
 
62
  final_message_parts.append(f"{name}\n{summary}")
63
  except Exception as e:
64
  print(f"Error processing {name}: {e}")