File size: 2,633 Bytes
4f591e5
 
 
 
 
 
 
78b4c1f
4f591e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b92108d
0daa81d
 
 
 
 
78b4c1f
 
 
 
8a2aebb
b92108d
4f591e5
a5cd8cb
4f591e5
 
78b4c1f
 
 
4f591e5
 
 
78b4c1f
 
 
4f591e5
 
33ff5ca
4f591e5
 
 
 
 
78b4c1f
 
 
 
 
 
 
4f591e5
78b4c1f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from flask import Flask
from flask_sqlalchemy import SQLAlchemy
from flask_login import LoginManager
from flask import render_template
from transformers import LongformerTokenizer, AutoTokenizer, AutoModelForCausalLM, LongformerConfig
from models.model import *
from utils.util_func import *
from safetensors.torch import load_file

# Flask extensions are created unbound at import time; create_app() attaches
# them to the application via init_app().
login_manager = LoginManager()
db = SQLAlchemy()

# Process-wide model cache. load_models() fills these in on its first call
# and flips MODELS_LOADED so later calls are no-ops.
MODELS_LOADED = False
LONGFORMER_TOKENIZER = LONGFORMER_MODEL = None
QWEN_TOKENIZER = QWEN_MODEL = None
# NOTE(review): MODEL_SESSION is never assigned in this module — presumably
# set or read elsewhere; confirm it is still needed.
MODEL_SESSION = None

def load_models():
    """Load the Longformer and Qwen tokenizers/models into module globals once.

    The first successful call populates ``LONGFORMER_TOKENIZER``,
    ``LONGFORMER_MODEL``, ``QWEN_TOKENIZER`` and ``QWEN_MODEL`` and then sets
    ``MODELS_LOADED``; subsequent calls return immediately. The flag is only
    set after every load has succeeded, so a call that raised part-way is
    retried in full the next time this function is invoked.
    """
    global MODELS_LOADED, LONGFORMER_TOKENIZER, LONGFORMER_MODEL, QWEN_TOKENIZER, QWEN_MODEL
    if MODELS_LOADED:
        return

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("DEVICE", device)

    LONGFORMER_TOKENIZER = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096', device='auto')
    config = LongformerConfig.from_json_file("Longformer_checkpoint/config.json")
    # NOTE(review): from_pretrained is invoked on a freshly constructed
    # instance. If it is the usual transformers classmethod, that instance
    # (and ``config``) is discarded — confirm before simplifying this line.
    LONGFORMER_MODEL = CustomLongformerForSequenceClassification(config).from_pretrained('SFM2001/LongFormerScorer')
    LONGFORMER_MODEL = LONGFORMER_MODEL.to(device)
    LONGFORMER_MODEL.eval()

    model_name = 'Qwen/Qwen3-1.7B'
    QWEN_TOKENIZER = AutoTokenizer.from_pretrained(model_name, device='auto')
    # Reuse EOS as the padding token so padded batches can be built.
    QWEN_TOKENIZER.pad_token_id = QWEN_TOKENIZER.eos_token_id
    # The model is loaded 8-bit quantized with device_map="auto", so
    # from_pretrained already casts and places the weights. transformers
    # rejects .half() and .to() on such a model, so neither is called here.
    QWEN_MODEL = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        device_map="auto",
        load_in_8bit=True,
        torch_dtype=torch.float16,
    )

    MODELS_LOADED = True

def create_app():
    """Build and configure the Flask application.

    Calls ``set_seed(42)`` and ``load_models()``, reads configuration from
    ``configs.py``, binds the SQLAlchemy and Flask-Login extensions,
    registers the blueprints from ``views``, installs a catch-all error
    handler that renders ``error.html``, and creates the database tables.

    Returns:
        The configured ``Flask`` application instance.
    """
    set_seed(42)
    load_models()

    app = Flask(__name__)
    app.config.from_pyfile('configs.py')
    db.init_app(app)
    login_manager.init_app(app)
    login_manager.login_view = 'auth.login'

    @login_manager.user_loader
    def load_user(user_id):
        """Return the user for *user_id*, or None when the id is unusable."""
        # Flask-Login expects None (not an exception) for an invalid id,
        # e.g. a stale or tampered session value.
        try:
            primary_key = int(user_id)
        except (TypeError, ValueError):
            return None
        # ``User`` is bound by the import further down in this function.
        return User.query.get(primary_key)

    with app.app_context():
        from views import auth_bp, dashboard_bp, infer_bp, about_bp, error_bp
        app.register_blueprint(auth_bp)
        app.register_blueprint(dashboard_bp)
        app.register_blueprint(infer_bp)
        app.register_blueprint(about_bp)
        app.register_blueprint(error_bp)

        @app.errorhandler(Exception)
        def handle_all_exceptions(e):
            """Render the error page for any exception raised by a request."""
            code = getattr(e, 'code', None)
            # ``code`` is not guaranteed to be an HTTP status: some exception
            # types carry a string code or None. Fall back to 500 unless it
            # is a real status integer (bool is excluded explicitly because
            # it is an int subclass).
            if isinstance(code, bool) or not isinstance(code, int) or not 100 <= code <= 599:
                code = 500
            if hasattr(e, 'description'):
                error_message = str(e)
            else:
                # Not an HTTP-style error: keep the details out of the page
                # but record the traceback so the failure can be diagnosed.
                app.logger.error("Unhandled exception", exc_info=e)
                error_message = "Something went wrong."
            return render_template('error.html', code=code, error_message=error_message), code

        # Importing the models here makes them known before create_all().
        from database import User, History
        db.create_all()

    return app