luxopes commited on
Commit
670c60e
·
verified ·
1 Parent(s): abc23bf

Update server.py

Browse files
Files changed (1) hide show
  1. server.py +71 -66
server.py CHANGED
@@ -1,52 +1,65 @@
1
  import torch
2
  import re
3
  from html import unescape
4
- from transformers import GPT2LMHeadModel, GPT2Tokenizer, BitsAndBytesConfig
5
  from peft import PeftModel
6
  from transformers import StoppingCriteria, StoppingCriteriaList
7
  from difflib import SequenceMatcher
8
  from flask import Flask, request, jsonify
9
 
10
- # Step 2: Load tokenizer
 
 
 
 
 
 
 
 
11
  model_path = "./"
12
  try:
13
  tokenizer = GPT2Tokenizer.from_pretrained(model_path)
14
  tokenizer.pad_token = tokenizer.eos_token
15
- print("Tokenizer loaded successfully")
16
  except Exception as e:
17
- print(f"Error loading tokenizer: {e}")
18
  exit()
19
 
20
- # Step 3: Load model (kvantizace + fallback)
21
- quant_config = BitsAndBytesConfig(
22
- load_in_4bit=True,
23
- bnb_4bit_use_double_quant=True,
24
- bnb_4bit_quant_type="nf4",
25
- bnb_4bit_compute_dtype=torch.bfloat16
26
- )
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  try:
29
  base_model = GPT2LMHeadModel.from_pretrained(
30
  model_path,
31
  quantization_config=quant_config,
32
- device_map={"": 0},
33
- low_cpu_mem_usage=True
34
- )
35
- print("Base model loaded successfully (4bit quantized)")
 
36
  except Exception as e:
37
- print(f"Error loading base model: {e}")
38
- try:
39
- base_model = GPT2LMHeadModel.from_pretrained(
40
- model_path,
41
- low_cpu_mem_usage=True,
42
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
43
- ).to("cuda:0" if torch.cuda.is_available() else "cpu")
44
- print("Base model loaded without quantization")
45
- except Exception as e:
46
- print(f"Error loading base model without quantization: {e}")
47
- exit()
48
 
49
- # Step 4: Load PEFT (LoRA)
 
 
50
  try:
51
  model = PeftModel.from_pretrained(
52
  base_model,
@@ -54,16 +67,21 @@ try:
54
  is_trainable=False,
55
  device_map={"": 0} if torch.cuda.is_available() else None
56
  )
57
- print("PEFT model loaded successfully")
 
58
  except Exception as e:
59
- print(f"Error loading PEFT model: {e}")
60
- exit()
61
 
 
62
  # Step 5: System prompt
 
63
  system_prompt = """You are TinyGPT, a friendly AI assistant made by LuxAI.
64
  You must answer very short."""
65
 
 
66
  # Step 6: Stopping criteria
 
67
  class CustomStoppingCriteria(StoppingCriteria):
68
  def __init__(self, stop_token_id):
69
  self.stop_token_id = stop_token_id
@@ -73,59 +91,47 @@ class CustomStoppingCriteria(StoppingCriteria):
73
 
74
  stopping_criteria = StoppingCriteriaList([CustomStoppingCriteria(tokenizer.eos_token_id)])
75
 
76
-
77
- # Step 6.5: Čisticí a kontrolní funkce
 
78
  def clean_response(text):
79
  """Odstraní HTML, Markdown a redundantní mezery."""
80
  original_text = text
81
- text = re.sub(r"<[^>]+>", " ", text) # odstraní HTML tagy
82
- text = unescape(text) # dekóduje HTML entity
83
- text = re.sub(r"[*#`_~]+", "", text) # odstraní markdown znaky
84
  text = re.sub(r"\s+", " ", text).strip()
85
  if text != original_text:
86
- print("🧹 Response cleaned from HTML/Markdown artifacts.")
87
  return text
88
 
89
 
90
  def remove_repetitions(text, similarity_threshold=0.8):
91
- """
92
- Pokud se opakují stejné věty (např. 'I'm TinyGPT...' 8x),
93
- ponechá pouze první.
94
- """
95
  sentences = re.split(r'(?<=[.!?])\s+', text)
96
  if len(sentences) <= 1:
97
  return text
98
-
99
  unique_sentences = []
100
  for sent in sentences:
101
  sent_clean = sent.strip()
102
  if not sent_clean:
103
  continue
104
- if not unique_sentences:
105
- unique_sentences.append(sent_clean)
106
- continue
107
- ratio = SequenceMatcher(None, sent_clean, unique_sentences[-1]).ratio()
108
- if ratio < similarity_threshold:
109
  unique_sentences.append(sent_clean)
110
-
111
- if len(unique_sentences) < len(sentences):
112
- print("🧩 Repetitive content detected and reduced.")
113
  return " ".join(unique_sentences)
