Alibrown committed on
Commit
094916c
·
verified ·
1 Parent(s): bf2c919

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +334 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,336 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
 
5
 
6
"""
# Welcome to Streamlit!

Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
forums](https://discuss.streamlit.io).

In the meantime, below is an example of what you can do with just a few lines of code:
"""

# Two sliders drive the demo: how many points to draw and how tightly the
# spiral winds.
num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
num_turns = st.slider("Number of turns in spiral", 1, 300, 31)

# `indices` runs evenly from 0 to 1 and doubles as the radius, so the points
# move outwards while the angle sweeps through `num_turns` full revolutions.
indices = np.linspace(0, 1, num_points)
theta = 2 * np.pi * num_turns * indices
radius = indices

# Polar -> cartesian.
x = radius * np.cos(theta)
y = radius * np.sin(theta)

df = pd.DataFrame(
    {
        "x": x,
        "y": y,
        "idx": indices,
        "rand": np.random.randn(num_points),
    }
)

# Build the chart step by step: filled points, hidden axes, colour keyed on
# the position along the spiral and marker size keyed on the random column.
spiral_points = alt.Chart(df, height=700, width=700).mark_point(filled=True)
spiral_chart = spiral_points.encode(
    x=alt.X("x", axis=None),
    y=alt.Y("y", axis=None),
    color=alt.Color("idx", legend=None, scale=alt.Scale()),
    size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
)
st.altair_chart(spiral_chart)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # FREE to use under MIT + ESOL v.1.0
2
+ # See https://github.com/volkansah
 
3
  import streamlit as st
4
+ import requests
5
+ import re
6
+ import os
7
+ import tempfile
8
+ from typing import Dict, List, Tuple
9
+ import json
10
+ from huggingface_hub import InferenceClient
11
 
12
# ============================================
# STREAMLIT PERMISSION FIX
# ============================================
# Point Streamlit at a directory under the system temp dir and switch off
# usage-stat collection. NOTE(review): presumably a workaround for hosts where
# the default Streamlit home is not writable - confirm against the deployment.
TEMP_STREAMLIT_HOME = os.path.join(tempfile.gettempdir(), "st_config_workaround")
os.makedirs(TEMP_STREAMLIT_HOME, exist_ok=True)
os.environ["STREAMLIT_HOME"] = TEMP_STREAMLIT_HOME
os.environ["STREAMLIT_GATHER_USAGE_STATS"] = "false"
CONFIG_PATH = os.path.join(TEMP_STREAMLIT_HOME, "config.toml")
try:
    # Exclusive-create ("x") writes the default config only when the file is
    # absent, without the exists()/open() race, and never truncates a config
    # that another process has already written.
    with open(CONFIG_PATH, "x", encoding="utf-8") as config_file:
        config_file.write("[browser]\ngatherUsageStats = false\n")
except FileExistsError:
    # An existing config is kept as-is.
    pass
