# app.py — Flask inference server exposing the KASHH-4/phi_finetuned model
# (designed to run as a Hugging Face Space; listens on $PORT, default 7860).
| from flask import Flask, request, jsonify, send_from_directory | |
| from flask_cors import CORS | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig | |
| import torch | |
| import os | |
# Flask app serving the UI from ./static; CORS enabled for all origins so a
# separately-hosted frontend can call the API.
app = Flask(__name__, static_folder='static')
CORS(app)

# Hugging Face Hub repo id of the fine-tuned model to serve.
MODEL_NAME = "KASHH-4/phi_finetuned"
print(f"Loading model: {MODEL_NAME}")

print("Loading tokenizer from YOUR merged model (slow tokenizer)...")
# Your model HAS tokenizer files, use them with use_fast=False
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
if tokenizer.pad_token is None:
    # Some causal-LM tokenizers ship without a pad token; reuse EOS so
    # padding during generation is well-defined.
    tokenizer.pad_token = tokenizer.eos_token
print("Tokenizer loaded successfully!")

print("Loading YOUR model weights...")
# Optimized for 16GB RAM - load in 8-bit quantization
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,  # Use 8-bit to fit in 16GB RAM
    llm_int8_threshold=6.0
)
# device_map="auto" lets accelerate place layers on whatever device is
# available; trust_remote_code is required if the repo ships custom model
# code — NOTE(review): this executes code from the Hub repo, so the repo
# must be trusted.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=quantization_config,
    device_map="auto",
    low_cpu_mem_usage=True,
    trust_remote_code=True
)
print("Model loaded successfully!")
@app.route('/')
def index():
    """Serve the single-page frontend from the static folder."""
    return send_from_directory('static', 'index.html')
@app.route('/generate', methods=['POST'])
def generate():
    """Generate text from the loaded model.

    Expects a JSON body: {"prompt": str, "max_new_tokens"?: int,
    "temperature"?: float, "top_p"?: float}. Returns JSON with the decoded
    generation (which includes the prompt, since the full output sequence is
    decoded) plus the original prompt; 400 on a missing/invalid body,
    500 on a generation failure.
    """
    try:
        # silent=True: return None instead of raising 415/400 when the body
        # is absent or not JSON, so the check below produces our own error.
        data = request.get_json(silent=True)
        if not data or 'prompt' not in data:
            return jsonify({'error': 'Missing prompt in request body'}), 400
        prompt = data['prompt']
        # Coerce sampling parameters — they arrive from untrusted JSON and
        # model.generate expects proper int/float types.
        max_new_tokens = int(data.get('max_new_tokens', 256))
        temperature = float(data.get('temperature', 0.7))
        top_p = float(data.get('top_p', 0.9))

        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                # Explicit pad id silences the "pad_token_id not set" warning.
                pad_token_id=tokenizer.eos_token_id
            )
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return jsonify({
            'generated_text': generated_text,
            'prompt': prompt
        })
    except Exception as e:
        # Route-level boundary: surface the failure as JSON instead of an
        # opaque HTML 500 page.
        print(f"Error during generation: {e}")
        return jsonify({'error': str(e)}), 500
@app.route('/health')
def health():
    """Liveness/readiness probe: report model name and placement device."""
    return jsonify({
        'status': 'ok',
        'model': MODEL_NAME,
        'device': str(model.device)
    })
if __name__ == '__main__':
    # Bind to all interfaces; the port comes from the environment so the
    # hosting platform can choose it (falls back to 7860).
    serve_port = int(os.environ.get('PORT', 7860))
    app.run(host='0.0.0.0', port=serve_port, debug=False)