import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch
import os
from huggingface_hub import login
import spaces

# Authenticate with Hugging Face
hf_token = os.getenv("HF_TOKEN")
login(token=hf_token)
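# Note: HF_TOKEN is expected to be set as a Space secret, and the token must
# have access to the gated meta-llama/Llama-3.2-3B-Instruct repository.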
# Model repository IDs
base_model_id = "meta-llama/Llama-3.2-3B-Instruct"
peft_model_id = "ubiodee/Plutuslearn-Llama-3.2-3B-Instruct"  # Replace with your model repo (e.g., ubiodee/my-finetuned-model)

# Load the tokenizer from the fine-tuned model
tokenizer = AutoTokenizer.from_pretrained(peft_model_id, token=hf_token)
# Load the base model in half precision, sharded automatically across available devices
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    torch_dtype=torch.float16,
    device_map="auto",
    token=hf_token,
    low_cpu_mem_usage=True,
    trust_remote_code=True
)

# Resize embeddings in case the fine-tuning run added tokens to the tokenizer
base_model.resize_token_embeddings(len(tokenizer))
# Load the PEFT adapter on top of the base model
model = PeftModel.from_pretrained(base_model, peft_model_id, token=hf_token)
model.eval()  # inference mode: disables dropout in the LoRA layers
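# Optional (a sketch, not required for the demo): merging the LoRA weights into
# the base model with PEFT's merge_and_unload() removes the adapter indirection
# and can speed up generation slightly:
# model = model.merge_and_unload()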
# Prediction function; @spaces.GPU requests a ZeroGPU device for the duration of each call
@spaces.GPU
def predict(text, max_new_tokens=100):
    try:
        messages = [{"role": "user", "content": text}]
        inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True)
        # apply_chat_template returns a dict when return_dict=True, otherwise a bare input_ids tensor
        if isinstance(inputs, dict):
            inputs = {key: val.to(model.device) for key, val in inputs.items()}
            prompt_len = inputs["input_ids"].shape[-1]
            outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
        else:
            # inputs is a tensor of input_ids
            inputs = inputs.to(model.device)
            prompt_len = inputs.shape[-1]
            outputs = model.generate(input_ids=inputs, max_new_tokens=max_new_tokens)
        # Decode only the newly generated tokens so the prompt is not echoed back
        return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    except Exception as e:
        return f"Error during inference: {str(e)}"
# Create the Gradio interface
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(label="Input Text"),
        gr.Slider(label="Max New Tokens", minimum=50, maximum=500, value=100, step=1)
    ],
    outputs=gr.Textbox(label="Model Output"),
    title="LearnPlutus Demo",
    description="Test the fine-tuned Llama-3.2-3B-Instruct model on ZeroGPU.",
    flagging_mode="never"
)
# Launch the app (share=True is unsupported on Spaces, so it is omitted)
demo.launch(
    server_name="0.0.0.0",
    server_port=7860,
    debug=True
)
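# Assumed Space dependencies (requirements.txt, versions not pinned here):
# gradio, transformers, peft, torch, accelerate (needed for device_map="auto"), spaces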