luxopes commited on
Commit
d081ad0
·
verified ·
1 Parent(s): a005021

Create server.py

Browse files
Files changed (1) hide show
  1. server.py +199 -0
server.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import re
import sys
from difflib import SequenceMatcher
from html import unescape

import torch
from flask import Flask, request, jsonify
from peft import PeftModel
from transformers import (
    BitsAndBytesConfig,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)
9
+
10
# Step 2: Load tokenizer from the current directory (model files live alongside server.py)
model_path = "./"
try:
    tokenizer = GPT2Tokenizer.from_pretrained(model_path)
    # GPT-2 ships without a pad token; reuse EOS so padding during generation works.
    tokenizer.pad_token = tokenizer.eos_token
    print("Tokenizer loaded successfully")
except Exception as e:
    print(f"Error loading tokenizer: {e}")
    # sys.exit is explicit and works even when the site-provided exit() builtin is absent.
    sys.exit(1)
19
+
20
# Step 3: Load base model (4-bit quantization, with a full/half-precision fallback)
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,  # second quantization pass over the quant constants
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

try:
    base_model = GPT2LMHeadModel.from_pretrained(
        model_path,
        quantization_config=quant_config,
        device_map={"": 0},  # place the whole model on GPU 0
        low_cpu_mem_usage=True
    )
    print("Base model loaded successfully (4bit quantized)")
except Exception as e:
    print(f"Error loading base model: {e}")
    # Fallback: load without bitsandbytes quantization (fp16 on GPU, fp32 on CPU).
    try:
        base_model = GPT2LMHeadModel.from_pretrained(
            model_path,
            low_cpu_mem_usage=True,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
        ).to("cuda:0" if torch.cuda.is_available() else "cpu")
        print("Base model loaded without quantization")
    except Exception as e:
        print(f"Error loading base model without quantization: {e}")
        # sys.exit is explicit and works even when the site-provided exit() builtin is absent.
        sys.exit(1)
48
+
49
# Step 4: Attach the LoRA adapter (PEFT) from the same directory on top of the base model
try:
    model = PeftModel.from_pretrained(
        base_model,
        model_path,
        is_trainable=False,  # inference only; keeps adapter weights frozen
        device_map={"": 0} if torch.cuda.is_available() else None
    )
    print("PEFT model loaded successfully")
except Exception as e:
    print(f"Error loading PEFT model: {e}")
    # sys.exit is explicit and works even when the site-provided exit() builtin is absent.
    sys.exit(1)
61
+
62
# Step 5: System prompt — prepended to every user request in generate_response().
# NOTE: this is runtime text fed to the model; do not reformat casually.
system_prompt = """You are TinyGPT, a friendly AI assistant made by LuxAI.
You must answer very short."""
65
+
66
# Step 6: Stopping criteria
class CustomStoppingCriteria(StoppingCriteria):
    """Halt generation when the stop token appears or the sequence exceeds 256 tokens."""

    def __init__(self, stop_token_id):
        self.stop_token_id = stop_token_id

    def __call__(self, input_ids, scores, **kwargs):
        sequence = input_ids[0]
        # Stop on the configured token, or hard-cap runaway generations at 256 tokens.
        hit_stop = sequence[-1] == self.stop_token_id
        return hit_stop or len(sequence) > 256

stopping_criteria = StoppingCriteriaList([CustomStoppingCriteria(tokenizer.eos_token_id)])
75
+
76
+
77
# Step 6.5: Cleaning and validation helpers
def clean_response(text):
    """Strip HTML tags/entities and Markdown markers, then collapse whitespace."""
    before = text
    without_tags = re.sub(r"<[^>]+>", " ", text)    # remove HTML tags
    decoded = unescape(without_tags)                # decode HTML entities
    no_markdown = re.sub(r"[*#`_~]+", "", decoded)  # remove Markdown markers
    text = re.sub(r"\s+", " ", no_markdown).strip()
    if text != before:
        print("🧹 Response cleaned from HTML/Markdown artifacts.")
    return text
