Nav772 committed on
Commit
fb6f0ce
·
verified ·
1 Parent(s): ee43ff0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -5
app.py CHANGED
@@ -10,7 +10,7 @@ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
10
 
11
  # --- Basic Agent Definition ---
12
  # ----- THIS IS WHERE YOU CAN BUILD WHAT YOU WANT ------
13
- from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
14
 
15
  class BasicAgent:
16
  def __init__(self):
@@ -18,13 +18,16 @@ class BasicAgent:
18
 
19
  model_id = "mistralai/Mistral-7B-Instruct-v0.1"
20
 
 
21
  self.tokenizer = AutoTokenizer.from_pretrained(model_id)
22
  self.model = AutoModelForCausalLM.from_pretrained(model_id)
 
 
23
  self.pipeline = pipeline(
24
  "text-generation",
25
  model=self.model,
26
  tokenizer=self.tokenizer,
27
- device=-1 # CPU
28
  )
29
 
30
  def __call__(self, question: str) -> str:
@@ -32,13 +35,17 @@ class BasicAgent:
32
 
33
  try:
34
  prompt = f"<s>[INST] {question.strip()} [/INST]"
35
- result = self.pipeline(prompt, max_new_tokens=256, temperature=0.7)
36
- return result[0]["generated_text"].split("[/INST]")[-1].strip()
 
 
 
 
 
37
  except Exception as e:
38
  print(f"❌ Error during model inference: {e}")
39
  return f"❌ Model Error: {str(e)}"
40
 
41
-
42
  def run_and_submit_all( profile: gr.OAuthProfile | None):
43
  """
44
  Fetches all questions, runs the BasicAgent on them, submits all answers,
 
10
 
11
  # --- Basic Agent Definition ---
12
  # ----- THIS IS WHERE YOU CAN BUILD WHAT YOU WANT ------
13
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
14
 
15
  class BasicAgent:
16
  def __init__(self):
 
18
 
19
  model_id = "mistralai/Mistral-7B-Instruct-v0.1"
20
 
21
+ # Load model and tokenizer directly
22
  self.tokenizer = AutoTokenizer.from_pretrained(model_id)
23
  self.model = AutoModelForCausalLM.from_pretrained(model_id)
24
+
25
+ # Create inference pipeline
26
  self.pipeline = pipeline(
27
  "text-generation",
28
  model=self.model,
29
  tokenizer=self.tokenizer,
30
+ device=-1
31
  )
32
 
33
  def __call__(self, question: str) -> str:
 
35
 
36
  try:
37
  prompt = f"<s>[INST] {question.strip()} [/INST]"
38
+ output = self.pipeline(prompt, max_new_tokens=256, temperature=0.7)
39
+
40
+ # Extract and clean the response
41
+ generated_text = output[0]["generated_text"]
42
+ response = generated_text.split("[/INST]")[-1].strip()
43
+ return response
44
+
45
  except Exception as e:
46
  print(f"❌ Error during model inference: {e}")
47
  return f"❌ Model Error: {str(e)}"
48
 
 
49
  def run_and_submit_all( profile: gr.OAuthProfile | None):
50
  """
51
  Fetches all questions, runs the BasicAgent on them, submits all answers,