|
|
from flask import Flask, request, render_template, send_file |
|
|
from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline |
|
|
from collections import Counter |
|
|
import datetime, json |
|
|
import os |
|
|
|
|
|
app = Flask(__name__)

# Directory (or locally cached model id) the fine-tuned model is loaded from.
_MODEL_DIR = "bioelectra_model"

# Build the model, tokenizer and inference pipeline once at import time so
# every request handler below reuses the same objects. local_files_only=True
# means transformers must find the files locally instead of downloading them.
try:
    model = AutoModelForTokenClassification.from_pretrained(
        _MODEL_DIR, local_files_only=True
    )
    tokenizer = AutoTokenizer.from_pretrained(_MODEL_DIR, local_files_only=True)
    # "simple" aggregation merges sub-word tokens into grouped entities; the
    # handlers read the resulting "entity_group" key from each prediction.
    nlp = pipeline(
        "token-classification",
        model=model,
        tokenizer=tokenizer,
        aggregation_strategy="simple",
    )
    print("Pipeline loaded successfully!")
except Exception as exc:
    # Report the failure, then propagate it so the app never starts without
    # a working pipeline.
    print(f"Error loading pipeline: {exc}")
    raise
|
|
|
|
|
|
|
|
@app.route("/", methods=["GET", "POST"])
def index():
    """Render the tagging form and, on POST, tag the submitted text.

    On GET the page is rendered with no results. On POST:

    - The ``input_text`` form field is run through the token-classification
      pipeline.
    - One JSON line (timestamp, input, predictions) is appended to
      ``log.jsonl`` in the current working directory.
    - The page is rendered with the predictions plus per-``entity_group``
      counts (``tag_labels`` / ``tag_counts``).

    Returns:
        The rendered ``index.html`` page, or a plain-text ``400`` response
        when the submitted text is empty or the pipeline finds no entities.
    """
    predictions = []
    tag_labels = []
    tag_counts = []
    input_text = ""

    if request.method == "POST":
        input_text = request.form.get("input_text", "")
        if not input_text.strip():
            # Nothing to tag: don't run the model or write a log entry.
            return "Error: No input text provided.", 400

        print("User Input:", input_text)
        print("Tokenized:", tokenizer.tokenize(input_text))

        predictions = nlp(input_text)
        print("Predictions:", predictions)

        if not predictions:
            return "Error: No medical abbreviation detected in the input.", 400

        log_entry = {
            "timestamp": str(datetime.datetime.now()),
            "input": input_text,
            "predictions": predictions,
        }
        # Append only. The full log is served by /view_log and /download_log.
        # Re-reading and printing it on every request costs time proportional
        # to the whole log, and it echoes every past user's input to stdout.
        with open("log.jsonl", "a", encoding="utf-8") as f:
            f.write(json.dumps(log_entry, default=_json_default) + "\n")

        # Per-label frequencies feed the chart in the template.
        label_counter = Counter(item["entity_group"] for item in predictions)
        tag_labels = list(label_counter.keys())
        tag_counts = list(label_counter.values())

    return render_template(
        "index.html",
        predictions=predictions,
        input_text=input_text,
        tag_labels=tag_labels,
        tag_counts=tag_counts,
    )


def _json_default(obj):
    """Fallback serializer for values ``json`` cannot encode natively.

    Objects exposing a callable ``.item()`` are converted to the equivalent
    Python value, so numbers stay numeric in the log. NumPy scalars do this;
    presumably the pipeline's ``score`` values are among them (TODO confirm).
    Anything else, or an ``.item()`` that fails, falls back to ``str(obj)``,
    which was the previous behavior.
    """
    item = getattr(obj, "item", None)
    if callable(item):
        try:
            return item()
        except (TypeError, ValueError):
            # e.g. a multi-element array, whose .item() raises ValueError.
            pass
    return str(obj)
|
|
|
|
|
|
|
|
@app.route("/view_log")
def view_log():
    """Show the raw request log (``log.jsonl``) as preformatted HTML.

    Log lines contain user-submitted text, so the contents are HTML-escaped
    before being embedded in the page. Without this, a submitted
    ``<script>`` would run in the browser of anyone viewing the log.
    If nothing has been logged yet (the file does not exist), an empty
    block is shown instead of a server error.
    """
    import html

    try:
        with open("log.jsonl", "r", encoding="utf-8") as f:
            contents = f.read()
    except FileNotFoundError:
        contents = ""
    return f"<pre>{html.escape(contents)}</pre>"
|
|
|
|
|
|
|
|
@app.route("/download_log")
def download_log():
    """Send the request log (``log.jsonl``) as a file attachment.

    Returns:
        The file download, or a plain-text ``404`` response when no log
        file exists yet.
    """
    # Flask's send_file resolves a *relative* path against app.root_path,
    # whereas index() appends to "log.jsonl" relative to the current working
    # directory. Make the path absolute from the CWD so both refer to the
    # same file regardless of where the app was launched from.
    log_path = os.path.abspath("log.jsonl")
    if not os.path.isfile(log_path):
        return "Error: No log file exists yet.", 404
    return send_file(log_path, as_attachment=True)
|
|
|
|
|
if __name__ == "__main__":
    # Serve on all interfaces. The port comes from $PORT when set, otherwise
    # it defaults to 7860.
    app.run(host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))
|
|
|