from flask import Flask, request, jsonify, render_template_string
from flask_cors import CORS
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
import os
import time
import gc  # For garbage collection optimization
app = Flask(__name__)
CORS(app)
MODEL_NAME = "KASHH-4/phi_finetuned"
print("\n" + "="*80)
print("š SIMPLE TEXT GENERATOR - Loading Model")
print("="*80)
print(f"Model: {MODEL_NAME}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
print("ā
Tokenizer loaded!")
# Optimized quantization for 2 vCPU + 18GB RAM
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=False,  # Disabled for CPU efficiency
)
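# Note: bitsandbytes 4-bit quantization primarily targets CUDA GPUs; on a
# CPU-only host this config may be ignored or fail depending on the
# bitsandbytes build, so verify it actually takes effect in your environment.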
# Set CPU threads BEFORE loading model to reduce startup CPU spike
torch.set_num_threads(2)
torch.set_num_interop_threads(1)
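# (set_num_interop_threads must be called before any inter-op parallel work
# starts, otherwise PyTorch raises a RuntimeError)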
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=quantization_config,
    device_map="auto",
    low_cpu_mem_usage=True,
    trust_remote_code=True,
    torch_dtype=torch.float16,
)
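# (device_map="auto" requires the accelerate package to be installed)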
# Set model to eval mode and optimize for inference
model.eval()
for param in model.parameters():
    param.requires_grad = False
print("ā
Model loaded and optimized!")
print(f"Device: {model.device}")
print(f"Threads: {torch.get_num_threads()}")
print("="*80 + "\n")
HTML_TEMPLATE = """
Phi Finetuned Chat
š¤ Phi Finetuned Chat
Chat with your fine-tuned Phi model
ā³ Generating response... Please wait (typically 15-30 seconds on 2 vCPU)
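<!-- Minimal client-side glue (a sketch, not necessarily the original page's
     script): posts the prompt to /api/generate and renders the JSON reply. -->
<script>
async function send() {
  document.getElementById('wait').hidden = false;
  const r = await fetch('/api/generate', {
    method: 'POST',
    headers: {'Content-Type': 'application/json'},
    body: JSON.stringify({prompt: document.getElementById('msg').value})
  });
  const d = await r.json();
  document.getElementById('wait').hidden = true;
  document.getElementById('out').textContent = d.generated_text || d.error;
}
</script>
</body>
</html>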
"""
@app.route('/')
def index():
    return render_template_string(HTML_TEMPLATE)
@app.route('/api/generate', methods=['POST'])
def generate():
    try:
        # get_json(silent=True) avoids an unhandled error on a non-JSON body
        data = request.get_json(silent=True) or {}
        prompt = data.get('prompt', '')
        if not prompt:
            return jsonify({'error': 'No prompt provided'}), 400
print(f"\n{'='*80}", flush=True)
print(f"š NEW REQUEST", flush=True)
print(f"Prompt: {prompt[:100]}...", flush=True)
sys.stdout.flush()
start_time = time.time()
        # Format for chat-based interaction with the Phi model
        chat_prompt = f"<|user|>\n{prompt}<|end|>\n<|assistant|>\n"
        inputs = tokenizer(chat_prompt, return_tensors="pt").to(model.device)
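        # (If the tokenizer ships a chat template, tokenizer.apply_chat_template
        # could build this prompt instead; the manual string above assumes
        # Phi-3-style <|user|> / <|end|> / <|assistant|> markers.)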
print(f"ā” Generating...", flush=True)
sys.stdout.flush()
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=150,       # Capped for faster responses
                do_sample=True,
                temperature=0.3,          # Low temperature = more focused output
                top_p=0.85,               # Slightly tightened nucleus sampling
                top_k=40,                 # Limit the sampling space
                repetition_penalty=1.1,   # Discourage repetition loops
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                use_cache=True,           # KV cache speeds up decoding
                num_beams=1,              # Single beam; beam search would be slower
            )
        # Decode only the newly generated tokens; skip_special_tokens strips
        # the <|assistant|> marker, so splitting the full string on it is fragile
        prompt_len = inputs["input_ids"].shape[1]
        generated_text = tokenizer.decode(
            outputs[0][prompt_len:], skip_special_tokens=True
        ).strip()
        elapsed = time.time() - start_time
        print(f"Generated in {elapsed:.2f}s", flush=True)
        print("\nOUTPUT:", flush=True)
        print(f"{'-'*80}", flush=True)
        print(generated_text, flush=True)
        print(f"{'-'*80}\n", flush=True)
        return jsonify({'generated_text': generated_text})
    except Exception as e:
        print(f"ERROR: {e}", flush=True)
        return jsonify({'error': str(e)}), 500
if __name__ == '__main__':
    # Reclaim memory left over from model loading
    gc.collect()
    port = int(os.environ.get('PORT', 7860))
    print(f"Starting server on port {port}...\n")
    print("CPU usage should normalize after the initial model load\n")
    # Use threaded mode for better concurrency on 2 vCPU
    app.run(host='0.0.0.0', port=port, debug=False, threaded=True, use_reloader=False)
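# Example request once the server is up (assuming the default port 7860):
#   curl -X POST http://localhost:7860/api/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Hello!"}'
# Expected response shape: {"generated_text": "..."}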