File size: 5,415 Bytes
2e7b8f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import aiohttp
import asyncio
import socket
import ipaddress
from urllib.parse import urlparse
from typing import Optional, List
from pydantic import BaseModel

from app.models import SourceConfig, SourceType
from app.grabber import GitHubGrabber


class SourceValidationResult(BaseModel):
    """Outcome of validating a single proxy source."""

    # True only when format, reachability, and proxy-extraction checks all passed.
    valid: bool
    # Human-readable failure reason; None when the source is valid.
    error_message: Optional[str] = None
    # Number of proxies extracted during validation (0 when validation failed early).
    proxy_count: int = 0
    # Sample of extracted proxy URLs (at most 5 — see test_proxy_extraction).
    sample_proxies: List[str] = []


class SourceValidator:
    """Validates a proxy source end-to-end: URL format, reachability
    (with SSRF protection), and that proxies can actually be extracted."""

    # Bounds enforced on fetched source content by validate_url_reachable.
    MIN_CONTENT_CHARS = 10
    MAX_CONTENT_BYTES = 50_000_000  # 50 MB

    def __init__(self, timeout: int = 15):
        """
        Args:
            timeout: total per-request timeout, in seconds.
        """
        self.timeout = aiohttp.ClientTimeout(total=timeout)
        self.grabber = GitHubGrabber()

    def is_internal_url(self, url: str) -> bool:
        """Return True if *url* points at an internal network (SSRF protection).

        Blocks well-known internal hostnames, then resolves the host and
        rejects private, loopback, link-local, reserved, and unspecified
        addresses.
        """
        try:
            hostname = urlparse(url).hostname
            if not hostname:
                # No host component at all -- treat as unsafe.
                return True

            # Fast path for common internal names / literals.
            if hostname.lower() in {"localhost", "127.0.0.1", "::1", "0.0.0.0"}:
                return True

            # NOTE(review): getaddrinfo is a blocking call and this method is
            # invoked from async code, so slow DNS stalls the event loop.
            for item in socket.getaddrinfo(hostname, None):
                addr = ipaddress.ip_address(item[4][0])  # parse once per result
                if (
                    addr.is_private
                    or addr.is_loopback
                    or addr.is_link_local
                    or addr.is_reserved
                    or addr.is_unspecified
                ):
                    return True

            return False
        except Exception:
            # Resolution failure: fail open here -- an unresolvable host is
            # rejected anyway by validate_url_reachable's connection attempt.
            return False

    async def validate_url_reachable(self, url: str) -> tuple[bool, Optional[str]]:
        """Fetch *url* and check it returns plausible source content.

        Returns:
            (ok, error_message) -- error_message is None on success.
        """
        if self.is_internal_url(url):
            return False, "Access to internal networks is restricted (SSRF protection)"

        try:
            async with aiohttp.ClientSession(timeout=self.timeout) as session:
                # NOTE(security): ssl=False disables certificate verification.
                # Presumably deliberate best-effort for sources with broken
                # certs -- worth revisiting.
                async with session.get(url, ssl=False) as resp:
                    if resp.status == 200:
                        # Reject oversized bodies before downloading them
                        # when the server advertises a Content-Length.
                        if (
                            resp.content_length is not None
                            and resp.content_length > self.MAX_CONTENT_BYTES
                        ):
                            return False, "Source content too large (> 50MB)"

                        content = await resp.text()

                        if len(content) < self.MIN_CONTENT_CHARS:
                            return False, "Source content too short (< 10 characters)"

                        if len(content) > self.MAX_CONTENT_BYTES:
                            return False, "Source content too large (> 50MB)"

                        return True, None
                    elif resp.status == 404:
                        return False, "Source not found (404)"
                    elif resp.status == 403:
                        return False, "Access forbidden (403)"
                    elif resp.status >= 500:
                        return False, f"Server error ({resp.status})"
                    else:
                        return False, f"HTTP error {resp.status}"

        except asyncio.TimeoutError:
            return False, "Connection timeout - source took too long to respond"
        except aiohttp.ClientConnectorError:
            return False, "Cannot connect to source URL"
        except Exception as e:
            # Truncate arbitrary exception text so the message stays short.
            return False, f"Error: {str(e)[:100]}"

    async def validate_source_format(
        self, source: SourceConfig
    ) -> tuple[bool, Optional[str]]:
        """Cheap, offline sanity checks on the source URL for its type.

        Returns:
            (ok, error_message) -- error_message is None on success.
        """
        url_str = str(source.url)

        if source.type == SourceType.GITHUB_RAW:
            if "github.com" not in url_str:
                return False, "GitHub source must contain 'github.com'"
            # Raw files live under /raw/ or on githubusercontent.com.
            if "/raw/" not in url_str and "githubusercontent.com" not in url_str:
                return False, "GitHub source must be a raw file URL"

        elif source.type == SourceType.SUBSCRIPTION_BASE64:
            if not url_str.startswith(("http://", "https://")):
                return False, "Subscription source must start with http:// or https://"

        return True, None

    async def test_proxy_extraction(
        self, source: SourceConfig
    ) -> tuple[int, List[str], Optional[str]]:
        """Try to extract proxies from *source*.

        Returns:
            (count, sample_urls, error_message) -- sample_urls holds at most
            5 proxy URLs; error_message is None on success.
        """
        try:
            proxies = await self.grabber.extract_proxies(source)

            if not proxies:
                return 0, [], "No proxies found in source"

            sample = [p.url for p in proxies[:5]]
            return len(proxies), sample, None

        except Exception as e:
            return 0, [], f"Failed to extract proxies: {str(e)[:100]}"

    async def validate_source(self, source: SourceConfig) -> SourceValidationResult:
        """Run the full pipeline: format -> reachability -> extraction.

        Stops at the first failing stage and reports its error.
        """
        is_format_valid, format_error = await self.validate_source_format(source)
        if not is_format_valid:
            return SourceValidationResult(valid=False, error_message=format_error)

        is_reachable, reachable_error = await self.validate_url_reachable(
            str(source.url)
        )
        if not is_reachable:
            return SourceValidationResult(valid=False, error_message=reachable_error)

        (
            proxy_count,
            sample_proxies,
            extraction_error,
        ) = await self.test_proxy_extraction(source)
        if extraction_error:
            return SourceValidationResult(
                valid=False, error_message=extraction_error, proxy_count=proxy_count
            )

        return SourceValidationResult(
            valid=True, proxy_count=proxy_count, sample_proxies=sample_proxies
        )


source_validator = SourceValidator()