| from flask import Flask, request, jsonify, send_from_directory
|
| from flask_cors import CORS
|
| from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
| import torch
|
| import os
|
| import sys
|
|
|
# Flask application; frontend assets are served from ./static.
app = Flask(__name__, static_folder='static')

# Allow cross-origin requests so a separately hosted frontend can call this API.
CORS(app)

# Hugging Face Hub id of the merged fine-tuned checkpoint loaded at startup.
MODEL_NAME = "KASHH-4/phi_finetuned"
|
|
|
# ----------------------------------------------------------------------------
# Startup banner + tokenizer load.
# NOTE(review): several emoji in the log strings were mangled by an encoding
# round-trip ("π", "π¦", ... are mojibake). Only literals that were left
# syntactically broken (a raw newline inside a string) are repaired here.
# ----------------------------------------------------------------------------
print("\n" + "="*80)
print("π LEGALDOCS AI - MODEL INITIALIZATION")
print("="*80)
print(f"π¦ Model: {MODEL_NAME}")
# Bug fix: this line previously printed torch.__version__ as the Python version.
print(f"π Python: {sys.version.split()[0]}")
print(f"π₯ PyTorch: {torch.__version__}")
print(f"π€ Transformers: Loading...")
print("="*80 + "\n")

# use_fast=False forces the slow (Python) tokenizer, matching the tokenizer
# files shipped with this merged checkpoint.
print("Loading tokenizer from YOUR merged model (slow tokenizer)...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)

# generate() needs a pad token; fall back to EOS when the checkpoint has none.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Bug fix: the original success line contained a raw newline inside the string
# literal (mojibake for an emoji) — a SyntaxError; repaired as one line.
print("[OK] Tokenizer loaded successfully!")
print(f"   - Vocab size: {tokenizer.vocab_size}")
print(f"   - Model max length: {tokenizer.model_max_length}")
print(f"   - Pad token: {tokenizer.pad_token}")
|
|
|
print("Loading YOUR model weights...")

# 4-bit NF4 weights with double quantization minimize the memory footprint;
# matmuls are computed in float16 (bnb_4bit_compute_dtype).
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=quantization_config,
    device_map="auto",        # let accelerate place layers on available devices
    low_cpu_mem_usage=True,
    # NOTE(review): trust_remote_code executes code fetched from the Hub repo;
    # acceptable only because this is the author's own checkpoint.
    trust_remote_code=True,
    torch_dtype=torch.float16,
)

# Bug fix: the original success line contained a raw newline inside the string
# literal (mojibake for an emoji) — a SyntaxError; repaired as one line.
print("[OK] Model loaded successfully!")
print(f"   - Device: {model.device}")
print(f"   - Model type: {type(model).__name__}")
print(f"   - Quantization: 4-bit NF4")
print(f"   - Compute dtype: float16")
|
|
|
|
|
# Report where the model landed; memory statistics only exist on CUDA.
if torch.cuda.is_available():
    print(f"   - GPU: {torch.cuda.get_device_name(0)}")
    print(f"   - GPU Memory allocated: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB")
    print(f"   - GPU Memory reserved: {torch.cuda.memory_reserved(0) / 1024**3:.2f} GB")
else:
    print(f"   - Running on CPU")

print("\n" + "="*80)
# Bug fix: the original literal contained a raw newline (mojibake for an
# emoji) — a SyntaxError; repaired as one line.
print("[OK] MODEL READY - Server starting...")
print("="*80 + "\n")
|
|
|
|
|
@app.route('/')
def index():
    """Serve the single-page frontend's entry point."""
    # Flask's safe static-file sender guards against path traversal.
    directory, page = 'static', 'index.html'
    return send_from_directory(directory, page)
