kacperbb commited on
Commit
c24fba4
Β·
verified Β·
1 Parent(s): a756dca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -23
app.py CHANGED
@@ -11,43 +11,45 @@ logger = logging.getLogger(__name__)
11
 
12
  app = Flask(__name__)
13
  model = None
 
14
 
15
  def load_model():
16
- global model
17
  try:
18
  logger.info("Loading YOUR fine-tuned model...")
19
- from transformers import pipeline
20
 
21
- model = pipeline(
22
- "text-generation",
23
- model="kacperbb/phi-3.5-merged-lora",
24
  trust_remote_code=True
25
  )
 
 
 
 
 
 
 
 
 
 
 
26
  logger.info("βœ… YOUR fine-tuned model loaded successfully!")
27
  return True
28
  except Exception as e:
29
  logger.error(f"❌ Error loading your model: {e}")
30
- logger.info("Trying with base model...")
31
  try:
32
- model = pipeline(
33
- "text-generation",
34
- model="microsoft/Phi-3.5-mini-instruct",
35
- trust_remote_code=True
36
- )
37
- logger.info("βœ… Base model loaded as fallback")
38
  return True
39
- except Exception as e2:
40
- logger.error(f"❌ Fallback failed: {e2}")
41
- try:
42
- model = pipeline("text-generation", model="gpt2")
43
- logger.info("βœ… GPT-2 fallback model loaded")
44
- return True
45
- except:
46
- return False
47
 
48
  @app.route('/generate', methods=['POST'])
49
  def generate_text():
50
- global model
51
  try:
52
  data = request.json
53
  prompt = data.get('inputs', data.get('prompt', ''))
@@ -56,11 +58,27 @@ def generate_text():
56
  if not prompt:
57
  return jsonify({"error": "No prompt provided"}), 400
58
 
59
- if model:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  result = model(prompt, max_new_tokens=max_tokens, do_sample=True)
61
  response = result[0]['generated_text']
62
  else:
63
- return jsonify({"error": "Model not loaded"}), 500
64
 
65
  return jsonify([{"generated_text": response}])
66
 
@@ -87,6 +105,7 @@ def home():
87
  })
88
 
89
  if __name__ == '__main__':
 
90
  logger.info("Starting Phi 3.5 API...")
91
  load_model()
92
  port = int(os.environ.get('PORT', 7860))
 
11
 
12
  app = Flask(__name__)
13
  model = None
14
+ tokenizer = None
15
 
16
  def load_model():
17
+ global model, tokenizer
18
  try:
19
  logger.info("Loading YOUR fine-tuned model...")
20
+ from transformers import AutoTokenizer, AutoModelForCausalLM
21
 
22
+ # Load model and tokenizer separately for better control
23
+ tokenizer = AutoTokenizer.from_pretrained(
24
+ "kacperbb/phi-3.5-merged-lora",
25
  trust_remote_code=True
26
  )
27
+ model = AutoModelForCausalLM.from_pretrained(
28
+ "kacperbb/phi-3.5-merged-lora",
29
+ trust_remote_code=True,
30
+ torch_dtype="auto",
31
+ device_map="cpu"
32
+ )
33
+
34
+ # Set pad token if not set
35
+ if tokenizer.pad_token is None:
36
+ tokenizer.pad_token = tokenizer.eos_token
37
+
38
  logger.info("βœ… YOUR fine-tuned model loaded successfully!")
39
  return True
40
  except Exception as e:
41
  logger.error(f"❌ Error loading your model: {e}")
 
42
  try:
43
+ from transformers import pipeline
44
+ model = pipeline("text-generation", model="gpt2")
45
+ logger.info("βœ… Fallback model loaded")
 
 
 
46
  return True
47
+ except:
48
+ return False
 
 
 
 
 
 
49
 
50
  @app.route('/generate', methods=['POST'])
51
  def generate_text():
52
+ global model, tokenizer
53
  try:
54
  data = request.json
55
  prompt = data.get('inputs', data.get('prompt', ''))
 
58
  if not prompt:
59
  return jsonify({"error": "No prompt provided"}), 400
60
 
61
+ if model and tokenizer and hasattr(model, 'generate'):
62
+ # Use model directly
63
+ inputs = tokenizer(prompt, return_tensors="pt", padding=True)
64
+
65
+ with torch.no_grad():
66
+ outputs = model.generate(
67
+ inputs.input_ids,
68
+ attention_mask=inputs.attention_mask,
69
+ max_new_tokens=max_tokens,
70
+ do_sample=True,
71
+ temperature=0.7,
72
+ pad_token_id=tokenizer.eos_token_id
73
+ )
74
+
75
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
76
+ elif model and hasattr(model, '__call__'):
77
+ # Use pipeline
78
  result = model(prompt, max_new_tokens=max_tokens, do_sample=True)
79
  response = result[0]['generated_text']
80
  else:
81
+ return jsonify({"error": "Model not properly loaded"}), 500
82
 
83
  return jsonify([{"generated_text": response}])
84
 
 
105
  })
106
 
107
  if __name__ == '__main__':
108
+ import torch
109
  logger.info("Starting Phi 3.5 API...")
110
  load_model()
111
  port = int(os.environ.get('PORT', 7860))