HusainHG committed on
Commit 47fae5d · verified · 1 Parent(s): eb0ceee

Update app.py

Files changed (1)
  1. app.py +93 -93
app.py CHANGED
@@ -1,93 +1,93 @@
- from flask import Flask, request, jsonify, send_from_directory
- from flask_cors import CORS
- from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
- import torch
- import os
-
- app = Flask(__name__, static_folder='static')
- CORS(app)
-
- MODEL_NAME = "KASHH-4/mistral_fine-tuned"
-
- print(f"Loading model: {MODEL_NAME}")
-
- print("Loading tokenizer from YOUR merged model (slow tokenizer)...")
- # Your model HAS tokenizer files, use them with use_fast=False
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
-
- if tokenizer.pad_token is None:
-     tokenizer.pad_token = tokenizer.eos_token
-
- print("Tokenizer loaded successfully!")
-
- print("Loading YOUR model weights...")
- # Optimized for 16GB RAM - load in 8-bit quantization
- quantization_config = BitsAndBytesConfig(
-     load_in_8bit=True,  # Use 8-bit to fit in 16GB RAM
-     llm_int8_threshold=6.0
- )
-
- model = AutoModelForCausalLM.from_pretrained(
-     MODEL_NAME,
-     quantization_config=quantization_config,
-     device_map="auto",
-     low_cpu_mem_usage=True,
-     trust_remote_code=True
- )
- print("Model loaded successfully!")
-
-
- @app.route('/')
- def index():
-     return send_from_directory('static', 'index.html')
-
-
- @app.route('/api/generate', methods=['POST'])
- def generate():
-     try:
-         data = request.json
-
-         if not data or 'prompt' not in data:
-             return jsonify({'error': 'Missing prompt in request body'}), 400
-
-         prompt = data['prompt']
-         max_new_tokens = data.get('max_new_tokens', 256)
-         temperature = data.get('temperature', 0.7)
-         top_p = data.get('top_p', 0.9)
-
-         inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-
-         with torch.no_grad():
-             outputs = model.generate(
-                 **inputs,
-                 max_new_tokens=max_new_tokens,
-                 temperature=temperature,
-                 top_p=top_p,
-                 do_sample=True,
-                 pad_token_id=tokenizer.eos_token_id
-             )
-
-         generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-         return jsonify({
-             'generated_text': generated_text,
-             'prompt': prompt
-         })
-
-     except Exception as e:
-         print(f"Error during generation: {e}")
-         return jsonify({'error': str(e)}), 500
-
-
- @app.route('/api/health', methods=['GET'])
- def health():
-     return jsonify({
-         'status': 'ok',
-         'model': MODEL_NAME,
-         'device': str(model.device)
-     })
-
-
- if __name__ == '__main__':
-     port = int(os.environ.get('PORT', 7860))
-     app.run(host='0.0.0.0', port=port, debug=False)
+ from flask import Flask, request, jsonify, send_from_directory
+ from flask_cors import CORS
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+ import torch
+ import os
+
+ app = Flask(__name__, static_folder='static')
+ CORS(app)
+
+ MODEL_NAME = "KASHH-4/Gemma-finetuned"
+
+ print(f"Loading model: {MODEL_NAME}")
+
+ print("Loading tokenizer from YOUR merged model (slow tokenizer)...")
+ # Your model HAS tokenizer files, use them with use_fast=False
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
+
+ if tokenizer.pad_token is None:
+     tokenizer.pad_token = tokenizer.eos_token
+
+ print("Tokenizer loaded successfully!")
+
+ print("Loading YOUR model weights...")
+ # Optimized for 16GB RAM - load in 8-bit quantization
+ quantization_config = BitsAndBytesConfig(
+     load_in_8bit=True,  # Use 8-bit to fit in 16GB RAM
+     llm_int8_threshold=6.0
+ )
+
+ model = AutoModelForCausalLM.from_pretrained(
+     MODEL_NAME,
+     quantization_config=quantization_config,
+     device_map="auto",
+     low_cpu_mem_usage=True,
+     trust_remote_code=True
+ )
+ print("Model loaded successfully!")
+
+
+ @app.route('/')
+ def index():
+     return send_from_directory('static', 'index.html')
+
+
+ @app.route('/api/generate', methods=['POST'])
+ def generate():
+     try:
+         data = request.json
+
+         if not data or 'prompt' not in data:
+             return jsonify({'error': 'Missing prompt in request body'}), 400
+
+         prompt = data['prompt']
+         max_new_tokens = data.get('max_new_tokens', 256)
+         temperature = data.get('temperature', 0.7)
+         top_p = data.get('top_p', 0.9)
+
+         inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
+         with torch.no_grad():
+             outputs = model.generate(
+                 **inputs,
+                 max_new_tokens=max_new_tokens,
+                 temperature=temperature,
+                 top_p=top_p,
+                 do_sample=True,
+                 pad_token_id=tokenizer.eos_token_id
+             )
+
+         generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+         return jsonify({
+             'generated_text': generated_text,
+             'prompt': prompt
+         })
+
+     except Exception as e:
+         print(f"Error during generation: {e}")
+         return jsonify({'error': str(e)}), 500
+
+
+ @app.route('/api/health', methods=['GET'])
+ def health():
+     return jsonify({
+         'status': 'ok',
+         'model': MODEL_NAME,
+         'device': str(model.device)
+     })
+
+
+ if __name__ == '__main__':
+     port = int(os.environ.get('PORT', 7860))
+     app.run(host='0.0.0.0', port=port, debug=False)
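For quick verification after this change, here is a minimal client sketch for the two endpoints app.py defines. It is not part of the commit and makes assumptions: the server is reachable at http://localhost:7860 (the PORT fallback in the code), the requests package is installed, and the prompt text is a placeholder.

import requests

BASE_URL = "http://localhost:7860"  # assumed local run; point at the deployed Space instead

# Confirm the model loaded and see which device it landed on.
print(requests.get(f"{BASE_URL}/api/health").json())

# Keys mirror the ones generate() reads from the JSON body.
payload = {
    "prompt": "Hello, world",  # placeholder prompt
    "max_new_tokens": 64,
    "temperature": 0.7,
    "top_p": 0.9,
}
resp = requests.post(f"{BASE_URL}/api/generate", json=payload)
resp.raise_for_status()
print(resp.json()["generated_text"])

Note that generated_text echoes the prompt at the start, since the route decodes the full output sequence, not just the newly generated tokens.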