limitedonly41 committed on
Commit
f1aa3d7
·
verified ·
1 Parent(s): 4eaaba6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -89
app.py CHANGED
@@ -8,40 +8,25 @@ from datetime import datetime, timezone
8
  import httpx
9
  from deep_translator import GoogleTranslator
10
  import torch
11
- from transformers import AutoTokenizer, AutoModelForCausalLM
12
-
13
- # Initialize model globals without unsloth
14
- model_name = "limitedonly41/mistral7b_v3_4_categories"
15
- model = None
16
- tokenizer = None
17
-
18
- def load_model():
19
- """Load model without unsloth"""
20
- global model, tokenizer
21
-
22
- print("Loading model with transformers...")
23
- try:
24
- tokenizer = AutoTokenizer.from_pretrained(model_name)
25
- model = AutoModelForCausalLM.from_pretrained(
26
- model_name,
27
- torch_dtype=torch.float16,
28
- device_map="auto",
29
- load_in_4bit=True,
30
- trust_remote_code=True
31
- )
32
-
33
- # Set pad token
34
- if tokenizer.pad_token is None:
35
- tokenizer.pad_token = tokenizer.eos_token
36
-
37
- print("Model loaded successfully")
38
- return True
39
- except Exception as e:
40
- print(f"Model loading error: {e}")
41
- return False
42
-
43
- # Try to load model at startup
44
- model_loaded = load_model()
45
 
46
  # In-memory storage (replacing Redis)
47
  task_storage = {}
@@ -88,69 +73,38 @@ def translate_text(text: str) -> str:
88
 
89
  @spaces.GPU
90
  def predict_inference(translated_text: str) -> str:
91
- """GPU-accelerated inference function using transformers"""
92
  try:
93
- global model, tokenizer
94
-
95
- if not model_loaded or model is None or tokenizer is None:
96
- return 'MODEL_ERROR'
97
-
98
  if len(translated_text) < 150:
99
  return 'Short'
100
 
101
  prompt = f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
102
 
103
  ### Instruction:
104
- Categorize the website into one of the 4 categories:\n\n1) OTHER\n2) NEWS/BLOG\n3) E-commerce\n4) COMPANIES
105
 
106
  ### Input:
107
  {translated_text}
108
 
109
  ### Response:"""
110
 
111
- # Tokenize input
112
- inputs = tokenizer(
113
- prompt,
114
- return_tensors="pt",
115
- max_length=2048,
116
- truncation=True,
117
- padding=True
118
- )
119
 
120
- # Move to GPU
121
- if torch.cuda.is_available():
122
- inputs = {k: v.to('cuda') for k, v in inputs.items()}
123
-
124
- # Generate response
125
- with torch.no_grad():
126
- outputs = model.generate(
127
- **inputs,
128
- max_new_tokens=16,
129
- temperature=0.1,
130
- do_sample=False,
131
- pad_token_id=tokenizer.eos_token_id
132
- )
133
-
134
- # Decode response
135
- generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
136
-
137
- # Extract prediction
138
- if '### Response:' in generated_text:
139
- ans_pred = generated_text.split('### Response:')[1].strip()
140
- else:
141
- ans_pred = generated_text.split(prompt)[1].strip() if prompt in generated_text else generated_text
142
 
143
- # Clean and categorize
144
- ans_pred = ans_pred.split('<')[0].strip()
145
 
146
- if 'OTHER' in ans_pred.upper():
147
  return 'OTHER'
148
- elif 'NEWS/BLOG' in ans_pred.upper() or 'NEWS' in ans_pred.upper() or 'BLOG' in ans_pred.upper():
149
  return 'NEWS/BLOG'
150
- elif 'E-COMMERCE' in ans_pred.upper() or 'ECOMMERCE' in ans_pred.upper():
151
  return 'E-commerce'
152
  else:
153
- return 'OTHER' # Default fallback
154
 
155
  except Exception as e:
156
  print(f"Inference error: {e}")
@@ -159,10 +113,10 @@ Categorize the website into one of the 4 categories:\n\n1) OTHER\n2) NEWS/BLOG\n
159
  async def scrape_single_url(session: httpx.AsyncClient, url: str) -> Dict:
160
  """Scrape a single URL"""
161
  try:
162
- response = await session.get(url, timeout=30.0, follow_redirects=True)
163
  if response.status_code == 200:
