import os
import sys
import torch
from flask import Flask, request, jsonify
from transformers import AutoModelForCausalLM, AutoTokenizer

# --- Configuration ---
# Target: TinyLlama 1.1B, small enough to fit comfortably within 16 GB of RAM.
MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

app = Flask(__name__)
model = None
tokenizer = None


def load_optimized_model():
    """Loads the model onto the CPU in half precision to reduce memory use."""
    global model, tokenizer
    print(f"Loading memory-optimized model: {MODEL_ID} to CPU...")

    # torch.float16 halves the memory footprint relative to float32,
    # even on CPU. This is the key to staying under the 16 GB limit.
    model_dtype = torch.float16

    try:
        # 1. Load the tokenizer
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
        if not tokenizer.pad_token:
            tokenizer.pad_token = tokenizer.eos_token

        # 2. Load the model in float16 and map it explicitly to the CPU
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            torch_dtype=model_dtype,
            device_map="cpu",
            trust_remote_code=True,
        )
        # NOTE: device_map="cpu" keeps every weight on the CPU for inference;
        # torch_dtype is the most important memory saver here.
        print("Model loaded successfully with FP16 precision!")
    except Exception as e:
        print(f"❌ CRITICAL ERROR: Failed to load model {MODEL_ID}: {e}", file=sys.stderr)
        model = None


# --- Model Initialization ---
with app.app_context():
    load_optimized_model()


# NOTE: the route decorators were missing from the original file; without them
# Flask never exposes these functions. The path and method below are assumptions
# and must match whatever the client script (chatbot.py) calls.
@app.route('/generate', methods=['POST'])
def generate_text():
    """API endpoint for text generation, compatible with chatbot.py memory."""
    if model is None or tokenizer is None:
        # Return an error if initialization failed
        return jsonify({"error": "Model initialization failed. Check Space logs."}), 500

    data = request.get_json()
    prompt = data.get('prompt')
    max_new_tokens = data.get('max_new_tokens', 100)
    temperature = data.get('temperature', 0.7)

    if not prompt:
        return jsonify({"error": "Missing 'prompt' in request body."}), 400

    try:
        # 1. The prompt received from the client script is already the full
        #    chat history plus the new message, so no extra templating is done here.
        # 2. Tokenize the input and move it to the model's device (CPU)
        input_ids = tokenizer.encode(
            prompt,
            return_tensors="pt",
            truncation=True
        ).to(model.device)

        # 3. Generate the output
        generated_ids = model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            pad_token_id=tokenizer.eos_token_id
        )

        # 4. Decode only the newly generated tokens (exclude the input prompt)
        new_text_start_index = input_ids.shape[-1]
        output_text = tokenizer.decode(
            generated_ids[0][new_text_start_index:],
            skip_special_tokens=True
        )

        return jsonify({"generated_text": output_text.strip()})
    except Exception as e:
        # Catch unexpected errors during inference
        return jsonify({"error": f"Inference failed during generation: {str(e)}"}), 500


# NOTE: assumed health-check route; the original had no decorator here either.
@app.route('/')
def home():
    """Simple health check endpoint."""
    return "TinyLlama FP16 API is Running!"


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860, debug=True)
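
For reference, a minimal client sketch for exercising the endpoint once the Space is running. It assumes the route is mounted at /generate (as in the decorator added above) and uses a placeholder URL that you would replace with your actual Space URL; it is not the original chatbot.py.

    import requests

    # Placeholder URL: replace with your actual Space endpoint.
    API_URL = "http://localhost:7860/generate"

    payload = {
        "prompt": "Hello, who are you?",  # chatbot.py would send the full formatted history here
        "max_new_tokens": 100,
        "temperature": 0.7,
    }

    # CPU FP16 generation can be slow, so allow a generous timeout.
    response = requests.post(API_URL, json=payload, timeout=120)
    response.raise_for_status()
    print(response.json()["generated_text"])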