from flask import Flask, request, jsonify, send_from_directory
from flask_cors import CORS
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
import os
import sys
import time
import traceback

app = Flask(__name__, static_folder='static')
CORS(app)

MODEL_NAME = "KASHH-4/phi_finetuned"

print("\n" + "="*80)
print("šŸš€ LEGALDOCS AI - MODEL INITIALIZATION")
print("="*80)
print(f"šŸ“¦ Model: {MODEL_NAME}")
print(f"šŸ Python: {sys.version.split()[0]}")
print(f"šŸ”„ PyTorch: {torch.__version__}")
print("šŸ¤— Transformers: Loading...")
print("="*80 + "\n")

print("Loading tokenizer from YOUR merged model (slow tokenizer)...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
print("āœ… Tokenizer loaded successfully!")
print(f"   - Vocab size: {tokenizer.vocab_size}")
print(f"   - Model max length: {tokenizer.model_max_length}")
print(f"   - Pad token: {tokenizer.pad_token}")

print("Loading YOUR model weights...")

# Optimized for 18GB RAM with 4-bit quantization
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=quantization_config,
    device_map="auto",
    low_cpu_mem_usage=True,
    trust_remote_code=True,
    torch_dtype=torch.float16,
)

print("āœ… Model loaded successfully!")
print(f"   - Device: {model.device}")
print(f"   - Model type: {type(model).__name__}")
print("   - Quantization: 4-bit NF4")
print("   - Compute dtype: float16")

# Memory info
if torch.cuda.is_available():
    print(f"   - GPU: {torch.cuda.get_device_name(0)}")
    print(f"   - GPU Memory allocated: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB")
    print(f"   - GPU Memory reserved: {torch.cuda.memory_reserved(0) / 1024**3:.2f} GB")
else:
    print("   - Running on CPU")

print("\n" + "="*80)
print("āœ… MODEL READY - Server starting...")
print("="*80 + "\n")


@app.route('/')
def index():
    return send_from_directory('static', 'index.html')


@app.route('/api/generate', methods=['POST'])
def generate():
    try:
        print("\n" + "="*80, flush=True)
        print("šŸš€ NEW GENERATION REQUEST RECEIVED", flush=True)
        print("="*80, flush=True)
        sys.stdout.flush()

        data = request.json
        if not data or 'prompt' not in data:
            print("āŒ ERROR: Missing prompt in request body", flush=True)
            sys.stdout.flush()
            return jsonify({'error': 'Missing prompt in request body'}), 400

        prompt = data['prompt']
        max_new_tokens = data.get('max_new_tokens', 400)
        temperature = data.get('temperature', 0.7)
        top_p = data.get('top_p', 0.9)

        print("\nšŸ“ REQUEST PARAMETERS:", flush=True)
        print(f"   - Prompt length: {len(prompt)} characters", flush=True)
        print(f"   - Prompt preview: {prompt[:200]}...", flush=True)
        print(f"   - Max new tokens: {max_new_tokens}", flush=True)
        print(f"   - Temperature: {temperature}", flush=True)
        print(f"   - Top P: {top_p}", flush=True)
        sys.stdout.flush()

        print("\nšŸ”„ TOKENIZING INPUT...", flush=True)
        sys.stdout.flush()
        tokenize_start = time.time()
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        tokenize_time = time.time() - tokenize_start
        input_token_count = inputs['input_ids'].shape[1]
        print(f"   āœ… Tokenization complete in {tokenize_time:.2f}s", flush=True)
        print(f"   - Input tokens: {input_token_count}", flush=True)
        print(f"   - Device: {model.device}", flush=True)
        sys.stdout.flush()

        print("\n🧠 GENERATING TEXT WITH MODEL...", flush=True)
        print(f"   Model: {MODEL_NAME}", flush=True)
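        # Note: the request's max_new_tokens, temperature, and top_p are parsed
        # and logged above but deliberately overridden in the generate() call
        # below, which caps output at 100 greedy tokens to keep CPU inference
        # within the request timeout.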
        print("   Status: Running inference...", flush=True)
        sys.stdout.flush()
        generation_start = time.time()

        # Optimized for Phi-3 on CPU - minimal tokens to avoid timeout
        with torch.no_grad():
            torch.set_num_threads(2)  # Use both CPU cores
            outputs = model.generate(
                **inputs,
                max_new_tokens=100,  # Very short to prevent timeout
                do_sample=False,  # Greedy decoding
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                use_cache=False,  # Disable KV cache to work around a compatibility issue
            )
        generation_time = time.time() - generation_start

        output_token_count = outputs.shape[1]
        tokens_generated = output_token_count - input_token_count
        tokens_per_second = tokens_generated / generation_time if generation_time > 0 else 0
        print(f"   āœ… Generation complete in {generation_time:.2f}s", flush=True)
        print(f"   - Output tokens: {output_token_count}", flush=True)
        print(f"   - New tokens generated: {tokens_generated}", flush=True)
        print(f"   - Speed: {tokens_per_second:.2f} tokens/second", flush=True)
        sys.stdout.flush()

        print("\nšŸ”„ DECODING OUTPUT...", flush=True)
        sys.stdout.flush()
        decode_start = time.time()
        # Decode the full output
        full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
        decode_time = time.time() - decode_start
        print(f"   āœ… Decoding complete in {decode_time:.2f}s", flush=True)
        sys.stdout.flush()

        # Remove the prompt from the output to return only the generated text
        generated_text = full_output[len(prompt):].strip()

        print("\nšŸ“Š FINAL RESULTS:", flush=True)
        print(f"   - Generated text length: {len(generated_text)} characters", flush=True)
        print(f"   - Total processing time: {(time.time() - tokenize_start):.2f}s", flush=True)
        print("\nšŸ“„ GENERATED OUTPUT:", flush=True)
        print("="*80, flush=True)
        print(generated_text, flush=True)
        print("="*80, flush=True)
        sys.stdout.flush()

        print("\nāœ… REQUEST COMPLETED SUCCESSFULLY", flush=True)
        print("="*80 + "\n", flush=True)
        sys.stdout.flush()

        return jsonify({
            'generated_text': generated_text,
            'prompt': prompt
        })

    except Exception as e:
        print("\nāŒ ERROR DURING GENERATION:", flush=True)
        print(f"   Error type: {type(e).__name__}", flush=True)
        print(f"   Error message: {str(e)}", flush=True)
        sys.stdout.flush()
        print(f"   Traceback:\n{traceback.format_exc()}", flush=True)
        print("="*80 + "\n", flush=True)
        sys.stdout.flush()
        return jsonify({'error': str(e)}), 500


@app.route('/api/health', methods=['GET'])
def health():
    return jsonify({
        'status': 'ok',
        'model': MODEL_NAME,
        'device': str(model.device)
    })


if __name__ == '__main__':
    port = int(os.environ.get('PORT', 7860))
    print(f"\n🌐 Starting Flask server on port {port}...")
    print(f"šŸ”— Access the app at: http://localhost:{port}")
    print(f"šŸ“Š Health check: http://localhost:{port}/api/health")
    print(f"šŸš€ API endpoint: http://localhost:{port}/api/generate\n")
    app.run(host='0.0.0.0', port=port, debug=False)
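
# Example client call (a sketch; assumes the server is running locally on the
# default port 7860 above, and the prompt text is purely illustrative):
#
#   curl -X POST http://localhost:7860/api/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Draft a confidentiality clause for a service agreement."}'
#
# On success the response body has the shape:
#   {"generated_text": "...", "prompt": "..."}
# and a missing "prompt" field returns HTTP 400 with {"error": "..."}.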