23
+
24
+ # ============================================
25
+ # LLM-POWERED ANALYZER
26
+ # ============================================
27
+
28
class MLRepoAnalyzerLLM:
    """Classify a public GitHub repository as real ML training code or an API wrapper.

    The analyzer lists the repository's files through the GitHub API, downloads
    a handful of Python files and scores them with regex patterns. When a
    HuggingFace token is supplied, each file is additionally classified by a
    chat model and those votes are weighted more heavily than the patterns.
    """

    def __init__(self, hf_token: str = None):
        """Store the token and set up the inference client and regex patterns.

        Args:
            hf_token: HuggingFace API token, or None to use pattern matching only.
        """
        self.hf_token = hf_token
        # Always define the attribute so that token-less instances expose
        # `client = None` instead of raising AttributeError on access.
        self.client = InferenceClient(token=hf_token) if hf_token else None

        # Fallback patterns (used when no token is available)
        self.fake_indicators = [
            r'openai\.', r'anthropic\.', r'cohere\.',
            r'replicate\.', r'api\.mistral', r'groq\.',
            r'requests\.post.*api', r'urllib.*api'
        ]
        self.legit_indicators = [
            r'torch\.optim', r'loss\.backward\(\)', r'model\.train\(\)',
            r'optimizer\.step\(\)', r'tf\.keras\.optimizers',
            r'from\s+transformers\s+import\s+Trainer',
            r'accelerator\.backward', r'DeepSpeed',
            r'torch\.nn\.Module', r'forward\(self'
        ]

    def extract_repo_info(self, url: str) -> Tuple[str, str, str]:
        """Extract owner, repo and branch from a GitHub URL.

        Args:
            url: Any string containing ``github.com/<owner>/<repo>``, optionally
                followed by ``/tree/<branch>``.

        Returns:
            ``(owner, repo, branch)``; branch falls back to ``'main'`` when the
            URL does not name one.

        Raises:
            ValueError: If no owner/repo pair can be found in ``url``.
        """
        # Segments stop at '/', whitespace, '?' and '#', so query strings,
        # fragments and stray spaces from copy/paste never leak into the names.
        pattern = r'github\.com/([^/\s?#]+)/([^/\s?#]+)(?:/tree/([^/\s?#]+))?'
        match = re.search(pattern, url or "")
        if not match:
            raise ValueError("Invalid GitHub URL")
        owner, repo = match.group(1), match.group(2)
        # Only drop a trailing ".git"; replacing it anywhere would corrupt
        # names such as "user.github.io".
        if repo.endswith('.git'):
            repo = repo[:-len('.git')]
        if not repo:
            raise ValueError("Invalid GitHub URL")
        branch = match.group(3) or 'main'
        return owner, repo, branch

    def fetch_repo_tree(self, owner: str, repo: str, branch: str) -> List[Dict]:
        """Fetch the recursive file tree of a branch via the GitHub API.

        Returns:
            The ``tree`` list of the API response (empty list if absent).

        Raises:
            Exception: If the API answers with a non-200 status code.
        """
        from urllib.parse import quote

        api_url = (
            f"https://api.github.com/repos/{owner}/{repo}"
            f"/git/trees/{quote(branch, safe='')}?recursive=1"
        )
        response = requests.get(api_url, timeout=10)
        if response.status_code != 200:
            raise Exception(f"GitHub API error: {response.status_code}")
        return response.json().get('tree', [])

    def fetch_file_content(self, owner: str, repo: str, branch: str, path: str) -> str:
        """Fetch the raw text of one file.

        Returns:
            The file content, or ``""`` when the request fails or the server
            does not answer with status 200.
        """
        from urllib.parse import quote

        # Quote the path so characters like '#', '?' or spaces in file names
        # are not interpreted as URL syntax ('/' is kept as the separator).
        raw_url = (
            f"https://raw.githubusercontent.com/{owner}/{repo}"
            f"/{quote(branch, safe='')}/{quote(path)}"
        )
        try:
            response = requests.get(raw_url, timeout=10)
        except requests.RequestException:
            # One unreachable file should be skipped, not abort the analysis.
            return ""
        return response.text if response.status_code == 200 else ""

    @staticmethod
    def _parse_llm_json(text: str) -> Dict:
        """Pull the JSON object out of a model reply.

        Accepts a fenced block (with or without a ``json`` tag) or a bare
        object surrounded by extra prose.

        Raises:
            ValueError: If no valid JSON object can be decoded
                (``json.JSONDecodeError`` is a ``ValueError`` subclass).
        """
        text = text or ""
        fenced = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', text, re.DOTALL)
        if fenced:
            candidate = fenced.group(1)
        else:
            start, end = text.find('{'), text.rfind('}')
            candidate = text[start:end + 1] if start != -1 and end > start else text
        parsed = json.loads(candidate)
        if not isinstance(parsed, dict):
            raise ValueError("LLM response is not a JSON object")
        return parsed

    def analyze_with_llm(self, code_snippet: str, filename: str) -> Dict:
        """Ask the HF Inference API to classify one file.

        Args:
            code_snippet: File content; only the first 2000 characters are sent.
            filename: Path shown to the model and used in warnings.

        Returns:
            The decoded JSON object from the model, or None when no token is
            configured or the call/parsing fails (a Streamlit warning is shown
            in the failure case).
        """
        if not self.hf_token or self.client is None:
            return None

        # Limit the snippet to avoid token limits. The truncation note lives
        # here rather than inside the prompt so it is not sent to the model
        # as if it were part of the analysed code.
        snippet = code_snippet[:2000]

        prompt = f"""Analyze this Python file from a machine learning repository: {filename}

Code snippet:
```python
{snippet}
```

Determine if this is:
1. REAL ML TRAINING CODE (contains actual model training, backprop, optimizers)
2. API WRAPPER (just calls external APIs like OpenAI, Anthropic, etc.)
3. UNCLEAR

Respond in JSON format:
{{
"classification": "REAL_TRAINING|API_WRAPPER|UNCLEAR",
"confidence": 0-100,
"reasoning": "brief explanation",
"key_indicators": ["indicator1", "indicator2"]
}}"""

        try:
            # Use Qwen2.5-Coder or similar code-focused model
            response = self.client.chat_completion(
                messages=[{"role": "user", "content": prompt}],
                model="Qwen/Qwen2.5-Coder-32B-Instruct",
                max_tokens=500,
                temperature=0.1
            )
            result_text = response.choices[0].message.content
            return self._parse_llm_json(result_text)
        except Exception as e:
            st.warning(f"LLM analysis failed for {filename}: {e}")
            return None

    def analyze_file_structure(self, files: List[Dict]) -> Dict:
        """Summarise the repository layout from its tree entries.

        Args:
            files: Tree entries; only each entry's ``path`` is inspected.

        Returns:
            Dict with the boolean flags ``has_train_script``, ``has_model_files``,
            ``has_config``, ``has_requirements`` and the int ``python_file_count``.
        """
        paths = [entry.get('path', '') for entry in files]
        py_paths = [p.lower() for p in paths if p.endswith('.py')]

        return {
            'has_train_script': any('train' in p for p in py_paths),
            'has_model_files': any('model' in p for p in py_paths),
            'has_config': any(p.endswith(('.yaml', '.yml', '.json', '.toml')) for p in paths),
            'has_requirements': any('requirements' in p or 'pyproject.toml' in p for p in paths),
            'python_file_count': len(py_paths)
        }

    def analyze_with_patterns(self, content: str) -> Tuple[int, int]:
        """Score file content with the fallback regex patterns.

        Returns:
            ``(fake_score, legit_score)``: 5 points per matching fake pattern
            and 10 points per matching legit pattern (case-insensitive, each
            pattern counted at most once).
        """
        fake_score = sum(
            5 for pattern in self.fake_indicators
            if re.search(pattern, content, re.IGNORECASE)
        )
        legit_score = sum(
            10 for pattern in self.legit_indicators
            if re.search(pattern, content, re.IGNORECASE)
        )
        return fake_score, legit_score

    @staticmethod
    def _count_label(llm_results: List[Dict], label: str) -> int:
        """Count LLM results whose classification equals ``label``.

        Missing or oddly-cased classifications are tolerated instead of
        raising, so one malformed model reply cannot abort the whole run.
        """
        return sum(
            1 for entry in llm_results
            if str(entry['result'].get('classification', '')).strip().upper() == label
        )

    @staticmethod
    def _verdict(total_score: int, has_evidence: bool = True) -> Tuple[str, str]:
        """Map the combined score to ``(verdict, confidence)``.

        Scores above 30 are LEGIT, scores above 0 are MIXED, anything else is
        FAKE. When nothing was detected at all, confidence is reported as
        "Low" because the FAKE verdict then only reflects absent evidence.
        """
        if total_score > 30:
            return "✅ LEGIT - Real ML Training Code", "High"
        if total_score > 0:
            return "⚠️ MIXED - Contains some training code", "Medium"
        return "❌ FAKE - API Wrapper / No Real Training", "High" if has_evidence else "Low"

    def classify_repo(self, url: str, use_llm: bool = True, max_files: int = 10) -> Dict:
        """Run the full analysis for one repository URL.

        Args:
            url: GitHub repository URL.
            use_llm: Whether to query the LLM (only effective with a token).
            max_files: Maximum number of Python files to download and score.

        Returns:
            On success a dict with ``verdict``, ``confidence``, ``score``,
            ``structure``, ``llm_results``, ``pattern_scores`` and
            ``repo_info``; on any failure ``{'error': <message>}``.
        """
        try:
            owner, repo, branch = self.extract_repo_info(url)
            files = self.fetch_repo_tree(owner, repo, branch)

            structure = self.analyze_file_structure(files)
            py_files = [f for f in files if f.get('path', '').endswith('.py')]
            py_files = py_files[:max(max_files, 0)]

            llm_results = []
            pattern_fake_score = 0
            pattern_legit_score = 0

            for file_info in py_files:
                content = self.fetch_file_content(owner, repo, branch, file_info['path'])
                if not content:
                    continue

                # LLM analysis (if token available)
                if use_llm and self.hf_token:
                    llm_result = self.analyze_with_llm(content, file_info['path'])
                    if isinstance(llm_result, dict) and llm_result:
                        llm_results.append({
                            'file': file_info['path'],
                            'result': llm_result
                        })

                # Pattern scoring always runs, with or without the LLM
                fake, legit = self.analyze_with_patterns(content)
                pattern_fake_score += fake
                pattern_legit_score += legit

            # Combine LLM + pattern results; each LLM vote is worth 30 points
            # so the model outweighs individual regex hits.
            llm_real_count = self._count_label(llm_results, 'REAL_TRAINING')
            llm_fake_count = self._count_label(llm_results, 'API_WRAPPER')
            total_score = (
                (llm_real_count - llm_fake_count) * 30
                + (pattern_legit_score - pattern_fake_score)
            )

            has_evidence = bool(
                llm_real_count or llm_fake_count
                or pattern_legit_score or pattern_fake_score
            )
            verdict, confidence = self._verdict(total_score, has_evidence)

            return {
                'verdict': verdict,
                'confidence': confidence,
                'score': total_score,
                'structure': structure,
                'llm_results': llm_results,
                'pattern_scores': {
                    'fake': pattern_fake_score,
                    'legit': pattern_legit_score
                },
                'repo_info': f"{owner}/{repo}@{branch}"
            }

        except Exception as e:
            # Top-level boundary: the UI expects an error dict, not a raise.
            return {'error': str(e)}
