Spaces:
Sleeping
Sleeping
| import logging | |
| import torch | |
| logger = logging.getLogger(__name__) | |
| _model = None | |
| _tokenizer = None | |
| _device = "cpu" | |
def load_model():
    """Load TinyLlama into module-level state for generating explanations.

    Idempotent: returns True immediately when the model is already resident.

    Returns:
        bool: True when the model and tokenizer are ready, False when the
        load failed (e.g. transformers missing, download error, OOM).
    """
    global _model, _tokenizer, _device
    if _model is not None:
        return True
    try:
        from transformers import AutoTokenizer, AutoModelForCausalLM

        model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
        logger.info(f"Loading SLM: {model_name}...")
        # Prefer GPU acceleration when available; otherwise stay on CPU.
        _device = "cuda" if torch.cuda.is_available() else "cpu"
        _tokenizer = AutoTokenizer.from_pretrained(model_name)
        # fp16 on GPU halves memory; fp32 is required for correct CPU inference.
        dtype = torch.float16 if _device == "cuda" else torch.float32
        _model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
        )
        if _device == "cuda":
            _model = _model.cuda()
        logger.info(f"SLM loaded on {_device}")
        return True
    except Exception as e:
        # Best-effort: callers fall back to template explanations on False.
        logger.error(f"Failed to load SLM: {e}")
        return False
def explain(url, result):
    """Generate an explanation for *result*, preferring the SLM.

    Uses the loaded TinyLlama model when available (attempting a lazy load
    first), and falls back to static templates when loading fails.

    Args:
        url: The analyzed URL (string).
        result: Analysis dict with keys like "verdict", "findings", "content".

    Returns:
        dict with "summary", "explanation" and "advice" keys.
    """
    # load_model() is an idempotent fast no-op when the model is already
    # resident, so one guard covers both the "loaded" and "auto-load" paths.
    # (Also avoids relying on the truthiness of a model object.)
    if _model is not None or load_model():
        return _explain_with_slm(url, result)
    return _template_explain(url, result)
def _explain_with_slm(url, result):
    """Generate a professional explanation using TinyLlama.

    Builds a TinyLlama chat prompt from the analysis result, generates up to
    100 new tokens, and trims the output to end on a complete sentence.
    Falls back to _template_explain() if generation raises.

    Args:
        url: The analyzed URL (string).
        result: Analysis dict ("findings", "verdict", "content", "is_unknown").

    Returns:
        dict with "summary", "explanation" and "advice" keys.
    """
    findings = result.get("findings", [])
    verdict = result.get("verdict", "UNKNOWN")
    content_data = result.get("content", {})
    # Content context - truncated for speed. (A duplicate of this assignment
    # was removed; it had no effect.)
    content_text = content_data.get('text', '')[:300]
    # Surface the "not in trusted database" status in the prompt when there
    # are no other findings to report.
    is_unknown = result.get('is_unknown', False)
    findings_str = ", ".join(findings) if findings else "No specific red flags."
    if is_unknown and not findings:
        findings_str += " Note: This website is not in our trusted database."
    # TinyLlama chat format: <|system|>/<|user|>/<|assistant|> turns, each
    # terminated with the </s> stop token.
    prompt = f"""<|system|>
You are a Cyber Security Analyst.
Task: Briefly describe what this website appears to be based on the content, and explain WHY it is {verdict}.
Rules:
1. Do NOT start with "URL:", "Verdict:", or "Signals:".
2. Do NOT repeat the input data.
3. Start directly with the explanation.
4. Be serious and direct.</s>
<|user|>
Analyze this URL: '{url}'
Verdict: {verdict}
Signals: {findings_str}
Content Snippet: {content_text}
Explain this analysis.</s>
<|assistant|>"""
    try:
        inputs = _tokenizer(prompt, return_tensors="pt").to(_device)
        with torch.no_grad():
            outputs = _model.generate(
                **inputs,
                max_new_tokens=100,      # Reduced for speed
                do_sample=True,
                temperature=0.2,         # Lower temp = less hallucination
                top_p=0.9,
                repetition_penalty=1.1,  # Prevents repeating phrases
            )
        # Decode only the newly generated tokens (skip the echoed prompt).
        generated_ids = outputs[0][len(inputs['input_ids'][0]):]
        response_text = _tokenizer.decode(generated_ids, skip_special_tokens=True)
        full_response = response_text.strip()
        # Ensure we end on a complete sentence: keep text up to the last period.
        if "." in full_response:
            full_response = full_response.rsplit('.', 1)[0] + "."
        return {
            "summary": "AI Analysis Completed",
            "explanation": full_response,
            "advice": "Do not enter credentials." if verdict in ["DANGER", "WARNING"] else "Safe to browse."
        }
    except Exception as e:
        logger.error(f"SLM generation failed: {e}")
        return _template_explain(url, result)
| def _template_explain(url, result): | |
| """Fallback template explanations - Professional Tone.""" | |
| verdict = result.get("verdict", "UNKNOWN") | |
| findings = result.get("findings", []) | |
| content_data = result.get("content", {}) | |
| issues = "; ".join(findings) if findings else "clean URL structure confirmed" | |
| if verdict == "DANGER": | |
| return { | |
| "summary": "Critical Threat Detected", | |
| "explanation": f"High-risk indicators identified: {issues}. The domain exhibits patterns consistent with phishing or malware delivery.", | |
| "advice": "Do not proceed. Close this page immediately." | |
| } | |
| elif verdict == "WARNING": | |
| return { | |
| "summary": "Suspicious Activity", | |
| "explanation": f"Potential risks detected: {issues}. The site may be legitimate but lacks verified trust indicators.", | |
| "advice": "Proceed only if you trust the source. Do not enter credentials." | |
| } | |
| else: | |
| is_unknown = result.get('is_unknown', False) | |
| explanation = "No malicious patterns detected. The URL structure and domain reputation appear legitimate." | |
| if is_unknown: | |
| explanation = "This website is not in our trusted database, but no malicious patterns were detected. It appears safe." | |
| return { | |
| "summary": "Low Risk / Safe", | |
| "explanation": explanation, | |
| "advice": "Safe to visit." | |
| } | |