Spectral99 committed on
Commit
0e9284b
·
verified ·
1 Parent(s): 1462060

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -32
app.py CHANGED
@@ -2,14 +2,13 @@
2
  # Imports
3
  # ===============================
4
  import torch
5
- import spaces
6
  from transformers import AutoTokenizer, AutoModelForCausalLM
7
 
8
 
9
  # ===============================
10
- # Load Model & Tokenizer (ONCE)
11
  # ===============================
12
- print("Loading Aksara v1 model...")
13
 
14
  tokenizer = AutoTokenizer.from_pretrained(
15
  "cropinailab/aksara_v1",
@@ -18,19 +17,17 @@ tokenizer = AutoTokenizer.from_pretrained(
18
 
19
  model = AutoModelForCausalLM.from_pretrained(
20
  "cropinailab/aksara_v1",
21
- torch_dtype=torch.bfloat16,
22
- device_map="cuda", # ensures GPU usage on Spaces
23
  )
24
 
25
  model.eval()
26
 
27
- print("Model loaded successfully!")
28
 
29
 
30
  # ===============================
31
- # Generation Function (GPU)
32
  # ===============================
33
- @spaces.GPU
34
  def generate_agri_response(plant, disease):
35
  prompt = f"""
36
  You are an agricultural expert specializing in plant pathology, crop nutrition, and safe farm management.
@@ -135,29 +132,22 @@ Include:
135
  - Proper storage of harvested produce (if disease affects storage)
136
  """
137
 
138
- # Tokenize
139
- inputs = tokenizer(
140
- prompt,
141
- return_tensors="pt",
142
- ).to(model.device)
143
-
144
- # Generate
145
- outputs = model.generate(
146
- **inputs,
147
- max_new_tokens=600,
148
- temperature=0.7,
149
- top_p=0.9,
150
- repetition_penalty=1.15,
151
- do_sample=True,
152
- pad_token_id=tokenizer.eos_token_id,
153
- )
154
-
155
- full_output = tokenizer.decode(
156
- outputs[0],
157
- skip_special_tokens=True,
158
- )
159
-
160
- # Remove echoed prompt safely
161
  if full_output.startswith(prompt):
162
  cleaned = full_output[len(prompt):].strip()
163
  else:
@@ -167,7 +157,7 @@ Include:
167
 
168
 
169
  # ===============================
170
- # Local Test (optional)
171
  # ===============================
172
  if __name__ == "__main__":
173
  print(generate_agri_response("Potato", "Late Blight"))
 
2
  # Imports
3
  # ===============================
4
  import torch
 
5
  from transformers import AutoTokenizer, AutoModelForCausalLM
6
 
7
 
8
  # ===============================
9
+ # Load Model & Tokenizer (CPU)
10
  # ===============================
11
+ print("Loading Aksara v1 model (CPU)...")
12
 
13
  tokenizer = AutoTokenizer.from_pretrained(
14
  "cropinailab/aksara_v1",
 
17
 
18
  model = AutoModelForCausalLM.from_pretrained(
19
  "cropinailab/aksara_v1",
20
+ torch_dtype=torch.float32, # CPU-safe dtype
 
21
  )
22
 
23
  model.eval()
24
 
25
+ print("Model loaded successfully on CPU!")
26
 
27
 
28
  # ===============================
29
+ # Generation Function
30
  # ===============================
 
31
  def generate_agri_response(plant, disease):
32
  prompt = f"""
33
  You are an agricultural expert specializing in plant pathology, crop nutrition, and safe farm management.
 
132
  - Proper storage of harvested produce (if disease affects storage)
133
  """
134
 
135
+ inputs = tokenizer(prompt, return_tensors="pt")
136
+
137
+ with torch.no_grad(): # important for CPU efficiency
138
+ outputs = model.generate(
139
+ **inputs,
140
+ max_new_tokens=600,
141
+ temperature=0.7,
142
+ top_p=0.9,
143
+ repetition_penalty=1.15,
144
+ do_sample=True,
145
+ pad_token_id=tokenizer.eos_token_id,
146
+ )
147
+
148
+ full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
149
+
150
+ # Remove echoed prompt
 
 
 
 
 
 
 
151
  if full_output.startswith(prompt):
152
  cleaned = full_output[len(prompt):].strip()
153
  else:
 
157
 
158
 
159
  # ===============================
160
+ # Test
161
  # ===============================
162
  if __name__ == "__main__":
163
  print(generate_agri_response("Potato", "Late Blight"))