import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
# Model configuration: gated Llama-2 checkpoint, authenticated with the
# Hugging Face access token taken from the environment.
model_name = "meta-llama/Llama-2-7b-hf"
token = os.getenv("HF_TOKEN")  # required to download gated Meta models

# Select the compute device: prefer CUDA whenever a GPU is visible.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")
try:
    # Load tokenizer and weights. device_map="auto" lets accelerate place the
    # model across the available devices; fp16 is used only when a GPU exists,
    # since fp16 on CPU is slow/unsupported for many ops.
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        token=token,
    )
    print("Model and tokenizer loaded successfully.")
except Exception as e:
    # Log the failure, then re-raise with the original traceback intact
    # (bare `raise` instead of `raise e`, which would re-anchor the trace).
    print(f"Error loading model or tokenizer: {e}")
    raise
def generate_content_with_parameters(prompt, temperature, max_length):
    """Generate text from *prompt* with user-tunable decoding parameters.

    Args:
        prompt: Input text to condition generation on.
        temperature: Sampling temperature from the UI slider (0-1). A value
            of 0 falls back to greedy (deterministic) decoding.
        max_length: Maximum total token count (prompt + generated tokens).

    Returns:
        The decoded generation as a string, or an error message on failure.
    """
    try:
        # Place inputs on the model's own device — with device_map="auto"
        # the model may not live on the module-level `device`.
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)
        # Bug fix: generate() defaults to greedy decoding, so without
        # do_sample=True the temperature value was silently ignored and the
        # UI slider had no effect. temperature == 0 keeps the original
        # deterministic behaviour.
        outputs = model.generate(
            **inputs,
            max_length=max_length,
            do_sample=temperature > 0,
            temperature=temperature if temperature > 0 else None,
            num_return_sequences=1,
            repetition_penalty=1.2,  # discourage verbatim repetition
        )
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        # Report the failure to the UI instead of crashing the Gradio app.
        return f"Error during generation: {e}"
# Build the Gradio UI: a prompt box plus two sliders, wired to the
# generation function defined above.
prompt_box = gr.Textbox(label="Prompt", placeholder="Enter your prompt here.")
temperature_slider = gr.Slider(label="Temperature", minimum=0, maximum=1, step=0.1, value=0.7)
length_slider = gr.Slider(label="Max Length", minimum=10, maximum=2048, step=10, value=512)

interface = gr.Interface(
    fn=generate_content_with_parameters,
    inputs=[prompt_box, temperature_slider, length_slider],
    outputs="text",
    title="Customizable Text Generator",
    description="Enter a prompt, adjust the temperature and max length, and generate consistent outputs.",
)
# Launch the Gradio app only when executed as a script (not on import).
# Bug fix: a stray trailing "|" (scrape artifact) after this guard was a
# SyntaxError and has been removed.
if __name__ == "__main__":
    interface.launch()