import os

import torch
from flask import Flask, render_template, request, Response, stream_with_context
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel

# Load the model and tokenizer once, at import time.
base_model_name = "unsloth/llama-3.2-3b-bnb-4bit"
adapter_model_name = "aismaanly/ai_synthetic"

# 4-bit NF4 quantization with double quantization to cut memory use;
# compute runs in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

print("Loading base model...")
model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    quantization_config=bnb_config,
    device_map="auto",
)

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(base_model_name)

print("Loading adapter...")
model = PeftModel.from_pretrained(model, adapter_model_name)
# Fold the LoRA weights into the base model so inference skips the
# adapter indirection.
model = model.merge_and_unload()
print("Model loaded!")

app = Flask(__name__)


@app.route("/")
def index():
    return render_template("index.html")


@app.route("/generate", methods=["POST"])
def generate():
    data = request.get_json()
    prompt = data.get("prompt", "")
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    def generate_stream():
        generated = inputs["input_ids"]
        past_key_values = None
        max_tokens = 100

        # The first pass must feed the full prompt to build the KV cache;
        # every later pass feeds only the newly sampled token.
        next_input = generated

        for _ in range(max_tokens):
            with torch.inference_mode():
                outputs = model(
                    input_ids=next_input,
                    past_key_values=past_key_values,
                    use_cache=True,
                )
            logits = outputs.logits[:, -1, :]
            # Greedy decoding: always pick the highest-probability token.
            next_token_id = torch.argmax(logits, dim=-1, keepdim=True)
            generated = torch.cat([generated, next_token_id], dim=-1)
            past_key_values = outputs.past_key_values
            next_input = next_token_id

            # Stop on the model's end-of-sequence token id rather than on
            # decoded text, which would also match ordinary whitespace tokens.
            if next_token_id.item() == tokenizer.eos_token_id:
                break

            yield tokenizer.decode(next_token_id[0], skip_special_tokens=True)

    return Response(stream_with_context(generate_stream()), content_type="text/plain")


if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))
    app.run(host="0.0.0.0", port=port, threaded=True)
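
# A minimal sketch of how to exercise the streaming endpoint, assuming the
# server is running locally on the default port above (7860); the prompt text
# is just a placeholder. curl's -N flag disables output buffering, so tokens
# appear in the terminal as they are generated:
#
#   curl -N -X POST http://localhost:7860/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Write a haiku about spring."}'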