import torch
import re
from html import unescape
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from peft import PeftModel
from transformers import StoppingCriteria, StoppingCriteriaList
from difflib import SequenceMatcher
from flask import Flask, request, jsonify
# --------------------------
# Step 1: Device selection
# --------------------------
# Prefer the GPU when torch can see one; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"🚀 Running on device: {device}")
# --------------------------
# Step 2: Load the tokenizer
# --------------------------
model_path = "./"
try:
    tokenizer = GPT2Tokenizer.from_pretrained(model_path)
    # GPT-2 ships without a dedicated pad token; reuse EOS so padding works.
    tokenizer.pad_token = tokenizer.eos_token
    print("✅ Tokenizer loaded successfully")
except Exception as e:
    print(f"❌ Error loading tokenizer: {e}")
    # FIX: bare exit() comes from the `site` module and exits with status 0;
    # SystemExit(1) reliably aborts with a nonzero exit code for supervisors.
    raise SystemExit(1)
# --------------------------
# Step 3: Load the model, with a quantization fallback
# --------------------------
quant_config = None
if torch.cuda.is_available():
    try:
        from transformers import BitsAndBytesConfig
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        print("✅ Using 4-bit quantization (GPU mode)")
    except Exception as e:
        print("⚠️ BitsAndBytes not available, continuing without quantization:", e)
else:
    print("💡 CPU mode — quantization disabled")
try:
    base_model = GPT2LMHeadModel.from_pretrained(
        model_path,
        quantization_config=quant_config,
        device_map={"": 0} if torch.cuda.is_available() else None,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    )
    # FIX: calling .to() on a bitsandbytes-quantized model raises an error,
    # and on GPU the device_map above already placed the weights. Only move
    # the model manually on the CPU/unquantized path.
    if quant_config is None and not torch.cuda.is_available():
        base_model = base_model.to(device)
    print("✅ Base model loaded successfully")
except Exception as e:
    print(f"❌ Error loading base model: {e}")
    # Nonzero exit code on fatal startup failure (see tokenizer step).
    raise SystemExit(1)
# --------------------------
# Step 4: Load the PEFT (LoRA) adapter
# --------------------------
# Attach the LoRA adapter on top of the base model; when no adapter can be
# loaded, fall back to the plain base model so the server can still answer.
try:
    placement = {"": 0} if torch.cuda.is_available() else None
    model = PeftModel.from_pretrained(
        base_model,
        model_path,
        is_trainable=False,
        device_map=placement,
    )
    model.to(device)
    print("✅ PEFT model loaded successfully")
except Exception as e:
    print(f"⚠️ Warning: Failed to load PEFT adapter, using base model. ({e})")
    model = base_model
# --------------------------
# Step 5: System prompt
# --------------------------
# Persona text prepended to every request before the user's message.
# FIX: corrected the misspelling "cooherent" -> "coherent" in the prompt.
system_prompt = """You are GPT-A, a friendly AI assistant made by LuxAI.
You must answer very short and coherent."""
# --------------------------
# Step 6: Stopping criteria
# --------------------------
class CustomStoppingCriteria(StoppingCriteria):
    """Stop generation at a stop token or after a maximum sequence length."""

    def __init__(self, stop_token_id, max_length=512):
        # max_length generalizes the previously hard-coded 512-token cap;
        # the default preserves the original behavior.
        self.stop_token_id = stop_token_id
        self.max_length = max_length

    def __call__(self, input_ids, scores, **kwargs):
        # Stop when the newest token is the stop token, or once the whole
        # sequence (prompt + generated tokens) exceeds max_length.
        return input_ids[0][-1] == self.stop_token_id or len(input_ids[0]) > self.max_length


stopping_criteria = StoppingCriteriaList([CustomStoppingCriteria(tokenizer.eos_token_id)])
# --------------------------
# Step 6.5: Utility functions
# --------------------------
def clean_response(text):
    """Strip HTML, Markdown markers and redundant whitespace from *text*."""
    before = text
    cleaned = re.sub(r"<[^>]+>", " ", text)         # drop HTML tags
    cleaned = unescape(cleaned)                     # decode HTML entities
    cleaned = re.sub(r"[*#`_~]+", "", cleaned)      # drop Markdown markers
    cleaned = re.sub(r"\s+", " ", cleaned).strip()  # collapse whitespace
    if cleaned != before:
        print("🧹 Cleaned response.")
    return cleaned