88
+
89
+
90
def remove_repetitions(text, similarity_threshold=0.8):
    """
    Collapse runs of near-duplicate sentences (e.g. 'I'm TinyGPT...' repeated 8x),
    keeping only the first occurrence of each consecutive run.
    """
    sentences = re.split(r'(?<=[.!?])\s+', text)
    if len(sentences) <= 1:
        return text

    kept = []
    for candidate in sentences:
        candidate = candidate.strip()
        if not candidate:
            continue
        # Compare only against the previously kept sentence — repeats are consecutive.
        if kept and SequenceMatcher(None, candidate, kept[-1]).ratio() >= similarity_threshold:
            continue
        kept.append(candidate)

    if len(kept) < len(sentences):
        print("🧩 Repetitive content detected and reduced.")
    return " ".join(kept)
114
+
115
def truncate_to_last_sentence(text):
    """Trim *text* so it ends at its last complete sentence (., ! or ?)."""
    parts = re.split(r'(?<=[.!?])\s+', text)
    if parts and parts[-1].strip():
        # Walk backwards to the last part ending in a sentence terminator.
        for idx in reversed(range(len(parts))):
            if re.search(r'[.!?]$', parts[idx].strip()):
                return " ".join(parts[:idx + 1]).strip()
    # No terminated sentence found (or trailing split empty): return stripped input.
    return text.strip()
126
+
127
+
128
# Step 7: Response generation
def generate_response(
    user_input,
    max_length=2048,
    temperature=0.7,
    top_k=50,
    top_p=0.7,
    repetition_penalty=10.0,
    num_beams=4,
    early_stopping=True,
    do_sample=True
):
    """
    Build the chat prompt, run model.generate, and post-process the reply.

    Returns the cleaned assistant reply as a string, or None on any error.
    Sampling knobs (temperature/top_k/top_p) are only applied when do_sample is True.
    """
    try:
        target_device = "cuda:0" if torch.cuda.is_available() else "cpu"
        prompt = f"{system_prompt}\n\nUser: {user_input}\nAssistant:"
        encoded = tokenizer(prompt, return_tensors="pt").to(target_device)
        print(f"Input device: {encoded['input_ids'].device}")

        generation_kwargs = dict(
            max_length=max_length,
            temperature=temperature if do_sample else 1.0,
            top_k=top_k if do_sample else None,
            top_p=top_p if do_sample else None,
            repetition_penalty=repetition_penalty,
            num_beams=num_beams,
            # early_stopping is only meaningful with beam search.
            early_stopping=early_stopping if num_beams > 1 else False,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            do_sample=do_sample,
            stopping_criteria=stopping_criteria,
        )
        with torch.no_grad():
            outputs = model.generate(**encoded, **generation_kwargs)

        decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Keep only what follows the final "Assistant:" marker.
        reply = decoded.split("Assistant:")[-1].strip()

        # Post-process: strip markup, collapse repetition, end on a full sentence.
        for postprocess in (clean_response, remove_repetitions, truncate_to_last_sentence):
            reply = postprocess(reply)

        return reply

    except Exception as e:
        print(f"Error during generation: {e}")
        return None
177
+
178
# Step 8: Initialize Flask app (single JSON endpoint defined below)
app = Flask(__name__)
180
+
181
# Step 9: Define API endpoint
@app.route('/generate', methods=['POST'])
def generate_text():
    """POST /generate — JSON body {"user_input": str} -> {"response": str}."""
    # Step 10: Get input, generate response, return JSON
    payload = request.get_json()
    if not payload or 'user_input' not in payload:
        return jsonify({'error': 'Missing user_input parameter'}), 400

    reply = generate_response(payload['user_input'])
    if reply is None:
        return jsonify({'error': 'Failed to generate response'}), 500

    return jsonify({'response': reply})
196
+
197
# Step 11: Run the Flask app
if __name__ == '__main__':
    # Listen on all interfaces, port 7860.
    app.run(host='0.0.0.0', port=7860)