Spaces:
Sleeping
Sleeping
import os

import torch
from flask import Flask, render_template, request, Response, stream_with_context
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
# ---------------------------------------------------------------------------
# One-time startup: load the 4-bit quantized base model, attach the LoRA
# adapter, and merge it into the weights. Runs once at import time so every
# request shares the same in-memory model.
# ---------------------------------------------------------------------------
base_model_name = "unsloth/llama-3.2-3b-bnb-4bit"
adapter_model_name = "aismaanly/ai_synthetic"

# NF4 double-quantization, bf16 compute — standard QLoRA inference config.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

print("Loading base model...")
model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    quantization_config=bnb_config,
    device_map="auto",  # let accelerate place layers on available devices
)

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(base_model_name)

print("Loading adapter...")
# Apply the fine-tuned LoRA weights, then fold them into the base model so
# inference pays no adapter overhead.
model = PeftModel.from_pretrained(model, adapter_model_name)
model = model.merge_and_unload()
print("Model loaded!")

app = Flask(__name__)
@app.route("/")
def index():
    """Serve the chat UI page (templates/index.html).

    Fix: the function was defined but never registered with Flask, so
    requests to "/" returned 404 — the @app.route decorator restores it.
    """
    return render_template("index.html")
@app.route("/generate", methods=["POST"])
def generate():
    """Stream greedily decoded tokens for a JSON ``{"prompt": ...}`` request.

    Returns a chunked ``text/plain`` response where each chunk is one
    decoded token.

    Fixes:
      * registered the route (the function was never wired into Flask);
      * the first forward pass now feeds the FULL prompt — the original fed
        only ``generated[:, -1:]`` on step one (when the KV cache was empty),
        silently discarding all prompt context except the last token;
      * a missing "prompt" key no longer crashes the tokenizer with ``None``;
      * generation stops on the tokenizer's actual EOS token id (the old
        ``"<|endoftext|>"`` string check never matches Llama-3 output);
      * inference runs under ``torch.no_grad()`` to avoid building graphs.
    """
    data = request.get_json()
    prompt = data.get("prompt", "")  # tolerate a missing key
    # BatchEncoding holding input_ids (and attention_mask), moved to the
    # model's device.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    def generate_stream():
        generated = inputs["input_ids"]
        past_key_values = None
        max_tokens = 100
        eos_id = tokenizer.eos_token_id
        with torch.no_grad():
            for _ in range(max_tokens):
                # Step 1 (no cache yet): run the whole prompt so the KV
                # cache covers it. Later steps: only the newly added token.
                step_input = generated if past_key_values is None else generated[:, -1:]
                outputs = model(
                    input_ids=step_input,
                    past_key_values=past_key_values,
                    use_cache=True,
                )
                logits = outputs.logits[:, -1, :]  # next-token distribution
                next_token_id = torch.argmax(logits, dim=-1, keepdim=True)
                generated = torch.cat([generated, next_token_id], dim=-1)
                past_key_values = outputs.past_key_values
                # Proper stop: the model emitted its EOS token.
                if eos_id is not None and next_token_id.item() == eos_id:
                    break
                token_text = tokenizer.decode(next_token_id[0])
                yield token_text
                # NOTE(review): legacy heuristic kept from the original —
                # stops on any pure-whitespace token; confirm this early
                # cutoff is actually desired.
                if token_text.strip() in ["", "\n", "\r\n", "<|endoftext|>"]:
                    break

    return Response(stream_with_context(generate_stream()), content_type='text/plain')
if __name__ == "__main__":
    # Hugging Face Spaces convention: honor $PORT if set, else 7860.
    listen_port = int(os.environ.get("PORT", 7860))
    app.run(host="0.0.0.0", port=listen_port, threaded=True)