umingpeng commited on
Commit
a5a061e
·
verified ·
1 Parent(s): 830c723

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -20
app.py CHANGED
@@ -1,31 +1,34 @@
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the tokenizer and model weights once at import time so the
# Gradio request handler can reuse them across calls.
model_name = "nvidia/NVLM-D-72B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
9
def generate_response(prompt):
    """Run the causal LM on *prompt* and return the decoded text."""
    # Encode the prompt into model-ready tensors (BatchEncoding dict).
    encoded = tokenizer(prompt, return_tensors="pt")

    # Generate with the model's default sampling/length settings.
    generated = model.generate(**encoded)

    # Convert the first (and only) output sequence back into a string.
    return tokenizer.decode(generated[0], skip_special_tokens=True)
19
 
20
# Build the Gradio interface.
# FIX: `gr.inputs.Textbox` is the legacy Gradio 2.x API and was removed in
# Gradio 3.x+; component classes now live at the package top level.
iface = gr.Interface(
    fn=generate_response,
    inputs=gr.Textbox(lines=2, placeholder="输入你的问题..."),
    outputs="text",
    title="NVLM-D-72B 交互式问答",
    description="使用 NVIDIA 的 NVLM-D-72B 模型进行问答。",
)

# Launch the web app only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    iface.launch()
 
 
 
 
1
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Load the chat model and its tokenizer once at module import.
model_name = "umingpeng/Meta-Llama-3.1-8B-Instruct"  # pick a suitable model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Prefer a GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model.to(device)
13
def generate_response(user_input):
    """Generate an assistant reply for a single user message.

    Parameters
    ----------
    user_input : str
        Raw text typed by the user.

    Returns
    -------
    str
        Decoded model output (special tokens stripped).  Note: this is the
        full decoded sequence, prompt included, matching the original
        behavior.
    """
    # Wrap the raw text in the chat-message schema the template expects.
    messages = [
        {"role": "user", "content": user_input}
    ]

    # BUG FIX: with return_tensors="pt" and the default return_dict=False,
    # apply_chat_template returns a bare tensor of input ids, so the original
    # `model.generate(**tokenized_input, ...)` raised a TypeError when
    # unpacking.  Requesting a dict yields input_ids AND attention_mask,
    # which `**` can unpack correctly.
    tokenized_input = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    ).to(device)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        output = model.generate(**tokenized_input, max_new_tokens=256)

    # Decode the first output sequence, dropping special tokens.
    response = tokenizer.decode(output[0], skip_special_tokens=True)
    return response
29
 
30
# Console entry point: ask one question and print the model's answer.
if __name__ == "__main__":
    question = input("你想问什么?")
    print("助手:", generate_response(question))