pangxiang committed on
Commit
785102a
·
verified ·
1 Parent(s): e34f693

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +340 -56
app.py CHANGED
@@ -1,65 +1,349 @@
1
  import gradio as gr
2
- import torch
3
- from transformers import pipeline
 
 
4
 
5
- # 简单的代码修复演示函数
6
def code_fix_function(input_code):
    """Apply a tiny table of literal substitutions to ``input_code``.

    This is a demonstration stand-in for a trained model: each known broken
    snippet that occurs in the input is replaced by its correction. When no
    substitution changes the text, a generic hint comment is appended.

    Args:
        input_code: The source text to repair.

    Returns:
        The repaired text, or the input followed by a hint comment when
        nothing was changed.
    """
    # (broken snippet, replacement) pairs, applied in this order.
    substitutions = (
        ("print()", "print()"),
        ("if =", "if condition:"),
        ("for i in range(:", "for i in range():"),
    )

    repaired = input_code
    for broken, replacement in substitutions:
        # Membership is tested against the untouched input, not the
        # partially repaired text.
        if broken in input_code:
            repaired = repaired.replace(broken, replacement)

    if repaired != input_code:
        return repaired

    # Nothing matched (or the match was a no-op): fall back to a generic hint.
    return input_code + "\n\n# 建议:检查语法错误和缩进"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
def main():
    """Build the demo Gradio interface and serve it on 0.0.0.0:7860."""
    code_input = gr.Textbox(
        lines=10,
        placeholder="输入需要修复的代码...",
        label="输入代码",
    )
    code_output = gr.Textbox(
        lines=10,
        label="修复后的代码",
        show_copy_button=True,
    )
    sample_inputs = [
        ["print('Hello World'"],  # missing closing parenthesis
        ["if = 10"],  # malformed condition
        ["for i in range(10"],  # missing closing parenthesis
    ]

    # Assemble the interface from the pieces above.
    iface = gr.Interface(
        fn=code_fix_function,
        inputs=code_input,
        outputs=code_output,
        title="🐑 Capricode 代码修复助手",
        description="输入有问题的代码,获AI修复建议",
        examples=sample_inputs,
    )

    # Start the application.
    iface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
    )


if __name__ == "__main__":
    main()
 
 
 
 
65
 
 
1
import ast
import json
import os
import re
from datetime import datetime

import gradio as gr
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
6
 
