Juna190825 committed on
Commit
818b367
·
verified ·
1 Parent(s): ea2f813

Update Dockerfile

Browse files
Files changed (1) hide show
  1. app.py +38 -6
app.py CHANGED
@@ -1,13 +1,45 @@
1
 
2
  import gradio as gr
3
- from huggingface_hub import InferenceClient
 
4
 
5
- client = InferenceClient("google/gemma-2b-it")
 
6
 
7
- def generate_text(prompt):
8
- response = client.text_generation(prompt, max_new_tokens=50)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  return response
10
 
11
- demo = gr.Interface(fn=generate_text, inputs="text", outputs="text")
12
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
 
1
 
2
  import gradio as gr
3
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Model to serve. NOTE(review): this is a gated repo — the runtime must
# have HF credentials with access granted, or loading will fail.
model_id = "meta-llama/Llama-2-7b-chat-hf"

# Prefer a GPU when one is visible; otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load tokenizer and weights (served from the local HF cache when present),
# then move the model to the selected device.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id).to(device)
15
+
16
def generate_text(prompt, max_length=200):
    """Sample a completion for *prompt* from the loaded LLaMA model.

    Args:
        prompt: Free-form text to condition the model on.
        max_length: Maximum number of NEW tokens to sample (forwarded to
            ``max_new_tokens``; the prompt length is not counted).

    Returns:
        The decoded text. The full output sequence is decoded, so the
        echo of the prompt appears at the front of the returned string.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # Inference only: disable autograd bookkeeping — with a 7B model the
    # retained graph is a significant, pointless memory cost.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_length,
            temperature=0.7,
            do_sample=True,
            # LLaMA defines no pad token; fall back to EOS to silence the
            # "Setting `pad_token_id`..." warning emitted on every call.
            pad_token_id=tokenizer.eos_token_id,
        )

    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response
29
 
30
# --- Gradio UI ---------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# LLaMA 2 7B Chat Demo")

    # Prompt and response side by side in a single row.
    with gr.Row():
        input_text = gr.Textbox(label="Input Prompt", lines=3)
        output_text = gr.Textbox(label="Generated Response", lines=3)

    generate_btn = gr.Button("Generate")
    generate_btn.click(fn=generate_text, inputs=input_text, outputs=output_text)

# Bind to all interfaces on the port a Docker / Spaces deployment expects.
demo.launch(server_name="0.0.0.0", server_port=7860)
45