mset committed on
Commit 31f371a · verified · 1 Parent(s): 0ce8f6a

Update app.py

Files changed (1)
  1. app.py +380 -290
app.py CHANGED
@@ -1,329 +1,419 @@
- import gradio as gr
  import requests
  import re
- import xml.etree.ElementTree as ET
  import random
- from datetime import datetime
- from collections import defaultdict, Counter
-
- class SimpleQAAI:
-     def __init__(self):
-         self.knowledge_base = defaultdict(list)
-         self.qa_patterns = {}
-         self.vocabulary = set()
-         self.total_tokens = 0
-         self.is_trained = False
-
-         # Initialize with basic Q&A patterns
-         self.initialize_basic_knowledge()
-
-     def initialize_basic_knowledge(self):
-         """Initialize with basic Q&A knowledge"""
-         basic_qa = {
-             "what is artificial intelligence": "Artificial intelligence is a technology that enables machines to perform tasks that typically require human intelligence, such as learning, reasoning, and problem-solving.",
-             "how do computers work": "Computers work by processing data through electronic circuits, following programmed instructions to perform calculations and operations.",
-             "where is paris": "Paris is located in France and serves as the capital city of the country.",
-             "why is education important": "Education is important because it develops knowledge, critical thinking skills, and prepares people for careers and civic participation.",
-             "what is machine learning": "Machine learning is a subset of artificial intelligence that allows systems to automatically learn and improve from data without being explicitly programmed.",
-             "how does the internet work": "The internet works through a global network of interconnected computers that communicate using standardized protocols to share information.",
-             "what is climate change": "Climate change refers to long-term shifts in global weather patterns and temperatures, largely attributed to human activities.",
-             "why renewable energy": "Renewable energy is important because it provides sustainable power sources that don't deplete natural resources and help reduce environmental impact."
-         }
-
-         for question, answer in basic_qa.items():
-             self.qa_patterns[question] = answer
-             words = question.split() + answer.split()
-             self.vocabulary.update(words)
-
-         self.total_tokens = sum(len(answer.split()) for answer in basic_qa.values())
-         print(f"🧠 Initialized with {len(basic_qa)} Q&A patterns")
-
-     def collect_training_data(self):
-         """Collect training data from public sources"""
-         print("🕷️ Collecting training data...")
-
-         collected_data = []
-
-         # Try to collect from news sources
-         news_data = self.fetch_news_data()
-         collected_data.extend(news_data)
-
-         # Process collected data
-         if collected_data:
-             self.process_collected_data(collected_data)
-             self.is_trained = True
-             return f"✅ Training completed! Collected {len(collected_data)} articles and {self.total_tokens} total tokens."
-         else:
-             # Use fallback training
-             self.is_trained = True
-             return "✅ Training completed using built-in knowledge patterns!"
-
-     def fetch_news_data(self):
-         """Fetch data from news sources"""
-         news_sources = [
-             "https://feeds.reuters.com/reuters/worldNews",
-             "https://feeds.bbci.co.uk/news/world/rss.xml"
-         ]
-
-         articles = []
-
-         for source in news_sources:
              try:
-                 response = requests.get(source, timeout=5)
                  if response.status_code == 200:
-                     root = ET.fromstring(response.content)
-                     for item in root.findall(".//item")[:3]:  # Limit to 3 per source
-                         title = item.find("title")
-                         if title is not None and title.text:
-                             clean_title = re.sub(r'[^\w\s]', ' ', title.text).strip()
-                             if len(clean_title) > 10:
-                                 articles.append(clean_title)
-                 print(f"📰 Collected {len(articles)} articles from {source}")
-             except Exception as e:
-                 print(f"⚠️ Failed to collect from {source}: {str(e)}")
                  continue
-
-         return articles
-
-     def process_collected_data(self, data):
-         """Process collected data into knowledge base"""
-         for text in data:
-             # Extract key topics and add to knowledge base
-             words = text.lower().split()
-             self.vocabulary.update(words)
-
-             # Simple topic extraction
-             if any(word in text.lower() for word in ['technology', 'ai', 'computer']):
-                 self.knowledge_base['technology'].append(text)
-             elif any(word in text.lower() for word in ['climate', 'environment', 'energy']):
-                 self.knowledge_base['environment'].append(text)
-             elif any(word in text.lower() for word in ['economy', 'market', 'business']):
-                 self.knowledge_base['economy'].append(text)
-             else:
-                 self.knowledge_base['general'].append(text)
-
-         # Update token count
-         self.total_tokens += sum(len(text.split()) for text in data)
-         print(f"📚 Processed data into {len(self.knowledge_base)} knowledge categories")
-
-     def answer_question(self, question):
-         """Answer a question using available knowledge"""
-         if not question.strip():
-             return "Hello! I'm an AI that learns from data. Ask me a question and I'll try to answer based on what I've learned!"
-
-         question_clean = question.lower().strip()
-
-         # Direct pattern matching
-         for pattern, answer in self.qa_patterns.items():
-             if self.calculate_similarity(question_clean, pattern) > 0.6:
-                 return f"Based on my training: {answer}"
-
-         # Topic-based responses
-         topic_response = self.get_topic_response(question_clean)
-         if topic_response:
-             return topic_response
-
-         # Fallback response
-         return self.generate_fallback_response(question_clean)
-
-     def calculate_similarity(self, text1, text2):
-         """Calculate similarity between two texts"""
-         words1 = set(text1.split())
-         words2 = set(text2.split())
-
-         if not words1 or not words2:
-             return 0.0
-
-         intersection = len(words1.intersection(words2))
-         union = len(words1.union(words2))
-
-         return intersection / union if union > 0 else 0.0
-
-     def get_topic_response(self, question):
-         """Get response based on topic matching"""
-         topic_keywords = {
-             'technology': ['technology', 'computer', 'ai', 'artificial', 'machine', 'internet', 'digital'],
-             'environment': ['climate', 'environment', 'energy', 'renewable', 'carbon', 'sustainability'],
-             'economy': ['economy', 'economic', 'market', 'business', 'finance', 'money'],
-             'education': ['education', 'learning', 'school', 'university', 'knowledge', 'study']
-         }
-
-         # Find matching topic
-         for topic, keywords in topic_keywords.items():
-             if any(keyword in question for keyword in keywords):
-                 if topic in self.knowledge_base and self.knowledge_base[topic]:
-                     return f"Based on recent information about {topic}: {self.knowledge_base[topic][0][:150]}..."
-                 else:
-                     return self.get_topic_template_response(topic, question)
-
-         return None
-
-     def get_topic_template_response(self, topic, question):
-         """Get template response for a topic"""
-         templates = {
-             'technology': "Technology is rapidly evolving and transforming how we work, communicate, and solve problems. Modern technological advances include artificial intelligence, machine learning, and digital innovations.",
-             'environment': "Environmental issues like climate change require urgent attention. Solutions include renewable energy adoption, sustainable practices, and reduced carbon emissions.",
-             'economy': "Economic factors influence global markets, employment, and business growth. Understanding economic principles helps in making informed decisions.",
-             'education': "Education plays a crucial role in personal development and societal progress. It provides knowledge, skills, and opportunities for growth."
-         }
-
-         base_response = templates.get(topic, "This is an important topic that involves multiple factors and considerations.")
-
-         if '?' in question:
-             return f"Regarding your question about {topic}: {base_response}"
-         else:
-             return f"About {topic}: {base_response}"
-
-     def generate_fallback_response(self, question):
-         """Generate fallback response for unknown questions"""
-         fallback_responses = [
-             "That's an interesting question. Based on general knowledge, this topic involves various factors that need consideration.",
-             "From what I understand, this subject has multiple aspects worth exploring further.",
-             "This is a complex topic that relates to several areas of knowledge and research.",
-             "Based on my training data, this question touches on important concepts that merit detailed analysis."
-         ]
-
-         return random.choice(fallback_responses)
-
-     def get_system_status(self):
-         """Get current system status"""
-         status = "🤖 **SIMPLE Q&A AI STATUS**\n\n"
-
-         if self.is_trained:
-             status += "✅ **System is trained and ready**\n\n"
-         else:
-             status += "**System ready for training**\n\n"
-
-         status += "**📊 Statistics:**\n"
-         status += f"• **Total tokens processed:** {self.total_tokens:,}\n"
-         status += f"• **Vocabulary size:** {len(self.vocabulary):,} words\n"
-         status += f"• **Q&A patterns:** {len(self.qa_patterns)} direct patterns\n"
-         status += f"• **Knowledge categories:** {len(self.knowledge_base)}\n"
-         status += f"• **Training status:** {'Completed' if self.is_trained else 'Pending'}\n"
-
-         status += "\n**🎯 Capabilities:**\n"
-         status += "• Answers questions using pattern matching\n"
-         status += "• Learns from news articles and data\n"
-         status += "• Handles multiple topics and domains\n"
-         status += "• Provides fallback responses for unknown queries\n"
-
-         return status
-
- # Initialize the AI system
- ai_system = SimpleQAAI()
-
- def start_training():
-     """Start the training process"""
-     try:
-         result = ai_system.collect_training_data()
-         return result
-     except Exception as e:
-         return f"❌ Training failed: {str(e)}"
-
- def chat_function(message, history):
-     """Handle chat interactions"""
-     if not message:
-         return history, ""
-
-     try:
-         response = ai_system.answer_question(message)
-         history.append([message, response])
-         return history, ""
-     except Exception as e:
-         error_response = f"Sorry, I encountered an error: {str(e)}"
-         history.append([message, error_response])
-         return history, ""
-
- def refresh_status():
-     """Refresh system status"""
-     return ai_system.get_system_status()
-
- # Create Gradio interface
- with gr.Blocks(theme=gr.themes.Soft(), title="Simple Q&A AI") as app:
-
-     gr.HTML("""
-     <div style="text-align: center; padding: 20px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; border-radius: 10px; margin-bottom: 20px;">
-         <h1>🤖 Simple Question Answering AI</h1>
-         <p><b>Learn from data and answer questions intelligently</b></p>
-         <p>Stable • Fast • Reliable</p>
-     </div>
-     """)
-
-     with gr.Row():
-         with gr.Column(scale=3):
-             gr.Markdown("### 💬 Chat with AI")
-
-             chatbot = gr.Chatbot(
-                 value=[],
-                 label="AI Assistant",
-                 height=400
-             )
-
-             msg_input = gr.Textbox(
-                 label="Your Question",
-                 placeholder="Ask me anything: What is AI? How does technology work?",
-                 lines=2
-             )
-
-             with gr.Row():
-                 send_btn = gr.Button("💬 Send", variant="primary")
-                 clear_btn = gr.Button("🗑️ Clear", variant="secondary")
-
-         with gr.Column(scale=1):
-             gr.Markdown("### ⚙️ System Control")
-
-             status_box = gr.Textbox(
-                 label="System Status",
-                 value=ai_system.get_system_status(),
-                 lines=16,
-                 interactive=False
-             )
-
-             train_btn = gr.Button("🚀 Start Training", variant="primary")
-             refresh_btn = gr.Button("🔄 Refresh Status", variant="secondary")
-
-     # Example questions
-     gr.Examples(
-         examples=[
-             "What is artificial intelligence?",
-             "How do computers work?",
-             "Why is education important?",
-             "What is climate change?",
-             "How does the internet work?",
-             "What is machine learning?"
-         ],
-         inputs=msg_input,
-         label="📝 Try these questions"
-     )
-
-     # Event handlers
-     send_btn.click(
-         fn=chat_function,
-         inputs=[msg_input, chatbot],
-         outputs=[chatbot, msg_input]
-     )
-
-     msg_input.submit(
-         fn=chat_function,
-         inputs=[msg_input, chatbot],
-         outputs=[chatbot, msg_input]
-     )
-
-     clear_btn.click(
-         fn=lambda: ([], ""),
-         outputs=[chatbot, msg_input]
-     )
-
-     train_btn.click(
-         fn=start_training,
-         outputs=[status_box]
-     )
-
-     refresh_btn.click(
-         fn=refresh_status,
-         outputs=[status_box]
-     )
-
- # Launch the app
  if __name__ == "__main__":
-     app.launch()
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from torch.utils.data import Dataset, DataLoader
+ import numpy as np
  import requests
  import re
+ import json
+ import os
+ import sys  # used by the command-line check in __main__
+ from collections import Counter
+ from typing import List, Tuple, Dict
  import random
+ import math
+ from datasets import load_dataset
+ from transformers import AutoTokenizer
+ import gradio as gr
+
+ class SelfOrganizingTokenizer:
+     def __init__(self, vocab_size=30000):
+         self.vocab_size = vocab_size
+         self.token_to_id = {'<PAD>': 0, '<UNK>': 1, '<BOS>': 2, '<EOS>': 3}
+         self.id_to_token = {0: '<PAD>', 1: '<UNK>', 2: '<BOS>', 3: '<EOS>'}
+         self.word_freq = Counter()
+
+     def build_vocab(self, texts):
+         for text in texts:
+             words = re.findall(r'\w+|[^\w\s]', text.lower())
+             self.word_freq.update(words)
+
+         # Keep the most frequent words; ids 0-3 are reserved for special tokens
+         most_common = self.word_freq.most_common(self.vocab_size - 4)
+         for i, (word, _) in enumerate(most_common):
+             idx = i + 4
+             self.token_to_id[word] = idx
+             self.id_to_token[idx] = word
+
+     def encode(self, text):
+         words = re.findall(r'\w+|[^\w\s]', text.lower())
+         return [self.token_to_id.get(word, 1) for word in words]
+
+     def decode(self, ids):
+         return ' '.join([self.id_to_token.get(id, '<UNK>') for id in ids])
+
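+ # Standard multi-head self-attention, extended with a learned "adaptation"
+ # gate that rescales the attention output per position.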
+ class SelfOrganizingAttention(nn.Module):
+     def __init__(self, embed_dim, num_heads):
+         super().__init__()
+         self.embed_dim = embed_dim
+         self.num_heads = num_heads
+         self.head_dim = embed_dim // num_heads
+
+         self.qkv = nn.Linear(embed_dim, embed_dim * 3)
+         self.proj = nn.Linear(embed_dim, embed_dim)
+         self.adaptation_layer = nn.Linear(embed_dim, embed_dim)
+
+     def forward(self, x):
+         B, T, C = x.shape
+         qkv = self.qkv(x).reshape(B, T, 3, self.num_heads, self.head_dim)
+         q, k, v = qkv.permute(2, 0, 3, 1, 4)
+
+         att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+         # Causal mask so position i cannot attend to positions > i; without it
+         # the model could see the shifted next-token targets during training
+         mask = torch.tril(torch.ones(T, T, device=x.device)).view(1, 1, T, T)
+         att = att.masked_fill(mask == 0, float('-inf'))
+         att = torch.softmax(att, dim=-1)
+
+         y = att @ v
+         y = y.transpose(1, 2).reshape(B, T, C)
+         y = self.proj(y)
+
+         # Self-organization: rescale the output with a learned, input-dependent gate
+         adaptation = torch.tanh(self.adaptation_layer(x))
+         y = y * (1 + 0.1 * adaptation)
+
+         return y
+
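+ # Decoder-style language model: each block feeds the attention output back
+ # into the residual stream, gated by a learned per-layer plasticity weight.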
+ class SelfOrganizingTransformer(nn.Module):
+     def __init__(self, vocab_size, embed_dim=512, num_heads=8, num_layers=6, max_len=1024):
+         super().__init__()
+         self.embed_dim = embed_dim
+         self.tok_embed = nn.Embedding(vocab_size, embed_dim)
+         self.pos_embed = nn.Embedding(max_len, embed_dim)
+
+         self.layers = nn.ModuleList([
+             nn.ModuleDict({
+                 'attn': SelfOrganizingAttention(embed_dim, num_heads),
+                 'norm1': nn.LayerNorm(embed_dim),
+                 'mlp': nn.Sequential(
+                     nn.Linear(embed_dim, 4 * embed_dim),
+                     nn.GELU(),
+                     nn.Linear(4 * embed_dim, embed_dim),
+                 ),
+                 'norm2': nn.LayerNorm(embed_dim),
+                 'adaptation': nn.Linear(embed_dim, embed_dim)
+             }) for _ in range(num_layers)
+         ])
+
+         self.ln_f = nn.LayerNorm(embed_dim)
+         self.head = nn.Linear(embed_dim, vocab_size)
+
+         # Per-layer plasticity parameters for self-organization
+         self.plasticity = nn.Parameter(torch.ones(num_layers) * 0.01)
+
+     def forward(self, x):
+         B, T = x.shape
+         pos = torch.arange(0, T, dtype=torch.long, device=x.device)
+
+         x = self.tok_embed(x) + self.pos_embed(pos)
+
+         for i, layer in enumerate(self.layers):
+             residual = x
+             x = layer['norm1'](x)
+             x = layer['attn'](x)
+
+             # Adaptive self-organization: plasticity controls how strongly the
+             # gate modulates this layer's contribution to the residual stream
+             adaptation = torch.tanh(layer['adaptation'](x))
+             x = residual + x * (1 + self.plasticity[i] * adaptation)
+
+             residual = x
+             x = layer['norm2'](x)
+             x = layer['mlp'](x)
+             x = residual + x
+
+         x = self.ln_f(x)
+         logits = self.head(x)
+         return logits
+
+ class TextDataset(Dataset):
+     def __init__(self, texts, tokenizer, max_len=512):
+         self.texts = texts
+         self.tokenizer = tokenizer
+         self.max_len = max_len
+
+     def __len__(self):
+         return len(self.texts)
+
+     def __getitem__(self, idx):
+         text = self.texts[idx]
+         tokens = self.tokenizer.encode(text)
+
+         # Pad with <PAD> (id 0) or truncate to a fixed length
+         if len(tokens) < self.max_len:
+             tokens = tokens + [0] * (self.max_len - len(tokens))
+         else:
+             tokens = tokens[:self.max_len]
+
+         # Shift by one token: each position is trained to predict the next one
+         return torch.tensor(tokens[:-1]), torch.tensor(tokens[1:])
+
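+ # AITrainer ties everything together: dataset collection, vocabulary
+ # construction, the training loop, and checkpointing.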
+ class AITrainer:
+     def __init__(self):
+         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+         self.tokenizer = None
+         self.model = None
+         self.datasets = []
+
+     def load_public_datasets(self):
+         """Load public datasets that require no API key"""
+         datasets = []
+
+         try:
+             # Italian Wikipedia
+             wiki = load_dataset("wikipedia", "20220301.it", split="train[:10000]")
+             for item in wiki:
+                 if len(item['text']) > 100:
+                     datasets.append(item['text'])
+         except Exception:
+             pass
+
+         try:
+             # Common Crawl
+             cc = load_dataset("cc100", lang="it", split="train[:5000]")
+             for item in cc:
+                 if len(item['text']) > 100:
+                     datasets.append(item['text'])
+         except Exception:
+             pass
+
+         try:
+             # OSCAR
+             oscar = load_dataset("oscar-corpus/OSCAR-2201", "it", split="train[:5000]")
+             for item in oscar:
+                 if len(item['text']) > 100:
+                     datasets.append(item['text'])
+         except Exception:
+             pass
+
+         # Plain-text sources from public URLs
+         urls = [
+             "https://www.gutenberg.org/files/2000/2000-0.txt",  # Divina Commedia
+             "https://www.gutenberg.org/files/1065/1065-0.txt"   # I Promessi Sposi
+         ]
+
+         for url in urls:
              try:
+                 response = requests.get(url, timeout=30)
                  if response.status_code == 200:
+                     text = response.text
+                     chunks = [text[i:i+2000] for i in range(0, len(text), 2000)]
+                     datasets.extend(chunks[:500])
+             except Exception:
                  continue
+
+         # Fall back to synthetic data if too little was collected
+         if len(datasets) < 1000:
+             synthetic_texts = self.generate_synthetic_data(5000)
+             datasets.extend(synthetic_texts)
+
+         self.datasets = datasets[:10000]  # Cap at 10k examples
+         print(f"Loaded {len(self.datasets)} training examples")
+
+     def generate_synthetic_data(self, num_samples):
+         """Generate synthetic Italian sentences for training"""
+         templates = [
+             "Il {sostantivo} {verbo} nel {luogo} durante {tempo}.",
+             "La {sostantivo} è molto {aggettivo} e {verbo} sempre.",
+             "Quando {verbo}, il {sostantivo} diventa {aggettivo}.",
+             "Nel {luogo}, la {sostantivo} {verbo} con {sostantivo}.",
+             "Il {aggettivo} {sostantivo} {verbo} ogni {tempo}."
+         ]
+
+         sostantivi = ["gatto", "cane", "casa", "albero", "fiume", "montagna", "libro", "sole"]
+         verbi = ["corre", "salta", "vola", "nuota", "dorme", "mangia", "gioca", "legge"]
+         aggettivi = ["bello", "grande", "piccolo", "veloce", "lento", "intelligente", "forte"]
+         luoghi = ["parco", "giardino", "bosco", "città", "mare", "cielo", "campo"]
+         tempi = ["giorno", "notte", "mattina", "sera", "inverno", "estate", "primavera"]
+
+         texts = []
+         for _ in range(num_samples):
+             template = random.choice(templates)
+             text = template.format(
+                 sostantivo=random.choice(sostantivi),
+                 verbo=random.choice(verbi),
+                 aggettivo=random.choice(aggettivi),
+                 luogo=random.choice(luoghi),
+                 tempo=random.choice(tempi)
+             )
+             texts.append(text)
+
+         return texts
+
+     def setup_model(self, vocab_size=30000):
+         """Configure the self-organizing transformer model"""
+         self.model = SelfOrganizingTransformer(
+             vocab_size=vocab_size,
+             embed_dim=512,
+             num_heads=8,
+             num_layers=6,
+             max_len=512
+         ).to(self.device)
+
+         # Report the parameter count
+         total_params = sum(p.numel() for p in self.model.parameters())
+         print(f"Model created with {total_params:,} parameters")
+
+     def train(self, epochs=5, batch_size=16, lr=3e-4):
+         """Train the model"""
+         print("Initializing tokenizer...")
+         self.tokenizer = SelfOrganizingTokenizer()
+         self.tokenizer.build_vocab(self.datasets)
+
+         print("Configuring model...")
+         self.setup_model(len(self.tokenizer.token_to_id))
+
+         print("Preparing dataset...")
+         dataset = TextDataset(self.datasets, self.tokenizer)
+         dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
+
+         optimizer = optim.AdamW(self.model.parameters(), lr=lr, weight_decay=0.01)
+         criterion = nn.CrossEntropyLoss(ignore_index=0)
+
+         print("Starting training...")
+         self.model.train()
+
+         for epoch in range(epochs):
+             total_loss = 0
+             num_batches = 0
+
+             for batch_idx, (input_ids, target_ids) in enumerate(dataloader):
+                 input_ids = input_ids.to(self.device)
+                 target_ids = target_ids.to(self.device)
+
+                 optimizer.zero_grad()
+
+                 logits = self.model(input_ids)
+                 loss = criterion(logits.reshape(-1, logits.size(-1)), target_ids.reshape(-1))
+
+                 loss.backward()
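+                 # Clip the gradient norm at 1.0 to keep updates stable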
+                 torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
+                 optimizer.step()
+
+                 total_loss += loss.item()
+                 num_batches += 1
+
+                 if batch_idx % 50 == 0:
+                     print(f"Epoch {epoch+1}/{epochs}, Batch {batch_idx}, Loss: {loss.item():.4f}")
+
+             avg_loss = total_loss / num_batches
+             print(f"Epoch {epoch+1}/{epochs} completed. Average loss: {avg_loss:.4f}")
+
+             # Generation test
+             if epoch % 2 == 0:
+                 self.test_generation("Il gatto")
+
+         print("Training complete!")
+         self.save_model()
+
+     def test_generation(self, prompt, max_length=50):
+         """Text-generation smoke test"""
+         self.model.eval()
+         with torch.no_grad():
+             tokens = self.tokenizer.encode(prompt)
+             input_ids = torch.tensor([tokens]).to(self.device)
+
+             for _ in range(max_length):
+                 logits = self.model(input_ids)
+                 # Greedy decoding: always take the most likely next token
+                 next_token = torch.argmax(logits[0, -1, :], dim=-1)
+                 input_ids = torch.cat([input_ids, next_token.unsqueeze(0).unsqueeze(0)], dim=1)
+
+                 if next_token.item() == self.tokenizer.token_to_id.get('<EOS>', 3):
+                     break
+
+             generated = self.tokenizer.decode(input_ids[0].cpu().numpy())
+             print(f"Generation: {generated}")
+
+         self.model.train()
+         return generated
+
+     def save_model(self):
+         """Save the model"""
+         torch.save({
+             'model_state_dict': self.model.state_dict(),
+             'tokenizer': self.tokenizer,
+             'vocab_size': len(self.tokenizer.token_to_id)
+         }, 'ai_model.pth')
+         print("Model saved to ai_model.pth")
+
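+     # Note: torch.save pickles the tokenizer object itself, so loading the
+     # checkpoint requires this module's SelfOrganizingTokenizer class to be
+     # importable.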
+     def load_model(self):
+         """Load the model"""
+         if os.path.exists('ai_model.pth'):
+             checkpoint = torch.load('ai_model.pth', map_location=self.device)
+             self.tokenizer = checkpoint['tokenizer']
+             self.setup_model(checkpoint['vocab_size'])
+             self.model.load_state_dict(checkpoint['model_state_dict'])
+             print("Model loaded from ai_model.pth")
+             return True
+         return False
+
+     def generate_text(self, prompt, max_length=100, temperature=0.8):
+         """Generate text from a prompt"""
+         if not self.model or not self.tokenizer:
+             return "Model not loaded. Run training first."
+
+         self.model.eval()
+         with torch.no_grad():
+             tokens = self.tokenizer.encode(prompt)
+             input_ids = torch.tensor([tokens]).to(self.device)
+
+             for _ in range(max_length):
+                 logits = self.model(input_ids)
+                 # Temperature scaling: values below 1.0 sharpen the
+                 # distribution, values above 1.0 flatten it
+                 logits = logits[0, -1, :] / temperature
+                 probs = torch.softmax(logits, dim=-1)
+                 next_token = torch.multinomial(probs, 1)
+
+                 input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
+
+                 if next_token.item() == self.tokenizer.token_to_id.get('<EOS>', 3):
+                     break
+
+             generated = self.tokenizer.decode(input_ids[0].cpu().numpy())
+             return generated
+
+ def create_interface():
367
+ """Crea interfaccia Gradio"""
368
+ trainer = AITrainer()
369
+
370
+ def start_training():
371
+ try:
372
+ trainer.load_public_datasets()
373
+ trainer.train(epochs=3)
374
+ return "Training completato con successo!"
375
+ except Exception as e:
376
+ return f"Errore durante il training: {str(e)}"
377
 
378
+ def generate(prompt, max_len, temp):
379
+ try:
380
+ if not trainer.load_model():
381
+ return "Modello non trovato. Esegui prima il training."
382
+ result = trainer.generate_text(prompt, max_len, temp)
383
+ return result
384
+ except Exception as e:
385
+ return f"Errore nella generazione: {str(e)}"
386
 
387
+ with gr.Blocks(title="AI Token Trainer") as demo:
388
+ gr.Markdown("# AI Training System - Predizione Token")
389
+
390
+ with gr.Tab("Training"):
391
+ train_btn = gr.Button("Avvia Training", variant="primary")
392
+ train_output = gr.Textbox(label="Stato Training", lines=5)
393
+ train_btn.click(start_training, outputs=train_output)
394
+
395
+ with gr.Tab("Generazione"):
396
+ prompt_input = gr.Textbox(label="Prompt", placeholder="Inserisci il testo di partenza...")
397
+ max_len_slider = gr.Slider(10, 200, value=50, label="Lunghezza massima")
398
+ temp_slider = gr.Slider(0.1, 2.0, value=0.8, label="Temperatura")
399
+ generate_btn = gr.Button("Genera Testo", variant="primary")
400
+ output_text = gr.Textbox(label="Testo Generato", lines=10)
401
+
402
+ generate_btn.click(
403
+ generate,
404
+ inputs=[prompt_input, max_len_slider, temp_slider],
405
+ outputs=output_text
406
+ )
407
 
408
+ return demo
 
 
 
409
 
 
410
  if __name__ == "__main__":
411
+ # Training automatico se richiesto
412
+ if len(os.sys.argv) > 1 and os.sys.argv[1] == "train":
413
+ trainer = AITrainer()
414
+ trainer.load_public_datasets()
415
+ trainer.train()
416
+ else:
417
+ # Interfaccia Gradio
418
+ demo = create_interface()
419
+ demo.launch(share=True)