7
class SmartCodeFixer:
    """Repair small code snippets with a language model or heuristic rules.

    The fixer guesses the snippet's language from keyword indicators, then
    either asks a causal language model for a corrected version or applies
    per-language textual heuristics. User feedback is appended to a JSON file
    so it can later be used for retraining.

    Every ``fix_*`` method and ``rule_based_fix`` returns a
    ``(fixed_code, fixes)`` tuple where ``fixes`` is a list of human-readable
    descriptions. ``ai_fix_code`` always returns the fixed code as a string.
    """

    # Opening -> closing delimiters balanced by the JavaScript heuristic.
    _JS_DELIMITERS = {'(': ')', '{': '}'}

    # Statement prefixes whose header line must end with a colon in Python.
    _PYTHON_BLOCK_KEYWORDS = ('if ', 'elif ', 'for ', 'while ', 'def ', 'class ')

    def __init__(self):
        self.feedback_file = "user_feedback.json"
        self.model = None
        self.tokenizer = None
        self.feedback_data = []
        self.load_model()
        self.load_feedback_data()

    def load_model(self):
        """Load the pretrained model and tokenizer.

        Loading is best-effort: on any failure the error is printed and both
        ``self.model`` and ``self.tokenizer`` are left as ``None`` so callers
        fall back to the rule-based fixer.
        """
        try:
            # NOTE(review): DialoGPT is a conversational model rather than a
            # code model; consider the CodeLlama alternative noted below.
            model_name = "microsoft/DialoGPT-medium"  # or "codellama/CodeLlama-7b-hf"
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModelForCausalLM.from_pretrained(model_name)
            print("模型加载成功!")
        except Exception as e:  # deliberate catch-all: the app must still start
            print(f"模型加载失败: {e}")
            self.model = None
            self.tokenizer = None

    def load_feedback_data(self):
        """Load previously saved feedback entries into ``self.feedback_data``.

        A missing, unreadable, corrupt, or non-list feedback file results in
        an empty list instead of an exception.
        """
        self.feedback_data = []
        if not os.path.exists(self.feedback_file):
            return

        try:
            with open(self.feedback_file, 'r', encoding='utf-8') as f:
                loaded = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            print(f"反馈数据加载失败: {e}")
            return

        if isinstance(loaded, list):
            self.feedback_data = loaded

    def save_feedback(self, original_code, fixed_code, user_feedback, is_correct):
        """Append one feedback entry and persist the full list to disk.

        Args:
            original_code: The code the user submitted.
            fixed_code: The code the fixer produced.
            user_feedback: Free-form comment from the user.
            is_correct: Whether the user judged the fix to be correct.

        Raises:
            OSError: If the feedback file cannot be written.
        """
        feedback_entry = {
            "timestamp": datetime.now().isoformat(),
            "original": original_code,
            "fixed": fixed_code,
            "feedback": user_feedback,
            "correct": is_correct,
            "language": self.detect_language(original_code),
        }

        self.feedback_data.append(feedback_entry)

        with open(self.feedback_file, 'w', encoding='utf-8') as f:
            json.dump(self.feedback_data, f, ensure_ascii=False, indent=2)

        # Trigger the (currently placeholder) retraining hook every 10 entries.
        if len(self.feedback_data) % 10 == 0:
            self.retrain_from_feedback()

    def detect_language(self, code):
        """Guess the programming language of ``code`` from keyword indicators.

        Matching is case-insensitive. The language with the most matching
        indicators wins; ties go to the language listed first.

        Args:
            code: Source text to classify (``None`` is treated as empty).

        Returns:
            One of ``'html'``, ``'python'``, ``'javascript'``, ``'java'``,
            ``'cpp'``, ``'css'``, or ``'unknown'`` when nothing matches.
        """
        code = code or ""
        code_lower = code.lower()

        language_indicators = {
            'html': ['<!doctype', '<html', '<div', '<span', 'class="', 'id="'],
            'python': ['def ', 'import ', 'print(', 'if __name__', 'lambda '],
            'javascript': ['function ', 'console.log', 'document.', 'addEventListener'],
            'java': ['public class', 'public static', 'System.out.println'],
            'cpp': ['#include', 'using namespace', 'cout <<', 'std::'],
            # Bare punctuation such as '{' or ':' is intentionally not listed
            # here: it occurs in almost every language and would outvote the
            # real indicators above.
            'css': ['font-size', 'font-family', 'color:', 'margin:',
                    'padding:', 'background:', 'display:'],
        }

        scores = {
            lang: sum(1 for indicator in indicators
                      if indicator.lower() in code_lower)
            for lang, indicators in language_indicators.items()
        }

        best_language, best_score = max(scores.items(), key=lambda item: item[1])
        if best_score > 0:
            return best_language

        # No keyword matched: a "selector { property: value" shape is the only
        # remaining hint we accept, and it points at CSS.
        if '{' in code and ':' in code:
            return 'css'
        return 'unknown'

    def ai_fix_code(self, code, language):
        """Fix ``code`` with the language model, falling back to the rules.

        Args:
            code: The source text to repair.
            language: Language name used in the prompt and for the fallback.

        Returns:
            The fixed code as a string. When no model is loaded, generation
            fails, or the model produces nothing, the rule-based result is
            returned instead.
        """
        if self.model is None or self.tokenizer is None:
            return self.rule_based_fix(code, language)[0]

        try:
            # The prompt closes the fence around the broken code and opens a
            # new one, so the model's continuation is the fixed code itself.
            prompt = (
                f"修复以下{language}代码的错误:\n\n"
                f"错误代码:\n```{language}\n{code}\n```\n\n"
                f"修复后的代码:\n```{language}\n"
            )

            inputs = self.tokenizer.encode(prompt, return_tensors="pt")
            prompt_length = inputs.shape[-1]
            outputs = self.model.generate(
                inputs,
                max_new_tokens=100,
                num_return_sequences=1,
                temperature=0.7,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )

            # Decode only the newly generated tokens, never the echoed prompt.
            generated = self.tokenizer.decode(
                outputs[0][prompt_length:], skip_special_tokens=True
            )
            # Everything up to the closing fence (if any) is the fixed code.
            fixed_code = generated.split("```")[0].strip()
            if fixed_code:
                return fixed_code
        except Exception as e:  # any model failure degrades to the rules
            print(f"AI修复失败: {e}")

        return self.rule_based_fix(code, language)[0]

    def rule_based_fix(self, code, language):
        """Dispatch to the heuristic fixer for ``language``.

        Unrecognised languages are handled by ``fix_generic``.

        Returns:
            A ``(fixed_code, fixes)`` tuple.
        """
        fixes = {
            'html': self.fix_html,
            'python': self.fix_python,
            'javascript': self.fix_javascript,
            'java': self.fix_java,
            'cpp': self.fix_cpp,
            'css': self.fix_css,
        }

        fix_function = fixes.get(language, self.fix_generic)
        return fix_function(code)

    def fix_html(self, code):
        """Close ``<div>`` tags, quote ``class`` values and add a page skeleton."""
        fixes = []
        lowered = code.lower()

        # Close every <div ...> that has no matching </div>.
        opened = len(re.findall(r'<div\b', code, flags=re.IGNORECASE))
        closed = len(re.findall(r'</div\s*>', code, flags=re.IGNORECASE))
        if opened > closed:
            code += '\n</div>' * (opened - closed)
            fixes.append("添加了缺失的 </div> 标签")

        # Wrap unquoted class values in double quotes: class=box -> class="box".
        code, quoted = re.subn(r'\bclass=([^\s"\'>]+)', r'class="\1"', code)
        if quoted:
            fixes.append("修复了属性引号")

        # Add a minimal document skeleton around bare fragments.
        if '<!doctype' not in lowered and '<html' not in lowered:
            code = (
                "<!DOCTYPE html>\n"
                "<html>\n"
                "<head>\n"
                '    <meta charset="UTF-8">\n'
                "    <title>Document</title>\n"
                "</head>\n"
                "<body>\n"
                f"{code}\n"
                "</body>\n"
                "</html>"
            )
            fixes.append("添加了基本的HTML结构")

        return code, fixes

    @staticmethod
    def _parses_as_python(source):
        """Return True when ``source`` is syntactically valid Python."""
        try:
            ast.parse(source)
        except (SyntaxError, ValueError):
            return False
        return True

    def fix_python(self, code):
        """Balance quotes and parentheses and add missing header colons.

        Code that already parses is returned untouched, so valid input is
        never damaged by the heuristics below.
        """
        fixes = []
        if self._parses_as_python(code):
            return code, fixes

        # Work on the text without trailing whitespace so appended characters
        # land directly after the code; the whitespace is restored at the end.
        body = code.rstrip()
        tail = code[len(body):]

        # Quotes are closed before parentheses so that print("hi becomes
        # print("hi") rather than print("hi)".
        if body.count('"') % 2 != 0:
            body += '"'
            fixes.append("修复了不匹配的双引号")

        if body.count("'") % 2 != 0:
            body += "'"
            fixes.append("修复了不匹配的单引号")

        missing_parens = body.count('(') - body.count(')')
        if missing_parens > 0:
            body += ')' * missing_parens
            fixes.append("修复了不匹配的括号")

        # Add a colon to block headers that lack one.
        lines = body.split('\n')
        colon_added = False
        for i, line in enumerate(lines):
            stripped = line.strip()
            if not stripped.startswith(self._PYTHON_BLOCK_KEYWORDS):
                continue
            # Already terminated, or the header continues on the next line.
            if stripped.endswith((':', '\\', ',', '(', '[', '{')):
                continue
            # One-line compound statements ("if x: return 1") are complete.
            if self._parses_as_python(stripped):
                continue
            lines[i] = line.rstrip() + ':'
            colon_added = True
        if colon_added:
            fixes.append("在条件/函数声明后添加了冒号")

        return '\n'.join(lines) + tail, fixes

    @staticmethod
    def _missing_closers(code, pairs):
        """Return the closers needed to balance ``code``, innermost first."""
        closers = set(pairs.values())
        pending = []
        for char in code:
            if char in pairs:
                pending.append(pairs[char])
            elif char in closers and pending and pending[-1] == char:
                pending.pop()
        return ''.join(reversed(pending))

    def fix_javascript(self, code):
        """Append missing ``)`` and ``}`` in correct nesting order."""
        fixes = []

        suffix = self._missing_closers(code, self._JS_DELIMITERS)
        if suffix:
            code += suffix
            if ')' in suffix:
                fixes.append("修复了不匹配的括号")
            if '}' in suffix:
                fixes.append("修复了不匹配的花括号")

        return code, fixes

    def fix_css(self, code):
        """Terminate declarations that are missing their semicolon."""
        fixes = []

        # Inside rule blocks: put the semicolon before the closing brace.
        patched = re.sub(r'([^;{}\s])(\s*\})', r'\1;\2', code)

        # Bare declarations without braces: terminate the last one.
        if '{' not in code and ':' in code and not code.rstrip().endswith(';'):
            patched = code.rstrip() + ';'

        if patched != code:
            fixes.append("添加了缺失的分号")

        return patched, fixes

    def fix_java(self, code):
        """Give a brace-less ``public class`` declaration a body with ``main``."""
        fixes = []

        if 'public class' in code and '{' not in code:
            lines = code.rstrip().split('\n')
            # Open the body on the declaration line, keeping the class name.
            for index, line in enumerate(lines):
                if 'public class' in line:
                    lines[index] = line.rstrip() + ' {'
                    break
            lines.extend([
                '    public static void main(String[] args) {',
                '        ',
                '    }',
                '}',
            ])
            code = '\n'.join(lines)
            fixes.append("添加了基本的类结构")

        return code, fixes

    def fix_cpp(self, code):
        """Append an empty ``main`` to C++ code that includes headers but has none."""
        fixes = []

        if '#include' in code and 'int main' not in code:
            code += '\n\nint main() {\n    return 0;\n}'
            fixes.append("添加了main函数")

        return code, fixes

    def fix_generic(self, code):
        """Fallback for unrecognised languages: return the code unchanged."""
        fixes = ["进行了通用语法检查"]
        return code, fixes

    def retrain_from_feedback(self):
        """Placeholder hook for learning from collected feedback."""
        print("正在从用户反馈中学习...")
        # Incremental-learning logic can be added here. For now feedback is
        # only recorded so that real retraining can be implemented later.
