prakhardoneria committed on
Commit
129cbeb
·
verified ·
1 Parent(s): 282328e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -8
app.py CHANGED
@@ -2,26 +2,23 @@ import torch
2
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
3
  import gradio as gr
4
 
5
- # Use lightweight, public model
6
  model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
7
 
8
  tokenizer = AutoTokenizer.from_pretrained(model_id)
9
  model = AutoModelForCausalLM.from_pretrained(
10
  model_id,
11
- torch_dtype=torch.float32,
12
- device_map="auto"
13
- )
14
 
15
  streamer = TextStreamer(tokenizer, skip_prompt=True)
16
 
17
- # Chat formatting
18
  def chat(message, history):
19
  prompt = ""
20
  for user, bot in history:
21
  prompt += f"<|user|>\n{user.strip()}\n<|assistant|>\n{bot.strip()}\n"
22
  prompt += f"<|user|>\n{message.strip()}\n<|assistant|>\n"
23
 
24
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
25
  outputs = model.generate(
26
  **inputs,
27
  max_new_tokens=256,
@@ -34,5 +31,4 @@ def chat(message, history):
34
  reply = text.split("<|assistant|>")[-1].strip()
35
  return reply
36
 
37
- # Gradio UI
38
- gr.ChatInterface(chat, title="TinyLlama Chat", description="Lightweight local LLM (1.1B)").launch()
 
2
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
3
  import gradio as gr
4
 
 
5
  model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
6
 
7
  tokenizer = AutoTokenizer.from_pretrained(model_id)
8
  model = AutoModelForCausalLM.from_pretrained(
9
  model_id,
10
+ torch_dtype=torch.float32 # use float32 for CPU compatibility
11
+ ).to("cpu")
 
12
 
13
  streamer = TextStreamer(tokenizer, skip_prompt=True)
14
 
 
15
  def chat(message, history):
16
  prompt = ""
17
  for user, bot in history:
18
  prompt += f"<|user|>\n{user.strip()}\n<|assistant|>\n{bot.strip()}\n"
19
  prompt += f"<|user|>\n{message.strip()}\n<|assistant|>\n"
20
 
21
+ inputs = tokenizer(prompt, return_tensors="pt").to("cpu")
22
  outputs = model.generate(
23
  **inputs,
24
  max_new_tokens=256,
 
31
  reply = text.split("<|assistant|>")[-1].strip()
32
  return reply
33
 
34
+ gr.ChatInterface(chat, title="TinyLlama Chat").launch()