|
|
|
|
|
@app.route('/api/generate', methods=['POST'])
def generate():
    """Run one text-generation request against the loaded model.

    JSON body:
        prompt (str, required): text to continue.
        max_new_tokens (int, optional, default 400): generation budget.
        temperature (float, optional, default 0.7): sampling temperature;
            0 selects greedy decoding.
        top_p (float, optional, default 0.9): nucleus-sampling cutoff.

    Returns:
        200 with {'generated_text': ..., 'prompt': ...} on success,
        400 when 'prompt' is missing, 500 on any inference failure.
    """
    import time
    try:
        print("\n" + "="*80, flush=True)
        print("π NEW GENERATION REQUEST RECEIVED", flush=True)
        print("="*80, flush=True)
        sys.stdout.flush()

        data = request.json

        if not data or 'prompt' not in data:
            print("β ERROR: Missing prompt in request body", flush=True)
            sys.stdout.flush()
            return jsonify({'error': 'Missing prompt in request body'}), 400

        prompt = data['prompt']
        max_new_tokens = data.get('max_new_tokens', 400)
        temperature = data.get('temperature', 0.7)
        top_p = data.get('top_p', 0.9)

        print(f"\nπ REQUEST PARAMETERS:", flush=True)
        print(f"   - Prompt length: {len(prompt)} characters", flush=True)
        print(f"   - Prompt preview: {prompt[:200]}...", flush=True)
        print(f"   - Max new tokens: {max_new_tokens}", flush=True)
        print(f"   - Temperature: {temperature}", flush=True)
        print(f"   - Top P: {top_p}", flush=True)
        sys.stdout.flush()

        print(f"\nπ TOKENIZING INPUT...", flush=True)
        sys.stdout.flush()
        tokenize_start = time.time()
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        tokenize_time = time.time() - tokenize_start
        input_token_count = inputs['input_ids'].shape[1]
        # (Repaired: this status line's literal contained a raw newline.)
        print(f"   [OK] Tokenization complete in {tokenize_time:.2f}s", flush=True)
        print(f"   - Input tokens: {input_token_count}", flush=True)
        print(f"   - Device: {model.device}", flush=True)
        sys.stdout.flush()

        print(f"\nπ§ GENERATING TEXT WITH MODEL...", flush=True)
        print(f"   Model: {MODEL_NAME}", flush=True)
        print(f"   Status: Running inference...", flush=True)
        sys.stdout.flush()
        generation_start = time.time()

        # Bug fix: the request's max_new_tokens/temperature/top_p were parsed
        # and logged above, but generation previously hard-coded
        # max_new_tokens=100 with greedy decoding, silently ignoring them.
        # Honor them here; temperature == 0 falls back to greedy decoding.
        do_sample = bool(temperature) and temperature > 0
        gen_kwargs = {
            'max_new_tokens': max_new_tokens,
            'do_sample': do_sample,
            'pad_token_id': tokenizer.pad_token_id,
            'eos_token_id': tokenizer.eos_token_id,
            # Bug fix: use_cache=False disabled the KV cache, recomputing all
            # past attention at every decode step (same output, far slower).
            'use_cache': True,
        }
        if do_sample:
            # temperature/top_p only apply when sampling; passing them with
            # do_sample=False triggers transformers warnings.
            gen_kwargs['temperature'] = temperature
            gen_kwargs['top_p'] = top_p

        with torch.no_grad():
            torch.set_num_threads(2)  # cap CPU threads to keep the host responsive
            outputs = model.generate(**inputs, **gen_kwargs)

        generation_time = time.time() - generation_start
        output_token_count = outputs.shape[1]
        tokens_generated = output_token_count - input_token_count
        tokens_per_second = tokens_generated / generation_time if generation_time > 0 else 0

        print(f"   [OK] Generation complete in {generation_time:.2f}s", flush=True)
        print(f"   - Output tokens: {output_token_count}", flush=True)
        print(f"   - New tokens generated: {tokens_generated}", flush=True)
        print(f"   - Speed: {tokens_per_second:.2f} tokens/second", flush=True)
        sys.stdout.flush()

        print(f"\nπ DECODING OUTPUT...", flush=True)
        sys.stdout.flush()
        decode_start = time.time()

        # Bug fix: previously the whole sequence was decoded and the prompt
        # stripped by len(prompt) *characters*; tokenizer round-trips do not
        # always reproduce the prompt verbatim, so slice by token position.
        generated_text = tokenizer.decode(
            outputs[0][input_token_count:], skip_special_tokens=True
        ).strip()

        decode_time = time.time() - decode_start
        print(f"   [OK] Decoding complete in {decode_time:.2f}s", flush=True)
        sys.stdout.flush()

        print(f"\nπ FINAL RESULTS:", flush=True)
        print(f"   - Generated text length: {len(generated_text)} characters", flush=True)
        print(f"   - Total processing time: {(time.time() - tokenize_start):.2f}s", flush=True)
        print(f"\nπ GENERATED OUTPUT:", flush=True)
        print("="*80, flush=True)
        print(generated_text, flush=True)
        print("="*80, flush=True)
        sys.stdout.flush()

        # (Repaired: this status line's literal contained a raw newline.)
        print(f"\n[OK] REQUEST COMPLETED SUCCESSFULLY", flush=True)
        print("="*80 + "\n", flush=True)
        sys.stdout.flush()

        return jsonify({
            'generated_text': generated_text,
            'prompt': prompt
        })

    except Exception as e:
        # Top-level boundary: log the full traceback server-side and surface
        # the failure to the client as HTTP 500.
        print(f"\nβ ERROR DURING GENERATION:", flush=True)
        print(f"   Error type: {type(e).__name__}", flush=True)
        print(f"   Error message: {str(e)}", flush=True)
        sys.stdout.flush()
        import traceback
        print(f"   Traceback:\n{traceback.format_exc()}", flush=True)
        print("="*80 + "\n", flush=True)
        sys.stdout.flush()
        return jsonify({'error': str(e)}), 500
|
|
|
|
|
@app.route('/api/health', methods=['GET'])
def health():
    """Liveness probe: report the loaded model id and its device."""
    payload = {
        'status': 'ok',
        'model': MODEL_NAME,
        'device': str(model.device),
    }
    return jsonify(payload)
|
|
|
|
|
if __name__ == '__main__':
    # Port is configurable via the PORT env var (7860 is the HF Spaces default).
    port = int(os.environ.get('PORT', 7860))
    startup_lines = (
        f"\nπ Starting Flask server on port {port}...",
        f"π Access the app at: http://localhost:{port}",
        f"π Health check: http://localhost:{port}/api/health",
        f"π API endpoint: http://localhost:{port}/api/generate\n",
    )
    for line in startup_lines:
        print(line)
    # Bind on all interfaces so the containerized app is reachable externally.
    app.run(host='0.0.0.0', port=port, debug=False)
|
|
|