245
+
246
# Create the module-level fixer instance. Constructing it runs
# SmartCodeFixer.__init__, which loads the model and the saved feedback.
fixer = SmartCodeFixer()
248
+
249
def process_code(input_code, use_ai=True):
    """Fix ``input_code`` and build a short report for the UI.

    Args:
        input_code: The code pasted by the user.
        use_ai: When true and a model is loaded, try the model first;
            otherwise use the rule-based fixer.

    Returns:
        A ``(fixed_code, report)`` tuple of strings.
    """
    # Nothing to fix: avoid running detection/generation on blank input.
    if not input_code or not input_code.strip():
        return "", "⚠️ 请先输入需要修复的代码"

    language = fixer.detect_language(input_code)

    if use_ai and fixer.model is not None:
        result = fixer.ai_fix_code(input_code, language)
        if isinstance(result, tuple):
            # The model path fell back to the rules and handed back
            # (code, fixes); unpack it so the textbox receives a string and
            # the report lists what was really done.
            fixed_code, fixes = result
        else:
            fixed_code, fixes = result, ["使用AI模型修复"]
    else:
        fixed_code, fixes = fixer.rule_based_fix(input_code, language)

    # Build the fix report.
    report = (
        "🔧 修复报告\n"
        f"📝 检测语言: {language}\n"
        f"✅ 修复内容: {', '.join(fixes) if fixes else '代码看起来没问题'}\n"
        "\n"
        "修复后的代码:"
    )

    return fixed_code, report
