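"""Gradio demo for the SBX Certification Query Helper.

Loads a Mistral-7B base model with LoRA adapter weights via PEFT and serves
a simple question-answering UI. Assumes gradio, torch, transformers, and
peft are installed.
"""
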
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from peft import PeftModel, PeftConfig

# Model and tokenizer initialization
MODEL_NAME = "satishpednekar/sbxcertqueryhelper"

def load_model_org():
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    # Load without quantization, in half precision
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float16,  # use torch.float32 on CPU-only hosts
        device_map="auto",
        trust_remote_code=True
    )
    return model, tokenizer


def load_model_gpu():
    # Load base model first
    base_model = AutoModelForCausalLM.from_pretrained(
        "unsloth/mistral-7b-v0.3",  # Use your base model name
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True
    )
    
    # Load the PEFT adapter weights
    model = PeftModel.from_pretrained(
        base_model,
        "satishpednekar/sbx-qhelper-mistral-loraWeights",  # Path to your trained LoRA weights
        torch_dtype=torch.float16,
        device_map="auto"
    )
    
    tokenizer = AutoTokenizer.from_pretrained(
        "unsloth/mistral-7b-v0.3",  # Use your base model name
        trust_remote_code=True
    )
    
    return model, tokenizer

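# Default loader used at startup: resolves the base model from the adapter's
# PeftConfig, loads it in float32 on CPU, and attaches the LoRA weights.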
def load_model():
    config = PeftConfig.from_pretrained("satishpednekar/sbx-qhelper-mistral-loraWeights")
    
    # Full-precision CPU load; no quantization parameters
    model = AutoModelForCausalLM.from_pretrained(
        config.base_model_name_or_path,
        torch_dtype=torch.float32,
        device_map=None,
        trust_remote_code=True
    )
    
    model = PeftModel.from_pretrained(
        model, 
        "satishpednekar/sbx-qhelper-mistral-loraWeights",
        torch_dtype=torch.float32
    )
    
    tokenizer = AutoTokenizer.from_pretrained(
        config.base_model_name_or_path,
        trust_remote_code=True
    )
    
    model = model.to("cpu").eval()
    
    return model, tokenizer

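# Optional (untested sketch): merging the LoRA weights into the base model
# removes the adapter indirection and can speed up CPU inference:
#     model = model.merge_and_unload()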


# Initialize model and tokenizer
print("Loading model...")
model, tokenizer = load_model()
print("Model loaded successfully!")

def generate_response(prompt, max_length=512, temperature=0.7, top_p=0.95):
    """
    Generate a response using the fine-tuned model
    """
    try:
        # Tokenize the prompt and move it to the model's device (CPU or GPU)
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        # Generate (note: max_length bounds prompt + completion tokens)
        outputs = model.generate(
            **inputs,
            max_length=max_length,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            num_return_sequences=1
        )

        # Decode the response
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        
        # Clean up the response by removing the prompt if it appears at the start
        if response.startswith(prompt):
            response = response[len(prompt):].strip()
            
        return response
    
    except Exception as e:
        return f"An error occurred: {str(e)}"

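# Example call (hypothetical prompt; the model must be loaded first):
#     print(generate_response("What does the SBX certification exam cover?", max_length=256))
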
# Create the Gradio interface
def main():
    with gr.Blocks(title="SBX Certification Query Helper") as demo:
        gr.Markdown("""
        # SBX Certification Query Helper
        Ask questions about SBX certifications and get detailed answers!
        """)
        
        with gr.Row():
            with gr.Column():
                input_text = gr.Textbox(
                    label="Your Question",
                    placeholder="Enter your question about SBX certifications...",
                    lines=3
                )
                
                with gr.Row():
                    temperature = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.7,
                        step=0.1,
                        label="Temperature",
                        info="Higher values make output more random, lower values make it more focused"
                    )
                    
                    max_length = gr.Slider(
                        minimum=64,
                        maximum=1024,
                        value=512,
                        step=64,
                        label="Maximum Length",
                        info="Maximum length of the generated response"
                    )
                
                submit_btn = gr.Button("Get Answer", variant="primary")
                
            with gr.Column():
                output_text = gr.Textbox(
                    label="Answer",
                    lines=10,
                    show_copy_button=True
                )
        
        # Set up the click event
        submit_btn.click(
            fn=generate_response,
            inputs=[input_text, max_length, temperature],
            outputs=output_text
        )
        
        gr.Markdown("""
        ### Tips:
        - Be specific in your questions
        - Include the certification name if you're asking about a specific certification
        - Adjust the temperature slider to control response creativity
        """)
        
    return demo

if __name__ == "__main__":
    demo = main()
    demo.queue()  # queue requests so concurrent users are served in order
    demo.launch(
        share=True,  # create a public share link
        server_name="0.0.0.0"  # listen on all network interfaces
    )
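
# Assumed usage: run this file directly (e.g. `python app.py`; filename assumed).
# Gradio serves the UI on port 7860 by default.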