def remove_repetitions(text, similarity_threshold=0.8):
    """Drop sentences that are near-duplicates of the previously kept one."""
    sentences = re.split(r'(?<=[.!?])\s+', text)
    if len(sentences) <= 1:
        return text
    kept = []
    for candidate in (s.strip() for s in sentences):
        if not candidate:
            continue
        # A sentence is a duplicate when it is too similar to the last kept one.
        is_duplicate = bool(kept) and (
            SequenceMatcher(None, candidate, kept[-1]).ratio() >= similarity_threshold
        )
        if not is_duplicate:
            kept.append(candidate)
    return " ".join(kept)
def truncate_to_last_sentence(text):
    """Trim any trailing fragment after the last fully terminated sentence."""
    parts = re.split(r'(?<=[.!?])\s+', text)
    # Walk backwards until a sentence that actually ends with punctuation.
    idx = len(parts) - 1
    while idx >= 0:
        if re.search(r'[.!?]$', parts[idx].strip()):
            return " ".join(parts[: idx + 1]).strip()
        idx -= 1
    # No terminated sentence at all — return the input unchanged (stripped).
    return text.strip()
# --------------------------
# Step 7: Response generation
# --------------------------
def generate_response(
    user_input,
    max_length=2048,
    temperature=0.7,
    top_k=50,
    top_p=0.7,
    repetition_penalty=10.0,
    num_beams=4,
    early_stopping=True,
    do_sample=True
):
    """Generate, decode and post-process one assistant reply.

    Returns the cleaned response string, or None when generation fails.
    NOTE(review): repetition_penalty=10.0 is far above the typical 1.1-1.5
    range — confirm this is intentional for the fine-tuned model.
    """
    try:
        prompt = f"{system_prompt}\n\nUser: {user_input}\nAssistant:"
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        print(f"📥 Input on device: {inputs['input_ids'].device}")
        # Sampling knobs are neutralized when do_sample is off; early
        # stopping only makes sense for beam search (num_beams > 1).
        gen_kwargs = dict(
            max_length=max_length,
            temperature=temperature if do_sample else 1.0,
            top_k=top_k if do_sample else None,
            top_p=top_p if do_sample else None,
            repetition_penalty=repetition_penalty,
            num_beams=num_beams,
            early_stopping=early_stopping if num_beams > 1 else False,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            do_sample=do_sample,
            stopping_criteria=stopping_criteria,
        )
        with torch.no_grad():
            outputs = model.generate(**inputs, **gen_kwargs)
        decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Keep only the text after the final "Assistant:" marker, then run
        # the post-processing pipeline in order.
        reply = decoded.split("Assistant:")[-1].strip()
        for post_process in (clean_response, remove_repetitions, truncate_to_last_sentence):
            reply = post_process(reply)
        return reply
    except Exception as e:
        print(f"❌ Error during generation: {e}")
        return None
# --------------------------
# Step 8: Flask API
# --------------------------
app = Flask(__name__)


@app.route('/generate', methods=['POST'])
def generate_text():
    """POST /generate with JSON body {"user_input": "..."}.

    Returns {"response": "..."} on success, or a JSON error with
    400 (missing/invalid input) or 500 (generation failure).
    """
    # FIX: silent=True makes malformed JSON / wrong Content-Type yield None
    # instead of Flask's default HTML error page, so clients always receive
    # the JSON 400 below.
    data = request.get_json(silent=True)
    if not data or 'user_input' not in data:
        return jsonify({'error': 'Missing user_input parameter'}), 400
    user_input = data['user_input']
    generated_response = generate_response(user_input)
    if generated_response is None:
        return jsonify({'error': 'Failed to generate response'}), 500
    return jsonify({'response': generated_response})
# --------------------------
# Step 9: Server startup
# --------------------------
if __name__ == '__main__':
    # Bind to all interfaces so the API is reachable beyond localhost.
    app.run(host='0.0.0.0', port=7860)