newtechdevng committed on
Commit
629bec0
·
verified ·
1 Parent(s): 1105083

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -10
app.py CHANGED
@@ -1,23 +1,21 @@
1
  from flask import Flask, request, jsonify
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
3
- from peft import PeftModel
4
  import torch
5
 
6
  app = Flask(__name__)
7
 
8
- BASE_MODEL = "Qwen/Qwen2.5-0.5B-Instruct"
9
- ADAPTER = "newtechdevng/math-tutor-smollm2-360M" # your HF repo
10
 
11
  SYSTEM_PROMPT = "You are a helpful math assistant."
12
 
13
  print("Loading model...")
14
- tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
15
- base_model = AutoModelForCausalLM.from_pretrained(
16
- BASE_MODEL,
17
- torch_dtype=torch.float32,
 
18
  device_map="auto"
19
  )
20
- model = PeftModel.from_pretrained(base_model, ADAPTER)
21
  model.eval()
22
  print("✅ Model ready!")
23
 
@@ -50,11 +48,9 @@ def generate():
50
  **inputs,
51
  max_new_tokens=max_new_tokens,
52
  do_sample=False,
53
- temperature=1.0,
54
  pad_token_id=tokenizer.eos_token_id,
55
  )
56
 
57
- # Decode only the new tokens
58
  new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
59
  answer = tokenizer.decode(new_tokens, skip_special_tokens=True)
60
 
 
1
  from flask import Flask, request, jsonify
2
  from transformers import AutoTokenizer, AutoModelForCausalLM
 
3
  import torch
4
 
5
  app = Flask(__name__)
6
 
7
+ MODEL_ID = "newtechdevng/math-tutor-smollm2-360M" # full model, load directly
 
8
 
9
  SYSTEM_PROMPT = "You are a helpful math assistant."
10
 
11
  print("Loading model...")
12
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
13
+
14
+ model = AutoModelForCausalLM.from_pretrained(
15
+ MODEL_ID,
16
+ dtype=torch.float32,
17
  device_map="auto"
18
  )
 
19
  model.eval()
20
  print("✅ Model ready!")
21
 
 
48
  **inputs,
49
  max_new_tokens=max_new_tokens,
50
  do_sample=False,
 
51
  pad_token_id=tokenizer.eos_token_id,
52
  )
53
 
 
54
  new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
55
  answer = tokenizer.decode(new_tokens, skip_special_tokens=True)
56