114
 
 
115
  def truncate_to_last_sentence(text):
116
  """Zkrátí text na poslední dokončenou větu."""
117
  sentences = re.split(r'(?<=[.!?])\s+', text)
118
- if sentences and sentences[-1].strip():
119
- # Najde poslední větu, která končí na . ? !
120
- for i in range(len(sentences) - 1, -1, -1):
121
- if re.search(r'[.!?]$', sentences[i].strip()):
122
- return " ".join(sentences[:i+1]).strip()
123
- # If no sentence ends with . ? !, return the whole text after cleaning
124
- return text.strip()
125
  return text.strip()
126
 
127
-
128
  # Step 7: Generování odpovědi
 
129
  def generate_response(
130
  user_input,
131
  max_length=2048,
@@ -139,9 +145,8 @@ def generate_response(
139
  ):
140
  try:
141
  prompt = f"{system_prompt}\n\nUser: {user_input}\nAssistant:"
142
- device = "cuda:0" if torch.cuda.is_available() else "cpu"
143
  inputs = tokenizer(prompt, return_tensors="pt").to(device)
144
- print(f"Input device: {inputs['input_ids'].device}")
145
 
146
  with torch.no_grad():
147
  outputs = model.generate(
@@ -163,25 +168,23 @@ def generate_response(
163
  generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
164
  response = generated_text.split("Assistant:")[-1].strip()
165
 
166
- # Vyčištění, odstranění opakování a zkrácení na poslední větu
167
  response = clean_response(response)
168
  response = remove_repetitions(response)
169
  response = truncate_to_last_sentence(response)
170
 
171
-
172
  return response
173
 
174
  except Exception as e:
175
- print(f"Error during generation: {e}")
176
  return None
177
 
178
- # Step 8: Initialize Flask app
 
 
179
  app = Flask(__name__)
180
 
181
- # Step 9: Define API endpoint
182
  @app.route('/generate', methods=['POST'])
183
  def generate_text():
184
- # Step 10: Get input, generate response, return JSON
185
  data = request.get_json()
186
  if not data or 'user_input' not in data:
187
  return jsonify({'error': 'Missing user_input parameter'}), 400
@@ -194,6 +197,8 @@ def generate_text():
194
 
195
  return jsonify({'response': generated_response})
196
 
197
- # Step 11: Run the Flask app
 
 
198
  if __name__ == '__main__':
199
  app.run(host='0.0.0.0', port=7860)
 
1
  import torch
2
  import re
3
  from html import unescape
4
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer
5
  from peft import PeftModel
6
  from transformers import StoppingCriteria, StoppingCriteriaList
7
  from difflib import SequenceMatcher
8
  from flask import Flask, request, jsonify
9
 
10
# --------------------------
# Step 1: Device selection
# --------------------------
# Prefer the GPU when CUDA is available; everything below keys off `device`.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"🚀 Running on device: {device}")
15
+
16
# --------------------------
# Step 2: Load the tokenizer
# --------------------------
model_path = "./"
try:
    tokenizer = GPT2Tokenizer.from_pretrained(model_path)
    # GPT-2 ships without a pad token; reuse EOS so padded tokenizer calls work.
    tokenizer.pad_token = tokenizer.eos_token
    print("Tokenizer loaded successfully")
except Exception as e:
    print(f"Error loading tokenizer: {e}")
    # BUG FIX: exit() (a site-module helper) terminates with status 0, so a
    # failed startup looked successful to any supervisor. SystemExit(1)
    # reports the failure and needs no extra import.
    raise SystemExit(1)
27
 
28
# --------------------------
# Step 3: Build the (optional) quantization config
# --------------------------
# quant_config stays None on CPU or when bitsandbytes is unavailable; the
# model loader below treats None as "no quantization".
quant_config = None
if not torch.cuda.is_available():
    print("💡 CPU mode — quantization disabled")
else:
    try:
        from transformers import BitsAndBytesConfig

        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        print("✅ Using 4-bit quantization (GPU mode)")
    except Exception as e:
        print("⚠️ BitsAndBytes not available, continuing without quantization:", e)
46
 
47
# Load the base model: quantized placement on GPU, plain fp32 on CPU.
try:
    base_model = GPT2LMHeadModel.from_pretrained(
        model_path,
        quantization_config=quant_config,
        device_map={"": 0} if torch.cuda.is_available() else None,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    )
    # BUG FIX: the original called .to(device) unconditionally. A model loaded
    # with a 4-bit BitsAndBytesConfig rejects .to() (transformers raises
    # ValueError), and a device_map-placed model is already on the target
    # device. Only move the model explicitly when it was loaded unquantized.
    if quant_config is None:
        base_model = base_model.to(device)
    print("✅ Base model loaded successfully")
except Exception as e:
    print(f"Error loading base model: {e}")
    # BUG FIX: exit() would terminate with status 0; signal real failure.
    raise SystemExit(1)
 
 
 
 
 
 
 
 
 
59
 
60
+ # --------------------------
61
+ # Step 4: Načtení PEFT (LoRA)
62
+ # --------------------------
63
  try:
64
  model = PeftModel.from_pretrained(
65
  base_model,
 
67
  is_trainable=False,
68
  device_map={"": 0} if torch.cuda.is_available() else None
69
  )
70
+ model.to(device)
71
+ print("✅ PEFT model loaded successfully")
72
  except Exception as e:
73
+ print(f"⚠️ Warning: Failed to load PEFT adapter, using base model. ({e})")
74
+ model = base_model
75
 
76
# --------------------------
# Step 5: System prompt
# --------------------------
# Prepended to every request; instructs the model to keep answers short.
system_prompt = """You are TinyGPT, a friendly AI assistant made by LuxAI.
You must answer very short."""
81
 
82
+ # --------------------------
83
  # Step 6: Stopping criteria
84
+ # --------------------------
85
  class CustomStoppingCriteria(StoppingCriteria):
86
  def __init__(self, stop_token_id):
87
  self.stop_token_id = stop_token_id
 
91
 
92
  stopping_criteria = StoppingCriteriaList([CustomStoppingCriteria(tokenizer.eos_token_id)])
93
 
94
+ # --------------------------
95
+ # Step 6.5: Utility funkce
96
+ # --------------------------
97
def clean_response(text):
    """Strip HTML tags/entities, Markdown markers, and redundant whitespace.

    Prints a short note whenever the text was actually modified, then
    returns the cleaned string.
    """
    original_text = text
    stripped = re.sub(r"<[^>]+>", " ", text)          # drop HTML tags
    stripped = unescape(stripped)                     # decode HTML entities
    stripped = re.sub(r"[*#`_~]+", "", stripped)      # drop Markdown markers
    stripped = re.sub(r"\s+", " ", stripped).strip()  # collapse whitespace
    if stripped != original_text:
        print("🧹 Cleaned response.")
    return stripped
107
 
108
 
109
def remove_repetitions(text, similarity_threshold=0.8):
    """Drop consecutive near-duplicate sentences.

    Splits on sentence-ending punctuation and keeps a sentence only when it
    is sufficiently different (SequenceMatcher ratio below the threshold)
    from the last sentence that was kept.
    """
    sentences = re.split(r'(?<=[.!?])\s+', text)
    if len(sentences) <= 1:
        return text
    kept = []
    for raw in sentences:
        candidate = raw.strip()
        if not candidate:
            continue
        if kept:
            similarity = SequenceMatcher(None, candidate, kept[-1]).ratio()
            if similarity >= similarity_threshold:
                continue  # too close to the previous kept sentence
        kept.append(candidate)
    return " ".join(kept)
122
 
123
+
124
def truncate_to_last_sentence(text):
    """Trim *text* to its last fully terminated sentence.

    If no sentence ends with '.', '!' or '?', the stripped input is
    returned unchanged.
    """
    parts = re.split(r'(?<=[.!?])\s+', text)
    for idx in reversed(range(len(parts))):
        if parts[idx].strip().endswith((".", "!", "?")):
            return " ".join(parts[: idx + 1]).strip()
    return text.strip()
131
 
132
+ # --------------------------
133
  # Step 7: Generování odpovědi
134
+ # --------------------------
135
  def generate_response(
136
  user_input,
137
  max_length=2048,
 
145
  ):
146
  try:
147
  prompt = f"{system_prompt}\n\nUser: {user_input}\nAssistant:"
 
148
  inputs = tokenizer(prompt, return_tensors="pt").to(device)
149
+ print(f"📥 Input on device: {inputs['input_ids'].device}")
150
 
151
  with torch.no_grad():
152
  outputs = model.generate(
 
168
  generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
169
  response = generated_text.split("Assistant:")[-1].strip()
170
 
 
171
  response = clean_response(response)
172
  response = remove_repetitions(response)
173
  response = truncate_to_last_sentence(response)
174
 
 
175
  return response
176
 
177
  except Exception as e:
178
+ print(f"Error during generation: {e}")
179
  return None
180
 
181
+ # --------------------------
182
+ # Step 8: Flask API
183
+ # --------------------------
184
  app = Flask(__name__)
185
 
 
186
  @app.route('/generate', methods=['POST'])
187
  def generate_text():
 
188
  data = request.get_json()
189
  if not data or 'user_input' not in data:
190
  return jsonify({'error': 'Missing user_input parameter'}), 400
 
197
 
198
  return jsonify({'response': generated_response})
199
 
200
# --------------------------
# Step 9: Run the server
# --------------------------
# Bind on all interfaces; 7860 is presumably chosen to match the Hugging Face
# Spaces convention — confirm against the deployment config.
if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)