Jizzo commited on
Commit
c8ee064
·
verified ·
1 Parent(s): b510034

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -29
app.py CHANGED
@@ -1,36 +1,27 @@
 
1
  import torch
2
- from transformers import AutoTokenizer, AutoModelForCausalLM
3
- import gradio as gr
4
 
5
- model_name = "LeoLM/leo-mistral-chat"
6
 
7
- tokenizer = AutoTokenizer.from_pretrained(model_name)
8
- model = AutoModelForCausalLM.from_pretrained(
9
- model_name,
10
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
11
- )
12
- model = model.to("cuda" if torch.cuda.is_available() else "cpu")
13
 
14
- def chat(user_input, history=[]):
15
- # Baue den Prompt auf
16
- prompt = ""
17
- for user, bot in history:
18
- prompt += f"User: {user}\nAssistant: {bot}\n"
19
- prompt += f"User: {user_input}\nAssistant:"
20
 
21
- inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)
22
- output = model.generate(**inputs, max_new_tokens=256, do_sample=True, top_p=0.9, temperature=0.7, pad_token_id=tokenizer.eos_token_id)
23
- decoded = tokenizer.decode(output[0], skip_special_tokens=True)
24
- answer = decoded.split("Assistant:")[-1].strip()
25
- history.append((user_input, answer))
26
- return history, history
27
 
28
- with gr.Blocks() as demo:
29
- gr.Markdown("## 🤖 Leichtgewichtiger KI-Chat auf Deutsch")
30
- chatbot = gr.Chatbot()
31
- msg = gr.Textbox(label="Deine Nachricht")
32
- state = gr.State([])
33
 
34
- msg.submit(chat, [msg, state], [chatbot, state])
35
-
36
- demo.launch()
 
1
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
2
  import torch
3
+ from flask import Flask, render_template, request, jsonify
 
4
 
5
+ app = Flask(__name__)
6
 
7
+ MODEL_NAME = "ml6team/german-gpt2"
8
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
9
+ model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
10
+ generator = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0 if torch.cuda.is_available() else -1)
 
 
11
 
12
+ @app.route("/")
13
+ def index():
14
+ return render_template("index.html")
 
 
 
15
 
16
+ @app.route("/chat", methods=["POST"])
17
+ def chat():
18
+ data = request.get_json()
19
+ user_input = data.get("message", "")
20
+ if not user_input:
21
+ return jsonify({"response": "Bitte geben Sie eine Nachricht ein."})
22
 
23
+ output = generator(user_input, max_length=200, num_return_sequences=1, do_sample=True)[0]["generated_text"]
24
+ return jsonify({"response": output.strip()})
 
 
 
25
 
26
+ if __name__ == "__main__":
27
+ app.run(debug=True, host="0.0.0.0", port=8080)