import os

import gradio as gr
import spaces
import torch
from huggingface_hub import login
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Authenticate with Hugging Face
hf_token = os.getenv("HF_TOKEN")
login(token=hf_token)

# Model repository IDs
base_model_id = "meta-llama/Llama-3.2-3B-Instruct"
peft_model_id = "ubiodee/Plutuslearn-Llama-3.2-3B-Instruct"  # Replace with your model repo (e.g., ubiodee/my-finetuned-model)

# Load the tokenizer from the fine-tuned repo so any tokens added during
# fine-tuning are available
tokenizer = AutoTokenizer.from_pretrained(peft_model_id, token=hf_token)

# Load the base model
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    torch_dtype=torch.float16,
    device_map="auto",
    token=hf_token,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
)

# Resize embeddings to match the (possibly extended) fine-tuned tokenizer
base_model.resize_token_embeddings(len(tokenizer))

# Load the PEFT adapter on top of the base model
model = PeftModel.from_pretrained(base_model, peft_model_id, token=hf_token)


# Define the prediction function with proper device handling; the
# @spaces.GPU decorator runs it on a ZeroGPU worker for up to 120 s
@spaces.GPU(duration=120)
def predict(text, max_length=100):
    try:
        messages = [{"role": "user", "content": text}]
        inputs = tokenizer.apply_chat_template(
            messages, return_tensors="pt", add_generation_prompt=True
        )

        # apply_chat_template returns either a dict of tensors or a bare
        # input_ids tensor depending on its arguments; handle both
        if isinstance(inputs, dict):
            inputs = {key: val.to("cuda:0") for key, val in inputs.items()}
            prompt_len = inputs["input_ids"].shape[-1]
            # max_new_tokens caps only the generated tokens, so a long
            # prompt cannot consume the whole budget as max_length would
            outputs = model.generate(**inputs, max_new_tokens=max_length)
        else:  # inputs is a tensor (input_ids)
            inputs = inputs.to("cuda:0")
            prompt_len = inputs.shape[-1]
            outputs = model.generate(input_ids=inputs, max_new_tokens=max_length)

        # Decode only the newly generated tokens, not the echoed prompt
        return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    except Exception as e:
        return f"Error during inference: {str(e)}"


# Create the Gradio interface
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(label="Input Text"),
        gr.Slider(label="Max Length", minimum=50, maximum=500, value=100, step=1),
    ],
    outputs=gr.Textbox(label="Model Output"),
    title="LearnPlutus Demo",
    description="Test the fine-tuned Llama-3.2-3B-Instruct model on ZeroGPU.",
    flagging_mode="never",
)

# Launch the app
demo.launch(
    server_name="0.0.0.0",
    server_port=7860,
    share=True,
    debug=True,
)