limitedonly41 committed on
Commit
f61233e
·
verified ·
1 Parent(s): 81aab5a

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +260 -0
  2. requirements.txt +13 -0
app.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
import spaces
import asyncio
import json
import time
from typing import List, Dict, Any
from datetime import datetime, timezone
import httpx
from deep_translator import GoogleTranslator
import torch
from torch.amp import autocast
from unsloth import FastLanguageModel

# Initialize model globally (outside GPU decorator)
# NOTE(review): loading at import time means the Space pays the download/
# load cost once at startup instead of on every @spaces.GPU call.
max_seq_length = 2048        # max token length passed to the tokenizer/model
dtype = None                 # presumably lets unsloth auto-select dtype — TODO confirm
load_in_4bit = True          # 4-bit quantized weights (bitsandbytes) to fit small GPUs
peft_model_name = "limitedonly41/website_mistral7b_v02"  # fine-tuned classifier checkpoint

# Load model once at startup
print("Loading model...")
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=peft_model_name,
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
)
# Switch unsloth model into inference mode (disables training paths).
FastLanguageModel.for_inference(model)
print("Model loaded successfully")
30
+
31
# In-memory storage (replacing Redis)
task_storage = {}
task_counter = 0

class TaskManager:
    """In-memory registry of URL-classification jobs and their progress."""

    def __init__(self):
        # Maps task_id -> mutable task record dict.
        self.tasks = {}

    def create_task(self, urls: List[str]) -> str:
        """Register a new batch job for *urls* and return its generated id."""
        global task_counter
        task_counter += 1
        task_id = f"task_{task_counter}"
        record = {
            "total": len(urls),
            "completed": 0,
            "scraped": 0,
            "status": "processing",
            "urls": urls,
            "results": {},
            "created_time": datetime.now(timezone.utc).isoformat(),
        }
        self.tasks[task_id] = record
        return task_id

    def update_progress(self, task_id: str, field: str, value: Any):
        """Set *field* on the task record; unknown ids are silently ignored."""
        record = self.tasks.get(task_id)
        if record is not None:
            record[field] = value

    def get_task(self, task_id: str) -> Dict:
        """Return the task record, or an empty dict for an unknown id."""
        return self.tasks.get(task_id, {})

task_manager = TaskManager()
63
+
64
def translate_text(text: str) -> str:
    """Best-effort translation of *text* to English.

    The input is clamped to 4990 characters (Google Translate rejects
    inputs near the 5000-char limit). On any failure the untranslated
    clamped text is returned instead of raising.
    """
    snippet = text[:4990]
    try:
        return GoogleTranslator(source='auto', target='en').translate(snippet)
    except Exception as e:
        print(f"Translation error: {e}")
        return snippet
73
+
74
@spaces.GPU
def predict_inference(translated_text: str) -> str:
    """Classify website text with the fine-tuned model (GPU-accelerated).

    Returns one of 'OTHER', 'NEWS/BLOG', 'E-commerce', 'Short' (input
    under 150 chars), or 'ERROR' on any failure or unrecognized output.
    """
    try:
        if len(translated_text) < 150:
            return 'Short'

        prompt = f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
Categorize the website into one of the 3 categories:\n\n1) OTHER \n2) NEWS/BLOG\n3) E-commerce

### Input:
{translated_text}

### Response:"""

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Bug fix: autocast must target the device actually in use. The
        # original hard-coded device_type='cuda', which raises on CPU-only
        # hosts and silently turned every prediction into 'ERROR'.
        with autocast(device_type=device.type):
            inputs = tokenizer(prompt, return_tensors="pt").to(device)
            outputs = model.generate(**inputs, max_new_tokens=16, use_cache=True)
            ans = tokenizer.batch_decode(outputs)[0]

        # Take the text between the '### Response:' marker and the first
        # special-token delimiter (e.g. '</s>').
        ans_pred = ans.split('### Response:')[1].split('<')[0].strip()

        if 'OTHER' in ans_pred:
            return 'OTHER'
        elif 'NEWS/BLOG' in ans_pred:
            return 'NEWS/BLOG'
        elif 'E-commerce' in ans_pred:
            return 'E-commerce'
        else:
            return 'ERROR'

    except Exception as e:
        print(f"Inference error: {e}")
        return 'ERROR'
112
+
113
async def scrape_single_url(session: httpx.AsyncClient, url: str) -> Dict:
    """Fetch one URL; never raises — failures are encoded in 'status'.

    Returns a dict with keys 'url', 'text' (first 5000 chars of the body,
    or '' on failure) and 'status' ('success', 'error_<code>' or
    'error_<message>').
    """
    result = {"url": url, "text": "", "status": ""}
    try:
        response = await session.get(url, timeout=30.0)
        if response.status_code == 200:
            # Simple text extraction (you can enhance this); cap content size.
            result["text"] = response.text[:5000]
            result["status"] = "success"
        else:
            result["status"] = f"error_{response.status_code}"
    except Exception as e:
        # Truncate the message so task records stay small.
        result["status"] = f"error_{str(e)[:100]}"
    return result
