# PlutusLearn / app.py
# Author: Remostart — commit 70a8ea1 ("Update app.py", verified)
# Gradio Space entry point for the fine-tuned Plutus Llama-3.2-3B model.
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch
import os
from huggingface_hub import login
import spaces
# Authenticate with Hugging Face (required to download the gated Llama base model).
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token)
else:
    # login(token=None) would try an interactive prompt, which hangs/fails in a
    # headless Space — warn instead and let the downloads below surface the error.
    print("Warning: HF_TOKEN is not set; gated model downloads may fail.")

# Model repository IDs
base_model_id = "meta-llama/Llama-3.2-3B-Instruct"
peft_model_id = "ubiodee/Plutuslearn-Llama-3.2-3B-Instruct"  # fine-tuned PEFT adapter repo
# Tokenizer comes from the fine-tuned repo so that any tokens added during
# fine-tuning are present (the base embeddings are resized to match below).
tokenizer = AutoTokenizer.from_pretrained(peft_model_id, token=hf_token)
# Base model in fp16; device_map="auto" lets accelerate place the weights
# (GPU when available), and low_cpu_mem_usage avoids a full fp32 CPU copy.
base_model = AutoModelForCausalLM.from_pretrained(
base_model_id,
torch_dtype=torch.float16,
device_map="auto",
token=hf_token,
low_cpu_mem_usage=True,
trust_remote_code=True
)
# Grow/shrink the embedding matrix to the (possibly extended) tokenizer vocab.
# Must happen BEFORE attaching the adapter so shapes line up with the checkpoint.
base_model.resize_token_embeddings(len(tokenizer))
# Wrap the base model with the LoRA/PEFT adapter weights; `model` is what
# predict() below calls generate() on.
model = PeftModel.from_pretrained(base_model, peft_model_id, token=hf_token)
# Define the prediction function with proper device handling
@spaces.GPU(duration=120)
def predict(text, max_length=100):
    """Generate a chat completion for a single user prompt.

    Args:
        text: The user's prompt, sent as one chat turn.
        max_length: Cap on the TOTAL sequence length (prompt + generated
            tokens) forwarded to ``model.generate``.

    Returns:
        The decoded model output (prompt echo included, since the full
        sequence is decoded), or an ``"Error during inference: ..."`` string
        if anything raises.
    """
    try:
        messages = [{"role": "user", "content": text}]
        inputs = tokenizer.apply_chat_template(
            messages, return_tensors="pt", add_generation_prompt=True
        )
        # Move inputs to wherever device_map="auto" actually placed the model,
        # instead of hardcoding "cuda:0" — avoids crashing on CPU fallback or
        # multi-device sharding.
        device = next(model.parameters()).device
        # apply_chat_template may return either a dict of tensors or a bare
        # input_ids tensor depending on transformers version/options.
        if isinstance(inputs, dict):
            inputs = {key: val.to(device) for key, val in inputs.items()}
            outputs = model.generate(**inputs, max_length=max_length)
        else:
            outputs = model.generate(input_ids=inputs.to(device), max_length=max_length)
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        # Surface the failure in the Gradio textbox rather than crashing the Space.
        return f"Error during inference: {str(e)}"
# Gradio UI: a prompt box plus a generation-length slider wired to predict().
prompt_input = gr.Textbox(label="Input Text")
length_input = gr.Slider(label="Max Length", minimum=50, maximum=500, value=100, step=1)
result_output = gr.Textbox(label="Model Output")

demo = gr.Interface(
    fn=predict,
    inputs=[prompt_input, length_input],
    outputs=result_output,
    title="LearnPlutus Demo",
    description="Test the fine-tuned Llama-3.2-3B-Instruct model on ZeroGPU.",
    flagging_mode="never",
)
# Launch the app on all interfaces at the port Spaces expects (7860).
# NOTE(review): share=True is only meaningful for local runs — on Hugging Face
# Spaces the public URL comes from the Space itself and the flag is ignored.
# debug=True streams server logs to stdout.
demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=True,
debug=True
)