# File size: 7,742 Bytes
# Commit hashes shown in the page gutter: c1b2914, 46545a3
from flask import Flask, request, jsonify, send_from_directory
from flask_cors import CORS
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
import os
import sys
# Flask application object; the routes below are registered on it.
app = Flask(__name__, static_folder='static')
CORS(app)  # enable flask_cors handling for every route of the app

# Hugging Face model identifier used for both the tokenizer and the weights.
MODEL_NAME = "KASHH-4/phi_finetuned"

# NOTE(review): the "π"/"β" prefixes in the log strings look like mis-decoded
# emoji. They are kept exactly as found because the original characters
# cannot be recovered with certainty.
print("\n" + "="*80)
print("π LEGALDOCS AI - MODEL INITIALIZATION")
print("="*80)
print(f"π¦ Model: {MODEL_NAME}")
# Interpreter version (the PyTorch version is reported on the next line).
print(f"π Python: {sys.version.split()[0]}")
print(f"π₯ PyTorch: {torch.__version__}")
print("π€ Transformers: Loading...")
print("="*80 + "\n")

print("Loading tokenizer from YOUR merged model (slow tokenizer)...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
if tokenizer.pad_token is None:
    # The generation call passes tokenizer.pad_token_id, so make sure a pad
    # token exists by reusing the end-of-sequence token.
    tokenizer.pad_token = tokenizer.eos_token
print("β Tokenizer loaded successfully!")
print(f" - Vocab size: {tokenizer.vocab_size}")
print(f" - Model max length: {tokenizer.model_max_length}")
print(f" - Pad token: {tokenizer.pad_token}")
print("Loading YOUR model weights...")

# Optimized for 18GB RAM with 4-bit quantization
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=quantization_config,
    device_map="auto",
    low_cpu_mem_usage=True,
    # NOTE(security): this lets code shipped in the model repository run
    # locally; keep it only while MODEL_NAME points at a trusted repo.
    trust_remote_code=True,
    torch_dtype=torch.float16,
)
print("β Model loaded successfully!")
print(f" - Device: {model.device}")
print(f" - Model type: {type(model).__name__}")
# The next two lines restate the BitsAndBytesConfig values chosen above.
print(" - Quantization: 4-bit NF4")
print(" - Compute dtype: float16")

# Memory info
if torch.cuda.is_available():
    print(f" - GPU: {torch.cuda.get_device_name(0)}")
    print(f" - GPU Memory allocated: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB")
    print(f" - GPU Memory reserved: {torch.cuda.memory_reserved(0) / 1024**3:.2f} GB")
else:
    print(" - Running on CPU")

print("\n" + "="*80)
print("β MODEL READY - Server starting...")
print("="*80 + "\n")
@app.route('/')
def index():
    """Serve ``index.html`` from the ``static`` directory as the landing page."""
    directory, filename = 'static', 'index.html'
    return send_from_directory(directory, filename)
