hhelesto commited on
Commit
501f078
·
verified ·
1 Parent(s): bb94c3d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -9
app.py CHANGED
@@ -1,8 +1,8 @@
1
  import os
2
  import torch
 
3
  from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
4
  from peft import PeftModel
5
- from flask import Flask, request, jsonify, render_template
6
 
7
  # --- Load Model & Tokenizer ---
8
 
@@ -35,23 +35,25 @@ print("Model ready!")
35
 
36
  app = Flask(__name__)
37
 
38
- @app.route("/")
39
  def index():
40
  return render_template("index.html")
41
 
42
  @app.route("/generate", methods=["POST"])
43
  def generate():
44
- data = request.get_json()
45
- prompt = data.get("prompt")
46
 
47
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
48
  outputs = model.generate(**inputs, max_new_tokens=100)
49
  text = tokenizer.decode(outputs[0], skip_special_tokens=True)
50
-
51
- return jsonify({
52
- "generated_text": text
53
- })
 
 
 
54
 
55
  if __name__ == "__main__":
56
  port = int(os.environ.get("PORT", 7860))
57
- app.run(host="0.0.0.0", port=port)
 
1
  import os
2
  import torch
3
+ from flask import Flask, render_template, request, redirect, url_for
4
  from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
5
  from peft import PeftModel
 
6
 
7
  # --- Load Model & Tokenizer ---
8
 
 
35
 
36
  app = Flask(__name__)
37
 
38
+ @app.route("/", methods=["GET"])
39
  def index():
40
  return render_template("index.html")
41
 
42
  @app.route("/generate", methods=["POST"])
43
  def generate():
44
+ prompt = request.form["prompt"]
 
45
 
46
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
47
  outputs = model.generate(**inputs, max_new_tokens=100)
48
  text = tokenizer.decode(outputs[0], skip_special_tokens=True)
49
+
50
+ return redirect(url_for("result", generated_text=text))
51
+
52
+ @app.route("/result")
53
+ def result():
54
+ generated_text = request.args.get("generated_text", "")
55
+ return render_template("result.html", generated_text=generated_text)
56
 
57
  if __name__ == "__main__":
58
  port = int(os.environ.get("PORT", 7860))
59
+ app.run(host="0.0.0.0", port=port)