|
|
from flask import Flask, render_template, request, jsonify |
|
|
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline |
|
|
from mistral_7b import generate_text |
|
|
import torch |
|
|
from inference import voice_inference |
|
|
|
|
|
|
|
|
# Flask application; static assets are served from /static.
app = Flask(__name__, static_url_path='/static')


# Tokenizer for the Korean->Jeju-dialect translation model (KoBART base).
tokenizer_name = "gogamza/kobart-base-v2"


tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)


# Fine-tuned seq2seq checkpoint loaded from a local path.
# NOTE(review): this is a Colab-specific absolute path — it will fail outside
# that environment; consider making it configurable (env var / CLI arg).
model_name = "/content/flask/eojin/checkpoint-142243"


model = AutoModelForSeq2SeqLM.from_pretrained(model_name)


# Translation pipeline mapping standard Korean to Jeju dialect.
# The task name 'translation_ko_to_ko' sets src/tgt language tags for the
# pipeline; presumably the checkpoint was fine-tuned for ko->Jeju — confirm.
nlg_pipeline = pipeline('translation_ko_to_ko', model=model, tokenizer=tokenizer)
|
|
|
|
|
@app.route('/')
def index():
    """Serve the landing page."""
    page = render_template("index.html")
    return page
|
|
|
|
|
@app.route('/voice/<sentence>')
def voice(sentence):
    """Synthesize speech for *sentence* and return the output WAV path.

    Delegates to ``voice_inference`` (project TTS helper); the returned
    string is the path of the generated audio file.
    """
    wav_path = voice_inference(sentence)
    return wav_path
|
|
|
|
|
@app.route('/chatbot')
def chatbot():
    """Serve the chatbot UI page."""
    page = render_template("chatbot.html")
    return page
|
|
|
|
|
|
|
|
@app.route('/process_input/<input_text>')
def process_input(input_text):
    """Generate a chatbot answer for *input_text* and its Jeju-dialect form.

    Pipeline:
      1. ``generate_text`` (Mistral-7B helper) produces a standard-Korean
         answer; wrapped in ``torch.no_grad()`` since this is inference only.
      2. ``nlg_pipeline`` translates that answer to Jeju dialect.

    Returns:
        JSON ``{'answer': ..., 'jeju_answer': ...}`` on success, or
        ``{'error': ...}`` with HTTP 500 on failure.
    """
    try:
        app.logger.info("process_input received: %s", input_text)

        # No gradients needed for pure inference — saves memory and time.
        with torch.no_grad():
            answer = generate_text(input_text)
        app.logger.info("generated answer: %s", answer)

        # max_length=60 caps the dialect translation length.
        jeju_answer = nlg_pipeline(answer, max_length=60)[0]['translation_text']
        app.logger.info("jeju translation: %s", jeju_answer)

        return jsonify({'answer': answer, 'jeju_answer': jeju_answer})

    except Exception:
        # Log the full traceback; previously the error was only print()ed and
        # returned with HTTP 200, which made failures look like successes.
        app.logger.exception("process_input failed")
        return jsonify({'error': 'internal error while processing input'}), 500
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # NOTE(review): debug=True together with host='0.0.0.0' exposes the
    # Werkzeug interactive debugger (arbitrary code execution) to every host
    # on the network — confirm this is a dev-only setup and disable debug
    # before any shared/production deployment.
    app.run(debug=True, host='0.0.0.0', port=8000)
|
|
|