arhamTariq committed on
Commit
e94343a
·
verified ·
1 Parent(s): 490dd4d

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -10
app.py CHANGED
@@ -1,16 +1,54 @@
 
1
  import os
2
- from huggingface_hub import InferenceClient
3
 
4
- api_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
 
 
5
 
6
- client = InferenceClient(
7
- model="google/flan-t5-large",
8
- token=api_token
9
- )
10
 
11
- prompt = "Hello world"
 
12
 
13
- # Use text generation method instead
14
- response = client.text_generation(prompt)
15
 
16
- print(response)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import python-dotenv
2
  import os
3
+ from dotenv import load_dotenv
4
 
5
+ # import from huggingface
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM
7
+ import torch
8
 
9
+ # import regex for clean response
10
+ import re
 
 
11
 
12
+ # import gradio for gui
13
+ import gradio as gr
14
 
 
 
15
 
16
# Load environment variables (e.g. the Hugging Face token) from a local .env file.
load_dotenv()
token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")

# Model configuration: which checkpoint to load and at what numeric precision.
model_id = "google/gemma-2b-it"
dtype = torch.bfloat16
23
+
24
+ # start with chat
25
def gemma_chat(message, history):
    """Generate a Gemma-2b-it reply for a Gradio ChatInterface.

    Parameters
    ----------
    message : str
        The latest user message.
    history : list
        Prior turns supplied by Gradio. Both the legacy
        ``[user, assistant]`` pair format and the newer
        ``{"role": ..., "content": ...}`` message format are accepted.
        (The original ignored this argument, so the bot had no
        multi-turn memory.)

    Returns
    -------
    str
        The model's reply — only the newly generated text.
    """
    # Load the tokenizer/model exactly once and cache them on the function
    # object: the original reloaded the multi-GB checkpoint on every call.
    if not hasattr(gemma_chat, "_cached"):
        tok = AutoTokenizer.from_pretrained(model_id, token=token)
        mdl = AutoModelForCausalLM.from_pretrained(
            model_id,
            token=token,
            # NOTE(review): kept from the original; presumably a workaround
            # for the Gemma "gelu vs gelu_pytorch_tanh" config mismatch
            # warning — confirm it is still needed.
            hidden_activation="gelu_pytorch_tanh",
            device_map="cuda",  # hard GPU requirement, as in the original
            torch_dtype=dtype,
        )
        gemma_chat._cached = (tok, mdl)
    tokenizer, model = gemma_chat._cached

    # Rebuild the full conversation so earlier turns provide context.
    # Gemma's chat template uses "user" and "model" as its two roles.
    chat = []
    for turn in history or []:
        if isinstance(turn, dict):  # gradio "messages" format
            role = "model" if turn.get("role") == "assistant" else "user"
            chat.append({"role": role, "content": turn.get("content", "")})
        else:  # legacy (user_message, bot_message) pair format
            user_msg, bot_msg = turn
            if user_msg:
                chat.append({"role": "user", "content": user_msg})
            if bot_msg:
                chat.append({"role": "model", "content": bot_msg})
    chat.append({"role": "user", "content": message})

    prompt = tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
    outputs = model.generate(input_ids=inputs.to(model.device), max_new_tokens=2048)

    # Decode only the tokens generated after the prompt. The original used
    # re.split("model", full_text)[1], which truncated any reply containing
    # the word "model" and raised IndexError when the marker was absent.
    reply = tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True)
    return reply.strip()
52
+
53
+
54
# Wire the chat function into a Gradio chat UI and start serving it.
demo = gr.ChatInterface(gemma_chat)
demo.launch()