from flask import Flask, request, jsonify, send_from_directory
from flask_cors import CORS
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
import os

app = Flask(__name__, static_folder='static')
CORS(app)

MODEL_NAME = "KASHH-4/phi_finetuned"

print(f"Loading model: {MODEL_NAME}")
print("Loading tokenizer from YOUR merged model (slow tokenizer)...")
# Your model HAS tokenizer files, use them with use_fast=False
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
print("Tokenizer loaded successfully!")
print("Loading YOUR model weights...")
# Optimized for 16GB RAM - load in 8-bit quantization
quantization_config = BitsAndBytesConfig(
load_in_8bit=True, # Use 8-bit to fit in 16GB RAM
llm_int8_threshold=6.0
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=quantization_config,
    device_map="auto",
    low_cpu_mem_usage=True,
    trust_remote_code=True
)
print("Model loaded successfully!")

@app.route('/')
def index():
    return send_from_directory('static', 'index.html')

@app.route('/api/generate', methods=['POST'])
def generate():
    try:
        data = request.json
        if not data or 'prompt' not in data:
            return jsonify({'error': 'Missing prompt in request body'}), 400
        prompt = data['prompt']
        # Sampling parameters, with defaults if the client omits them.
        max_new_tokens = data.get('max_new_tokens', 256)
        temperature = data.get('temperature', 0.7)
        top_p = data.get('top_p', 0.9)
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id
            )
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return jsonify({
            'generated_text': generated_text,
            'prompt': prompt
        })
    except Exception as e:
        print(f"Error during generation: {e}")
        return jsonify({'error': str(e)}), 500
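
# Example request against /api/generate (a sketch; assumes the server is
# running locally on the default port 7860):
#
#   curl -X POST http://localhost:7860/api/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Hello!", "max_new_tokens": 64}'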

@app.route('/api/health', methods=['GET'])
def health():
    return jsonify({
        'status': 'ok',
        'model': MODEL_NAME,
        'device': str(model.device)
    })
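
# Liveness check (same local-server assumption as above):
#   curl http://localhost:7860/api/health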

if __name__ == '__main__':
    # Hugging Face Spaces injects PORT; fall back to 7860 for local runs.
    port = int(os.environ.get('PORT', 7860))
    app.run(host='0.0.0.0', port=port, debug=False)
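
# Launching locally (a sketch; any free port works):
#   PORT=7860 python app.py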