# indio/app.py — uploaded by rishu834763 ("Upload 3 files", commit 4502f97)
# Flask API serving a LoRA fine-tuned Mistral-7B legal summarization model.
import torch
import os
from flask import Flask, request, jsonify
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
from flask_cors import CORS # Import CORS
app = Flask(__name__)
CORS(app)  # Enable CORS so a browser frontend on another origin can call the API
# Adapter/tokenizer location and base model are configurable via environment;
# the defaults match the Colab training run (/content path).
model_save_path = os.getenv("MODEL_SAVE_PATH", "/content/mistral_legal_lora")
base_model_name = os.getenv("BASE_MODEL_NAME", "mistralai/Mistral-7B-v0.1")
# Load the tokenizer that was saved alongside the fine-tuned PEFT adapter
tokenizer = AutoTokenizer.from_pretrained(model_save_path)
# 4-bit NF4 quantization config so the 7B base model fits in limited GPU memory
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
# Load the base model in 4-bit
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    quantization_config=bnb_config
)
# Apply the LoRA (PEFT) adapter weights on top of the quantized base model
model = PeftModel.from_pretrained(base_model, model_save_path)
model.eval()  # inference mode (disables dropout)
# Define the generation function (similar to what was in the notebook)
def generate_summary(instruction, input_text):
    """Generate a summary for *input_text* with the fine-tuned model.

    Args:
        instruction: Task instruction inserted under the "### Instruction:" tag.
        input_text: Document text inserted under the "### Input:" tag.

    Returns:
        The text generated after the "### Summary:" tag, stripped of
        surrounding whitespace; if the tag is missing from the decoded
        output, the full decoded output is returned instead.
    """
    # Reassemble the instruction-tuning prompt format used at training time.
    # (The pasted source had lost the "\n" escapes, which made the f-strings
    # span physical lines — a syntax error.)
    prompt = (
        f"### Instruction:\n{instruction}\n"
        f"### Input:\n{input_text}\n"
        f"### Summary:"
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=150,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            temperature=0.7,
        )
    decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # The decoded text echoes the whole prompt; keep only what follows the tag.
    summary_start_tag = "### Summary:"
    if summary_start_tag in decoded_output:
        summary = decoded_output.split(summary_start_tag, 1)[1].strip()
    else:
        summary = decoded_output
    return summary
@app.route('/summarize', methods=['POST'])
def summarize():
    """POST /summarize: JSON {"instruction", "input_document"} -> {"summary"}.

    Returns 400 when either field is missing or the body is not valid JSON,
    500 with the error message if generation fails.
    """
    # silent=True yields None instead of raising on a malformed/non-JSON body,
    # so bad requests get our explicit 400 below rather than an opaque error.
    data = request.get_json(silent=True) or {}
    instruction = data.get('instruction', '')
    input_document = data.get('input_document', '')
    if not instruction or not input_document:
        return jsonify({'error': 'Instruction and input document are required.'}), 400
    try:
        summary = generate_summary(instruction, input_document)
        return jsonify({'summary': summary})
    except Exception as e:
        # Surface generation failures (OOM, tokenizer errors, ...) as a 500.
        return jsonify({'error': str(e)}), 500
@app.route('/')
def index():
    """Serve the frontend entry page from the app's static folder."""
    return app.send_static_file('index.html')
@app.route('/<path:path>')
def static_files(path):
    """Serve any other static asset (JS/CSS/images) by its relative path.

    NOTE(review): Flask's send_static_file is documented to reject paths
    that escape the static folder, which should prevent directory traversal.
    """
    return app.send_static_file(path)
if __name__ == '__main__':
    # host='0.0.0.0' binds all interfaces so the server is reachable from
    # outside the container/Colab VM. The built-in dev server is fine for
    # testing; use a production WSGI server (gunicorn, etc.) for deployment.
    app.run(host='0.0.0.0', port=8000)