"""Flask front end for a locally stored BioELECTRA token-classification model.

The app loads the model from the local "bioelectra_model" directory, runs it
over text submitted through a form, appends each request/prediction pair to a
JSON Lines log, and exposes routes to view and download that log.
"""

from flask import Flask, request, render_template, send_file
from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline
from collections import Counter
import datetime
import html
import json
import os

app = Flask(__name__)

# Resolve the log path once, against the working directory at startup.
# open() resolves relative paths against the CWD, whereas Flask's send_file()
# resolves them against app.root_path; an absolute path keeps the writer, the
# viewer and the download route pointed at the same file.
LOG_PATH = os.path.abspath("log.jsonl")

try:
    # Load model from local files only (never fetched from the Hub).
    model = AutoModelForTokenClassification.from_pretrained("bioelectra_model", local_files_only=True)
    tokenizer = AutoTokenizer.from_pretrained("bioelectra_model", local_files_only=True)
    # aggregation_strategy="simple" groups sub-word tokens into entity spans;
    # each span dict carries the "entity_group" key counted in index().
    nlp = pipeline("token-classification", model=model, tokenizer=tokenizer, aggregation_strategy="simple")
    print("Pipeline loaded successfully!")
except Exception as e:
    print(f"Error loading pipeline: {str(e)}")
    raise


def _append_log(entry):
    """Append *entry* as one JSON line to the log file at LOG_PATH."""
    # default=str stringifies values json cannot encode natively
    # (presumably the numpy scalar scores in the pipeline output).
    with open(LOG_PATH, "a", encoding="utf-8") as f:
        f.write(json.dumps(entry, default=str) + "\n")


# Main route
@app.route("/", methods=["GET", "POST"])
def index():
    """Render the form; on POST, tag the submitted text and log the result.

    Returns:
        The rendered ``index.html`` with the predictions and per-tag counts,
        or a plain-text error with status 400 when the input is blank or the
        model detects no entities.
    """
    predictions = []
    tag_labels = []
    tag_counts = []
    input_text = ""

    if request.method == "POST":
        input_text = request.form.get("input_text", "")
        # Guard against a missing field or whitespace-only submission before
        # spending a model call on it.
        if not input_text.strip():
            return "Error: Please enter some text to analyse.", 400

        print("User Input:", input_text)
        print("Tokenized:", tokenizer.tokenize(input_text))
        predictions = nlp(input_text)
        print("Predictions:", predictions)

        if not predictions:
            return "Error: No medical abbreviation detected in the input.", 400

        _append_log({
            "timestamp": str(datetime.datetime.now()),
            "input": input_text,
            "predictions": predictions,
        })

        # Frequency of each predicted tag, split into parallel lists for the
        # template (labels and their counts in first-seen order).
        label_counter = Counter(item["entity_group"] for item in predictions)
        tag_labels = list(label_counter.keys())
        tag_counts = list(label_counter.values())

    return render_template("index.html", predictions=predictions, input_text=input_text,
                           tag_labels=tag_labels, tag_counts=tag_counts)


# Route to view log
@app.route("/view_log")
def view_log():
    """Return the log file as an HTML ``<pre>`` block (empty if none yet)."""
    try:
        with open(LOG_PATH, "r", encoding="utf-8") as f:
            contents = f.read()
    except FileNotFoundError:
        # Nothing has been logged yet; show an empty log instead of a 500.
        contents = ""
    # The log holds user-submitted text: escape it so it is displayed,
    # not interpreted by the browser as HTML/JavaScript.
    return f"<pre>{html.escape(contents)}</pre>"


# Route to download the log
@app.route("/download_log")
def download_log():
    """Send the log file as an attachment, or 404 if nothing is logged yet."""
    if not os.path.exists(LOG_PATH):
        return "Error: No log file available yet.", 404
    return send_file(LOG_PATH, as_attachment=True)


if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))
    app.run(host="0.0.0.0", port=port)