VictorM-Coder commited on
Commit
922d792
·
verified ·
1 Parent(s): 1319336

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -6
app.py CHANGED
@@ -2,13 +2,14 @@ import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch
4
 
5
- # Load model (choose Falcon or Mistral)
6
- model_name = "mistralai/Mistral-7B-Instruct-v0.2"
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
 
8
  model = AutoModelForCausalLM.from_pretrained(
9
  model_name,
10
- device_map="auto",
11
- torch_dtype=torch.float16
12
  )
13
 
14
  # Humanizer function
@@ -24,7 +25,7 @@ def humanize_text(ai_text):
24
  "Humanized version:"
25
  )
26
 
27
- inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
28
  outputs = model.generate(
29
  **inputs,
30
  max_new_tokens=400,
@@ -41,7 +42,7 @@ def humanize_text(ai_text):
41
  f"Humanized: {humanized}\n\n"
42
  "Feedback:"
43
  )
44
- fb_inputs = tokenizer(feedback_prompt, return_tensors="pt").to("cuda")
45
  fb_outputs = model.generate(**fb_inputs, max_new_tokens=250, temperature=0.8, top_p=0.9)
46
  feedback = tokenizer.decode(fb_outputs[0], skip_special_tokens=True)
47
 
 
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch
4
 
5
+ # Free, open Llama 3 instruct model
6
+ model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
+
9
  model = AutoModelForCausalLM.from_pretrained(
10
  model_name,
11
+ device_map="auto", # auto = CPU in free Spaces
12
+ low_cpu_mem_usage=True
13
  )
14
 
15
  # Humanizer function
 
25
  "Humanized version:"
26
  )
27
 
28
+ inputs = tokenizer(prompt, return_tensors="pt")
29
  outputs = model.generate(
30
  **inputs,
31
  max_new_tokens=400,
 
42
  f"Humanized: {humanized}\n\n"
43
  "Feedback:"
44
  )
45
+ fb_inputs = tokenizer(feedback_prompt, return_tensors="pt")
46
  fb_outputs = model.generate(**fb_inputs, max_new_tokens=250, temperature=0.8, top_p=0.9)
47
  feedback = tokenizer.decode(fb_outputs[0], skip_special_tokens=True)
48