kacperbb commited on
Commit
8d130b1
·
verified ·
1 Parent(s): 01e870b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -36
app.py CHANGED
@@ -1,67 +1,85 @@
1
  from flask import Flask, request, jsonify
2
- import requests
3
  import logging
4
- import os
5
 
6
  logging.basicConfig(level=logging.INFO)
7
  logger = logging.getLogger(__name__)
8
 
9
  app = Flask(__name__)
 
 
10
 
11
- # Your model endpoint
12
- MODEL_URL = "https://api-inference.huggingface.co/models/kacperbb/phi-3.5-merged-lora"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  @app.route('/generate', methods=['POST'])
15
  def generate():
 
16
  try:
17
  data = request.json
18
  prompt = data.get('inputs', '')
19
- max_tokens = data.get('parameters', {}).get('max_new_tokens', 100)
20
-
21
- # Get HF token from environment variable (you'll need to set this in Space settings)
22
- hf_token = os.environ.get('HF_TOKEN')
23
 
24
- if hf_token:
25
- # Call your actual model via HF Inference API
26
- headers = {"Authorization": f"Bearer {hf_token}"}
27
- payload = {
28
- "inputs": prompt,
29
- "parameters": {"max_new_tokens": max_tokens}
30
- }
 
31
 
32
- try:
33
- response = requests.post(MODEL_URL, headers=headers, json=payload, timeout=30)
34
- if response.status_code == 200:
35
- return jsonify(response.json())
36
- else:
37
- logger.warning(f"HF API returned status {response.status_code}")
38
- except requests.exceptions.RequestException as e:
39
- logger.error(f"Request failed: {e}")
40
-
41
- # Fallback response if API call fails
42
- return jsonify([{"generated_text": f"AI Response to '{prompt}': This would be generated by your fine-tuned Phi 3.5 model."}])
43
-
 
44
  except Exception as e:
45
- logger.error(f"Error: {e}")
46
  return jsonify({"error": str(e)}), 500
47
 
48
  @app.route('/health', methods=['GET'])
49
  def health():
50
- has_token = bool(os.environ.get('HF_TOKEN'))
51
  return jsonify({
52
- "status": "healthy",
53
- "model": "kacperbb/phi-3.5-merged-lora",
54
- "has_token": has_token
55
  })
56
 
57
  @app.route('/')
58
  def home():
59
  return jsonify({
60
- "message": "Phi 3.5 Fine-tuned API Running!",
61
- "model": "kacperbb/phi-3.5-merged-lora",
62
- "endpoint": "/generate"
63
  })
64
 
65
  if __name__ == '__main__':
66
- logger.info("Starting Phi 3.5 Fine-tuned API...")
 
67
  app.run(host='0.0.0.0', port=7860)
 
1
from flask import Flask, request, jsonify
import logging
import torch

# Module-wide logging: INFO level so model-loading progress is visible
# in the Space's container logs.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = Flask(__name__)

# Populated by load_model(); both stay None until loading succeeds,
# and the routes treat None as "model unavailable".
model = None
tokenizer = None
11
 
12
def load_model():
    """Load the fine-tuned tokenizer and model into the module globals.

    Downloads/loads ``kacperbb/phi-3.5-merged-lora`` with transformers.

    Returns:
        bool: True if both tokenizer and model loaded, False on any
        failure (the error is logged, never raised — callers check the
        ``model``/``tokenizer`` globals before use).
    """
    global model, tokenizer
    try:
        logger.info("Loading YOUR fine-tuned model with transformers...")
        # Imported lazily so the web server can still start (and /health
        # can report the failure) when transformers is missing or broken.
        from transformers import AutoTokenizer, AutoModelForCausalLM

        tokenizer = AutoTokenizer.from_pretrained(
            "kacperbb/phi-3.5-merged-lora",
            trust_remote_code=True
        )
        # FIX: float32 instead of float16. With device_map="cpu", half
        # precision is a defect — most fp16 kernels are unimplemented or
        # extremely slow on CPU backends, breaking generation.
        model = AutoModelForCausalLM.from_pretrained(
            "kacperbb/phi-3.5-merged-lora",
            trust_remote_code=True,
            torch_dtype=torch.float32,
            device_map="cpu"
        )
        model.eval()  # inference mode: disables dropout etc.
        logger.info("✅ YOUR fine-tuned model loaded successfully!")
        return True
    except Exception as e:
        # Boundary handler: log and report failure instead of crashing
        # the whole service at startup.
        logger.error(f"❌ Error loading model: {e}")
        return False
33
 
34
@app.route('/generate', methods=['POST'])
def generate():
    """Generate a completion for the prompt in the POST body.

    Expects JSON shaped like the HF Inference API:
    ``{"inputs": "...", "parameters": {"max_new_tokens": N}}``
    and returns ``[{"generated_text": "..."}]`` on success, or a JSON
    ``{"error": ...}`` with status 500 on failure.
    """
    global model, tokenizer
    try:
        # FIX: request.json raises / yields None on a missing or
        # malformed JSON body, turning bad client input into an opaque
        # 500. silent=True gives None instead, and we default to {}.
        data = request.get_json(silent=True) or {}
        prompt = data.get('inputs', '')
        max_tokens = data.get('parameters', {}).get('max_new_tokens', 50)

        # Explicit None guard (clearer than truthiness on an nn.Module).
        if model is None or tokenizer is None:
            return jsonify({"error": "Model not loaded"}), 500

        messages = [{"role": "user", "content": prompt}]
        inputs = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt"
        )

        with torch.no_grad():
            outputs = model.generate(
                inputs,
                max_new_tokens=max_tokens,
                do_sample=True,
                temperature=0.7
            )

        # Decode only the newly generated tokens — slice off the echoed
        # prompt prefix (inputs.shape[-1] tokens).
        response = tokenizer.decode(
            outputs[0][inputs.shape[-1]:], skip_special_tokens=True
        )
        return jsonify([{"generated_text": response}])

    except Exception as e:
        # Top-level route boundary: surface the failure as JSON rather
        # than Flask's HTML error page.
        logger.error(f"Generation error: {e}")
        return jsonify({"error": str(e)}), 500
67
 
68
@app.route('/health', methods=['GET'])
def health():
    """Liveness probe: always 200, reports whether the model is loaded."""
    loaded = model is not None
    return jsonify({"status": "healthy", "model_loaded": loaded})
74
 
75
@app.route('/')
def home():
    """Landing route: identifies the service and the model it serves."""
    info = {
        "message": "Your Phi 3.5 Fine-tuned Model API",
        "model": "kacperbb/phi-3.5-merged-lora",
    }
    return jsonify(info)
81
 
82
if __name__ == '__main__':
    # Load weights before serving so the first request is not penalized
    # (and /health immediately reflects whether loading succeeded).
    logger.info("Starting your fine-tuned model API...")
    load_model()
    # 0.0.0.0:7860 — the port HuggingFace Spaces expects to be exposed.
    app.run(host='0.0.0.0', port=7860)