164
- # Simple text extraction
165
- text_content = response.text[:5000]
166
  return {
167
  "url": url,
168
  "text": text_content,
@@ -222,9 +176,6 @@ async def process_urls_batch(urls: List[str], progress_callback=None) -> Dict[st
222
 
223
  def process_url_list(url_text: str, progress=gr.Progress()) -> str:
224
  """Main processing function for Gradio interface"""
225
- if not model_loaded:
226
- return "❌ Model loading failed. Please check the logs and try again."
227
-
228
  if not url_text.strip():
229
  return "Please provide URLs to process."
230
 
@@ -234,8 +185,8 @@ def process_url_list(url_text: str, progress=gr.Progress()) -> str:
234
  if not urls:
235
  return "No valid URLs found."
236
 
237
- if len(urls) > 20: # Reduced limit for stability
238
- return f"Too many URLs ({len(urls)}). Please limit to 20 URLs."
239
 
240
  try:
241
  # Process URLs
@@ -262,12 +213,9 @@ def process_url_list(url_text: str, progress=gr.Progress()) -> str:
262
 
263
  # Create Gradio interface
264
  def create_interface():
265
- status_msg = "✅ Model loaded successfully" if model_loaded else "❌ Model loading failed"
266
-
267
  with gr.Blocks(title="Website Category Classifier") as interface:
268
  gr.HTML("<h1>🔍 Website Category Classifier</h1>")
269
- gr.HTML(f"<p>Classify websites into categories: OTHER, NEWS/BLOG, or E-commerce</p>")
270
- gr.HTML(f"<p><strong>Status:</strong> {status_msg}</p>")
271
 
272
  with gr.Row():
273
  with gr.Column():
 
8
  import httpx
9
  from deep_translator import GoogleTranslator
10
  import torch
11
+ from torch.amp import autocast
12
+ from unsloth import FastLanguageModel
13
+
14
+ # Initialize model globally (outside GPU decorator)
15
+ max_seq_length = 2048
16
+ dtype = None
17
+ load_in_4bit = True
18
+ peft_model_name = "limitedonly41/website_mistral7b_v02"
19
+
20
+ # Load model once at startup
21
+ print("Loading model...")
22
+ model, tokenizer = FastLanguageModel.from_pretrained(
23
+ model_name=peft_model_name,
24
+ max_seq_length=max_seq_length,
25
+ dtype=dtype,
26
+ load_in_4bit=load_in_4bit,
27
+ )
28
+ FastLanguageModel.for_inference(model)
29
+ print("Model loaded successfully")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  # In-memory storage (replacing Redis)
32
  task_storage = {}
 
73
 
74
  @spaces.GPU
75
  def predict_inference(translated_text: str) -> str:
76
+ """GPU-accelerated inference function"""
77
  try:
 
 
 
 
 
78
  if len(translated_text) < 150:
79
  return 'Short'
80
 
81
  prompt = f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
82
 
83
  ### Instruction:
84
+ Categorize the website into one of the 3 categories:\n\n1) OTHER \n2) NEWS/BLOG\n3) E-commerce
85
 
86
  ### Input:
87
  {translated_text}
88
 
89
  ### Response:"""
90
 
91
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
 
 
 
 
92
 
93
+ with autocast(device_type='cuda'):
94
+ inputs = tokenizer(prompt, return_tensors="pt").to(device)
95
+ outputs = model.generate(**inputs, max_new_tokens=16, use_cache=True)
96
+ ans = tokenizer.batch_decode(outputs)[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
+ ans_pred = ans.split('### Response:')[1].split('<')[0].strip()
 
99
 
100
+ if 'OTHER' in ans_pred:
101
  return 'OTHER'
102
+ elif 'NEWS/BLOG' in ans_pred:
103
  return 'NEWS/BLOG'
104
+ elif 'E-commerce' in ans_pred:
105
  return 'E-commerce'
106
  else:
107
+ return 'ERROR'
108
 
109
  except Exception as e:
110
  print(f"Inference error: {e}")
 
113
  async def scrape_single_url(session: httpx.AsyncClient, url: str) -> Dict:
114
  """Scrape a single URL"""
115
  try:
116
+ response = await session.get(url, timeout=30.0)
117
  if response.status_code == 200:
118
+ # Simple text extraction (you can enhance this)
119
+ text_content = response.text[:5000] # Limit content
120
  return {
121
  "url": url,
122
  "text": text_content,
 
176
 
177
  def process_url_list(url_text: str, progress=gr.Progress()) -> str:
178
  """Main processing function for Gradio interface"""
 
 
 
179
  if not url_text.strip():
180
  return "Please provide URLs to process."
181
 
 
185
  if not urls:
186
  return "No valid URLs found."
187
 
188
+ if len(urls) > 50: # Limit for demo
189
+ return f"Too many URLs ({len(urls)}). Please limit to 50 URLs."
190
 
191
  try:
192
  # Process URLs
 
213
 
214
  # Create Gradio interface
215
  def create_interface():
 
 
216
  with gr.Blocks(title="Website Category Classifier") as interface:
217
  gr.HTML("<h1>🔍 Website Category Classifier</h1>")
218
+ gr.HTML("<p>Classify websites into categories: OTHER, NEWS/BLOG, or E-commerce</p>")
 
219
 
220
  with gr.Row():
221
  with gr.Column():