@app.route('/api/generate', methods=['POST'])
def generate():
    """Generate a completion for the prompt supplied in the JSON request body.

    Request body (JSON object):
        prompt: required string to complete.
        max_new_tokens, temperature, top_p: optional. They are read and
            logged only; generation itself runs with the fixed settings
            below (100 new tokens, greedy decoding).

    Returns:
        A JSON response with ``generated_text`` and ``prompt`` on success;
        ``({'error': ...}, 400)`` when the body is not a JSON object with a
        string ``prompt``; ``({'error': ...}, 500)`` when tokenization,
        generation or decoding raises.
    """
    import time
    import traceback

    def log(message=""):
        """Print one log line and flush stdout immediately."""
        print(message, flush=True)

    # NOTE(review): the "π"/"β" prefixes in the log strings look like
    # mis-decoded emoji; they are kept exactly as found.
    log("\n" + "="*80)
    log("π NEW GENERATION REQUEST RECEIVED")
    log("="*80)

    # silent=True yields None for a malformed or non-JSON body, so bad input
    # is answered with 400 here instead of surfacing later as a 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'prompt' not in data:
        log("β ERROR: Missing prompt in request body")
        return jsonify({'error': 'Missing prompt in request body'}), 400

    prompt = data['prompt']
    if not isinstance(prompt, str):
        log("β ERROR: Prompt must be a string")
        return jsonify({'error': 'Prompt must be a string'}), 400

    max_new_tokens = data.get('max_new_tokens', 400)
    temperature = data.get('temperature', 0.7)
    top_p = data.get('top_p', 0.9)

    try:
        log("\nπ REQUEST PARAMETERS:")
        log(f" - Prompt length: {len(prompt)} characters")
        log(f" - Prompt preview: {prompt[:200]}...")
        log(f" - Max new tokens: {max_new_tokens}")
        log(f" - Temperature: {temperature}")
        log(f" - Top P: {top_p}")

        log("\nπ TOKENIZING INPUT...")
        tokenize_start = time.time()
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        tokenize_time = time.time() - tokenize_start
        input_token_count = inputs['input_ids'].shape[1]
        log(f" β Tokenization complete in {tokenize_time:.2f}s")
        log(f" - Input tokens: {input_token_count}")
        log(f" - Device: {model.device}")

        log("\nπ§ GENERATING TEXT WITH MODEL...")
        log(f" Model: {MODEL_NAME}")
        log(" Status: Running inference...")
        generation_start = time.time()
        # Optimized for Phi-3 on CPU - minimal tokens to avoid timeout
        with torch.no_grad():
            torch.set_num_threads(2)  # Use both CPU cores
            outputs = model.generate(
                **inputs,
                max_new_tokens=100,  # Very short to prevent timeout
                do_sample=False,  # Greedy decoding
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                use_cache=False,  # Disable cache compatibility issue
            )
        generation_time = time.time() - generation_start

        output_token_count = outputs.shape[1]
        tokens_generated = output_token_count - input_token_count
        tokens_per_second = tokens_generated / generation_time if generation_time > 0 else 0
        log(f" β Generation complete in {generation_time:.2f}s")
        log(f" - Output tokens: {output_token_count}")
        log(f" - New tokens generated: {tokens_generated}")
        log(f" - Speed: {tokens_per_second:.2f} tokens/second")

        log("\nπ DECODING OUTPUT...")
        decode_start = time.time()
        # The output sequence starts with the prompt tokens. Decode only the
        # tokens after them: slicing the decoded text by len(prompt) is wrong
        # whenever the tokenizer round-trip does not reproduce the prompt
        # character for character.
        new_token_ids = outputs[0][input_token_count:]
        generated_text = tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()
        decode_time = time.time() - decode_start
        log(f" β Decoding complete in {decode_time:.2f}s")

        log("\nπ FINAL RESULTS:")
        log(f" - Generated text length: {len(generated_text)} characters")
        log(f" - Total processing time: {(time.time() - tokenize_start):.2f}s")
        log("\nπ GENERATED OUTPUT:")
        log("="*80)
        log(generated_text)
        log("="*80)
        log("\nβ REQUEST COMPLETED SUCCESSFULLY")
        log("="*80 + "\n")

        return jsonify({
            'generated_text': generated_text,
            'prompt': prompt
        })
    except Exception as e:
        # Request boundary: log the full traceback and answer with a JSON 500
        # rather than letting the exception propagate out of the view.
        log("\nβ ERROR DURING GENERATION:")
        log(f" Error type: {type(e).__name__}")
        log(f" Error message: {str(e)}")
        log(f" Traceback:\n{traceback.format_exc()}")
        log("="*80 + "\n")
        return jsonify({'error': str(e)}), 500
@app.route('/api/health', methods=['GET'])
def health():
    """Return a JSON status document naming the model and the device it is on."""
    payload = {
        'status': 'ok',
        'model': MODEL_NAME,
        'device': str(model.device),
    }
    return jsonify(payload)
if __name__ == '__main__':
    # Listen on the port named by the PORT environment variable, else 7860.
    port = int(os.getenv('PORT', 7860))
    # One print call, one message per line: same output as four prints.
    print(
        f"\nπ Starting Flask server on port {port}...",
        f"π Access the app at: http://localhost:{port}",
        f"π Health check: http://localhost:{port}/api/health",
        f"π API endpoint: http://localhost:{port}/api/generate\n",
        sep="\n",
    )
    app.run(host='0.0.0.0', port=port, debug=False)
|