WarApple committed on
Commit
18586c6
·
verified ·
1 Parent(s): fcb3be6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -21
app.py CHANGED
@@ -7,45 +7,79 @@ logger = logging.getLogger(__name__)
7
 
8
  app = Flask(__name__)
9
 
10
- # Используем pipeline для простоты
11
- chatbot = pipeline(
12
- "text-generation",
13
- model="microsoft/DialoGPT-medium",
14
- torch_dtype="auto",
15
- device_map="auto"
16
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  @app.route('/chat', methods=['POST'])
19
  def chat():
20
  try:
21
  data = request.get_json()
22
- message = data.get('message', '')
23
 
24
- if not message:
25
- return jsonify({"error": "No message provided"}), 400
 
 
 
 
 
 
26
 
27
- # Генерация ответа
28
  response = chatbot(
29
- message,
30
- max_length=100,
31
- temperature=0.7,
32
  do_sample=True,
33
  top_p=0.9,
34
- repetition_penalty=1.1
 
 
35
  )
36
 
 
 
 
 
 
 
 
 
 
 
37
  return jsonify({
38
- "response": response[0]['generated_text'],
39
  "status": "success"
40
  })
41
 
42
  except Exception as e:
43
- logger.error(f"Error: {e}")
44
  return jsonify({"error": "Internal server error"}), 500
45
 
46
- @app.route('/health', methods=['GET'])
47
- def health():
48
- return jsonify({"status": "healthy"})
 
49
 
50
  if __name__ == '__main__':
51
- app.run(host='0.0.0.0', port=7860)
 
7
 
8
  app = Flask(__name__)
9
 
10
+ # Инициализация pipeline с русскоязычной моделью
11
+ try:
12
+ chatbot = pipeline(
13
+ "text-generation",
14
+ model="sberbank-ai/rugpt3medium_based_on_gpt2", # Быстрая русская модель
15
+ torch_dtype="auto",
16
+ device_map="auto",
17
+ tokenizer="sberbank-ai/rugpt3medium_based_on_gpt2"
18
+ )
19
+ logger.info("Russian model loaded successfully")
20
+ except Exception as e:
21
+ logger.error(f"Error loading Russian model: {e}")
22
+ # Fallback to English model
23
+ chatbot = pipeline(
24
+ "text-generation",
25
+ model="microsoft/DialoGPT-medium",
26
+ torch_dtype="auto",
27
+ device_map="auto"
28
+ )
29
+ logger.info("Fallback English model loaded")
30
+
31
+ @app.route('/health', methods=['GET'])
32
+ def health_check():
33
+ return jsonify({"status": "healthy", "message": "Service is running"})
34
 
35
  @app.route('/chat', methods=['POST'])
36
  def chat():
37
  try:
38
  data = request.get_json()
 
39
 
40
+ if not data or 'message' not in data:
41
+ return jsonify({"error": "Missing 'message' in request"}), 400
42
+
43
+ user_message = data['message']
44
+ logger.info(f"Received message: {user_message}")
45
+
46
+ # Генерация ответа с промптом на русском
47
+ prompt = f"Пользователь: {user_message}\nАссистент:"
48
 
 
49
  response = chatbot(
50
+ prompt,
51
+ max_length=150,
52
+ temperature=0.8,
53
  do_sample=True,
54
  top_p=0.9,
55
+ repetition_penalty=1.1,
56
+ num_return_sequences=1,
57
+ pad_token_id=chatbot.tokenizer.eos_token_id
58
  )
59
 
60
+ # Извлекаем только ответ ассистента
61
+ generated_text = response[0]['generated_text']
62
+ assistant_response = generated_text.replace(prompt, "").strip()
63
+
64
+ # Очищаем ответ
65
+ if "Пользователь:" in assistant_response:
66
+ assistant_response = assistant_response.split("Пользователь:")[0].strip()
67
+
68
+ logger.info(f"Generated response: {assistant_response}")
69
+
70
  return jsonify({
71
+ "response": assistant_response,
72
  "status": "success"
73
  })
74
 
75
  except Exception as e:
76
+ logger.error(f"Error in /chat: {e}")
77
  return jsonify({"error": "Internal server error"}), 500
78
 
79
+ @app.route('/clear', methods=['POST'])
80
+ def clear_history():
81
+ # Для pipeline очистка не требуется, но endpoint для совместимости
82
+ return jsonify({"status": "success", "message": "OK"})
83
 
84
  if __name__ == '__main__':
85
+ app.run(host='0.0.0.0', port=7860, debug=False)