noding / blueprints /summarize.py
broadfield-dev's picture
Update blueprints/summarize.py
bc221f8 verified
raw
history blame
3.34 kB
import os
import torch
from flask import Blueprint, request, jsonify, render_template
from transformers import pipeline
from huggingface_hub import HfFolder
# Flask blueprint grouping the summarization UI page and its JSON API route.
summarize_bp = Blueprint('summarize', __name__)

# Single-slot, module-level cache for the most recently loaded pipeline.
# Keeping it at module scope lets later requests reuse an already-loaded
# model instead of reloading it on every call.
MODEL_CACHE = dict(
    model_name=None,  # identifier of the model currently held in "pipeline"
    pipeline=None,    # the loaded transformers pipeline object, or None
)
def get_pipeline(model_name, task_type):
    """
    Return a transformers pipeline for ``model_name``, loading it on demand.

    ``MODEL_CACHE`` holds one pipeline at a time. The cached pipeline is
    reused only when both the model name and the task type match the
    previous load; otherwise a new pipeline is built and replaces it.

    Args:
        model_name: Model identifier passed to ``pipeline(model=...)``.
        task_type: ``'SEQ_2_SEQ_LM'`` builds a "summarization" pipeline,
            ``'TOKEN_CLS'`` builds a "token-classification" pipeline with
            ``aggregation_strategy="simple"``, and any other value builds a
            "text-generation" pipeline.

    Returns:
        The cached or newly loaded pipeline object.
    """
    # Key the cache on the task as well as the model: the same checkpoint
    # loaded for a different task is a different pipeline object.
    # .get() because the cache dict may not have a "task_type" entry yet.
    if (
        MODEL_CACHE["pipeline"] is not None
        and MODEL_CACHE["model_name"] == model_name
        and MODEL_CACHE.get("task_type") == task_type
    ):
        return MODEL_CACHE["pipeline"]

    # Authentication: hand HF_TOKEN (if set) to huggingface_hub so that
    # token-protected models can be fetched.
    hf_token = os.getenv("HF_TOKEN")
    if hf_token:
        HfFolder.save_token(hf_token)

    # Device index for pipelines: 0 = first CUDA GPU, -1 = CPU.
    # Half precision is used only when running on the GPU.
    use_cuda = torch.cuda.is_available()
    device = 0 if use_cuda else -1
    dtype = torch.float16 if use_cuda else torch.float32

    # Drop the reference to the previous pipeline before loading the next
    # one so the old model can be released instead of both being resident.
    MODEL_CACHE["pipeline"] = None
    MODEL_CACHE["model_name"] = None
    MODEL_CACHE["task_type"] = None

    print(f"Loading model: {model_name}...")
    if task_type == 'SEQ_2_SEQ_LM':  # Summarization
        pipe = pipeline("summarization", model=model_name, device=device)
    elif task_type == 'TOKEN_CLS':
        pipe = pipeline(
            "token-classification",
            model=model_name,
            aggregation_strategy="simple",
            device=device,
        )
    else:
        pipe = pipeline(
            "text-generation",
            model=model_name,
            torch_dtype=dtype,
            device=device,
        )

    # Record the cache entry only after a successful load.
    MODEL_CACHE["model_name"] = model_name
    MODEL_CACHE["task_type"] = task_type
    MODEL_CACHE["pipeline"] = pipe
    return pipe
def run_inference_logic(config):
    """
    Run one inference request described by ``config``.

    Args:
        config: dict with keys ``'model_name'`` (passed to ``get_pipeline``),
            ``'text'`` (the input string), and ``'task_type'``
            (``'TOKEN_CLS'``, ``'SEQ_2_SEQ_LM'``, or anything else for text
            generation).

    Returns:
        For ``'TOKEN_CLS'``: ``{"masked_text": str, "labels": list}``, where
        each detected entity span in the text is replaced by
        ``<ENTITY_GROUP>`` and ``labels`` lists the entity groups in the
        order the pipeline returned them.
        For ``'SEQ_2_SEQ_LM'``: ``{"output": <summary_text>}``.
        Otherwise: ``{"output": <generated_text>}``.

    Raises:
        KeyError: if one of the required config keys is missing.
    """
    model_id = config['model_name']
    text = config['text']
    task_type = config['task_type']

    pipe = get_pipeline(model_id, task_type)

    if task_type == 'TOKEN_CLS':
        entities = pipe(text)
        # Replace spans right-to-left so the character offsets of entities
        # not yet processed stay valid as replacements change the length.
        masked_chars = list(text)
        for ent in sorted(entities, key=lambda e: e['start'], reverse=True):
            masked_chars[ent['start']:ent['end']] = list(f"<{ent['entity_group']}>")
        return {
            "masked_text": "".join(masked_chars),
            "labels": [ent['entity_group'] for ent in entities],
        }

    if task_type == 'SEQ_2_SEQ_LM':
        # Summarization-specific arguments. truncation=True clips inputs
        # longer than the model's maximum input length; without it a long
        # document can make the encoder fail instead of being summarized.
        out = pipe(
            text,
            max_length=512,
            min_length=30,
            do_sample=False,
            truncation=True,
        )
        return {"output": out[0]['summary_text']}

    out = pipe(text, max_new_tokens=1024)
    return {"output": out[0]['generated_text']}
# --- Routes ---
@summarize_bp.route('/', methods=['GET'])
def index():
    """Serve the summarization page by rendering ``inference.html``."""
    page_template = 'inference.html'
    return render_template(page_template)
@summarize_bp.route('/api/summarize', methods=['POST'])
def api_summarize():
    """
    JSON API endpoint that handles the AJAX request from the UI.

    Expects a JSON object body with a non-empty string ``text`` and an
    optional ``model_name`` (defaults to ``facebook/bart-large-cnn``).

    Returns:
        200 with the dict produced by ``run_inference_logic`` on success;
        400 ``{"error": "No text provided"}`` when the body is not a JSON
        object or ``text`` is missing, not a string, or blank;
        500 ``{"error": <message>}`` when loading the model or inference raises.
    """
    # silent=True: a missing/wrong Content-Type or malformed JSON yields None
    # instead of Flask raising its own error, so the client always receives
    # this endpoint's JSON 400 response.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'text' not in data:
        return jsonify({"error": "No text provided"}), 400

    text = data['text']
    if not isinstance(text, str) or not text.strip():
        return jsonify({"error": "No text provided"}), 400

    config = {
        "text": text,
        # `or` also covers an explicit null/empty model_name from the client.
        # NOTE(review): the model id comes straight from the request, so any
        # client can make the server download and load an arbitrary model;
        # consider an allow-list if this endpoint is publicly reachable.
        "model_name": data.get('model_name') or "facebook/bart-large-cnn",
        # We force this for the specific summarization UI,
        # but the backend logic supports others.
        "task_type": "SEQ_2_SEQ_LM",
    }

    try:
        result = run_inference_logic(config)
    except Exception as e:  # request boundary: report the failure as JSON
        print(f"Error: {e}")
        return jsonify({"error": str(e)}), 500
    return jsonify(result)