import os
import sys
import torch
from flask import Flask, request, jsonify
from transformers import AutoModelForCausalLM, AutoTokenizer
# --- Configuration ---
# Target: TinyLlama 1.1B, small enough to fit comfortably within a 16GB RAM limit.
MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
app = Flask(__name__)
model = None
tokenizer = None
def load_optimized_model():
    """Loads the model onto the CPU in a low-precision dtype to save memory."""
    global model, tokenizer
    print(f"Loading memory-optimized model: {MODEL_ID} to CPU...")

    # Use torch.float16 (half precision) to halve the memory footprint of the
    # weights, even on CPU. This is the key to staying under the 16GB limit.
    model_dtype = torch.float16

    try:
        # 1. Load the tokenizer
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
        if not tokenizer.pad_token:
            tokenizer.pad_token = tokenizer.eos_token

        # 2. Load the model with float16 precision and map it explicitly to the CPU
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            torch_dtype=model_dtype,
            device_map="cpu",
            trust_remote_code=True,
        )

        # NOTE: Once loaded in FP16, the entire model must live on the CPU for
        # inference. device_map="cpu" handles the placement; torch_dtype is the
        # main memory saver here.
        print("Model loaded successfully with FP16 precision!")
    except Exception as e:
        print(f"❌ CRITICAL ERROR: Failed to load model {MODEL_ID}: {e}", file=sys.stderr)
        model = None
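# The helper below is an optional sketch (not part of the original file): it
# estimates the resident size of the loaded weights, which is a quick way to
# confirm the FP16 saving described above (~2 bytes per parameter instead of 4).
def report_model_footprint():
    """Print an approximate memory footprint of the loaded model weights."""
    if model is None:
        return
    param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
    print(f"Approximate weight memory: {param_bytes / (1024 ** 3):.2f} GiB")
# Example (after load_optimized_model() has run): report_model_footprint()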
# --- Model Initialization ---
with app.app_context():
    load_optimized_model()
@app.route('/generate', methods=['POST'])
def generate_text():
    """API endpoint for text generation, compatible with the chatbot.py conversation memory."""
    if model is None or tokenizer is None:
        # Return an error if initialization failed
        return jsonify({"error": "Model initialization failed. Check Space logs."}), 500

    data = request.get_json(silent=True) or {}
    prompt = data.get('prompt')
    max_new_tokens = data.get('max_new_tokens', 100)
    temperature = data.get('temperature', 0.7)

    if not prompt:
        return jsonify({"error": "Missing 'prompt' in request body."}), 400

    try:
        # 1. The prompt received from the client script already contains the full
        #    Llama-chat-formatted history plus the new message, so no additional
        #    templating is applied here.
        # 2. Tokenize the input
        input_ids = tokenizer.encode(
            prompt,
            return_tensors="pt",
            truncation=True
        ).to(model.device)  # Move the tensor to the model's device (CPU here)

        # 3. Generate the output
        generated_ids = model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            pad_token_id=tokenizer.eos_token_id
        )

        # 4. Decode only the newly generated tokens (excluding the input prompt)
        new_text_start_index = input_ids.shape[-1]
        output_text = tokenizer.decode(
            generated_ids[0][new_text_start_index:],
            skip_special_tokens=True
        )

        return jsonify({"generated_text": output_text.strip()})
    except Exception as e:
        # Catch unexpected errors during inference
        return jsonify({"error": f"Inference failed during generation: {str(e)}"}), 500
@app.route('/', methods=['GET'])
def home():
    """Simple health-check endpoint."""
    return "TinyLlama FP16 API is Running!"
if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860, debug=True)