kacperbb committed · Commit 71fb2b8 · verified · 1 Parent(s): 267ac8a

Update app.py

Files changed (1): app.py (+87 -5)
app.py CHANGED
@@ -1,7 +1,89 @@
- from fastapi import FastAPI
-
- app = FastAPI()
-
- @app.get("/")
- def greet_json():
-     return {"Hello": "World!"}
+ from flask import Flask, request, jsonify
+ import os
+ import logging
+
+ # Set up logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ app = Flask(__name__)
+
+ # Global variables for model
+ model = None
+ tokenizer = None
+
+ def load_model():
+     global model, tokenizer
+     try:
+         logger.info("Loading MLX model...")
+         from mlx_lm import load
+
+         # Load the base model with the fine-tuned adapter
+         model, tokenizer = load(
+             'microsoft/Phi-3.5-mini-instruct',
+             adapter_path='kacperbb/phi-3.5-mlx-finetuned'
+         )
+         logger.info("✅ Model loaded successfully!")
+         return True
+     except Exception as e:
+         logger.error(f"❌ Error loading MLX model: {e}")
+         try:
+             # Fallback to transformers
+             from transformers import pipeline
+             # model is already declared global above; a second `global model`
+             # here, after the assignment, would raise a SyntaxError
+             model = pipeline("text-generation", model="microsoft/Phi-3.5-mini-instruct")
+             logger.info("✅ Fallback model loaded!")
+             return True
+         except Exception as e2:
+             logger.error(f"❌ Fallback failed: {e2}")
+             return False
+
+ @app.route('/generate', methods=['POST'])
+ def generate_text():
+     try:
+         data = request.json
+         prompt = data.get('inputs', data.get('prompt', ''))
+         max_tokens = data.get('parameters', {}).get('max_new_tokens', 100)
+
+         if not prompt:
+             return jsonify({"error": "No prompt provided"}), 400
+
+         if model and tokenizer:
+             # Use MLX
+             from mlx_lm import generate as mlx_generate
+             response = mlx_generate(model, tokenizer, prompt=prompt, max_tokens=max_tokens)
+         elif model:
+             # Use transformers pipeline (max_length is approximate here: it
+             # counts tokens, while the prompt length is counted in words)
+             result = model(prompt, max_length=len(prompt.split()) + max_tokens, do_sample=True)
+             response = result[0]['generated_text']
+         else:
+             return jsonify({"error": "Model not loaded"}), 500
+
+         return jsonify([{"generated_text": response}])
+
+     except Exception as e:
+         logger.error(f"Generation error: {e}")
+         return jsonify({"error": str(e)}), 500
+
+ @app.route('/health', methods=['GET'])
+ def health():
+     return jsonify({
+         "status": "healthy",
+         "model_loaded": model is not None
+     })
+
+ @app.route('/', methods=['GET'])
+ def home():
+     return jsonify({
+         "message": "Phi 3.5 API is running!",
+         "endpoints": {
+             "generate": "POST /generate",
+             "health": "GET /health"
+         }
+     })
+
+ if __name__ == '__main__':
+     logger.info("Starting Phi 3.5 API...")
+     load_model()
+     port = int(os.environ.get('PORT', 7860))
+     app.run(host='0.0.0.0', port=port, debug=False)
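
For reference, a minimal client sketch for the two main endpoints this commit adds. It assumes the app is reachable on the default port 7860 from app.py; the base URL and prompt text below are illustrative only.

```python
import requests

BASE_URL = "http://localhost:7860"  # hypothetical deployment URL

# Check that the server is up and whether a model was actually loaded
print(requests.get(f"{BASE_URL}/health").json())

# The handler accepts an HF Inference API-style payload: "inputs" plus an
# optional "parameters" dict carrying "max_new_tokens"
payload = {
    "inputs": "Explain what MLX is in one sentence.",
    "parameters": {"max_new_tokens": 64},
}
resp = requests.post(f"{BASE_URL}/generate", json=payload)
resp.raise_for_status()
print(resp.json()[0]["generated_text"])
```

Because the request and response mirror the Hugging Face Inference API shape, the same client works whether the MLX model or the transformers fallback ends up serving the request.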