kacperbb committed on
Commit
4748268
·
verified ·
1 Parent(s): 9754e19

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -32
app.py CHANGED
@@ -2,44 +2,29 @@ from flask import Flask, request, jsonify
2
  import os
3
  import logging
4
 
5
- # Set up logging
6
  logging.basicConfig(level=logging.INFO)
7
  logger = logging.getLogger(__name__)
8
 
9
  app = Flask(__name__)
10
-
11
- # Global variables for model
12
  model = None
13
- tokenizer = None
14
 
15
  def load_model():
16
- global model, tokenizer
17
  try:
18
- logger.info("Loading MLX model...")
19
- from mlx_lm import load, generate as mlx_generate
20
 
21
- # Load your fine-tuned model
22
- model, tokenizer = load(
23
- 'microsoft/Phi-3.5-mini-instruct',
24
- adapter_path='kacperbb/phi-3.5-mlx-finetuned'
25
- )
26
  logger.info("✅ Model loaded successfully!")
27
  return True
28
  except Exception as e:
29
- logger.error(f"❌ Error loading MLX model: {e}")
30
- try:
31
- # Fallback to transformers
32
- from transformers import pipeline
33
- model = pipeline("text-generation", model="microsoft/Phi-3.5-mini-instruct")
34
- logger.info("✅ Fallback model loaded!")
35
- return True
36
- except Exception as e2:
37
- logger.error(f"❌ Fallback failed: {e2}")
38
- return False
39
 
40
  @app.route('/generate', methods=['POST'])
41
  def generate_text():
42
- global model, tokenizer
43
  try:
44
  data = request.json
45
  prompt = data.get('inputs', data.get('prompt', ''))
@@ -48,13 +33,8 @@ def generate_text():
48
  if not prompt:
49
  return jsonify({"error": "No prompt provided"}), 400
50
 
51
- if model and tokenizer:
52
- # Use MLX
53
- from mlx_lm import generate as mlx_generate
54
- response = mlx_generate(model, tokenizer, prompt=prompt, max_tokens=max_tokens)
55
- elif model:
56
- # Use transformers pipeline
57
- result = model(prompt, max_length=len(prompt.split()) + max_tokens, do_sample=True)
58
  response = result[0]['generated_text']
59
  else:
60
  return jsonify({"error": "Model not loaded"}), 500
@@ -68,7 +48,7 @@ def generate_text():
68
  @app.route('/health', methods=['GET'])
69
  def health():
70
  return jsonify({
71
- "status": "healthy",
72
  "model_loaded": model is not None
73
  })
74
 
@@ -77,7 +57,7 @@ def home():
77
  return jsonify({
78
  "message": "Phi 3.5 API is running!",
79
  "endpoints": {
80
- "generate": "POST /generate",
81
  "health": "GET /health"
82
  }
83
  })
 
2
  import os
3
  import logging
4
 
 
5
  logging.basicConfig(level=logging.INFO)
6
  logger = logging.getLogger(__name__)
7
 
8
  app = Flask(__name__)
 
 
9
  model = None
 
10
 
11
def load_model():
    """Load the text-generation pipeline into the module-level ``model``.

    Returns:
        bool: True when the pipeline was created successfully, False when
        anything went wrong (the failure is logged either way — callers can
        keep serving and report "model not loaded").
    """
    global model
    try:
        logger.info("Loading model with transformers...")
        # Imported lazily so the web app can start even if transformers
        # is missing; the failure is then surfaced through the except path.
        from transformers import pipeline as hf_pipeline

        # Use the base model
        loaded = hf_pipeline(
            "text-generation",
            model="microsoft/Phi-3.5-mini-instruct",
            trust_remote_code=True,
        )
        model = loaded
        logger.info("✅ Model loaded successfully!")
        return True
    except Exception as err:
        logger.error(f"❌ Error loading model: {err}")
        return False
 
 
 
 
 
 
 
 
24
 
25
  @app.route('/generate', methods=['POST'])
26
  def generate_text():
27
+ global model
28
  try:
29
  data = request.json
30
  prompt = data.get('inputs', data.get('prompt', ''))
 
33
  if not prompt:
34
  return jsonify({"error": "No prompt provided"}), 400
35
 
36
+ if model:
37
+ result = model(prompt, max_new_tokens=max_tokens, do_sample=True, temperature=0.7)
 
 
 
 
 
38
  response = result[0]['generated_text']
39
  else:
40
  return jsonify({"error": "Model not loaded"}), 500
 
48
@app.route('/health', methods=['GET'])
def health():
    """Health-check endpoint.

    Reports a static "healthy" status plus whether the module-level
    ``model`` has been populated by ``load_model``.
    """
    payload = {
        "status": "healthy",
        "model_loaded": model is not None,
    }
    return jsonify(payload)
54
 
 
57
  return jsonify({
58
  "message": "Phi 3.5 API is running!",
59
  "endpoints": {
60
+ "generate": "POST /generate",
61
  "health": "GET /health"
62
  }
63
  })