267
+
268
def handle_feedback(original_code, fixed_code, user_feedback, is_correct):
    """Persist one piece of user feedback and return a status message.

    Args:
        original_code: The code the user originally submitted.
        fixed_code: The fixed code the user is commenting on.
        user_feedback: Free-form suggestion text.
        is_correct: Radio selection; ``None`` when nothing was chosen.

    Returns:
        A message for the feedback result textbox.
    """
    # Reject completely empty submissions so they neither pollute the
    # feedback file nor count toward the periodic retraining trigger.
    texts = (original_code, fixed_code, user_feedback)
    if is_correct is None and not any(text and text.strip() for text in texts):
        return "请先填写反馈内容后再提交"

    try:
        fixer.save_feedback(original_code, fixed_code, user_feedback, is_correct)
    except OSError as e:
        # Report a failed write to the user instead of crashing the callback.
        return f"反馈保存失败: {e}"

    return "感谢您的反馈!系统正在学习改进... 💡"
272
+
273
# Build the Gradio interface. Components are laid out in creation order
# inside the nested Tab/Row/Column contexts, so statement order matters.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🚀 Capricode 智能代码修复助手")
    gr.Markdown("支持 HTML, Python, JavaScript, Java, C++, CSS 等多种语言!")

    # Tab 1: paste code on the left, read the fixed code and report on the right.
    with gr.Tab("代码修复"):
        with gr.Row():
            with gr.Column():
                input_code = gr.Textbox(
                    label="📥 输入需要修复的代码",
                    placeholder="粘贴你的代码到这里...",
                    lines=10
                )
                use_ai = gr.Checkbox(label="使用AI智能修复", value=True)
                fix_btn = gr.Button("🔧 修复代码", variant="primary")

            with gr.Column():
                output_code = gr.Textbox(
                    label="📤 修复后的代码",
                    lines=10,
                    show_copy_button=True
                )
                report = gr.Textbox(
                    label="📊 修复报告",
                    lines=3
                )

    # Tab 2: feedback form. The fields are filled in by the user; they are
    # not wired to the components of the first tab.
    with gr.Tab("反馈学习"):
        gr.Markdown("## 💡 帮助系统变得更好")
        with gr.Row():
            with gr.Column():
                feedback_original = gr.Textbox(label="原始代码", lines=3)
                feedback_fixed = gr.Textbox(label="修复后的代码", lines=3)
                user_feedback = gr.Textbox(
                    label="您的反馈建议",
                    placeholder="这里可以如何改进?",
                    lines=3
                )
                is_correct = gr.Radio(
                    choices=[("正确修复", True), ("需要改进", False)],
                    label="修复是否正确?"
                )
                feedback_btn = gr.Button("提交反馈", variant="secondary")
                feedback_result = gr.Textbox(label="反馈结果", interactive=False)

    # Event handlers.
    # process_code returns (fixed_code, report), matching the two outputs.
    fix_btn.click(
        fn=process_code,
        inputs=[input_code, use_ai],
        outputs=[output_code, report]
    )

    # handle_feedback returns a single status message.
    feedback_btn.click(
        fn=handle_feedback,
        inputs=[feedback_original, feedback_fixed, user_feedback, is_correct],
        outputs=[feedback_result]
    )

    # Examples: clicking one fills the code box and the AI checkbox.
    gr.Markdown("## 🎯 试试这些例子:")
    gr.Examples(
        examples=[
            ["<div>Hello World", True],  # HTML
            ["print('Hello World'", True],  # Python
            ["function test() {", True],  # JavaScript
            ["public class MyClass", True],  # Java
        ],
        inputs=[input_code, use_ai]
    )

if __name__ == "__main__":
    # NOTE(review): share=True requests a public Gradio share link in addition
    # to binding 0.0.0.0:7860 (the previous version used share=False) —
    # confirm this is intended for the deployment target.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True
    )
349