Gajendra5490 committed on
Commit 38304ad · verified · 1 Parent(s): ae003b2

Create app.py

Files changed (1)
app.py +192 -0
app.py ADDED
@@ -0,0 +1,192 @@
import torch
import json
import logging
import gradio as gr
from huggingface_hub import hf_hub_download

# Configuration constants
MODEL_ID = "Gajendra5490/Scrached_Trained_Model"
CURRENT_USER = "gajendra82"
CURRENT_UTC = "2025-05-06 15:05:18"

def setup_logging():
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler('inference.log'),
            logging.StreamHandler()
        ]
    )
    return logging.getLogger(__name__)

class ModelInference:
    def __init__(self, model_id):
        self.logger = logging.getLogger(__name__)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model_id = model_id
        self.load_model()

    def load_model(self):
        try:
            # Download model checkpoint and tokenizer from Hugging Face
            self.logger.info(f"Downloading model from {self.model_id}")

            model_path = hf_hub_download(
                repo_id=self.model_id,
                filename="model.pt"
            )

            tokenizer_path = hf_hub_download(
                repo_id=self.model_id,
                filename="tokenizer.json"
            )

            # weights_only=False because the checkpoint stores a config dict
            # alongside the state dict; only safe for checkpoints you trust
            model_data = torch.load(
                model_path,
                map_location=self.device,
                weights_only=False
            )

            # Load tokenizer vocabulary
            with open(tokenizer_path, 'r', encoding='utf-8') as f:
                tokenizer_data = json.load(f)

            # Initialize model from the saved config (local model.py module)
            from model import ImprovedTransformer
            model_config = model_data['model_config']

            self.model = ImprovedTransformer(
                vocab_size=len(tokenizer_data['vocab']),
                d_model=model_config.get('d_model', 512),
                nhead=model_config.get('nhead', 8),
                num_encoder_layers=model_config.get('num_encoder_layers', 6),
                num_decoder_layers=model_config.get('num_decoder_layers', 6),
                dim_feedforward=model_config.get('dim_feedforward', 2048),
                dropout=model_config.get('dropout', 0.1),
                max_seq_length=model_config.get('max_seq_length', 128)
            ).to(self.device)

            # Load trained weights and switch to inference mode
            self.model.load_state_dict(model_data['model_state_dict'])
            self.model.eval()

            # Initialize tokenizer (local tokenizer.py module)
            from tokenizer import EnhancedTokenizer
            self.tokenizer = EnhancedTokenizer(tokenizer_data['vocab'])

            self.logger.info("Model loaded successfully")

        except Exception as e:
            self.logger.error(f"Error loading model: {e}")
            raise

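    # Note: generate_answer below does plain greedy decoding -- the model is
    # re-run on the growing prefix at every step and the argmax token is
    # appended until <eos> or the 150-token cap; no sampling or beam search.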
    @torch.no_grad()
    def generate_answer(self, input_text: str) -> str:
        try:
            # Tokenize input with the chat-style markers used in training
            input_ids = self.tokenizer.encode(f"<user> {input_text} <sep>")
            input_tensor = torch.tensor([input_ids]).to(self.device)

            # Initialize response with the assistant start token
            response_ids = [self.tokenizer.special_tokens["<assistant>"]]

            # Greedy decoding loop
            for _ in range(150):  # max length
                curr_output = self.model(input_tensor, torch.tensor([response_ids]).to(self.device))
                next_token = curr_output[0, -1].argmax().item()

                if next_token == self.tokenizer.special_tokens["<eos>"]:
                    break

                response_ids.append(next_token)

            # Decode output and strip the special markers
            answer = self.tokenizer.decode(response_ids)
            answer = answer.replace("<assistant>", "").replace("<eos>", "").strip()

            return answer

        except Exception as e:
            self.logger.error(f"Error generating answer: {e}")
            return "Error generating answer"

# Initialize model globally
try:
    print("Loading model from Hugging Face...")
    model = ModelInference(MODEL_ID)
    print("Model loaded successfully")
except Exception as e:
    print(f"Error loading model: {e}")
    raise

def process_input(input_text):
    """Process a single request from the Gradio interface."""
    logger = logging.getLogger(__name__)
    try:
        # Log the input
        logger.info(f"Input received: {input_text}")

        # Generate answer
        answer = model.generate_answer(input_text)

        # Log the output
        logger.info(f"Generated answer: {answer}")

        return answer
    except Exception as e:
        logger.error(f"Error processing input: {e}")
        return f"Error: {str(e)}"

def create_gradio_interface():
    """Create the Gradio interface"""
    iface = gr.Interface(
        fn=process_input,
        inputs=gr.Textbox(
            label="Input",
            placeholder="Enter your input here...",
            lines=2
        ),
        outputs=gr.Textbox(
            label="Answer",
            lines=4
        ),
        title="Inference Interface",
        description=f"""
        Model: {MODEL_ID}
        Current User: {CURRENT_USER}
        Last Updated: {CURRENT_UTC} UTC
        """,
        theme=gr.themes.Soft(),
        allow_flagging="never",
        analytics_enabled=False
    )
    return iface

def main():
    logger = setup_logging()
    logger.info(f"Starting inference at {CURRENT_UTC}")
    logger.info(f"User: {CURRENT_USER}")

    try:
        # Create and launch Gradio interface
        iface = create_gradio_interface()
        iface.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=False
        )

    except Exception as e:
        logger.error(f"Error in main: {e}")
        print(f"Error: {str(e)}")

if __name__ == "__main__":
    main()
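
Note: app.py imports EnhancedTokenizer from a local tokenizer.py that is not included in this commit. The sketch below shows only the interface the app relies on (construction from a vocab mapping, encode, decode, and a special_tokens dict), as inferred from the calls above; the whitespace tokenization and the <unk> fallback are assumptions, not the actual training tokenizer.

# tokenizer.py (sketch -- interface inferred from app.py, details assumed)
class EnhancedTokenizer:
    def __init__(self, vocab):
        self.vocab = vocab                                  # token -> id
        self.inv_vocab = {i: t for t, i in vocab.items()}   # id -> token
        # Special markers used by app.py; assumes they exist in the vocab
        self.special_tokens = {
            t: vocab[t]
            for t in ("<user>", "<sep>", "<assistant>", "<eos>")
            if t in vocab
        }
        self.unk_id = vocab.get("<unk>", 0)

    def encode(self, text):
        # Assumption: simple whitespace tokenization with <unk> fallback
        return [self.vocab.get(tok, self.unk_id) for tok in text.split()]

    def decode(self, ids):
        return " ".join(self.inv_vocab.get(i, "<unk>") for i in ids)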
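
Similarly, ImprovedTransformer is imported from a local model.py that this commit does not include. The constructor signature below is taken from the keyword arguments app.py passes; everything inside the class is an assumed sketch built on torch.nn.Transformer, not the author's actual architecture. Its output shape (batch, tgt_len, vocab_size) matches the curr_output[0, -1] indexing in the decoding loop above.

# model.py (sketch -- signature taken from app.py, internals assumed)
import math
import torch
import torch.nn as nn

class ImprovedTransformer(nn.Module):
    def __init__(self, vocab_size, d_model=512, nhead=8,
                 num_encoder_layers=6, num_decoder_layers=6,
                 dim_feedforward=2048, dropout=0.1, max_seq_length=128):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_embedding = nn.Embedding(max_seq_length, d_model)
        self.transformer = nn.Transformer(
            d_model=d_model, nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward, dropout=dropout,
            batch_first=True
        )
        self.output_proj = nn.Linear(d_model, vocab_size)
        self.scale = math.sqrt(d_model)

    def forward(self, src_ids, tgt_ids):
        # Embed tokens plus learned positions
        src_pos = torch.arange(src_ids.size(1), device=src_ids.device)
        tgt_pos = torch.arange(tgt_ids.size(1), device=tgt_ids.device)
        src = self.embedding(src_ids) * self.scale + self.pos_embedding(src_pos)
        tgt = self.embedding(tgt_ids) * self.scale + self.pos_embedding(tgt_pos)
        # Causal mask so each target position attends only to earlier ones
        tgt_mask = self.transformer.generate_square_subsequent_mask(
            tgt_ids.size(1)).to(src_ids.device)
        out = self.transformer(src, tgt, tgt_mask=tgt_mask)
        return self.output_proj(out)  # (batch, tgt_len, vocab_size)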
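
For the Space to build, a requirements.txt next to app.py would need at least the packages imported above; the unpinned list below is an assumption, since this commit only adds app.py.

# requirements.txt (sketch)
torch
gradio
huggingface_hub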