|
|
import torch |
|
|
import re |
|
|
from html import unescape |
|
|
from transformers import GPT2LMHeadModel, GPT2Tokenizer |
|
|
from peft import PeftModel |
|
|
from transformers import StoppingCriteria, StoppingCriteriaList |
|
|
from difflib import SequenceMatcher |
|
|
from flask import Flask, request, jsonify |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Select the compute device: use the GPU when CUDA is visible, else the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"🚀 Running on device: {device}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Directory holding the fine-tuned model, tokenizer, and PEFT adapter files.
model_path = "./"

try:
    tokenizer = GPT2Tokenizer.from_pretrained(model_path)
    # GPT-2 ships without a pad token; reuse EOS so padded batches work.
    tokenizer.pad_token = tokenizer.eos_token
    print("✅ Tokenizer loaded successfully")
except Exception as e:
    print(f"❌ Error loading tokenizer: {e}")
    # `exit()` is the interactive-shell helper; raise SystemExit with a
    # non-zero status so scripted/containerized runs report the failure.
    raise SystemExit(1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Optional 4-bit quantization: attempted only on GPU hosts where the
# bitsandbytes integration is importable; otherwise quant_config stays None
# and the model loads at full precision.
quant_config = None
if not torch.cuda.is_available():
    print("💡 CPU mode — quantization disabled")
else:
    try:
        from transformers import BitsAndBytesConfig

        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        print("✅ Using 4-bit quantization (GPU mode)")
    except Exception as e:
        print("⚠️ BitsAndBytes not available, continuing without quantization:", e)
|
|
|
|
|
try:
    base_model = GPT2LMHeadModel.from_pretrained(
        model_path,
        quantization_config=quant_config,
        # With a device_map, GPU weights are placed during loading.
        device_map={"": 0} if torch.cuda.is_available() else None,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
    )
    # `.to()` raises on 4-bit quantized models, and device_map has already
    # placed quantized weights — only move the model when not quantized.
    if quant_config is None:
        base_model = base_model.to(device)
    print("✅ Base model loaded successfully")
except Exception as e:
    print(f"❌ Error loading base model: {e}")
    # Non-zero exit status so supervisors detect the startup failure.
    raise SystemExit(1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
try:
    model = PeftModel.from_pretrained(
        base_model,
        model_path,
        is_trainable=False,
        device_map={"": 0} if torch.cuda.is_available() else None
    )
    # Moving a 4-bit quantized model with `.to()` is unsupported; the adapter
    # inherits the base model's placement, so only move in full-precision mode.
    if quant_config is None:
        model.to(device)
    print("✅ PEFT model loaded successfully")
except Exception as e:
    # Best-effort: fall back to the bare base model rather than aborting.
    print(f"⚠️ Warning: Failed to load PEFT adapter, using base model. ({e})")
    model = base_model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# System prompt prepended to every request; instructs the model to answer
# briefly. Fixed typo: "cooherent" -> "coherent".
system_prompt = """You are GPT-A, a friendly AI assistant made by LuxAI.


You must answer very short and coherent."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CustomStoppingCriteria(StoppingCriteria):
    """Stop generation on a given stop token or once a length cap is reached."""

    def __init__(self, stop_token_id, max_length=512):
        # Token id (normally EOS) that ends generation immediately.
        self.stop_token_id = stop_token_id
        # Hard cap on total sequence length (prompt + generation); was a
        # hard-coded 512, now parameterized with the same default.
        self.max_length = max_length

    def __call__(self, input_ids, scores, **kwargs):
        # Stop when the last token is the stop token, or the sequence has
        # grown past the length cap.
        return input_ids[0][-1] == self.stop_token_id or len(input_ids[0]) > self.max_length
|
|
|
|
|
# Shared criteria list: stop on EOS (CustomStoppingCriteria also caps length).
stopping_criteria = StoppingCriteriaList([CustomStoppingCriteria(tokenizer.eos_token_id)])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def clean_response(text):
    """Strip HTML tags, Markdown markers, and redundant whitespace from *text*."""
    before = text
    # Replace HTML tags with spaces, then decode entities such as &amp;.
    cleaned = re.sub(r"<[^>]+>", " ", text)
    cleaned = unescape(cleaned)
    # Drop common Markdown punctuation and collapse runs of whitespace.
    cleaned = re.sub(r"[*#`_~]+", "", cleaned)
    cleaned = re.sub(r"\s+", " ", cleaned).strip()
    if cleaned != before:
        print("🧹 Cleaned response.")
    return cleaned
|
|
|
|
|
|
|
|
def remove_repetitions(text, similarity_threshold=0.8):
    """Drop sentences that are near-duplicates of the previously kept sentence."""
    parts = re.split(r'(?<=[.!?])\s+', text)
    if len(parts) <= 1:
        return text
    kept = []
    for part in parts:
        candidate = part.strip()
        if not candidate:
            continue
        # Only compare against the last kept sentence (consecutive repeats).
        if kept and SequenceMatcher(None, candidate, kept[-1]).ratio() >= similarity_threshold:
            continue
        kept.append(candidate)
    return " ".join(kept)
|
|
|
|
|
|
|
|
def truncate_to_last_sentence(text):
    """Trim a trailing unfinished fragment, keeping text up to the last ., ! or ?."""
    pieces = re.split(r'(?<=[.!?])\s+', text)
    idx = len(pieces)
    # Walk backwards to the last piece that ends with a sentence terminator.
    while idx > 0:
        idx -= 1
        if re.search(r'[.!?]$', pieces[idx].strip()):
            return " ".join(pieces[:idx + 1]).strip()
    # No complete sentence found: return the input unchanged (stripped).
    return text.strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_response(
    user_input,
    max_length=2048,
    temperature=0.7,
    top_k=50,
    top_p=0.7,
    repetition_penalty=10.0,
    num_beams=4,
    early_stopping=True,
    do_sample=True
):
    """Generate a cleaned, de-duplicated assistant reply for *user_input*.

    Returns the post-processed response string, or None on any generation
    error. Sampling knobs (temperature/top_k/top_p) only apply when
    do_sample=True. NOTE(review): repetition_penalty=10.0 is far above the
    typical 1.0-2.0 range — presumably deliberate to suppress GPT-2 loops;
    confirm with the model's author.
    """
    try:
        prompt = f"{system_prompt}\n\nUser: {user_input}\nAssistant:"
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        print(f"📥 Input on device: {inputs['input_ids'].device}")

        # Bug fix: GPT-2 only has 1024 learned positions; a larger max_length
        # would index past the position embeddings and crash mid-generation.
        model_max = getattr(model.config, "n_positions", 1024)
        max_length = min(max_length, model_max)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_length=max_length,
                temperature=temperature if do_sample else 1.0,
                top_k=top_k if do_sample else None,
                top_p=top_p if do_sample else None,
                repetition_penalty=repetition_penalty,
                num_beams=num_beams,
                # early_stopping is only meaningful for beam search.
                early_stopping=early_stopping if num_beams > 1 else False,
                num_return_sequences=1,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                do_sample=do_sample,
                stopping_criteria=stopping_criteria
            )

        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Keep only the text after the final "Assistant:" marker.
        response = generated_text.split("Assistant:")[-1].strip()

        # Post-process: strip markup, drop repeated sentences, trim fragments.
        response = clean_response(response)
        response = remove_repetitions(response)
        response = truncate_to_last_sentence(response)

        return response

    except Exception as e:
        print(f"❌ Error during generation: {e}")
        return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Flask application exposing the text-generation endpoint.
app = Flask(__name__)
|
|
|
|
|
@app.route('/generate', methods=['POST'])
def generate_text():
    """POST /generate — JSON body {"user_input": str} -> {"response": str}.

    Returns 400 for a missing/invalid body and 500 if generation fails.
    """
    # Robustness fix: silent=True makes malformed or non-JSON bodies yield
    # None (handled below as a JSON 400) instead of Flask raising an HTML
    # BadRequest page.
    data = request.get_json(silent=True)
    if not data or 'user_input' not in data:
        return jsonify({'error': 'Missing user_input parameter'}), 400

    user_input = data['user_input']
    generated_response = generate_response(user_input)

    # generate_response returns None on any internal error.
    if generated_response is None:
        return jsonify({'error': 'Failed to generate response'}), 500

    return jsonify({'response': generated_response})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Script entry point: run the Flask dev server on all interfaces, port 7860.
if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)
|
|
|