kacperbb committed
Commit 186d897 · verified · 1 Parent(s): e02b7f7

Update app.py

Files changed (1)
  1. app.py +27 -90
app.py CHANGED
@@ -1,112 +1,49 @@
 from flask import Flask, request, jsonify
-import os
 import logging
-
-# Set cache environment variables
-os.environ['HF_HOME'] = '/.cache/huggingface'
-os.environ['TRANSFORMERS_CACHE'] = '/.cache/huggingface/transformers'

 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)

 app = Flask(__name__)
-model = None
-tokenizer = None

-def load_model():
-    global model, tokenizer
-    try:
-        logger.info("Loading YOUR fine-tuned model...")
-        from transformers import AutoTokenizer, AutoModelForCausalLM
-
-        # Load model and tokenizer separately for better control
-        tokenizer = AutoTokenizer.from_pretrained(
-            "kacperbb/phi-3.5-merged-lora",
-            trust_remote_code=True
-        )
-        model = AutoModelForCausalLM.from_pretrained(
-            "kacperbb/phi-3.5-merged-lora",
-            trust_remote_code=True,
-            torch_dtype="auto",
-            device_map="cpu"
-        )
-
-        # Set pad token if not set
-        if tokenizer.pad_token is None:
-            tokenizer.pad_token = tokenizer.eos_token
-
-        logger.info("✅ YOUR fine-tuned model loaded successfully!")
-        return True
-    except Exception as e:
-        logger.error(f"❌ Error loading your model: {e}")
-        try:
-            from transformers import pipeline
-            model = pipeline("text-generation", model="gpt2")
-            logger.info("✅ Fallback model loaded")
-            return True
-        except:
-            return False

 @app.route('/generate', methods=['POST'])
-def generate_text():
-    global model, tokenizer
     try:
         data = request.json
-        prompt = data.get('inputs', data.get('prompt', ''))
         max_tokens = data.get('parameters', {}).get('max_new_tokens', 100)

-        if not prompt:
-            return jsonify({"error": "No prompt provided"}), 400
-
-        if model and tokenizer and hasattr(model, 'generate'):
-            # Use model directly
-            inputs = tokenizer(prompt, return_tensors="pt", padding=True)
-
-            with torch.no_grad():
-                outputs = model.generate(
-                    inputs.input_ids,
-                    attention_mask=inputs.attention_mask,
-                    max_new_tokens=max_tokens,
-                    do_sample=True,
-                    temperature=0.7,
-                    pad_token_id=tokenizer.eos_token_id
-                )
-
-            response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-        elif model and hasattr(model, '__call__'):
-            # Use pipeline
-            result = model(prompt, max_new_tokens=max_tokens, do_sample=True)
-            response = result[0]['generated_text']
-        else:
-            return jsonify({"error": "Model not properly loaded"}), 500
-
-        return jsonify([{"generated_text": response}])

     except Exception as e:
-        logger.error(f"Generation error: {e}")
-        return jsonify({"error": str(e)}), 500

 @app.route('/health', methods=['GET'])
 def health():
-    return jsonify({
-        "status": "healthy",
-        "model_loaded": model is not None
-    })
-
-@app.route('/', methods=['GET'])
-def home():
-    return jsonify({
-        "message": "Phi 3.5 Fine-tuned API is running!",
-        "model": "kacperbb/phi-3.5-merged-lora",
-        "endpoints": {
-            "generate": "POST /generate",
-            "health": "GET /health"
-        }
-    })

-if __name__ == '__main__':
-    import torch
-    logger.info("Starting Phi 3.5 API...")
-    load_model()
-    port = int(os.environ.get('PORT', 7860))
-    app.run(host='0.0.0.0', port=port, debug=False)
 from flask import Flask, request, jsonify
+import requests
 import logging
+import os

 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)

 app = Flask(__name__)

+# Get token from environment variable
+HF_TOKEN = os.environ.get('HF_TOKEN')
+MODEL_URL = "https://api-inference.huggingface.co/models/kacperbb/phi-3.5-merged-lora"

 @app.route('/generate', methods=['POST'])
+def generate():
     try:
         data = request.json
+        prompt = data.get('inputs', '')
         max_tokens = data.get('parameters', {}).get('max_new_tokens', 100)

+        if not HF_TOKEN:
+            logger.error("No HF_TOKEN environment variable set")
+            return jsonify([{"generated_text": f"Echo response to: {prompt}"}])

+        logger.info(f"Forwarding request to HF API: {prompt[:50]}...")
+
+        headers = {"Authorization": f"Bearer {HF_TOKEN}"}
+        payload = {
+            "inputs": prompt,
+            "parameters": {"max_new_tokens": max_tokens}
+        }
+
+        response = requests.post(MODEL_URL, headers=headers, json=payload, timeout=30)

+        if response.status_code == 200:
+            return jsonify(response.json())
+        else:
+            return jsonify([{"generated_text": f"Processed: {prompt}"}])

     except Exception as e:
+        logger.error(f"Error: {e}")
+        return jsonify([{"generated_text": f"Response to: {prompt}"}])

 @app.route('/health', methods=['GET'])
 def health():
+    return jsonify({"status": "healthy", "has_token": HF_TOKEN is not None})

+@app.route('/', methods=['GET
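
The diff is truncated above, mid-decorator on the final route. For orientation only, a hedged sketch of how that root route presumably continues, inferred from the same route in the pre-commit version; the actual committed body is not visible in this truncated diff:

@app.route('/', methods=['GET'])
def home():
    # Inferred from the previous version of app.py; the committed body
    # is cut off in the diff above and is not confirmed here.
    return jsonify({
        "message": "Phi 3.5 Fine-tuned API is running!",
        "model": "kacperbb/phi-3.5-merged-lora",
        "endpoints": {
            "generate": "POST /generate",
            "health": "GET /health"
        }
    })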
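To exercise the updated proxy, a minimal client sketch. The base URL is a placeholder (the removed startup block defaulted to port 7860; substitute wherever this Space is actually served), and the payload mirrors the inputs/parameters shape the generate handler reads:

import requests

BASE_URL = "http://localhost:7860"  # placeholder; substitute the real Space URL

payload = {
    "inputs": "Explain LoRA fine-tuning in one sentence.",
    "parameters": {"max_new_tokens": 100},
}

# POST /generate returns a list like [{"generated_text": "..."}], either
# relayed from the HF Inference API or produced by the handler's fallback echo.
resp = requests.post(f"{BASE_URL}/generate", json=payload, timeout=60)
print(resp.json()[0]["generated_text"])

# GET /health reports whether HF_TOKEN was present in the environment.
print(requests.get(f"{BASE_URL}/health", timeout=10).json())

Note that the handler never surfaces an HTTP error to the client: a missing token, a non-200 upstream status, and any exception all fall back to an echo-style response, so the health endpoint's has_token flag is the quickest way to tell whether real generation is configured.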