| from flask import Flask, render_template, request, flash, jsonify |
| import torch |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
| from huggingface_hub import login, whoami, HfApi |
| import numpy as np |
| import os, json |
|
|
# Flask app; the secret key is random per process, so flashed messages /
# sessions reset on every restart (acceptable for a single-instance demo).
app = Flask(__name__)
app.secret_key = os.urandom(24)

# Lazily populated module globals: filled by the "start_server" action in
# index() and read by the /generate endpoint.
ee_model = None       # AutoModelForCausalLM, fp16, device-mapped
ee_tokenizer = None   # AutoTokenizer matching ee_model
ee_config = None      # parsed ee_config.json downloaded from the model repo
|
|
| @app.route("/", methods=["GET", "POST"]) |
| def index(): |
| global ee_model, ee_tokenizer, ee_config |
| if request.method == "POST": |
| action = request.form.get("action") |
|
|
| if action == "start_server": |
| ee_model_name = request.form["ee_model_name"].strip() |
| hf_token = request.form["hf_token"].strip() |
|
|
| try: |
| login(token=hf_token) |
| global ee_model, ee_tokenizer, ee_config |
|
|
| ee_model = AutoModelForCausalLM.from_pretrained( |
| ee_model_name, |
| torch_dtype=torch.float16, |
| |
| device_map="auto", |
| trust_remote_code=True |
| ) |
| ee_tokenizer = AutoTokenizer.from_pretrained(ee_model_name, trust_remote_code=True) |
|
|
| |
| from huggingface_hub import hf_hub_download |
| config_path = hf_hub_download(ee_model_name, "ee_config.json") |
| with open(config_path) as f: |
| ee_config = json.load(f) |
|
|
| flash(f"✅ Server ready! Model loaded: {ee_model_name}", "success") |
| flash("Now use the Client Space and point it to this Space's URL", "info") |
|
|
| except Exception as e: |
| flash(f"Error: {str(e)}", "danger") |
|
|
| return render_template("index.html") |
|
|
| |
| @app.route("/generate", methods=["POST"]) |
| def generate(): |
| if ee_model is None: |
| return jsonify({"error": "Server not started yet"}), 400 |
|
|
| data = request.json |
| encrypted_embeds = torch.tensor(data["encrypted_embeds"]).to(ee_model.device) |
| attention_mask = torch.tensor(data.get("attention_mask", [[1]*encrypted_embeds.shape[1]])).to(ee_model.device) |
| max_new = int(data.get("max_new_tokens", 256)) |
|
|
| with torch.no_grad(): |
| output_ids = ee_model.generate( |
| inputs_embeds=encrypted_embeds, |
| attention_mask=attention_mask, |
| max_new_tokens=max_new, |
| do_sample=True, |
| temperature=0.7, |
| top_p=0.9, |
| pad_token_id=ee_tokenizer.eos_token_id |
| ) |
|
|
| return jsonify({"generated_ids": output_ids[0].tolist()}) |
|
|
if __name__ == "__main__":
    # Bind on all interfaces; 7860 is the conventional HF Spaces port.
    server_host, server_port = "0.0.0.0", 7860
    app.run(host=server_host, port=server_port)