# mistral-api / app.py
# Last change: "Update app.py" by HusainHG (commit a1a4453, verified)
from flask import Flask, request, jsonify, send_from_directory
from flask_cors import CORS
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
import os
# WSGI application object; static assets live in the "static" folder.
app = Flask(import_name=__name__, static_folder='static')
# Enable cross-origin requests on every route using flask-cors defaults.
CORS(app)
# Model identifier passed to the transformers from_pretrained() loaders.
MODEL_NAME = "KASHH-4/phi_finetuned"


def _load_tokenizer(model_id):
    """Load the slow (pure-Python) tokenizer bundled with *model_id*.

    When the checkpoint defines no pad token, the EOS token is reused as
    the pad token.
    """
    tok = AutoTokenizer.from_pretrained(model_id, use_fast=False)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    return tok


print(f"Loading model: {MODEL_NAME}")

# --- Tokenizer ---------------------------------------------------------------
print("Loading tokenizer from YOUR merged model (slow tokenizer)...")
tokenizer = _load_tokenizer(MODEL_NAME)
print("Tokenizer loaded successfully!")

# --- Model weights -----------------------------------------------------------
print("Loading YOUR model weights...")
# 8-bit bitsandbytes quantization to shrink the memory footprint.
# llm_int8_threshold=6.0 is the library's default outlier threshold, kept
# explicit here.
# NOTE(review): bitsandbytes 8-bit loading generally requires a CUDA GPU;
# the "16GB RAM" wording suggests CPU hardware -- confirm the deployment
# target has a GPU, otherwise from_pretrained is likely to fail.
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_threshold=6.0,
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    device_map="auto",
    quantization_config=quantization_config,
)
print("Model loaded successfully!")
@app.route('/')
def index():
    """Serve the front-end page, static/index.html."""
    folder, page = 'static', 'index.html'
    return send_from_directory(folder, page)
# Defaults used when a sampling field is omitted from the request body
# (or explicitly sent as JSON null).
_DEFAULT_MAX_NEW_TOKENS = 256
_DEFAULT_TEMPERATURE = 0.7
_DEFAULT_TOP_P = 0.9


def _is_number(value):
    """Return True for int/float values, excluding bool (a subclass of int)."""
    return isinstance(value, (int, float)) and not isinstance(value, bool)


def _parse_generation_request(data):
    """Validate the JSON body of a /api/generate request.

    Returns ``(params, None)`` on success, where ``params`` is a dict with
    ``prompt`` (non-empty str), ``max_new_tokens`` (int >= 1),
    ``temperature`` (float >= 0) and ``top_p`` (float in [0, 1]).
    On failure returns ``(None, message)`` describing the first bad field.
    """
    # A JSON string body would otherwise pass a substring "in" check.
    if not isinstance(data, dict) or 'prompt' not in data:
        return None, 'Missing prompt in request body'

    prompt = data['prompt']
    if not isinstance(prompt, str) or not prompt:
        return None, "'prompt' must be a non-empty string"

    def _field(key, default):
        value = data.get(key)
        return default if value is None else value

    max_new_tokens = _field('max_new_tokens', _DEFAULT_MAX_NEW_TOKENS)
    temperature = _field('temperature', _DEFAULT_TEMPERATURE)
    top_p = _field('top_p', _DEFAULT_TOP_P)

    # Accept integral floats such as 256.0 that some JSON clients emit.
    if isinstance(max_new_tokens, float) and max_new_tokens.is_integer():
        max_new_tokens = int(max_new_tokens)
    if (isinstance(max_new_tokens, bool)
            or not isinstance(max_new_tokens, int)
            or max_new_tokens < 1):
        return None, "'max_new_tokens' must be a positive integer"

    # Comparisons are negated ("not x >= 0") so that NaN is rejected too.
    if not _is_number(temperature) or not temperature >= 0:
        return None, "'temperature' must be a number >= 0"
    if not _is_number(top_p) or not 0 <= top_p <= 1:
        return None, "'top_p' must be a number between 0 and 1"

    return {
        'prompt': prompt,
        'max_new_tokens': max_new_tokens,
        # transformers' temperature warper insists on a float, so an integer
        # such as 2 coming from JSON must be converted.
        'temperature': float(temperature),
        'top_p': float(top_p),
    }, None


@app.route('/api/generate', methods=['POST'])
def generate():
    """Generate a completion for a JSON ``{"prompt": ...}`` request.

    Optional body fields: ``max_new_tokens`` (default 256), ``temperature``
    (default 0.7; 0 selects greedy decoding) and ``top_p`` (default 0.9).

    Returns JSON ``{"generated_text", "prompt"}`` with status 200,
    ``{"error"}`` with status 400 for an invalid request body, or
    ``{"error"}`` with status 500 when tokenization/generation fails.
    """
    # silent=True yields None for malformed JSON or a non-JSON Content-Type,
    # so the client gets a 400 instead of the HTTP error being caught by the
    # generic handler below and reported as a 500.
    data = request.get_json(silent=True)
    params, error = _parse_generation_request(data)
    if error is not None:
        return jsonify({'error': error}), 400

    prompt = params['prompt']
    if params['temperature'] > 0:
        decoding = {
            'do_sample': True,
            'temperature': params['temperature'],
            'top_p': params['top_p'],
        }
    else:
        # transformers rejects sampling with temperature 0; interpret a zero
        # temperature as greedy decoding instead of failing the request.
        decoding = {'do_sample': False}

    try:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=params['max_new_tokens'],
                pad_token_id=tokenizer.eos_token_id,
                **decoding,
            )
        # For a causal LM, outputs[0] holds the prompt tokens followed by the
        # new tokens, so generated_text echoes the prompt.
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:  # request boundary: report model failures as 500
        app.logger.exception("Error during generation")
        return jsonify({'error': str(e)}), 500

    return jsonify({
        'generated_text': generated_text,
        'prompt': prompt,
    })
@app.route('/api/health', methods=['GET'])
def health():
    """Liveness probe: status, the configured model id, and model.device."""
    payload = {'status': 'ok', 'model': MODEL_NAME}
    payload['device'] = str(model.device)
    return jsonify(payload)
if __name__ == '__main__':
    # Bind to all interfaces; the port comes from $PORT, or 7860 when unset.
    listen_port = int(os.getenv('PORT', '7860'))
    app.run(debug=False, host='0.0.0.0', port=listen_port)