137
+
138
async def process_urls_batch(urls: List[str], progress_callback=None) -> Dict[str, str]:
    """Scrape, translate and classify each URL sequentially.

    Registers the batch with the global task_manager, updates its progress
    counters as it goes, and returns a mapping of url -> predicted label
    (or an 'Error: ...' string for per-URL failures).
    """
    task_id = task_manager.create_task(urls)
    predictions: Dict[str, str] = {}
    total = len(urls)

    async with httpx.AsyncClient() as client:
        for index, url in enumerate(urls, start=1):
            try:
                scraped = await scrape_single_url(client, url)
                task_manager.update_progress(task_id, "scraped", index)

                page_text = scraped.get("text", "")

                # Very short pages are labeled without hitting the model.
                if len(page_text) < 150:
                    label = "Short"
                else:
                    label = predict_inference(translate_text(page_text))

                predictions[url] = label
                task_manager.update_progress(task_id, "completed", index)

                if progress_callback:
                    progress_callback(f"Processed {index}/{total} URLs")

            except Exception as e:
                predictions[url] = f"Error: {str(e)[:100]}"

    task_manager.update_progress(task_id, "status", "completed")
    task_manager.update_progress(task_id, "results", predictions)

    return predictions
176
+
177
def process_url_list(url_text: str, progress=gr.Progress()) -> str:
    """Main processing function for the Gradio interface.

    Parses newline-separated URLs from *url_text*, classifies each, and
    returns one 'url → label' line per URL. Returns a human-readable
    message for empty input, too many URLs, or a processing failure.
    """
    if not url_text.strip():
        return "Please provide URLs to process."

    # Parse URLs: one per line, skipping blank lines.
    urls = [url.strip() for url in url_text.strip().split('\n') if url.strip()]

    if not urls:
        return "No valid URLs found."

    if len(urls) > 50:  # Limit for demo
        return f"Too many URLs ({len(urls)}). Please limit to 50 URLs."

    try:
        progress(0, desc="Starting processing...")

        def progress_callback(msg):
            progress(None, desc=msg)

        # Bug fix: the original created a loop via new_event_loop() /
        # set_event_loop() and called loop.close() outside any finally, so
        # an exception leaked the loop and left it installed as the thread's
        # current loop. asyncio.run() creates, runs, and always closes it.
        results = asyncio.run(process_urls_batch(urls, progress_callback))

        # Format results, one line per URL, preserving input order.
        return "\n".join(
            f"{url} → {prediction}" for url, prediction in results.items()
        )

    except Exception as e:
        return f"Error processing URLs: {str(e)}"
213
+
214
# Create Gradio interface
def create_interface():
    """Build and return the Blocks UI for the website classifier."""
    with gr.Blocks(title="Website Category Classifier") as demo:
        # Page header.
        gr.HTML("<h1>🔍 Website Category Classifier</h1>")
        gr.HTML("<p>Classify websites into categories: OTHER, NEWS/BLOG, or E-commerce</p>")

        with gr.Row():
            # Left column: URL input and the trigger button.
            with gr.Column():
                url_input = gr.Textbox(
                    label="URLs (one per line)",
                    placeholder="https://example1.com\nhttps://example2.com\nhttps://example3.com",
                    lines=10,
                    max_lines=20,
                )
                process_btn = gr.Button("🚀 Classify Websites", variant="primary")

            # Right column: read-only results box.
            with gr.Column():
                output = gr.Textbox(
                    label="Results",
                    lines=15,
                    max_lines=30,
                    interactive=False,
                )

        # Clickable example URL sets.
        gr.Examples(
            examples=[
                ["https://news.google.com\nhttps://amazon.com\nhttps://github.com"],
                ["https://techcrunch.com\nhttps://shopify.com\nhttps://stackoverflow.com"],
            ],
            inputs=[url_input],
        )

        # Wire the button to the classification pipeline.
        process_btn.click(
            fn=process_url_list,
            inputs=[url_input],
            outputs=[output],
            show_progress=True,
        )

    return demo
256
+
257
# Launch the app when run as a script.
if __name__ == "__main__":
    create_interface().launch()
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# UI framework and Hugging Face Spaces GPU decorator
gradio>=4.44.0
spaces
# Model runtime — torch upper bound presumably for unsloth compatibility; verify
torch>=2.1.0,<2.6.0
transformers>=4.40.0
# Fine-tuned-model loader, installed from source
unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git
# Translates scraped page text to English before classification
deep-translator>=1.11.4
# Async HTTP client used for scraping
httpx>=0.25.0
# NOTE(review): not imported by app.py as uploaded — possibly for planned HTML parsing
beautifulsoup4>=4.12.0
# 4-bit quantized loading and LoRA adapter support
accelerate>=0.21.0
bitsandbytes>=0.41.0
peft>=0.5.0
datasets>=2.14.0
safetensors>=0.3.2