from flask import Flask, request, jsonify, render_template_string
from flask_cors import CORS
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
import os
import gc
import time
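
# Note: the 4-bit loading below relies on the bitsandbytes package (used by
# BitsAndBytesConfig) and device_map="auto" relies on accelerate; both must be
# installed alongside flask, flask-cors, torch, and transformers.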

app = Flask(__name__)
CORS(app)

# Hugging Face Hub ID of the fine-tuned Phi checkpoint to serve.
MODEL_NAME = "KASHH-4/phi_finetuned"
| print("\n" + "="*80)
|
| print("π SIMPLE TEXT GENERATOR - Loading Model")
|
| print("="*80)
|
| print(f"Model: {MODEL_NAME}")
|
|
|

# Load the tokenizer (use_fast=False forces the slow tokenizer) and reuse the
# EOS token for padding when the checkpoint defines no pad token.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print("✅ Tokenizer loaded!")

# Load weights in 4-bit NF4 so the model fits in limited memory; compute still
# runs in float16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=False,
)
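
# Cap PyTorch threading so inference does not oversubscribe the 2-vCPU host
# this app targets (see the loading notice in the UI below).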
torch.set_num_threads(2)
torch.set_num_interop_threads(1)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=quantization_config,
    device_map="auto",
    low_cpu_mem_usage=True,
    trust_remote_code=True,
    torch_dtype=torch.float16,
)

# Inference only: eval mode plus frozen parameters, so no gradient state is kept.
model.eval()
for param in model.parameters():
    param.requires_grad = False

print("✅ Model loaded and optimized!")
print(f"Device: {model.device}")
print(f"Threads: {torch.get_num_threads()}")
print("="*80 + "\n")
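
# Single-page chat UI served at /; it posts to the /api/generate endpoint below.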
HTML_TEMPLATE = """
<!DOCTYPE html>
<html>
<head>
    <title>Phi Finetuned Chat</title>
    <style>
        * { margin: 0; padding: 0; box-sizing: border-box; }
        body {
            font-family: 'Segoe UI', Arial, sans-serif;
            background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
            min-height: 100vh;
            padding: 20px;
        }
        .container {
            max-width: 800px;
            margin: 0 auto;
            background: white;
            border-radius: 16px;
            padding: 40px;
            box-shadow: 0 20px 60px rgba(0,0,0,0.3);
        }
        h1 {
            color: #1a237e;
            margin-bottom: 10px;
            font-size: 2em;
        }
        .subtitle {
            color: #666;
            margin-bottom: 30px;
        }
        textarea {
            width: 100%;
            padding: 15px;
            border: 2px solid #e0e0e0;
            border-radius: 8px;
            font-size: 1em;
            font-family: inherit;
            min-height: 150px;
            resize: vertical;
            margin-bottom: 20px;
        }
        textarea:focus {
            outline: none;
            border-color: #667eea;
        }
        button {
            background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
            color: white;
            border: none;
            padding: 15px 40px;
            font-size: 1.1em;
            font-weight: 600;
            border-radius: 8px;
            cursor: pointer;
            transition: transform 0.2s;
            width: 100%;
        }
        button:hover {
            transform: translateY(-2px);
        }
        button:disabled {
            opacity: 0.6;
            cursor: not-allowed;
        }
        .output {
            margin-top: 30px;
            padding: 20px;
            background: #f8f9fa;
            border: 2px solid #e0e0e0;
            border-radius: 8px;
            white-space: pre-wrap;
            font-family: 'Courier New', monospace;
            max-height: 400px;
            overflow-y: auto;
            display: none;
        }
        .output.show {
            display: block;
        }
        .loading {
            color: #667eea;
            text-align: center;
            padding: 20px;
            display: none;
        }
        .loading.show {
            display: block;
        }
        .error {
            color: #dc3545;
            background: #fff5f5;
            border-color: #dc3545;
        }
    </style>
</head>
<body>
    <div class="container">
        <h1>🤖 Phi Finetuned Chat</h1>
        <p class="subtitle">Chat with your fine-tuned Phi model</p>

        <textarea id="prompt" placeholder="Ask me anything..."></textarea>

        <button onclick="generate()" id="generateBtn">💬 Send</button>

        <div class="loading" id="loading">
            <p>⏳ Generating response... Please wait (typically 15-30 seconds on 2 vCPU)</p>
        </div>

        <div class="output" id="output"></div>
    </div>

    <script>
        async function generate() {
            const prompt = document.getElementById('prompt').value.trim();
            const btn = document.getElementById('generateBtn');
            const output = document.getElementById('output');
            const loading = document.getElementById('loading');

            if (!prompt) {
                alert('Please enter a prompt');
                return;
            }

            btn.disabled = true;
            btn.textContent = '⏳ Generating...';
            loading.classList.add('show');
            output.classList.remove('show', 'error');

            try {
                const response = await fetch('/api/generate', {
                    method: 'POST',
                    headers: { 'Content-Type': 'application/json' },
                    body: JSON.stringify({ prompt: prompt })
                });

                const data = await response.json();

                if (response.ok) {
                    output.textContent = data.generated_text;
                    output.classList.add('show');
                } else {
                    output.textContent = 'Error: ' + data.error;
                    output.classList.add('show', 'error');
                }
            } catch (error) {
                output.textContent = 'Error: ' + error.message;
                output.classList.add('show', 'error');
            } finally {
                btn.disabled = false;
                btn.textContent = '💬 Send';
                loading.classList.remove('show');
            }
        }

        // Allow Enter to submit (Shift+Enter for new line)
        document.getElementById('prompt').addEventListener('keydown', function(e) {
            if (e.key === 'Enter' && !e.shiftKey) {
                e.preventDefault();
                generate();
            }
        });
    </script>
</body>
</html>
"""


@app.route('/')
def index():
    return render_template_string(HTML_TEMPLATE)
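

# JSON API: accepts {"prompt": "..."} and returns {"generated_text": "..."}.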
@app.route('/api/generate', methods=['POST'])
def generate():
    try:
        data = request.json
        prompt = data.get('prompt', '')

        if not prompt:
            return jsonify({'error': 'No prompt provided'}), 400

        print(f"\n{'='*80}", flush=True)
        print("🆕 NEW REQUEST", flush=True)
        print(f"Prompt: {prompt[:100]}...", flush=True)

        start_time = time.time()
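
        # Wrap the raw prompt in Phi-style chat markers, presumably the
        # format the model saw during fine-tuning.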
        chat_prompt = f"<|user|>\n{prompt}<|end|>\n<|assistant|>\n"

        inputs = tokenizer(chat_prompt, return_tensors="pt").to(model.device)

        print("⚡ Generating...", flush=True)
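
        # Sampling settings: low temperature plus top-p/top-k filtering keeps
        # answers focused while still varying between runs; a single beam and
        # the KV cache keep CPU generation as cheap as possible.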
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=150,
                do_sample=True,
                temperature=0.3,
                top_p=0.85,
                top_k=40,
                repetition_penalty=1.1,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                use_cache=True,
                num_beams=1,
            )

        full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Keep only the assistant's reply, dropping the echoed prompt.
        if "<|assistant|>" in full_output:
            generated_text = full_output.split("<|assistant|>")[-1].strip()
        else:
            generated_text = full_output[len(chat_prompt):].strip()

        elapsed = time.time() - start_time

        print(f"✅ Generated in {elapsed:.2f}s", flush=True)
        print("\n📝 OUTPUT:", flush=True)
        print(f"{'-'*80}", flush=True)
        print(generated_text, flush=True)
        print(f"{'-'*80}\n", flush=True)

        return jsonify({'generated_text': generated_text})

    except Exception as e:
        print(f"❌ ERROR: {str(e)}", flush=True)
        return jsonify({'error': str(e)}), 500


if __name__ == '__main__':
    # Reclaim any memory left over from model loading before serving.
    gc.collect()

    port = int(os.environ.get('PORT', 7860))
    print(f"🚀 Starting server on port {port}...\n")
    print("💡 CPU usage should normalize after initial model load\n")

    app.run(host='0.0.0.0', port=port, debug=False, threaded=True, use_reloader=False)
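
# Example request once the server is up (the port defaults to 7860 via the
# PORT env var; the host and prompt below are illustrative):
#
#   curl -X POST http://localhost:7860/api/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Hello!"}'
#
# A successful response has the shape: {"generated_text": "..."}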