File size: 2,633 Bytes
4f591e5 78b4c1f 4f591e5 b92108d 0daa81d 78b4c1f 8a2aebb b92108d 4f591e5 a5cd8cb 4f591e5 78b4c1f 4f591e5 78b4c1f 4f591e5 33ff5ca 4f591e5 78b4c1f 4f591e5 78b4c1f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
from flask import Flask
from flask_sqlalchemy import SQLAlchemy
from flask_login import LoginManager
from flask import render_template
from transformers import LongformerTokenizer, AutoTokenizer, AutoModelForCausalLM, LongformerConfig
from models.model import *
from utils.util_func import *
from safetensors.torch import load_file
# Flask extension singletons; they are bound to the app inside create_app().
db = SQLAlchemy()
login_manager = LoginManager()

# Process-wide model cache. load_models() populates these globals once and
# then flips MODELS_LOADED to True so later calls are no-ops.
MODELS_LOADED = False
LONGFORMER_TOKENIZER = LONGFORMER_MODEL = None
QWEN_TOKENIZER = QWEN_MODEL = None
# NOTE(review): never assigned in this file; presumably set/used by another
# module — confirm before removing.
MODEL_SESSION = None
def load_models():
    """Load the Longformer scorer and the Qwen causal LM into module globals.

    Runs at most once per process: once ``MODELS_LOADED`` is True, later calls
    return immediately.

    Side effects:
        Assigns the module globals ``LONGFORMER_TOKENIZER``,
        ``LONGFORMER_MODEL``, ``QWEN_TOKENIZER``, ``QWEN_MODEL`` and
        ``MODELS_LOADED``. Reads ``Longformer_checkpoint/config.json``
        (relative to the working directory) and fetches pretrained weights
        via ``from_pretrained``.
    """
    global MODELS_LOADED, LONGFORMER_TOKENIZER, LONGFORMER_MODEL, QWEN_TOKENIZER, QWEN_MODEL
    if MODELS_LOADED:
        return

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("DEVICE", device)

    # Tokenizers have no device; the old ``device='auto'`` kwarg was meaningless.
    LONGFORMER_TOKENIZER = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096')
    config = LongformerConfig.from_json_file("Longformer_checkpoint/config.json")
    # ``from_pretrained`` is a classmethod. Calling it on an instance first
    # built a throwaway randomly-initialised model and ignored ``config``;
    # call it on the class and pass the local config so it is actually used.
    LONGFORMER_MODEL = CustomLongformerForSequenceClassification.from_pretrained(
        'SFM2001/LongFormerScorer', config=config
    )
    LONGFORMER_MODEL = LONGFORMER_MODEL.to(device)
    LONGFORMER_MODEL.eval()

    model_name = 'Qwen/Qwen3-1.7B'
    QWEN_TOKENIZER = AutoTokenizer.from_pretrained(model_name)
    QWEN_TOKENIZER.pad_token_id = QWEN_TOKENIZER.eos_token_id

    if device.type == "cuda":
        from transformers import BitsAndBytesConfig

        # 8-bit bitsandbytes weights are already cast and placed by
        # ``device_map``; transformers rejects ``.half()`` / ``.to()`` on
        # them, so the model is used exactly as returned.
        QWEN_MODEL = AutoModelForCausalLM.from_pretrained(
            model_name,
            trust_remote_code=True,
            device_map="auto",
            quantization_config=BitsAndBytesConfig(load_in_8bit=True),
            torch_dtype=torch.float16,
        )
    else:
        # bitsandbytes 8-bit quantization needs a GPU; on CPU load the plain
        # (unquantized, default-dtype) model and move it explicitly.
        QWEN_MODEL = AutoModelForCausalLM.from_pretrained(
            model_name, trust_remote_code=True
        ).to(device)

    MODELS_LOADED = True
def create_app():
    """Build and configure the Flask application.

    Calls ``set_seed(42)`` and ``load_models()``, loads settings from
    ``configs.py``, initialises Flask-SQLAlchemy and Flask-Login, registers
    the view blueprints plus a catch-all error page, and runs
    ``db.create_all()``.

    Returns:
        The configured ``Flask`` instance.
    """
    set_seed(42)
    load_models()

    app = Flask(__name__)
    app.config.from_pyfile('configs.py')

    db.init_app(app)
    login_manager.init_app(app)
    login_manager.login_view = 'auth.login'

    @login_manager.user_loader
    def load_user(user_id):
        # Flask-Login expects None (not an exception) for an ID it cannot
        # resolve, e.g. a tampered session value that is not numeric.
        try:
            return User.query.get(int(user_id))
        except (TypeError, ValueError):
            return None

    with app.app_context():
        from views import auth_bp, dashboard_bp, infer_bp, about_bp, error_bp
        for blueprint in (auth_bp, dashboard_bp, infer_bp, about_bp, error_bp):
            app.register_blueprint(blueprint)

        @app.errorhandler(Exception)
        def handle_all_exceptions(e):
            """Render ``error.html`` for any exception reaching Flask."""
            status = getattr(e, 'code', None)
            # Only HTTP-style errors (an int status in 400-599 plus a
            # ``description``) keep their status and message. Other exceptions
            # may define ``code`` with an unrelated meaning (a string, None,
            # ...), which must not become the response status.
            if isinstance(status, int) and 400 <= status <= 599 and hasattr(e, 'description'):
                code = status
                error_message = str(e)
            else:
                # Unexpected failure: keep the traceback in the log instead of
                # discarding it, and show the user a generic message.
                app.logger.error("Unhandled exception", exc_info=e)
                code = 500
                error_message = "Something went wrong."
            return render_template('error.html', code=code, error_message=error_message), code

        # Imported here so the model classes register with ``db`` before
        # create_all(); this also binds ``User`` for load_user's closure.
        from database import User, History
        db.create_all()
    return app