AC-1ML committed on
Commit
1aefa2b
·
verified ·
1 Parent(s): cb51eed

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -31
app.py CHANGED
@@ -1,38 +1,19 @@
1
  import gradio as gr
2
- from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch
4
 
5
- # Load model and tokenizer
6
- model_name = "mistralai/Mistral-7B-Instruct-v0.1"
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
- model = AutoModelForCausalLM.from_pretrained(
9
- model_name,
10
- torch_dtype=torch.float16,
11
- device_map="auto"
12
- )
13
 
14
- def chat_with_mistral(user_input):
15
- prompt = f"<s>[INST] {user_input.strip()} [/INST]"
16
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
17
- output = model.generate(
18
- **inputs,
19
- max_new_tokens=256,
20
- temperature=0.7,
21
- top_p=0.9,
22
- do_sample=True
23
- )
24
- response = tokenizer.decode(output[0], skip_special_tokens=True)
25
- # Remove the prompt portion
26
- if "[/INST]" in response:
27
- response = response.split("[/INST]")[1].strip()
28
- return response
29
 
30
- iface = gr.Interface(
31
- fn=chat_with_mistral,
32
- inputs=gr.Textbox(lines=2, placeholder="Ask something..."),
33
- outputs=gr.Textbox(),
34
- title="Mistral 7B Chatbot",
35
- description="A chatbot powered by Mistral-7B-Instruct-v0.1."
36
- )
37
 
38
- iface.launch()
 
1
  import gradio as gr
2
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
3
  import torch
4
 
5
+ # Load free public model
6
+ model_name = "google/flan-t5-small"
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
 
 
 
 
9
 
10
+ def chat(prompt):
11
+ input_ids = tokenizer(prompt, return_tensors="pt").input_ids
12
+ with torch.no_grad():
13
+ outputs = model.generate(input_ids, max_new_tokens=200)
14
+ reply = tokenizer.decode(outputs[0], skip_special_tokens=True)
15
+ return reply
 
 
 
 
 
 
 
 
 
16
 
17
+ demo = gr.Interface(fn=chat, inputs="text", outputs="text", title="FLAN-T5 Chatbot")
 
 
 
 
 
 
18
 
19
+ demo.launch()