khaledsayed1 commited on
Commit
c81829d
·
verified ·
1 Parent(s): 6230fd9

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +85 -38
handler.py CHANGED
@@ -1,45 +1,92 @@
1
  import torch
 
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
 
4
- # Load the model and tokenizer from Hugging Face (with GPU support)
5
- model_name = "khaledsayed1/llama_QA" # Replace with your actual model name
6
- model = AutoModelForCausalLM.from_pretrained(model_name).to("cuda") # Ensure it's loaded on GPU
7
- tokenizer = AutoTokenizer.from_pretrained(model_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  def predict(input_data):
10
  """
11
- Process the input data and generate an answer from the model.
12
- Args:
13
- input_data (dict): The input question.
14
- Returns:
15
- dict: The model's generated answer.
16
- """
17
- question = input_data.get('question', '')
18
- if not question:
19
- return {"error": "No question provided."}
20
-
21
- # Define the prompt with the user's question
22
- formatted_prompt = f"""
23
- السؤال: {question}
24
- الإجابة:
25
  """
26
- inputs = tokenizer([formatted_prompt], return_tensors="pt").to("cuda") # Move inputs to GPU
27
-
28
- try:
29
- # Generate the output using the model
30
- outputs = model.generate(
31
- **inputs,
32
- max_new_tokens=128,
33
- temperature=0.7,
34
- top_k=50,
35
- top_p=0.95,
36
- )
37
- decoded_output = tokenizer.batch_decode(outputs, skip_special_tokens=True)
38
-
39
- # Clean up the output and remove the question itself
40
- clean_output = decoded_output[0].replace("السؤال:", "").replace("الإجابة:", "").strip()
41
-
42
- return {"answer": clean_output}
43
-
44
- except Exception as e:
45
- return {"error": str(e)}
 
1
  import torch
2
+ import os
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
 
5
+ class ModelHandler:
6
+ def __init__(self):
7
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
8
+ self.model = None
9
+ self.tokenizer = None
10
+ self.initialized = False
11
+
12
+ def initialize(self):
13
+ """Initialize the model and tokenizer"""
14
+ if self.initialized:
15
+ return
16
+
17
+ try:
18
+ # Load model and tokenizer from the local path
19
+ model_path = os.path.dirname(os.path.abspath(__file__))
20
+ self.model = AutoModelForCausalLM.from_pretrained(
21
+ model_path,
22
+ device_map="auto",
23
+ torch_dtype=torch.float16 # Use float16 for T4 GPU optimization
24
+ )
25
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path)
26
+ self.initialized = True
27
+ except Exception as e:
28
+ raise RuntimeError(f"Error initializing model: {str(e)}")
29
+
30
+ def predict(self, input_data):
31
+ """
32
+ Process the input data and generate an answer from the model.
33
+ Args:
34
+ input_data (dict): The input question.
35
+ Returns:
36
+ dict: The model's generated answer.
37
+ """
38
+ if not self.initialized:
39
+ self.initialize()
40
+
41
+ try:
42
+ # Extract the question from input_data
43
+ question = input_data.get('question', '')
44
+ if not question:
45
+ return {"error": "No question provided."}
46
+
47
+ # Define the prompt with the user's question
48
+ alpaca_prompt = f"""
49
+ السؤال: {question}
50
+ الإجابة:
51
+ """
52
+ formatted_prompt = alpaca_prompt.strip()
53
+
54
+ # Tokenize the input
55
+ inputs = self.tokenizer([formatted_prompt], return_tensors="pt")
56
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
57
+
58
+ # Generate with proper error handling and memory management
59
+ with torch.no_grad():
60
+ outputs = self.model.generate(
61
+ **inputs,
62
+ max_new_tokens=128,
63
+ temperature=0.7,
64
+ top_k=50,
65
+ top_p=0.95,
66
+ use_cache=True,
67
+ pad_token_id=self.tokenizer.eos_token_id
68
+ )
69
+
70
+ # Decode the output
71
+ decoded_output = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
72
+
73
+ # Clean up the output
74
+ clean_output = decoded_output[0].replace("السؤال:", "").replace("الإجابة:", "").strip()
75
+
76
+ # Clear CUDA cache if using GPU
77
+ if self.device == "cuda":
78
+ torch.cuda.empty_cache()
79
+
80
+ return {"answer": clean_output}
81
+
82
+ except Exception as e:
83
+ return {"error": f"Prediction error: {str(e)}"}
84
+
85
+ # Create a global handler instance
86
+ handler = ModelHandler()
87
 
88
  def predict(input_data):
89
  """
90
+ Wrapper function for the handler's predict method
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  """
92
+ return handler.predict(input_data)