import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from peft import PeftModel, PeftConfig
# Model and tokenizer initialization
MODEL_NAME = "satishpednekar/sbxcertqueryhelper"
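# Overview: three loader variants are defined below. load_model_org() loads the
# merged fine-tuned checkpoint in float16, load_model_gpu() attaches the LoRA
# adapter to the base Mistral model on GPU, and load_model() is the CPU/float32
# PEFT path that this Space actually calls further down.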
def load_model_org():
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    # Load the fine-tuned model in float16 without 8-bit quantization
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
        load_in_8bit=False
    )
    return model, tokenizer

def load_model_gpu():
    # Load the base model first
    base_model = AutoModelForCausalLM.from_pretrained(
        "unsloth/mistral-7b-v0.3",  # Base model the LoRA adapter was trained on
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True
    )
    # Attach the trained LoRA adapter weights to the base model
    model = PeftModel.from_pretrained(
        base_model,
        "satishpednekar/sbx-qhelper-mistral-loraWeights",  # Trained LoRA weights
        torch_dtype=torch.float16,
        device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(
        "unsloth/mistral-7b-v0.3",
        trust_remote_code=True
    )
    return model, tokenizer

def load_model():
    config = PeftConfig.from_pretrained("satishpednekar/sbx-qhelper-mistral-loraWeights")
    # Load the base model in float32 with no quantization parameters (CPU-friendly)
    model = AutoModelForCausalLM.from_pretrained(
        config.base_model_name_or_path,
        torch_dtype=torch.float32,
        device_map=None,
        trust_remote_code=True
    )
    # Attach the LoRA adapter weights on top of the base model
    model = PeftModel.from_pretrained(
        model,
        "satishpednekar/sbx-qhelper-mistral-loraWeights",
        torch_dtype=torch.float32
    )
    tokenizer = AutoTokenizer.from_pretrained(
        config.base_model_name_or_path,
        trust_remote_code=True
    )
    model = model.to("cpu").eval()
    return model, tokenizer
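# Note: load_model() pins everything to CPU in float32 and avoids quantization;
# this is the slow but quantization-free path, presumably chosen for CPU-only
# Space hardware. Swap in load_model_gpu() if GPU hardware is available.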
# Initialize model and tokenizer
print("Loading model...")
model, tokenizer = load_model()
print("Model loaded successfully!")
def generate_response(prompt, max_length=512, temperature=0.7, top_p=0.95):
    """
    Generate a response using the fine-tuned model.
    """
    try:
        # Tokenize the prompt and move it to the same device as the model
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        # Generate a completion (max_length counts prompt + generated tokens)
        outputs = model.generate(
            **inputs,
            max_length=max_length,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            num_return_sequences=1
        )
        # Decode the generated tokens
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Strip the prompt if the model echoed it at the start of the output
        if response.startswith(prompt):
            response = response[len(prompt):].strip()
        return response
    except Exception as e:
        return f"An error occurred: {str(e)}"
# Create the Gradio interface
def main():
    with gr.Blocks(title="SBX Certification Query Helper") as demo:
        gr.Markdown("""
        # SBX Certification Query Helper
        Ask questions about SBX certifications and get detailed answers!
        """)
        with gr.Row():
            with gr.Column():
                input_text = gr.Textbox(
                    label="Your Question",
                    placeholder="Enter your question about SBX certifications...",
                    lines=3
                )
                with gr.Row():
                    temperature = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.7,
                        step=0.1,
                        label="Temperature",
                        info="Higher values make the output more random; lower values make it more focused"
                    )
                    max_length = gr.Slider(
                        minimum=64,
                        maximum=1024,
                        value=512,
                        step=64,
                        label="Maximum Length",
                        info="Maximum length of the generated response"
                    )
                submit_btn = gr.Button("Get Answer", variant="primary")
            with gr.Column():
                output_text = gr.Textbox(
                    label="Answer",
                    lines=10,
                    show_copy_button=True
                )
        # Wire the button to the generation function
        # (input order matches generate_response: prompt, max_length, temperature)
        submit_btn.click(
            fn=generate_response,
            inputs=[input_text, max_length, temperature],
            outputs=output_text
        )
        gr.Markdown("""
        ### Tips:
        - Be specific in your questions
        - Include the certification name if you're asking about a specific certification
        - Adjust the temperature slider to control response creativity
        """)
    return demo

if __name__ == "__main__":
    demo = main()
    # Queue requests so multiple users can be served
    # (enable_queue was removed from launch() in recent Gradio versions)
    demo.queue()
    demo.launch(
        share=True,              # Enable a public share link
        server_name="0.0.0.0"    # Listen on all network interfaces
    )