"""Flask server that serves text generation from the KASHH-4/phi_finetuned model.

Routes:
    GET  /              -> static/index.html
    POST /api/generate  -> JSON body {"prompt": str, optional "max_new_tokens",
                           "temperature", "top_p"}; returns the decoded output
    GET  /api/health    -> status, model name and the model's device
"""
import logging
import math
import os

import torch
from flask import Flask, jsonify, request, send_from_directory
from flask_cors import CORS
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = Flask(__name__, static_folder="static")
CORS(app)

MODEL_NAME = "KASHH-4/phi_finetuned"

# Sampling defaults used when the request omits a field (or sends null).
DEFAULT_MAX_NEW_TOKENS = 256
DEFAULT_TEMPERATURE = 0.7
DEFAULT_TOP_P = 0.9
# Upper bound on tokens generated per request, so a single caller cannot tie up
# the model with an unbounded generation. Larger requests are clamped to this.
MAX_NEW_TOKENS_LIMIT = int(os.environ.get("MAX_NEW_TOKENS_LIMIT", 1024))

logger.info("Loading model: %s", MODEL_NAME)

logger.info("Loading tokenizer from YOUR merged model (slow tokenizer)...")
# The model repo ships its own tokenizer files; use_fast=False selects the
# slow (Python) tokenizer implementation.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
if tokenizer.pad_token is None:
    # No pad token defined: reuse EOS so padding-aware code has an id to use.
    tokenizer.pad_token = tokenizer.eos_token
logger.info("Tokenizer loaded successfully!")

logger.info("Loading YOUR model weights...")
# Load weights in 8-bit to reduce memory use (the original target was a 16GB
# machine). NOTE(review): bitsandbytes 8-bit loading generally expects a CUDA
# GPU -- confirm the deployment hardware supports it.
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_threshold=6.0,  # LLM.int8() outlier threshold (library default)
)
# SECURITY NOTE: trust_remote_code=True executes Python code shipped inside the
# model repository. Keep it only if that repo is trusted and actually needs it.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=quantization_config,
    device_map="auto",
    low_cpu_mem_usage=True,
    trust_remote_code=True,
)
logger.info("Model loaded successfully!")


class _InvalidRequest(ValueError):
    """Raised when the /api/generate body fails validation (mapped to HTTP 400)."""


def _read_number(data, key, default, *, integer=False):
    """Return ``data[key]`` (or ``default`` if absent/null) as a finite JSON number.

    Raises:
        _InvalidRequest: if the value is not a number (or not an int when
            ``integer`` is True), or is NaN/infinite.
    """
    value = data.get(key)
    if value is None:
        value = default
    allowed = (int,) if integer else (int, float)
    # bool is a subclass of int; reject it so JSON `true` is not silently 1.
    if isinstance(value, bool) or not isinstance(value, allowed):
        kind = "an integer" if integer else "a number"
        raise _InvalidRequest(f"'{key}' must be {kind}")
    if not math.isfinite(value):
        raise _InvalidRequest(f"'{key}' must be finite")
    return value


def _parse_generation_request(data):
    """Validate the /api/generate body and return ``(prompt, generate_kwargs)``.

    Raises:
        _InvalidRequest: with a client-facing message on any invalid field.
    """
    if not isinstance(data, dict) or "prompt" not in data:
        raise _InvalidRequest("Missing prompt in request body")
    prompt = data["prompt"]
    if not isinstance(prompt, str) or not prompt:
        raise _InvalidRequest("'prompt' must be a non-empty string")

    max_new_tokens = _read_number(
        data, "max_new_tokens", DEFAULT_MAX_NEW_TOKENS, integer=True
    )
    if max_new_tokens < 1:
        raise _InvalidRequest("'max_new_tokens' must be at least 1")
    max_new_tokens = min(max_new_tokens, MAX_NEW_TOKENS_LIMIT)

    temperature = _read_number(data, "temperature", DEFAULT_TEMPERATURE)
    if temperature < 0:
        raise _InvalidRequest("'temperature' must be >= 0")
    top_p = _read_number(data, "top_p", DEFAULT_TOP_P)
    if not 0 < top_p <= 1:
        raise _InvalidRequest("'top_p' must be in (0, 1]")

    generate_kwargs = {"max_new_tokens": max_new_tokens}
    if temperature == 0:
        # transformers rejects temperature=0 when sampling; treat it as a
        # request for greedy decoding instead of failing.
        generate_kwargs["do_sample"] = False
    else:
        generate_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
    return prompt, generate_kwargs


@app.route("/")
def index():
    """Serve the frontend page from the app's static folder."""
    return send_from_directory(app.static_folder, "index.html")


@app.route("/api/generate", methods=["POST"])
def generate():
    """Generate a completion for the JSON ``prompt`` in the request body.

    Returns 400 for a missing/malformed body or invalid sampling parameters,
    500 if tokenization or generation fails, otherwise
    ``{"generated_text": ..., "prompt": ...}``.
    """
    # silent=True: a malformed or non-JSON body yields None (-> 400) rather
    # than raising an HTTP error that would previously surface as a 500.
    data = request.get_json(silent=True)
    try:
        prompt, generate_kwargs = _parse_generation_request(data)
    except _InvalidRequest as err:
        return jsonify({"error": str(err)}), 400

    # NOTE(review): concurrent requests share one model instance; if memory
    # pressure appears under load, consider serializing generate() calls.
    try:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                pad_token_id=tokenizer.eos_token_id,
                **generate_kwargs,
            )
        # Decodes the whole output sequence, which includes the prompt tokens.
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception:
        # Top-level boundary: log the full traceback server-side, but do not
        # leak internal exception details to the (CORS-open) client.
        logger.exception("Error during generation")
        return jsonify({"error": "Generation failed"}), 500

    return jsonify({"generated_text": generated_text, "prompt": prompt})


@app.route("/api/health", methods=["GET"])
def health():
    """Report liveness, the served model name, and the model's device."""
    return jsonify({
        "status": "ok",
        "model": MODEL_NAME,
        "device": str(model.device),
    })


if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))
    app.run(host="0.0.0.0", port=port, debug=False)