"""Flask inference service for a LoRA fine-tuned Mistral legal summarizer.

Loads the 4-bit quantized base model and applies the saved PEFT (LoRA)
adapter at import time, then exposes:

* ``POST /summarize`` — JSON ``{instruction, input_document}`` in,
  JSON ``{summary}`` out.
* ``GET /`` and ``GET /<path>`` — static front-end files.
"""

import os

import torch
from flask import Flask, jsonify, request
from flask_cors import CORS
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

app = Flask(__name__)
CORS(app)  # Enable CORS for all routes so a browser front-end can call the API.

# Locations are configurable via environment; defaults match the Colab
# training run that produced the adapter.
model_save_path = os.getenv("MODEL_SAVE_PATH", "/content/mistral_legal_lora")
base_model_name = os.getenv("BASE_MODEL_NAME", "mistralai/Mistral-7B-v0.1")

# The tokenizer was saved alongside the fine-tuned adapter, so load it
# from the adapter directory rather than from the base model.
tokenizer = AutoTokenizer.from_pretrained(model_save_path)

# 4-bit NF4 double-quantization config for memory-efficient inference.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Load the quantized base model, then stack the LoRA adapter on top of it.
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name, quantization_config=bnb_config
)
model = PeftModel.from_pretrained(base_model, model_save_path)
model.eval()  # Inference only: disable dropout / training-mode layers.


def generate_summary(instruction, input_text):
    """Generate a summary of ``input_text`` following ``instruction``.

    Builds the same "### Instruction / ### Input / ### Summary" prompt
    format used during fine-tuning, samples up to 150 new tokens, and
    returns only the text that follows the "### Summary:" marker.

    Note: ``do_sample=True`` makes the output nondeterministic by design.
    """
    prompt = (
        f"### Instruction: {instruction} "
        f"### Input: {input_text} "
        f"### Summary:"
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=150,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,  # Mistral defines no pad token.
            do_sample=True,
            top_k=50,
            top_p=0.95,
            temperature=0.7,
        )
    decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # The decoded text echoes the prompt; keep only what follows the marker.
    summary_start_tag = "### Summary:"
    if summary_start_tag in decoded_output:
        summary = decoded_output.split(summary_start_tag, 1)[1].strip()
    else:
        summary = decoded_output
    return summary


@app.route('/summarize', methods=['POST'])
def summarize():
    """POST JSON {instruction, input_document} -> JSON {summary} or error."""
    # silent=True yields None (instead of raising) for a missing or invalid
    # JSON body, so malformed requests get our 400 message below rather than
    # Flask's generic error page.
    data = request.get_json(silent=True) or {}
    instruction = data.get('instruction', '')
    input_document = data.get('input_document', '')
    if not instruction or not input_document:
        return jsonify({'error': 'Instruction and input document are required.'}), 400
    try:
        summary = generate_summary(instruction, input_document)
        return jsonify({'summary': summary})
    except Exception as e:  # top-level API boundary: report failure as JSON 500
        return jsonify({'error': str(e)}), 500


@app.route('/')
def index():
    """Serve the front-end entry page."""
    return app.send_static_file('index.html')


# BUG FIX: this route was previously registered at '/', the same rule as
# index(), leaving it shadowed/unreachable — and its required ``path``
# argument could never have been bound. It must capture the request path.
@app.route('/<path:path>')
def static_files(path):
    """Serve any other static asset by its relative path."""
    return app.send_static_file(path)


if __name__ == '__main__':
    # host='0.0.0.0' makes the server reachable from outside the Colab
    # runtime / container; adjust the port for your deployment as needed.
    app.run(host='0.0.0.0', port=8000)