# Flask inference server exposing a LoRA fine-tuned Mistral-7B legal-summarization model.
import os

import torch
from flask import Flask, request, jsonify
from flask_cors import CORS  # CORS lets a browser front-end hosted elsewhere call this API
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
# Application setup: Flask app with CORS enabled for all routes so a
# browser front-end can call the API cross-origin.
app = Flask(__name__)
CORS(app)

# Locations of the fine-tuned PEFT adapter/tokenizer and the base checkpoint.
# Overridable via environment variables; defaults match the Colab notebook
# this service was developed in.
model_save_path = os.getenv("MODEL_SAVE_PATH", "/content/mistral_legal_lora")
base_model_name = os.getenv("BASE_MODEL_NAME", "mistralai/Mistral-7B-v0.1")

# The tokenizer was saved alongside the fine-tuned adapter.
tokenizer = AutoTokenizer.from_pretrained(model_save_path)

# 4-bit NF4 quantization (with double quantization) shrinks the 7B base model
# enough for a single Colab-class GPU; compute runs in bfloat16.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Load the quantized base model, then layer the LoRA adapter on top of it.
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    quantization_config=quant_config,
)
model = PeftModel.from_pretrained(base_model, model_save_path)
model.eval()  # inference only — disables dropout etc.
| # Define the generation function (similar to what was in the notebook) | |
def generate_summary(instruction, input_text):
    """Generate a summary of *input_text* following *instruction*.

    Builds an instruction-style prompt, samples up to 150 new tokens from the
    fine-tuned model, and returns only the text after the "### Summary:" tag.
    Uses the module-level ``tokenizer`` and ``model``.
    """
    # NOTE(review): the original prompt literal was split across physical
    # lines (a syntax error); reconstructed here with explicit newlines.
    prompt = (
        f"### Instruction:\n{instruction}\n"
        f"### Input:\n{input_text}\n"
        f"### Summary:"
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():  # no gradients needed for generation
        outputs = model.generate(
            **inputs,
            max_new_tokens=150,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,  # Mistral defines no pad token
            do_sample=True,
            top_k=50,
            top_p=0.95,
            temperature=0.7,
        )
    decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # The decoded text echoes the prompt; keep only what follows the tag.
    summary_start_tag = "### Summary:"
    if summary_start_tag in decoded_output:
        summary = decoded_output.split(summary_start_tag, 1)[1].strip()
    else:
        summary = decoded_output
    return summary
@app.route('/summarize', methods=['POST'])
def summarize():
    """POST /summarize — JSON body {"instruction": str, "input_document": str}.

    Returns {"summary": str} on success, or {"error": str} with HTTP 400
    (missing fields) / 500 (generation failure).
    """
    # Without @app.route the view was never registered with Flask; restored.
    # get_json(silent=True) yields None instead of raising on a missing or
    # invalid JSON body, so malformed requests get a clean 400 below.
    data = request.get_json(silent=True) or {}
    instruction = data.get('instruction', '')
    input_document = data.get('input_document', '')
    if not instruction or not input_document:
        return jsonify({'error': 'Instruction and input document are required.'}), 400
    try:
        summary = generate_summary(instruction, input_document)
        return jsonify({'summary': summary})
    except Exception as e:
        # Surface the failure (e.g. CUDA OOM) to the client instead of an
        # opaque 500 error page.
        return jsonify({'error': str(e)}), 500
@app.route('/')
def index():
    """Serve the front-end entry page from the Flask static folder.

    Restored the @app.route('/') registration — without it this view is
    never wired into the app.
    """
    return app.send_static_file('index.html')
@app.route('/<path:path>')
def static_files(path):
    """Serve any other static asset (JS/CSS/images) by its relative path.

    Restored the @app.route registration — without it this view is never
    wired into the app. Flask's <path:> converter keeps slashes in *path*.
    """
    return app.send_static_file(path)
if __name__ == '__main__':
    # In a Colab environment, you might need to adjust the host/port for external access if deploying
    # For local testing, host='0.0.0.0' makes it accessible within the Colab environment
    # (binds on all interfaces; dev server only — use a WSGI server in production)
    app.run(host='0.0.0.0', port=8000)