"""Gradio front end for the SBX certification query helper.

Loads a LoRA-adapted causal language model once at import time and serves a
small web UI that forwards the user's question to the model.
"""

import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from peft import PeftModel, PeftConfig

# Model and tokenizer identifiers
MODEL_NAME = "satishpednekar/sbxcertqueryhelper"
BASE_MODEL_NAME = "unsloth/mistral-7b-v0.3"
ADAPTER_NAME = "satishpednekar/sbx-qhelper-mistral-loraWeights"


def load_model_org():
    """Load the model and tokenizer published under ``MODEL_NAME``.

    The weights are loaded in float16 with ``device_map="auto"`` and without
    8-bit quantization.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)

    # No quantization arguments are passed, so the weights load un-quantized.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )
    return model, tokenizer


def load_model_gpu():
    """Load the base model in float16 and attach the LoRA adapter.

    Both the base model and the adapter are loaded with ``device_map="auto"``.

    Returns:
        A ``(model, tokenizer)`` tuple; the tokenizer is the base model's.
    """
    # Load base model first
    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_NAME,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )

    # Load the PEFT adapter weights on top of the base model
    model = PeftModel.from_pretrained(
        base_model,
        ADAPTER_NAME,
        torch_dtype=torch.float16,
        device_map="auto",
    )

    tokenizer = AutoTokenizer.from_pretrained(
        BASE_MODEL_NAME,
        trust_remote_code=True,
    )
    return model, tokenizer


def load_model():
    """Load the LoRA-adapted model in float32 on the CPU.

    The base model name is read from the adapter's own PEFT config, so it
    does not need to be hard-coded here. The resulting model is put in
    evaluation mode.

    Returns:
        A ``(model, tokenizer)`` tuple; the tokenizer is the base model's.
    """
    config = PeftConfig.from_pretrained(ADAPTER_NAME)

    # No quantization-related parameters are passed.
    model = AutoModelForCausalLM.from_pretrained(
        config.base_model_name_or_path,
        torch_dtype=torch.float32,
        device_map=None,
        trust_remote_code=True,
    )

    model = PeftModel.from_pretrained(
        model,
        ADAPTER_NAME,
        torch_dtype=torch.float32,
    )

    tokenizer = AutoTokenizer.from_pretrained(
        config.base_model_name_or_path,
        trust_remote_code=True,
    )

    model = model.to("cpu").eval()
    return model, tokenizer


# Initialize model and tokenizer
print("Loading model...")
model, tokenizer = load_model()
print("Model loaded successfully!")


def generate_response(prompt, max_length=512, temperature=0.7, top_p=0.95):
    """Generate a response using the fine-tuned model.

    Args:
        prompt: The user's question.
        max_length: Upper bound on the number of *generated* tokens. The
            prompt's own tokens do not count against this limit.
        temperature: Sampling temperature passed to ``model.generate``.
        top_p: Nucleus-sampling threshold passed to ``model.generate``.

    Returns:
        The decoded model output without the prompt, or a human-readable
        message if the prompt is empty or generation raised an exception.
    """
    if not prompt or not prompt.strip():
        return "Please enter a question."

    try:
        # Prepare the input on whatever device the model actually lives on.
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        prompt_token_count = inputs["input_ids"].shape[-1]

        # Generate. `max_new_tokens` is used so that a long prompt cannot
        # consume the budget the UI advertises for the answer; the cast
        # guards against a slider delivering a float.
        outputs = model.generate(
            **inputs,
            max_new_tokens=int(max_length),
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            num_return_sequences=1,
        )

        # Drop the prompt by token position rather than by string prefix:
        # decoded text is not guaranteed to reproduce the prompt exactly.
        generated_tokens = outputs[0][prompt_token_count:]
        return tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
    except Exception as e:
        # UI boundary: surface the failure in the answer box.
        return f"An error occurred: {str(e)}"


# Create the Gradio interface
def main():
    """Build the Gradio Blocks UI and return it without launching it."""
    with gr.Blocks(title="SBX Certification Query Helper") as demo:
        gr.Markdown("""
        # SBX Certification Query Helper
        Ask questions about SBX certifications and get detailed answers!
        """)

        with gr.Row():
            with gr.Column():
                input_text = gr.Textbox(
                    label="Your Question",
                    placeholder="Enter your question about SBX certifications...",
                    lines=3,
                )

                with gr.Row():
                    temperature = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.7,
                        step=0.1,
                        label="Temperature",
                        info="Higher values make output more random, lower values make it more focused",
                    )
                    max_length = gr.Slider(
                        minimum=64,
                        maximum=1024,
                        value=512,
                        step=64,
                        label="Maximum Length",
                        info="Maximum length of the generated response",
                    )

                submit_btn = gr.Button("Get Answer", variant="primary")

            with gr.Column():
                output_text = gr.Textbox(
                    label="Answer",
                    lines=10,
                    show_copy_button=True,
                )

        # Set up the click event; the input order matches the positional
        # parameters of generate_response (prompt, max_length, temperature).
        submit_btn.click(
            fn=generate_response,
            inputs=[input_text, max_length, temperature],
            outputs=output_text,
        )

        gr.Markdown("""
        ### Tips:
        - Be specific in your questions
        - Include the certification name if you're asking about a specific certification
        - Adjust the temperature slider to control response creativity
        """)

    return demo


if __name__ == "__main__":
    demo = main()
    # Enable the request queue via queue(); the `enable_queue` argument to
    # launch() is not accepted by current Gradio releases.
    demo.queue()
    demo.launch(
        share=True,  # Enable sharing
        server_name="0.0.0.0",  # Listen on all network interfaces
    )