File size: 2,156 Bytes
7425b43
 
916b5d1
1e4613a
7425b43
85065bb
 
 
1214f20
85065bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e4613a
4a38fe1
 
10d867d
1214f20
85065bb
 
4a38fe1
1214f20
 
9cc11a1
10d867d
 
9cc11a1
32a9f05
e10cf34
4a38fe1
1214f20
85065bb
1214f20
 
7425b43
85065bb
 
10d867d
85065bb
10d867d
 
 
85065bb
 
10d867d
 
85065bb
7425b43
85065bb
 
4a38fe1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os

# Load the tokenizer and model. The Hugging Face access token comes from the
# HF_TOKEN environment variable (None when the variable is unset).
model_name = "meta-llama/Llama-2-7b-hf"
token = os.getenv("HF_TOKEN")

# Prefer the GPU when one is available; this also decides the weight dtype below.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

try:
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",  # let accelerate decide where the weights go
        # Half precision on GPU to save memory; full precision on CPU, where
        # half-precision kernels are limited. Reuses `device` so the two
        # decisions cannot diverge.
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        token=token,
    )
    print("Model and tokenizer loaded successfully.")
except Exception as e:
    print(f"Error loading model or tokenizer: {e}")
    raise  # bare raise keeps the original traceback untouched

# Generation callback wired into the Gradio interface.

def generate_content_with_parameters(prompt, temperature, max_length):
    """Generate a continuation of *prompt* with the loaded model.

    Args:
        prompt: Input text to continue.
        temperature: Sampling temperature. A value <= 0 selects greedy
            (deterministic) decoding instead of sampling.
        max_length: Maximum total length in tokens, prompt included.
            Coerced to ``int`` because UI sliders may deliver floats.

    Returns:
        The decoded text (prompt plus continuation, special tokens removed),
        or an ``"Error during generation: ..."`` message if anything fails.
    """
    try:
        # Send the inputs to wherever the model's weights were placed; with
        # device_map="auto" that is not necessarily the `device` string.
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)

        # Make the decoding strategy explicit instead of inheriting it from the
        # checkpoint's generation config. Sample only for a positive temperature;
        # the sampler rejects 0. Otherwise decode greedily.
        do_sample = temperature is not None and float(temperature) > 0
        gen_kwargs = {"do_sample": do_sample}
        if do_sample:
            gen_kwargs["temperature"] = float(temperature)

        # Fall back to EOS as the pad token when the tokenizer defines none,
        # which avoids generate()'s missing-pad-token warning.
        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = tokenizer.eos_token_id

        outputs = model.generate(
            **inputs,
            max_length=int(max_length),
            num_return_sequences=1,
            repetition_penalty=1.2,  # discourage repeated phrases
            pad_token_id=pad_token_id,
            **gen_kwargs,
        )

        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        # UI boundary: show the failure to the user instead of crashing the request.
        return f"Error during generation: {e}"

# Gradio UI. The controls are listed in the same order as the parameters of
# generate_content_with_parameters(prompt, temperature, max_length).
_ui_inputs = [
    gr.Textbox(label="Prompt", placeholder="Enter your prompt here."),
    gr.Slider(label="Temperature", minimum=0, maximum=1, step=0.1, value=0.7),
    gr.Slider(label="Max Length", minimum=10, maximum=2048, step=10, value=512),
]

interface = gr.Interface(
    title="Customizable Text Generator",
    description="Enter a prompt, adjust the temperature and max length, and generate consistent outputs.",
    fn=generate_content_with_parameters,
    inputs=_ui_inputs,
    outputs="text",
)

# Start the web app only when executed as a script, not when imported.
if __name__ == "__main__":
    interface.launch()