206
+
207
# ============================================
# STREAMLIT UI
# ============================================
# Top-level Streamlit script: sidebar for the optional HuggingFace token,
# a URL input plus "Analyze" button, and the rendering of the result dict
# returned by MLRepoAnalyzerLLM.classify_repo.

st.set_page_config(page_title="ML Repo Detector 🔍", page_icon="🤖", layout="wide")

st.title("🤖 ML Training Repo Analyzer (LLM-Powered)")
st.markdown("**AI-powered detection of fake ML repos using your HuggingFace token**")

# Token input in sidebar
with st.sidebar:
    st.markdown("### 🔑 HuggingFace Setup")
    # Strip the token: pasted secrets frequently carry whitespace/newlines
    # that would otherwise be sent verbatim as part of the credential.
    hf_token = st.text_input(
        "HF Token (optional)",
        type="password",
        help="Get your free token at https://huggingface.co/settings/tokens"
    ).strip()

    use_llm = st.checkbox(
        "Use LLM Analysis",
        value=bool(hf_token),
        disabled=not hf_token,
        help="Requires HF token. Uses Qwen2.5-Coder for deep analysis"
    )

    st.markdown("---")
    st.markdown("### 🛠️ Models Used")
    if use_llm:
        st.success("✅ Qwen2.5-Coder-32B (Free)")
    else:
        st.info("📊 Pattern Matching Only")

    st.markdown("---")
    st.markdown("### 💡 How it works")
    st.markdown("""
**With LLM:**
- Deep code understanding
- Context-aware analysis
- Higher accuracy

**Without LLM:**
- Pattern matching
- Regex-based detection
- Still pretty good!
""")

