File size: 3,039 Bytes
9abf8fd
4502f97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9abf8fd
4502f97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103

import torch
import os
from flask import Flask, request, jsonify
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
from flask_cors import CORS # Import CORS

app = Flask(__name__)
CORS(app)  # Enable CORS for all routes so a browser front-end on another origin can call the API

# Paths are overridable via environment; the defaults target a Colab layout.
# model_save_path holds the fine-tuned PEFT (LoRA) adapter AND the tokenizer.
model_save_path = os.getenv("MODEL_SAVE_PATH", "/content/mistral_legal_lora")
base_model_name = os.getenv("BASE_MODEL_NAME", "mistralai/Mistral-7B-v0.1")

# Load the tokenizer from the fine-tuned save directory (not the base model),
# so any tokens/config added during fine-tuning are preserved.
tokenizer = AutoTokenizer.from_pretrained(model_save_path)

# 4-bit NF4 quantization with double quantization for memory-efficient
# inference; compute runs in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Load the base model in 4-bit per the config above.
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    quantization_config=bnb_config
)

# Apply the LoRA adapter on top of the quantized base model, then switch to
# eval mode (disables dropout etc.) for inference.
model = PeftModel.from_pretrained(base_model, model_save_path)
model.eval()

# Define the generation function (similar to what was in the notebook)
def generate_summary(instruction, input_text):
    """Generate a summary for *input_text* with the fine-tuned model.

    Parameters
    ----------
    instruction : str
        Task instruction placed under the "### Instruction:" header.
    input_text : str
        Document text placed under the "### Input:" header.

    Returns
    -------
    str
        The decoded generation with everything up to and including the
        "### Summary:" tag stripped (or the full decoded output if the
        tag is absent from the decode).
    """
    # FIX: in the original source these f-strings were split across raw
    # newlines (a SyntaxError) — the "\n" escapes were lost. Reconstructed
    # the intended Alpaca-style prompt layout with explicit escapes.
    prompt = (
        f"### Instruction:\n{instruction}\n\n"
        f"### Input:\n{input_text}\n\n"
        f"### Summary:"
    )

    # Tokenize and move tensors to the same device the model lives on.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():  # inference only — no autograd bookkeeping
        outputs = model.generate(
            **inputs,
            max_new_tokens=150,
            num_return_sequences=1,
            # Mistral has no pad token by default; reuse EOS to silence warnings.
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            temperature=0.7,
        )

    decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Causal LMs echo the prompt; keep only the text after the summary tag.
    summary_start_tag = "### Summary:"
    if summary_start_tag in decoded_output:
        summary = decoded_output.split(summary_start_tag, 1)[1].strip()
    else:
        summary = decoded_output

    return summary

@app.route('/summarize', methods=['POST'])
def summarize():
    """POST /summarize: {"instruction": str, "input_document": str} -> {"summary": str}.

    Returns 400 when the body is missing/not JSON or either field is
    empty, and 500 with the exception message if generation fails.
    """
    # FIX: request.json raises (415) or returns None when the request has
    # no JSON body / wrong Content-Type, crashing before our validation.
    # get_json(silent=True) returns None instead; fall back to {} so the
    # .get() calls below are always safe and the intended 400 is reached.
    data = request.get_json(silent=True) or {}
    instruction = data.get('instruction', '')
    input_document = data.get('input_document', '')

    if not instruction or not input_document:
        return jsonify({'error': 'Instruction and input document are required.'}), 400

    try:
        summary = generate_summary(instruction, input_document)
        return jsonify({'summary': summary})
    except Exception as e:
        # Surface the failure text to the client; generation errors (OOM,
        # tokenizer issues) would otherwise be an opaque generic 500 page.
        return jsonify({'error': str(e)}), 500

@app.route('/')
def index():
    # Serve the front-end entry page from Flask's static folder.
    return app.send_static_file('index.html')

@app.route('/<path:path>')
def static_files(path):
    # Catch-all for front-end assets (JS/CSS/etc.) under the static folder.
    # send_static_file performs its own path sanitization against traversal.
    return app.send_static_file(path)

if __name__ == '__main__':
    # In a Colab environment, you might need to adjust the host/port for external access if deploying
    # For local testing, host='0.0.0.0' makes it accessible within the Colab environment
    # NOTE(review): this is Flask's development server — use a WSGI server
    # (gunicorn/waitress) for anything beyond local/Colab testing.
    app.run(host='0.0.0.0', port=8000)