|
|
from flask import Flask, request, jsonify |
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
import torch |
|
|
|
|
|
|
|
|
# Hugging Face Hub model id for a German GPT-2 checkpoint; both the
# tokenizer and the causal-LM weights below are loaded from it.
MODEL_NAME = "dbmdz/german-gpt2"


# Load once at import time so every request reuses the same tokenizer
# and model instead of paying the (multi-second) load cost per call.
# NOTE(review): first run downloads the weights from the Hub — the
# process needs network access and disk cache at startup.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)


app = Flask(__name__)
|
|
|
|
|
@app.route("/chat", methods=["POST"])
def chat():
    """Generate a model continuation for the posted chat message.

    Expects a JSON body like ``{"message": "..."}`` and returns
    ``{"response": "..."}`` containing only the newly generated text
    (the prompt is stripped). Responds 400 on a missing/empty message.
    """
    # silent=True yields None (instead of raising 400 with an opaque
    # body) on a missing or non-JSON payload; fall back to an empty
    # dict so .get() below is always safe.
    data = request.get_json(silent=True) or {}
    user_input = data.get("message", "")

    if not user_input:
        # Nothing to continue from — fail fast with a client error
        # rather than sampling from an empty prompt.
        return jsonify({"error": "message must be a non-empty string"}), 400

    input_ids = tokenizer.encode(user_input, return_tensors="pt")

    # inference_mode disables autograd bookkeeping; generation never
    # needs gradients, and this saves memory on every request.
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            max_new_tokens=100,
            do_sample=True,
            top_k=50,
            # GPT-2 has no pad token; reusing EOS silences the
            # "Setting pad_token_id to eos_token_id" warning that
            # generate() would otherwise log on each call.
            pad_token_id=tokenizer.eos_token_id,
        )

    # Strip the prompt by TOKEN count, not character count:
    # decode(encode(text)) is not guaranteed to reproduce the original
    # string byte-for-byte, so slicing the decoded text with
    # len(user_input) can cut mid-word or leave prompt fragments.
    new_tokens = output_ids[0][input_ids.shape[-1]:]
    answer_only = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    return jsonify({"response": answer_only})
|
|
|
|
|
if __name__ == "__main__":
    # Development server only: debug=True enables the interactive
    # debugger and auto-reloader — never expose this in production;
    # use a WSGI server (gunicorn/uwsgi) there instead.
    app.run(debug=True)