File size: 1,748 Bytes
7339449
 
02f2b8a
7339449
 
 
 
2aa19f3
58b71d4
2aa19f3
58b71d4
7339449
05fe403
7339449
 
 
58b71d4
2aa19f3
 
58b71d4
2aa19f3
05fe403
7339449
 
 
2aa19f3
7339449
2aa19f3
 
 
58b71d4
7339449
 
 
 
 
 
 
 
58b71d4
 
2aa19f3
 
 
7339449
 
58b71d4
 
2aa19f3
7339449
 
2aa19f3
 
58b71d4
 
 
02f2b8a
7339449
 
 
 
 
 
 
58b71d4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import flask
from flask import request, jsonify
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

app = flask.Flask(__name__)

# ---------------------------
# SUPER FAST SMALL MODEL
# ---------------------------
# Small instruct model chosen for low-latency CPU/GPU inference.
model_id = "HuggingFaceTB/SmolLM2-360M-Instruct"

print(f"🔄 Loading {model_id} model...")

tokenizer = AutoTokenizer.from_pretrained(model_id)

# Resolve hardware once: bfloat16 on CUDA, full float32 on CPU.
_has_cuda = torch.cuda.is_available()
_dtype = torch.bfloat16 if _has_cuda else torch.float32

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=_dtype,
)

device = torch.device("cuda" if _has_cuda else "cpu")
model.to(device)

print(f"✅ {model_id} loaded successfully!")

# ---------------------------
# Chat Endpoint
# ---------------------------
@app.route("/chat", methods=["POST"])
def chat():
    """Generate a reply for a JSON payload of the form {"message": "..."}.

    Returns:
        200 with {"reply": <generated text>} on success,
        400 with {"error": ...} when the message is missing or not a string,
        500 with {"error": ...} on unexpected generation failures.
    """
    try:
        # silent=True yields None (instead of raising) on a malformed or
        # non-JSON body, so bad requests get a 400 rather than a 500.
        data = request.get_json(silent=True) or {}
        msg = data.get("message", "")

        # Reject empty or non-string messages up front.
        if not isinstance(msg, str) or not msg:
            return jsonify({"error": "No message sent"}), 400

        # SmolLM2 format: no chat special tokens needed
        prompt = f"User: {msg}\nAssistant:"

        inputs = tokenizer(prompt, return_tensors="pt").to(device)

        # Greedy decoding keeps latency low and the output deterministic.
        output = model.generate(
            **inputs,
            max_new_tokens=128,   # fast
            do_sample=False,      # FASTEST
            pad_token_id=tokenizer.eos_token_id,
        )

        reply = tokenizer.decode(output[0], skip_special_tokens=True)

        # The decoded text echoes the prompt; keep only the assistant part.
        if "Assistant:" in reply:
            reply = reply.split("Assistant:")[-1].strip()

        return jsonify({"reply": reply})

    except Exception as e:
        # Top-level boundary: surface the failure as a JSON 500.
        return jsonify({"error": str(e)}), 500


if __name__ == "__main__":
    # Listen on all interfaces so the container/host can reach the server;
    # port 7860 — presumably matching the deployment platform's expected port.
    app.run(host="0.0.0.0", port=7860)