"""GPT-2 + LoRA text-generation micro-service.

Loads a GPT-2 base model (4-bit quantized when a CUDA GPU and bitsandbytes
are available) plus an optional PEFT/LoRA adapter from the current
directory, and serves generation requests over a small Flask API:

    POST /generate   {"user_input": "..."}  ->  {"response": "..."}
"""

import re
from difflib import SequenceMatcher
from html import unescape

import torch
from flask import Flask, jsonify, request
from peft import PeftModel
from transformers import (
    GPT2LMHeadModel,
    GPT2Tokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)

# --------------------------
# Step 1: Device selection
# --------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🚀 Running on device: {device}")

# --------------------------
# Step 2: Tokenizer
# --------------------------
model_path = "./"
try:
    tokenizer = GPT2Tokenizer.from_pretrained(model_path)
    # GPT-2 ships without a pad token; reuse EOS so padding works.
    tokenizer.pad_token = tokenizer.eos_token
    print("✅ Tokenizer loaded successfully")
except Exception as e:
    print(f"❌ Error loading tokenizer: {e}")
    # BUGFIX: bare exit() is a site-module convenience not guaranteed in all
    # environments; raise SystemExit with a non-zero status instead.
    raise SystemExit(1)

# --------------------------
# Step 3: Model loading with quantization fallback
# --------------------------
quant_config = None
if torch.cuda.is_available():
    try:
        from transformers import BitsAndBytesConfig

        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        print("✅ Using 4-bit quantization (GPU mode)")
    except Exception as e:
        print("⚠️ BitsAndBytes not available, continuing without quantization:", e)
else:
    print("💡 CPU mode — quantization disabled")

try:
    base_model = GPT2LMHeadModel.from_pretrained(
        model_path,
        quantization_config=quant_config,
        device_map={"": 0} if torch.cuda.is_available() else None,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    )
    # BUGFIX: calling .to() on a bitsandbytes-quantized model raises a
    # ValueError — the model is already placed by device_map. Only move the
    # model manually when no quantization is in play.
    if quant_config is None:
        base_model = base_model.to(device)
    print("✅ Base model loaded successfully")
except Exception as e:
    print(f"❌ Error loading base model: {e}")
    raise SystemExit(1)

# --------------------------
# Step 4: PEFT (LoRA) adapter — falls back to the base model on failure
# --------------------------
try:
    model = PeftModel.from_pretrained(
        base_model,
        model_path,
        is_trainable=False,
        device_map={"": 0} if torch.cuda.is_available() else None,
    )
    # Same quantization caveat as above: never .to() a 4-bit model.
    if quant_config is None:
        model.to(device)
    print("✅ PEFT model loaded successfully")
except Exception as e:
    print(f"⚠️ Warning: Failed to load PEFT adapter, using base model. ({e})")
    model = base_model

# --------------------------
# Step 5: System prompt
# --------------------------
# BUGFIX: corrected the typo "cooherent" -> "coherent" in the instruction text.
system_prompt = """You are GPT-A, a friendly AI assistant made by LuxAI. You must answer very short and coherent."""

# --------------------------
# Step 6: Stopping criteria
# --------------------------
class CustomStoppingCriteria(StoppingCriteria):
    """Stop generation on EOS or once the sequence exceeds 512 tokens."""

    def __init__(self, stop_token_id):
        self.stop_token_id = stop_token_id

    def __call__(self, input_ids, scores, **kwargs):
        # Hard cap of 512 tokens acts as a safety net independent of
        # generate()'s own max_length.
        return input_ids[0][-1] == self.stop_token_id or len(input_ids[0]) > 512


stopping_criteria = StoppingCriteriaList(
    [CustomStoppingCriteria(tokenizer.eos_token_id)]
)

