import json
import os

import torch
from flask import Flask, flash, jsonify, render_template, request
from huggingface_hub import hf_hub_download, login
from transformers import AutoModelForCausalLM, AutoTokenizer
app = Flask(__name__)
# Random per-process secret key: flashed messages are only valid for this run.
app.secret_key = os.urandom(24)

# Module-level cache for the currently loaded model; populated lazily by "/"
# and read by "/generate".
ee_model = None
ee_tokenizer = None
ee_config = None
loaded_model_name = None

# Hugging Face Spaces injects SPACE_HOST; outside Spaces fall back to the
# local development address.
SPACE_HOST = os.environ.get("SPACE_HOST", "")
SPACE_URL = f"https://{SPACE_HOST}" if SPACE_HOST else "http://localhost:7860"
@app.route("/", methods=["GET", "POST"])
def index():
    """Render the control page and, on POST, load the requested model.

    A POST with ``action == "start_server"`` expects the form fields
    ``ee_model_name`` and ``hf_token``.  On success the model, tokenizer, and
    the repo's ``ee_config.json`` are cached in module globals for use by
    ``/generate``; any failure is surfaced to the user via ``flash``.
    """
    global ee_model, ee_tokenizer, ee_config, loaded_model_name

    if request.method == "POST" and request.form.get("action") == "start_server":
        # .get() instead of [] so a missing form field produces a flashed
        # error message instead of an unhandled 400 BadRequest.
        ee_model_name = request.form.get("ee_model_name", "").strip()
        hf_token = request.form.get("hf_token", "").strip()

        if not ee_model_name or not hf_token:
            flash("Both a model name and a Hugging Face token are required.", "danger")
        else:
            try:
                login(token=hf_token)

                ee_model = AutoModelForCausalLM.from_pretrained(
                    ee_model_name,
                    torch_dtype=torch.float16,
                    device_map="auto",
                    trust_remote_code=True,
                )
                ee_tokenizer = AutoTokenizer.from_pretrained(
                    ee_model_name, trust_remote_code=True
                )

                # The repo is expected to ship an ee_config.json next to the
                # weights; hf_hub_download raises if it is absent.
                config_path = hf_hub_download(ee_model_name, "ee_config.json")
                with open(config_path) as f:
                    ee_config = json.load(f)

                loaded_model_name = ee_model_name
                flash(f"Model loaded successfully: {ee_model_name}", "success")
                flash("Point your Client Space to this Space's URL below.", "info")
            except Exception as e:
                # Broad catch is deliberate: any hub/auth/load failure becomes
                # a user-visible message rather than a 500.
                flash(f"Error: {str(e)}", "danger")

    return render_template(
        "index.html",
        server_ready=(ee_model is not None),
        model_name=loaded_model_name,
        space_url=SPACE_URL,
    )
@app.route("/generate", methods=["POST"])
def generate():
    """Generate token ids from client-supplied (encrypted) input embeddings.

    Expects a JSON body: ``{"encrypted_embeds": [[...]], "attention_mask":
    optional, "max_new_tokens": optional (default 256)}``.  Returns
    ``{"generated_ids": [...]}`` or a JSON error with status 400.
    """
    if ee_model is None:
        return jsonify({"error": "Server not started yet"}), 400

    # silent=True: a malformed/non-JSON body yields None here so we can answer
    # with a clean 400 instead of Flask's default abort (or a KeyError 500).
    data = request.get_json(silent=True)
    if not data or "encrypted_embeds" not in data:
        return jsonify({"error": "Missing 'encrypted_embeds' in request body"}), 400

    # Cast to the model's dtype: torch.tensor() on a JSON list produces
    # float32, which would mismatch the fp16 weights loaded in index().
    encrypted_embeds = torch.tensor(data["encrypted_embeds"]).to(
        ee_model.device, dtype=ee_model.dtype
    )
    # Default mask: attend to every position of the (assumed single-batch)
    # input — TODO confirm clients always send batch size 1.
    attention_mask = torch.tensor(
        data.get("attention_mask", [[1] * encrypted_embeds.shape[1]])
    ).to(ee_model.device)
    max_new = int(data.get("max_new_tokens", 256))

    with torch.no_grad():
        output_ids = ee_model.generate(
            inputs_embeds=encrypted_embeds,
            attention_mask=attention_mask,
            max_new_tokens=max_new,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=ee_tokenizer.eos_token_id,
        )

    return jsonify({"generated_ids": output_ids[0].tolist()})
if __name__ == "__main__":
    # Bind to all interfaces on the Hugging Face Spaces default port.
    host, port = "0.0.0.0", 7860
    app.run(host=host, port=port)