# simple-ui / app.py
from flask import Flask, request, jsonify, render_template_string
from flask_cors import CORS
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
import sys
import os
import time
import gc  # For garbage collection optimization
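# Assumed runtime dependencies (not pinned in this file): flask, flask-cors,
# transformers, torch, bitsandbytes, and accelerate (needed by transformers
# for device_map="auto" and 4-bit loading).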
app = Flask(__name__)
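# Allow cross-origin requests so the /api/generate endpoint can be called from other origins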
CORS(app)
MODEL_NAME = "KASHH-4/phi_finetuned"
print("\n" + "="*80)
print("🚀 SIMPLE TEXT GENERATOR - Loading Model")
print("="*80)
print(f"Model: {MODEL_NAME}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
print("✅ Tokenizer loaded!")
# Optimized quantization for 2 vCPU + 18GB RAM
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=False,  # Disabled for CPU efficiency
)
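# Note: 4-bit loading via bitsandbytes is primarily a CUDA feature; whether it
# works on a CPU-only Space depends on the installed bitsandbytes build.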
# Set CPU threads BEFORE loading model to reduce startup CPU spike
torch.set_num_threads(2)
torch.set_num_interop_threads(1)
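# set_num_interop_threads must be called before any inter-op parallel work starts,
# which is another reason both thread settings come before from_pretrained.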
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=quantization_config,
    device_map="auto",
    low_cpu_mem_usage=True,
    trust_remote_code=True,
    torch_dtype=torch.float16,
)
# Set model to eval mode and optimize for inference
model.eval()
for param in model.parameters():
    param.requires_grad = False
print("✅ Model loaded and optimized!")
print(f"Device: {model.device}")
print(f"Threads: {torch.get_num_threads()}")
print("="*80 + "\n")
HTML_TEMPLATE = """
<!DOCTYPE html>
<html>
<head>
    <title>Phi Finetuned Chat</title>
    <style>
        * { margin: 0; padding: 0; box-sizing: border-box; }
        body {
            font-family: 'Segoe UI', Arial, sans-serif;
            background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
            min-height: 100vh;
            padding: 20px;
        }
        .container {
            max-width: 800px;
            margin: 0 auto;
            background: white;
            border-radius: 16px;
            padding: 40px;
            box-shadow: 0 20px 60px rgba(0,0,0,0.3);
        }
        h1 {
            color: #1a237e;
            margin-bottom: 10px;
            font-size: 2em;
        }
        .subtitle {
            color: #666;
            margin-bottom: 30px;
        }
        textarea {
            width: 100%;
            padding: 15px;
            border: 2px solid #e0e0e0;
            border-radius: 8px;
            font-size: 1em;
            font-family: inherit;
            min-height: 150px;
            resize: vertical;
            margin-bottom: 20px;
        }
        textarea:focus {
            outline: none;
            border-color: #667eea;
        }
        button {
            background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
            color: white;
            border: none;
            padding: 15px 40px;
            font-size: 1.1em;
            font-weight: 600;
            border-radius: 8px;
            cursor: pointer;
            transition: transform 0.2s;
            width: 100%;
        }
        button:hover {
            transform: translateY(-2px);
        }
        button:disabled {
            opacity: 0.6;
            cursor: not-allowed;
        }
        .output {
            margin-top: 30px;
            padding: 20px;
            background: #f8f9fa;
            border: 2px solid #e0e0e0;
            border-radius: 8px;
            white-space: pre-wrap;
            font-family: 'Courier New', monospace;
            max-height: 400px;
            overflow-y: auto;
            display: none;
        }
        .output.show {
            display: block;
        }
        .loading {
            color: #667eea;
            text-align: center;
            padding: 20px;
            display: none;
        }
        .loading.show {
            display: block;
        }
        .error {
            color: #dc3545;
            background: #fff5f5;
            border-color: #dc3545;
        }
    </style>
</head>
<body>
    <div class="container">
        <h1>🤖 Phi Finetuned Chat</h1>
        <p class="subtitle">Chat with your fine-tuned Phi model</p>
        <textarea id="prompt" placeholder="Ask me anything..."></textarea>
        <button onclick="generate()" id="generateBtn">💬 Send</button>
        <div class="loading" id="loading">
            <p>⏳ Generating response... Please wait (typically 15-30 seconds on 2 vCPU)</p>
        </div>
        <div class="output" id="output"></div>
    </div>

    <script>
        async function generate() {
            const prompt = document.getElementById('prompt').value.trim();
            const btn = document.getElementById('generateBtn');
            const output = document.getElementById('output');
            const loading = document.getElementById('loading');

            if (!prompt) {
                alert('Please enter a prompt');
                return;
            }

            btn.disabled = true;
            btn.textContent = '⏳ Generating...';
            loading.classList.add('show');
            output.classList.remove('show', 'error');

            try {
                const response = await fetch('/api/generate', {
                    method: 'POST',
                    headers: { 'Content-Type': 'application/json' },
                    body: JSON.stringify({ prompt: prompt })
                });
                const data = await response.json();

                if (response.ok) {
                    output.textContent = data.generated_text;
                    output.classList.add('show');
                } else {
                    output.textContent = 'Error: ' + data.error;
                    output.classList.add('show', 'error');
                }
            } catch (error) {
                output.textContent = 'Error: ' + error.message;
                output.classList.add('show', 'error');
            } finally {
                btn.disabled = false;
                btn.textContent = '💬 Send';
                loading.classList.remove('show');
            }
        }

        // Allow Enter to submit (Shift+Enter for new line)
        document.getElementById('prompt').addEventListener('keydown', function(e) {
            if (e.key === 'Enter' && !e.shiftKey) {
                e.preventDefault();
                generate();
            }
        });
    </script>
</body>
</html>
"""
@app.route('/')
def index():
    return render_template_string(HTML_TEMPLATE)
@app.route('/api/generate', methods=['POST'])
def generate():
    try:
        data = request.get_json(silent=True) or {}
        prompt = data.get('prompt', '')
        if not prompt:
            return jsonify({'error': 'No prompt provided'}), 400

        print(f"\n{'='*80}", flush=True)
        print("🚀 NEW REQUEST", flush=True)
        print(f"Prompt: {prompt[:100]}...", flush=True)
        sys.stdout.flush()

        start_time = time.time()

        # Format for chat-based interaction with the Phi model
        chat_prompt = f"<|user|>\n{prompt}<|end|>\n<|assistant|>\n"
        inputs = tokenizer(chat_prompt, return_tensors="pt").to(model.device)

        print("⚡ Generating...", flush=True)
        sys.stdout.flush()

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=150,       # Kept small for faster responses
                do_sample=True,
                temperature=0.3,          # Low temperature = more focused output
                top_p=0.85,               # Nucleus sampling cutoff
                top_k=40,                 # Limit the sampling space
                repetition_penalty=1.1,   # Discourage repetition loops
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                use_cache=True,           # Enable the KV cache for speed
                num_beams=1,              # No beam search = faster
            )

        # Extract only the assistant's response by decoding just the newly generated
        # tokens; splitting the full decoded text on "<|assistant|>" is unreliable
        # because skip_special_tokens strips that marker.
        new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
        generated_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

        elapsed = time.time() - start_time
        print(f"✅ Generated in {elapsed:.2f}s", flush=True)
        print("\n📄 OUTPUT:", flush=True)
        print(f"{'-'*80}", flush=True)
        print(generated_text, flush=True)
        print(f"{'-'*80}\n", flush=True)
        sys.stdout.flush()

        return jsonify({'generated_text': generated_text})
    except Exception as e:
        print(f"❌ ERROR: {str(e)}", flush=True)
        sys.stdout.flush()
        return jsonify({'error': str(e)}), 500
if __name__ == '__main__':
    # Force garbage collection after model load (gc is imported at the top)
    gc.collect()

    port = int(os.environ.get('PORT', 7860))
    print(f"🌐 Starting server on port {port}...\n")
    print("💡 CPU usage should normalize after the initial model load\n")

    # Use threaded mode for better concurrency on 2 vCPU
    app.run(host='0.0.0.0', port=port, debug=False, threaded=True, use_reloader=False)
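# Example client call (a sketch, not part of the app): assumes the server is
# reachable on localhost:7860 and that the `requests` package is installed.
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/api/generate",
#       json={"prompt": "Summarize what a KV cache does."},
#       timeout=120,
#   )
#   print(resp.json()["generated_text"])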