SimpleAES / create_app.py
SFM2001's picture
>
a1577b4
raw
history blame
3.27 kB
from flask import Flask
from flask_sqlalchemy import SQLAlchemy
from flask_login import LoginManager
from flask import render_template
from transformers import LongformerTokenizer, AutoTokenizer, AutoModelForCausalLM, LongformerConfig
from models.model import *
from utils.util_func import *
import os
# Flask extension singletons; they are attached to the app inside create_app()
# through their init_app() methods.
db = SQLAlchemy()
login_manager = LoginManager()

# Process-wide model cache. load_models() fills these in and flips
# MODELS_LOADED to True so later calls skip the expensive loading.
MODELS_LOADED = False
QWEN_TOKENIZER = QWEN_MODEL = None
LONGFORMER_TOKENIZER = LONGFORMER_MODEL = None
# NOTE(review): not assigned anywhere in this file; presumably set or used by
# another module — verify before removing.
MODEL_SESSION = None
def _load_qwen(device):
    """Load the Qwen3-0.6B tokenizer and causal-LM model onto *device*.

    Returns:
        A ``(tokenizer, model)`` tuple.

    Raises:
        Whatever ``from_pretrained`` raises; the model error is printed first.
    """
    print("START TO GET QWEN")
    model_name = 'Qwen/Qwen3-0.6B'
    # Tokenizers run on the CPU; they take no device argument.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Pad with the EOS token id.
    tokenizer.pad_token_id = tokenizer.eos_token_id
    print("QWEN TOKENIZER LOADED")
    try:
        # NOTE(review): device_map="auto" already places the weights; the
        # following .to(device) is redundant on a single device and is
        # refused by accelerate if the model gets split across devices.
        model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
        model = model.to(device)
    except Exception as e:
        print(f"ERROR LOADING QWEN MODEL: {str(e)}")
        raise  # Re-raise so the full traceback still surfaces.
    print(model)
    print("QWEN MODEL LOADED", flush=True)
    return tokenizer, model


def _load_longformer(device):
    """Load the Longformer tokenizer and the fine-tuned scorer onto *device*.

    The scorer is switched to eval mode before being returned.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096')
    print("LONGFORMER TOKENIZER LOADED")
    # Call from_pretrained on the class itself. The previous
    # `CustomLongformerForSequenceClassification(config).from_pretrained(...)`
    # first built a randomly initialised model from the local
    # Longformer_checkpoint/config.json and then discarded it, because
    # from_pretrained constructs a fresh model from the checkpoint's own
    # config. This assumes the class inherits transformers'
    # PreTrainedModel.from_pretrained classmethod — verify in models/model.py.
    model = CustomLongformerForSequenceClassification.from_pretrained('SFM2001/LongFormerScorer')
    model = model.to(device)
    print(model)
    print("LONGFORMER MODEL LOADED", flush=True)
    model.eval()
    return tokenizer, model


def load_models():
    """Load the Qwen and Longformer tokenizers/models into module globals.

    Runs at most once per process: after a successful load MODELS_LOADED is
    True and later calls return immediately. If loading raises, the error
    propagates and MODELS_LOADED stays False, so a later call retries.
    """
    global MODELS_LOADED, LONGFORMER_TOKENIZER, LONGFORMER_MODEL, QWEN_TOKENIZER, QWEN_MODEL
    if MODELS_LOADED:
        return
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("DEVICE=", device)
    QWEN_TOKENIZER, QWEN_MODEL = _load_qwen(device)
    LONGFORMER_TOKENIZER, LONGFORMER_MODEL = _load_longformer(device)
    MODELS_LOADED = True
    print("LOAD ENDED", flush=True)
def create_app():
    """Build and configure the Flask application.

    Seeds randomness with ``set_seed(42)``, loads the ML models, reads
    ``configs.py``, stores users in ``<instance_path>/users.db``, wires up
    Flask-Login, installs a catch-all error handler, registers the blueprints
    and creates the database tables.

    Returns:
        The configured ``Flask`` application.
    """
    set_seed(42)
    load_models()

    app = Flask(__name__)
    app.config.from_pyfile('configs.py')
    # SQLite will not create missing parent directories, and Flask-SQLAlchemy
    # only auto-creates the instance folder for *relative* sqlite paths. This
    # path is absolute, so create the folder before create_all() opens the file.
    os.makedirs(app.instance_path, exist_ok=True)
    app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///' + os.path.join(app.instance_path, 'users.db')

    db.init_app(app)
    login_manager.init_app(app)
    login_manager.login_view = 'auth.login'

    @login_manager.user_loader
    def load_user(user_id):
        # Flask-Login expects None for an unusable id (e.g. a tampered or
        # stale session value) rather than an exception.
        try:
            uid = int(user_id)
        except (TypeError, ValueError):
            return None
        # `User` is bound by the import at the end of create_app(), which runs
        # before any request can invoke this loader.
        return db.session.get(User, uid)

    @app.errorhandler(Exception)
    def handle_all_exceptions(e):
        code = getattr(e, 'code', None)
        # Only trust integer HTTP status codes. Other exception types also
        # expose a `code` attribute that is not an HTTP status (SQLAlchemy
        # errors use short strings), and a bare HTTPException has code=None.
        if not isinstance(code, int) or not 100 <= code <= 599:
            code = 500
        if code >= 500:
            # Server errors would otherwise be swallowed without a trace.
            app.logger.error("Unhandled exception while serving request", exc_info=e)
        error_message = str(e) if hasattr(e, 'description') else "Something went wrong."
        return render_template('error.html', code=code, error_message=error_message), code

    with app.app_context():
        from views import auth_bp, dashboard_bp, infer_bp, about_bp, error_bp
        for blueprint in (auth_bp, dashboard_bp, infer_bp, about_bp, error_bp):
            app.register_blueprint(blueprint)
        # Presumably imported so both models are registered on `db` before
        # create_all(); `User` is also used by load_user above.
        from database import User, History
        db.create_all()
    return app