# --------------------------
# Step 6.5: Utility functions
# --------------------------
def clean_response(text):
    """Strip HTML tags, Markdown markers, and redundant whitespace."""
    original_text = text
    text = re.sub(r"<[^>]+>", " ", text)      # drop HTML tags
    text = unescape(text)                      # decode HTML entities
    text = re.sub(r"[*#`_~]+", "", text)      # drop Markdown punctuation
    text = re.sub(r"\s+", " ", text).strip()  # collapse whitespace
    if text != original_text:
        print("🧹 Cleaned response.")
    return text


def remove_repetitions(text, similarity_threshold=0.8):
    """Drop consecutive near-duplicate sentences (ratio >= threshold)."""
    sentences = re.split(r'(?<=[.!?])\s+', text)
    if len(sentences) <= 1:
        return text
    unique_sentences = []
    for sent in sentences:
        sent_clean = sent.strip()
        if not sent_clean:
            continue
        # Only compares against the immediately preceding kept sentence,
        # so non-adjacent repeats are intentionally left alone.
        if not unique_sentences or SequenceMatcher(
            None, sent_clean, unique_sentences[-1]
        ).ratio() < similarity_threshold:
            unique_sentences.append(sent_clean)
    return " ".join(unique_sentences)


def truncate_to_last_sentence(text):
    """Trim trailing partial output down to the last completed sentence."""
    sentences = re.split(r'(?<=[.!?])\s+', text)
    for i in range(len(sentences) - 1, -1, -1):
        if re.search(r'[.!?]$', sentences[i].strip()):
            return " ".join(sentences[:i + 1]).strip()
    return text.strip()


# --------------------------
# Step 7: Response generation
# --------------------------
def generate_response(
    user_input,
    max_length=2048,
    temperature=0.7,
    top_k=50,
    top_p=0.7,
    repetition_penalty=10.0,
    num_beams=4,
    early_stopping=True,
    do_sample=True,
):
    """Generate, clean, and return a reply for *user_input*.

    Returns the post-processed assistant text, or None if generation fails.
    NOTE(review): repetition_penalty=10.0 is far above the usual 1.1–1.5
    range and heavily distorts token probabilities; kept as the default for
    caller compatibility — confirm it is intentional.
    """
    try:
        prompt = f"{system_prompt}\n\nUser: {user_input}\nAssistant:"
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        print(f"📥 Input on device: {inputs['input_ids'].device}")

        # BUGFIX: GPT-2's context window is n_positions (1024 for stock
        # GPT-2); requesting max_length=2048 makes generate() index past the
        # position embeddings. Clamp to the model's real limit.
        model_ctx = getattr(model.config, "n_positions", 1024)
        effective_max_length = min(max_length, model_ctx)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_length=effective_max_length,
                temperature=temperature if do_sample else 1.0,
                top_k=top_k if do_sample else None,
                top_p=top_p if do_sample else None,
                repetition_penalty=repetition_penalty,
                num_beams=num_beams,
                early_stopping=early_stopping if num_beams > 1 else False,
                num_return_sequences=1,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                do_sample=do_sample,
                stopping_criteria=stopping_criteria,
            )

        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Keep only the text after the final "Assistant:" marker.
        response = generated_text.split("Assistant:")[-1].strip()
        response = clean_response(response)
        response = remove_repetitions(response)
        response = truncate_to_last_sentence(response)
        return response
    except Exception as e:
        print(f"❌ Error during generation: {e}")
        return None


# --------------------------
# Step 8: Flask API
# --------------------------
app = Flask(__name__)


@app.route('/generate', methods=['POST'])
def generate_text():
    """POST /generate — JSON body {"user_input": str} -> {"response": str}."""
    data = request.get_json()
    if not data or 'user_input' not in data:
        return jsonify({'error': 'Missing user_input parameter'}), 400
    user_input = data['user_input']
    generated_response = generate_response(user_input)
    if generated_response is None:
        return jsonify({'error': 'Failed to generate response'}), 500
    return jsonify({'response': generated_response})


# --------------------------
# Step 9: Server entry point
# --------------------------
if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)