# Main interface
analyzer = MLRepoAnalyzerLLM(hf_token=hf_token or None)

# Strip the URL so leading/trailing whitespace from copy/paste cannot end up
# in the parsed repository name.
repo_url = st.text_input(
    "GitHub Repository URL",
    placeholder="https://github.com/username/repo",
    help="Enter a public GitHub repository URL"
).strip()

# Narrow first column keeps the button from stretching across the page; the
# second column is only a spacer.
button_col, _ = st.columns([1, 4])
with button_col:
    analyze_btn = st.button("🚀 Analyze", type="primary", use_container_width=True)

if analyze_btn:
    if not repo_url:
        st.error("Enter a GitHub URL!")
    else:
        with st.spinner("🔍 Analyzing repository..." + (" (using LLM)" if use_llm else " (pattern matching)")):
            result = analyzer.classify_repo(repo_url, use_llm=use_llm and bool(hf_token))

        if 'error' in result:
            st.error(f"❌ Error: {result['error']}")
        else:
            # Verdict
            st.markdown("---")
            verdict_col, confidence_col, score_col = st.columns([3, 1, 1])
            with verdict_col:
                st.markdown(f"## {result['verdict']}")
            with confidence_col:
                st.metric("Confidence", result['confidence'])
            with score_col:
                st.metric("Score", result['score'])

            # LLM Results (only the first five files are shown)
            if result.get('llm_results'):
                st.markdown("### 🤖 LLM Analysis Results")
                for llm_res in result['llm_results'][:5]:
                    with st.expander(f"📄 {llm_res['file']}"):
                        res = llm_res.get('result')
                        # The model's reply is free-form JSON; fall back to an
                        # empty dict so a non-object reply renders as UNKNOWN
                        # instead of crashing on .get().
                        if not isinstance(res, dict):
                            res = {}

                        label_col, conf_col = st.columns(2)
                        with label_col:
                            classification = res.get('classification', 'UNKNOWN')
                            if classification == 'REAL_TRAINING':
                                st.success(f"✅ {classification}")
                            elif classification == 'API_WRAPPER':
                                st.error(f"❌ {classification}")
                            else:
                                st.warning(f"⚠️ {classification}")

                        with conf_col:
                            st.metric("Confidence", f"{res.get('confidence', 0)}%")

                        st.markdown(f"**Reasoning:** {res.get('reasoning', 'N/A')}")

                        if res.get('key_indicators'):
                            st.markdown("**Key Indicators:**")
                            for indicator in res['key_indicators']:
                                st.markdown(f"- {indicator}")

            # Pattern Analysis (fallback/additional)
            st.markdown("### 📊 Pattern Analysis")
            legit_col, fake_col = st.columns(2)
            with legit_col:
                st.metric("Legit Patterns", result['pattern_scores']['legit'])
            with fake_col:
                st.metric("Fake Patterns", result['pattern_scores']['fake'])

            # Structure
            st.markdown("### 📁 Repository Structure")
            struct = result['structure']
            cols = st.columns(4)
            with cols[0]:
                st.metric("Python Files", struct['python_file_count'])
            with cols[1]:
                st.write("✅" if struct['has_train_script'] else "❌", "train.py")
            with cols[2]:
                st.write("✅" if struct['has_model_files'] else "❌", "model files")
            with cols[3]:
                st.write("✅" if struct['has_config'] else "❌", "configs")

# Footer
st.markdown("---")
st.markdown("**💡 Your HF token = your quota. No data stored. Analysis runs on HF's free inference API.**")