kacperbb committed on
Commit
4245fe0
·
verified ·
1 Parent(s): b260d66

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -32
app.py CHANGED
@@ -1,9 +1,8 @@
1
  from flask import Flask, request, jsonify
2
  import os
3
  import logging
4
- from transformers import pipeline
5
 
6
- # ====== Cache & Logging Setup ======
7
  os.environ['HF_HOME'] = '/.cache/huggingface'
8
  os.environ['TRANSFORMERS_CACHE'] = '/.cache/huggingface/transformers'
9
 
@@ -13,57 +12,68 @@ logger = logging.getLogger(__name__)
13
  app = Flask(__name__)
14
  model = None
15
 
16
- # ====== Model Loader ======
17
  def load_model():
18
  global model
19
- models_to_try = [
20
- ("kacperbb/phi-3.5-hf-finetuned", "YOUR fine-tuned model"),
21
- ("microsoft/Phi-3.5-mini-instruct", "Base model"),
22
- ("microsoft/DialoGPT-medium", "Smaller fallback model")
23
- ]
24
-
25
- for model_name, description in models_to_try:
 
 
 
 
 
 
 
26
  try:
27
- logger.info(f"Loading {description}...")
28
  model = pipeline(
29
- "text-generation",
30
- model=model_name,
31
- trust_remote_code=True,
32
- cache_dir="/.cache/huggingface"
33
  )
34
- logger.info(f"βœ… {description} loaded successfully!")
35
  return True
36
- except Exception as e:
37
- logger.error(f"❌ Failed to load {description}: {e}")
 
 
 
 
 
 
38
 
39
- logger.error("❌ All model loading attempts failed!")
40
- return False
41
-
42
- # ====== Endpoints ======
43
  @app.route('/generate', methods=['POST'])
44
  def generate_text():
45
  global model
46
  try:
47
  data = request.json
48
- prompt = data.get('inputs') or data.get('prompt', '')
49
  max_tokens = data.get('parameters', {}).get('max_new_tokens', 100)
50
-
51
  if not prompt:
52
  return jsonify({"error": "No prompt provided"}), 400
53
-
54
- if not model:
 
 
 
55
  return jsonify({"error": "Model not loaded"}), 500
56
-
57
- result = model(prompt, max_new_tokens=max_tokens, do_sample=True, temperature=0.7)
58
- return jsonify([{"generated_text": result[0]['generated_text']}])
59
-
60
  except Exception as e:
61
  logger.error(f"Generation error: {e}")
62
  return jsonify({"error": str(e)}), 500
63
 
64
  @app.route('/health', methods=['GET'])
65
  def health():
66
- return jsonify({"status": "healthy", "model_loaded": model is not None})
 
 
 
67
 
68
  @app.route('/', methods=['GET'])
69
  def home():
@@ -76,7 +86,6 @@ def home():
76
  }
77
  })
78
 
79
- # ====== Start Server ======
80
  if __name__ == '__main__':
81
  logger.info("Starting Phi 3.5 API...")
82
  load_model()
 
1
  from flask import Flask, request, jsonify
2
  import os
3
  import logging
 
4
 
5
+ # Set cache environment variables
6
  os.environ['HF_HOME'] = '/.cache/huggingface'
7
  os.environ['TRANSFORMERS_CACHE'] = '/.cache/huggingface/transformers'
8
 
 
12
  app = Flask(__name__)
13
  model = None
14
 
 
15
def load_model():
    """Load a text-generation pipeline into the module-level ``model``.

    Tries candidates in order of preference: the fine-tuned model, the
    base Phi-3.5 model, then GPT-2 as a last resort.

    Returns:
        bool: True as soon as one candidate loads, False if all fail.
    """
    global model

    # Lazy import so the Flask app (and /health) can come up even if the
    # heavy transformers import/download is slow or broken.
    try:
        from transformers import pipeline
    except Exception as e:
        logger.error(f"❌ transformers unavailable: {e}")
        return False

    # (model id, extra pipeline kwargs, success log message)
    candidates = [
        ("kacperbb/phi-3.5-hf-finetuned", {"trust_remote_code": True},
         "βœ… YOUR fine-tuned model loaded successfully!"),
        ("microsoft/Phi-3.5-mini-instruct", {"trust_remote_code": True},
         "βœ… Base model loaded as fallback"),
        ("gpt2", {}, "βœ… GPT-2 fallback model loaded"),
    ]

    for model_id, extra_kwargs, success_msg in candidates:
        try:
            logger.info(f"Loading {model_id}...")
            model = pipeline("text-generation", model=model_id, **extra_kwargs)
            logger.info(success_msg)
            return True
        except Exception as e:
            # Any load failure (download, OOM, bad weights) means "try next";
            # never a bare except, so KeyboardInterrupt etc. still propagate.
            logger.error(f"❌ Failed to load {model_id}: {e}")

    return False
47
 
 
 
 
 
48
@app.route('/generate', methods=['POST'])
def generate_text():
    """Generate text from a JSON payload.

    Expects ``{"inputs": ...}`` (or ``{"prompt": ...}``) with optional
    ``{"parameters": {"max_new_tokens": int}}``. Responds with a
    HF-Inference-style list: ``[{"generated_text": ...}]``.

    Returns 400 on a missing prompt, 500 when the model is not loaded
    or generation fails.
    """
    global model
    try:
        # silent=True: a missing/malformed JSON body yields None instead of
        # raising, so the client gets a clean 400 rather than a 500.
        data = request.get_json(silent=True) or {}
        prompt = data.get('inputs') or data.get('prompt', '')
        max_tokens = data.get('parameters', {}).get('max_new_tokens', 100)

        # Guard clauses: validate input and model state before any work.
        if not prompt:
            return jsonify({"error": "No prompt provided"}), 400
        if model is None:
            return jsonify({"error": "Model not loaded"}), 500

        result = model(prompt, max_new_tokens=max_tokens, do_sample=True)
        return jsonify([{"generated_text": result[0]['generated_text']}])

    except Exception as e:
        logger.error(f"Generation error: {e}")
        return jsonify({"error": str(e)}), 500
70
 
71
@app.route('/health', methods=['GET'])
def health():
    """Liveness probe: report service status and whether a model is loaded."""
    loaded = model is not None
    return jsonify({"status": "healthy", "model_loaded": loaded})
77
 
78
  @app.route('/', methods=['GET'])
79
  def home():
 
86
  }
87
  })
88
 
 
89
  if __name__ == '__main__':
90
  logger.info("Starting Phi 3.5 API...")
91
  load_model()