import logging
import os

from flask import Flask, request, jsonify
from transformers import pipeline


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = Flask(__name__)

# Hugging Face Hub id of the causal language model served by /chat.
# NOTE(review): the ``sberbank-ai`` organisation appears to have been renamed
# to ``ai-forever`` on the Hub -- confirm this id still resolves.
MODEL_NAME = "sberbank-ai/rugpt3medium_based_on_gpt2"

# Build the text-generation pipeline once, at import time. Failure is
# deliberately non-fatal: ``chatbot`` is left as ``None`` so the server still
# starts, /health reports ``model_loaded: false`` and /chat answers with an
# error message instead of crashing.
try:
    chatbot = pipeline(
        "text-generation",
        model=MODEL_NAME,
        device_map="auto",
        model_kwargs={"torch_dtype": "auto"},
    )
    logger.info("Model loaded successfully")
except Exception:
    # Record the full traceback, not just str(e), so load failures
    # (missing weights, missing optional packages, OOM, ...) are diagnosable.
    logger.exception("Failed to load model %s", MODEL_NAME)
    chatbot = None


@app.route('/health', methods=['GET'])
def health_check():
    """Liveness probe.

    Always reports ``"status": "healthy"``; ``model_loaded`` tells the caller
    whether the module-level text-generation pipeline was created.
    """
    model_loaded = chatbot is not None
    payload = {"status": "healthy", "model_loaded": model_loaded}
    return jsonify(payload)


def _extract_reply(generated_text):
    """Return the assistant's reply from the model's raw continuation.

    The model may keep writing the dialogue and invent the next user turn, so
    everything from the first ``"Пользователь:"`` marker onward is dropped.
    """
    reply, _, _ = generated_text.strip().partition("Пользователь:")
    return reply.strip()


@app.route('/chat', methods=['POST'])
def chat():
    """Generate an assistant reply to a user's question about Telegram stickers.

    Expects a JSON object body of the form ``{"message": "<text>"}``.

    Returns:
        200 with ``{"response": <text>, "status": "success"}`` on success;
        400 with ``{"error": ...}`` when the body is not a JSON object, or the
        ``message`` field is not a string or is blank;
        500 with ``{"error": ...}`` when the model is not loaded or text
        generation fails.
    """
    if chatbot is None:
        return jsonify({"error": "Model not loaded"}), 500

    # silent=True: a missing or malformed JSON body yields None instead of
    # raising, so it is reported as a client error (400) rather than a 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    user_message = data.get('message', '')
    if not isinstance(user_message, str):
        return jsonify({"error": "Field 'message' must be a string"}), 400
    user_message = user_message.strip()
    if not user_message:
        return jsonify({"error": "Empty message"}), 400

    logger.info("Received: %s", user_message)

    prompt = (
        "Ты - профессиональный помощник по Telegram стикерам. Отвечай вежливо и по делу.\n"
        "\n"
        f"Пользователь: {user_message}\n"
        "Помощник:"
    )

    try:
        outputs = chatbot(
            prompt,
            max_new_tokens=100,
            temperature=0.7,
            do_sample=True,
            top_p=0.9,
            repetition_penalty=1.1,
            num_return_sequences=1,
            pad_token_id=chatbot.tokenizer.eos_token_id,
            truncation=True,
            # Return only the newly generated text, so the prompt does not
            # have to be stripped back out of the output by string replacement.
            return_full_text=False,
        )
        assistant_response = _extract_reply(outputs[0]['generated_text'])
    except Exception:
        # Request boundary: log the full traceback, return a generic error.
        logger.exception("Text generation failed")
        return jsonify({"error": "Internal error"}), 500

    logger.info("Response: %s", assistant_response)

    return jsonify({
        "response": assistant_response,
        "status": "success",
    })


@app.route('/', methods=['GET'])
def home():
    """Root endpoint: identify the service with a short JSON banner."""
    banner = {"message": "Sticker assistant API"}
    return jsonify(banner)


if __name__ == '__main__':
    # Listen on all interfaces. The port defaults to 7860; a PORT environment
    # variable overrides it without a code change.
    port = int(os.environ.get("PORT", "7860"))
    app.run(host='0.0.0.0', port=port, debug=False)