from flask import Flask, request, jsonify, send_from_directory
from flask_cors import CORS
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
import os
import sys
import time
import traceback
app = Flask(__name__, static_folder='static')
CORS(app)  # Allow cross-origin requests to the API endpoints
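# Hugging Face Hub ID of the fine-tuned Phi model served by this app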
MODEL_NAME = "KASHH-4/phi_finetuned"
print("\n" + "="*80)
print("πŸš€ LEGALDOCS AI - MODEL INITIALIZATION")
print("="*80)
print(f"πŸ“¦ Model: {MODEL_NAME}")
print(f"🐍 Python: {torch.__version__}")
print(f"πŸ”₯ PyTorch: {torch.__version__}")
print(f"πŸ€— Transformers: Loading...")
print("="*80 + "\n")
print("Loading tokenizer from YOUR merged model (slow tokenizer)...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
if tokenizer.pad_token is None:
    # Fall back to the EOS token when the tokenizer defines no pad token
    tokenizer.pad_token = tokenizer.eos_token
print("βœ… Tokenizer loaded successfully!")
print(f" - Vocab size: {tokenizer.vocab_size}")
print(f" - Model max length: {tokenizer.model_max_length}")
print(f" - Pad token: {tokenizer.pad_token}")
print("Loading YOUR model weights...")
# Optimized for 18GB RAM with 4-bit quantization
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)
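# Note: 4-bit bitsandbytes quantization generally requires a CUDA GPU;
# on a CPU-only host this quantized load may fail rather than fall back.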
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=quantization_config,
    device_map="auto",  # Let accelerate place layers on the available devices
    low_cpu_mem_usage=True,
    trust_remote_code=True,
    torch_dtype=torch.float16,
)
print("βœ… Model loaded successfully!")
print(f" - Device: {model.device}")
print(f" - Model type: {type(model).__name__}")
print(f" - Quantization: 4-bit NF4")
print(f" - Compute dtype: float16")
# Memory info
if torch.cuda.is_available():
    print(f" - GPU: {torch.cuda.get_device_name(0)}")
    print(f" - GPU Memory allocated: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB")
    print(f" - GPU Memory reserved: {torch.cuda.memory_reserved(0) / 1024**3:.2f} GB")
else:
    print(" - Running on CPU")
print("\n" + "="*80)
print("βœ… MODEL READY - Server starting...")
print("="*80 + "\n")
@app.route('/')
def index():
    return send_from_directory('static', 'index.html')
@app.route('/api/generate', methods=['POST'])
def generate():
    try:
        print("\n" + "="*80, flush=True)
        print("🚀 NEW GENERATION REQUEST RECEIVED", flush=True)
        print("="*80, flush=True)
        sys.stdout.flush()
        # get_json(silent=True) returns None on malformed JSON instead of raising,
        # so the explicit 400 below covers bad bodies as well as missing prompts
        data = request.get_json(silent=True)
        if not data or 'prompt' not in data:
            print("❌ ERROR: Missing prompt in request body", flush=True)
            sys.stdout.flush()
            return jsonify({'error': 'Missing prompt in request body'}), 400
        prompt = data['prompt']
        max_new_tokens = int(data.get('max_new_tokens', 400))
        temperature = data.get('temperature', 0.7)  # Parsed and logged, but unused: decoding below is greedy
        top_p = data.get('top_p', 0.9)  # Parsed and logged, but unused: decoding below is greedy
print(f"\nπŸ“ REQUEST PARAMETERS:", flush=True)
print(f" - Prompt length: {len(prompt)} characters", flush=True)
print(f" - Prompt preview: {prompt[:200]}...", flush=True)
print(f" - Max new tokens: {max_new_tokens}", flush=True)
print(f" - Temperature: {temperature}", flush=True)
print(f" - Top P: {top_p}", flush=True)
sys.stdout.flush()
print(f"\nπŸ”„ TOKENIZING INPUT...", flush=True)
sys.stdout.flush()
tokenize_start = time.time()
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
tokenize_time = time.time() - tokenize_start
input_token_count = inputs['input_ids'].shape[1]
print(f" βœ… Tokenization complete in {tokenize_time:.2f}s", flush=True)
print(f" - Input tokens: {input_token_count}", flush=True)
print(f" - Device: {model.device}", flush=True)
sys.stdout.flush()
print(f"\n🧠 GENERATING TEXT WITH MODEL...", flush=True)
print(f" Model: {MODEL_NAME}", flush=True)
print(f" Status: Running inference...", flush=True)
sys.stdout.flush()
generation_start = time.time()
# Optimized for Phi-3 on CPU - minimal tokens to avoid timeout
with torch.no_grad():
torch.set_num_threads(2) # Use both CPU cores
outputs = model.generate(
**inputs,
max_new_tokens=100, # Very short to prevent timeout
do_sample=False, # Greedy decoding
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
use_cache=False # Disable cache compatibility issue
)
generation_time = time.time() - generation_start
output_token_count = outputs.shape[1]
tokens_generated = output_token_count - input_token_count
tokens_per_second = tokens_generated / generation_time if generation_time > 0 else 0
print(f" βœ… Generation complete in {generation_time:.2f}s", flush=True)
print(f" - Output tokens: {output_token_count}", flush=True)
print(f" - New tokens generated: {tokens_generated}", flush=True)
print(f" - Speed: {tokens_per_second:.2f} tokens/second", flush=True)
sys.stdout.flush()
print(f"\nπŸ”„ DECODING OUTPUT...", flush=True)
sys.stdout.flush()
decode_start = time.time()
# Decode the full output
full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
decode_time = time.time() - decode_start
print(f" βœ… Decoding complete in {decode_time:.2f}s", flush=True)
sys.stdout.flush()
# Remove the prompt from the output to return only the generated text
generated_text = full_output[len(prompt):].strip()
print(f"\nπŸ“Š FINAL RESULTS:", flush=True)
print(f" - Generated text length: {len(generated_text)} characters", flush=True)
print(f" - Total processing time: {(time.time() - tokenize_start):.2f}s", flush=True)
print(f"\nπŸ“„ GENERATED OUTPUT:", flush=True)
print("="*80, flush=True)
print(generated_text, flush=True)
print("="*80, flush=True)
sys.stdout.flush()
print(f"\nβœ… REQUEST COMPLETED SUCCESSFULLY", flush=True)
print("="*80 + "\n", flush=True)
sys.stdout.flush()
return jsonify({
'generated_text': generated_text,
'prompt': prompt
})
    except Exception as e:
        print("\n❌ ERROR DURING GENERATION:", flush=True)
        print(f" Error type: {type(e).__name__}", flush=True)
        print(f" Error message: {str(e)}", flush=True)
        sys.stdout.flush()
        print(f" Traceback:\n{traceback.format_exc()}", flush=True)
        print("="*80 + "\n", flush=True)
        sys.stdout.flush()
        return jsonify({'error': str(e)}), 500
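# Example request (hypothetical prompt; assumes the server is on localhost:7860):
#   curl -X POST http://localhost:7860/api/generate \
#     -H "Content-Type: application/json" \
#     -d '{"prompt": "Draft a short confidentiality clause.", "max_new_tokens": 100}'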
@app.route('/api/health', methods=['GET'])
def health():
    return jsonify({
        'status': 'ok',
        'model': MODEL_NAME,
        'device': str(model.device)
    })
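# Example response (device value depends on the host, e.g. "cpu" or "cuda:0"):
#   {"status": "ok", "model": "KASHH-4/phi_finetuned", "device": "cuda:0"}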
if __name__ == '__main__':
    port = int(os.environ.get('PORT', 7860))
    print(f"\n🌐 Starting Flask server on port {port}...")
    print(f"🔗 Access the app at: http://localhost:{port}")
    print(f"📊 Health check: http://localhost:{port}/api/health")
    print(f"🚀 API endpoint: http://localhost:{port}/api/generate\n")
    app.run(host='0.0.0.0', port=port, debug=False)
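# To run locally (assuming flask, flask-cors, torch, transformers, and
# bitsandbytes are installed): python app.py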