Anupam007 committed on
Commit
f5bdf65
·
verified ·
1 Parent(s): ce9ac5a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +311 -0
app.py ADDED
@@ -0,0 +1,311 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import base64
import io
import os
import re
import time

import gradio as gr
import matplotlib.pyplot as plt
import requests
import torch
from datasets import Dataset
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, Seq2SeqTrainingArguments, Seq2SeqTrainer, DataCollatorForSeq2Seq
#from IPython.display import display, Markdown, HTML # Remove IPython dependency
12
+
13
# Pick the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## Loading the Pre-trained Model

# The checkpoint is loaded once at import time and shared by fine-tuning and
# inference through the module-level `tokenizer` and `model` names.
model_name = "facebook/bart-large"  # You could also use "t5-base" or other seq2seq models
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
22
+
23
## Define Training Data (Optional for Fine-tuning)

# Sample training data: [(text_description, mermaid_code), ...]
# Each Mermaid snippet is spelled out line by line so the exact text the model
# is trained on is visible regardless of how this file is indented.
training_data = [
    (
        "A flowchart showing user login process with success and failure paths",
        "graph TD\n"
        "A[Start] --> B{User has account?}\n"
        "B -->|Yes| C[Enter credentials]\n"
        "B -->|No| D[Register]\n"
        "C --> E{Valid credentials?}\n"
        "E -->|Yes| F[Login successful]\n"
        "E -->|No| G[Login failed]\n"
        "D --> C\n",
    ),
    (
        "A sequence diagram showing client-server authentication",
        "sequenceDiagram\n"
        "participant Client\n"
        "participant Server\n"
        "Client->>Server: Authentication Request\n"
        "Server->>Client: Challenge\n"
        "Client->>Server: Challenge Response\n"
        "Server->>Client: Auth Success/Failure\n",
    ),
    (
        "A simple entity relationship diagram for a blog system",
        "erDiagram\n"
        "AUTHOR ||--o{ POST : writes\n"
        "POST ||--o{ COMMENT : contains\n"
        "AUTHOR ||--o{ COMMENT : writes\n",
    ),
    # Add more examples for better fine-tuning
]
60
+
61
+ ## Fine-tuning (Optional but Recommended)
62
+
63
def fine_tune_model():
    """Fine-tune the module-level seq2seq ``model`` on ``training_data``.

    The (description, Mermaid code) pairs in ``training_data`` are tokenized
    and fed to a ``Seq2SeqTrainer``. Checkpoints are written under
    ``./results`` and the final weights and tokenizer under
    ``./fine_tuned_model``.

    Returns:
        tuple: ``(model, tokenizer)`` - the same module-level objects, with the
        model weights updated in place by training.
    """
    dataset = Dataset.from_dict(
        {
            "input_text": [description for description, _ in training_data],
            "target_text": [mermaid_code for _, mermaid_code in training_data],
        }
    )

    def preprocess_function(examples):
        """Tokenize one batch of descriptions and their target Mermaid code."""
        # Deliberately no padding here: DataCollatorForSeq2Seq pads every batch
        # dynamically and fills label padding with -100, so padded positions
        # are ignored by the loss. Padding labels to max_length up front would
        # instead train the model to emit pad tokens.
        model_inputs = tokenizer(examples["input_text"], max_length=128, truncation=True)
        labels = tokenizer(text_target=examples["target_text"], max_length=256, truncation=True)
        model_inputs["labels"] = labels["input_ids"]
        return model_inputs

    tokenized_dataset = dataset.map(
        preprocess_function,
        batched=True,
        remove_columns=dataset.column_names,  # keep only the tensor-ready fields
    )

    # No evaluation strategy is configured because there is no held-out
    # dataset; requesting per-epoch evaluation without an eval_dataset makes
    # the Trainer fail.
    training_args = Seq2SeqTrainingArguments(
        output_dir="./results",
        learning_rate=5e-5,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        weight_decay=0.01,
        save_total_limit=3,
        num_train_epochs=3,
        predict_with_generate=True,
        no_cuda=not torch.cuda.is_available(),  # handle machines without a GPU
    )

    data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        data_collator=data_collator,
        tokenizer=tokenizer,
    )

    trainer.train()

    # Persist the fine-tuned weights together with the tokenizer files.
    model.save_pretrained("./fine_tuned_model")
    tokenizer.save_pretrained("./fine_tuned_model")

    return model, tokenizer
120
+
121
+ # Uncomment the line below to run fine-tuning
122
+ # model, tokenizer = fine_tune_model()
123
+
124
+ ## Text to Diagram Function
125
+
126
def get_entity_relationship_diagram():
    """Return the predefined Mermaid ER diagram for a blog system.

    Returns:
        str: Mermaid ``erDiagram`` source, one statement per line, ending with
        a newline.
    """
    return (
        "erDiagram\n"
        "AUTHOR ||--o{ POST : writes\n"
        "POST ||--o{ COMMENT : contains\n"
        "USER ||--o{ COMMENT : writes\n"
        "USER ||--o{ AUTHOR : can_be\n"
        "POST }|--|| CATEGORY : belongs_to\n"
    )
137
+
138
def get_flowchart_diagram():
    """Return the predefined Mermaid flowchart of a user login process.

    Returns:
        str: Mermaid ``graph TD`` source, one statement per line, ending with
        a newline.
    """
    return (
        "graph TD\n"
        "A[Start] --> B{User has account?}\n"
        "B -->|Yes| C[Enter credentials]\n"
        "B -->|No| D[Register]\n"
        "C --> E{Valid credentials?}\n"
        "E -->|Yes| F[Login successful]\n"
        "E -->|No| G[Login failed]\n"
        "D --> C\n"
    )
151
+
152
def get_sequence_diagram():
    """Return the predefined Mermaid sequence diagram.

    The diagram shows a User, a System and a Database exchanging a data
    request and its results.

    Returns:
        str: Mermaid ``sequenceDiagram`` source, one statement per line, ending
        with a newline.
    """
    return (
        "sequenceDiagram\n"
        "participant User\n"
        "participant System\n"
        "participant Database\n"
        "User->>System: Request data\n"
        "System->>Database: Query data\n"
        "Database->>System: Return results\n"
        "System->>User: Display results\n"
    )
165
+
166
# Matches the opening keyword of the Mermaid diagram types this app produces.
_MERMAID_HEADER_RE = re.compile(
    r"\s*(?:(?:graph|flowchart)\s+(?:TB|TD|BT|RL|LR)\b|sequenceDiagram\b|erDiagram\b)"
)

_GENERIC_FLOWCHART = (
    "graph TD\n"
    "A[Start] --> B[Process]\n"
    "B --> C[End]\n"
)


def _match_known_diagram(lower_text):
    """Return a curated diagram for a recognised description, else ``None``."""
    if "entity" in lower_text and "relation" in lower_text and "blog" in lower_text:
        return get_entity_relationship_diagram()
    if "flow" in lower_text and "login" in lower_text:
        return get_flowchart_diagram()
    if "sequence" in lower_text and "client" in lower_text and "server" in lower_text:
        return get_sequence_diagram()
    return None


def _generic_template(lower_text):
    """Return a minimal placeholder diagram of the type hinted at by the text."""
    if "flow" in lower_text:  # also covers "flowchart"
        return _GENERIC_FLOWCHART
    if "sequence" in lower_text:
        return (
            "sequenceDiagram\n"
            "participant A\n"
            "participant B\n"
            "A->>B: Message\n"
            "B->>A: Response\n"
        )
    # "er" must be a standalone word; as a bare substring it would match
    # almost any description ("user", "server", "order", ...).
    if "entity" in lower_text or re.search(r"\ber\b", lower_text):
        return (
            "erDiagram\n"
            "ENTITY1 ||--o{ ENTITY2 : relates\n"
        )
    return _GENERIC_FLOWCHART


def _generate_with_model(text_description):
    """Run the seq2seq model on the description and return the decoded text."""
    inputs = tokenizer(
        text_description, return_tensors="pt", max_length=128, truncation=True
    ).to(device)
    outputs = model.generate(
        **inputs,
        max_length=256,
        num_beams=5,
        early_stopping=True,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


def _save_diagram_image(diagram_code):
    """Render the diagram via mermaid.ink and return the saved file path, or ``None``."""
    img_url = render_mermaid_to_url(diagram_code)
    if img_url is None:
        return None

    try:
        response = requests.get(img_url, timeout=10)
    except requests.RequestException as e:
        print(f"Error downloading image: {e}")
        return None

    if response.status_code != 200:
        print(f"Error downloading image: HTTP {response.status_code}")
        return None

    # Fixed filename for simplicity; every call overwrites the previous image.
    temp_img_path = "temp_diagram.png"
    try:
        with open(temp_img_path, "wb") as f:
            f.write(response.content)
    except OSError as e:
        print(f"Error rendering diagram: {e}")
        return None
    return temp_img_path


def text_to_diagram(text_description):
    """Convert a text description into Mermaid code and a rendered image.

    Descriptions matching one of the curated patterns get a predefined
    diagram. Anything else is sent to the seq2seq model; its output is used
    only when it starts like a Mermaid diagram, otherwise a generic template
    chosen from keywords in the description is returned. If inference itself
    fails, a small error diagram is returned instead.

    Args:
        text_description (str): Natural-language description of the diagram.

    Returns:
        tuple: ``(diagram_code, image_path)`` where ``image_path`` is the path
        of the downloaded image, or ``None`` when rendering or download failed.
    """
    lower_text = text_description.lower()

    diagram_code = _match_known_diagram(lower_text)
    if diagram_code is None:
        try:
            candidate = _generate_with_model(text_description)
        except Exception as e:  # inference failures must not break the UI
            print(f"Error generating diagram code: {e}")
            diagram_code = (
                "graph TD\n"
                "A[Error] --> B[Could not generate diagram]\n"
            )
        else:
            # A model that is not fine-tuned rarely emits valid Mermaid, so its
            # output is kept only when it at least starts like a diagram.
            if _MERMAID_HEADER_RE.match(candidate):
                diagram_code = candidate
            else:
                diagram_code = _generic_template(lower_text)

    return diagram_code, _save_diagram_image(diagram_code)
251
+
252
def render_mermaid_to_url(mermaid_code):
    """Build the mermaid.ink image URL for a piece of Mermaid code.

    No network request is made here; the function only encodes the diagram
    source as URL-safe base64 and appends it to the mermaid.ink ``/img/``
    endpoint.

    Args:
        mermaid_code (str): Mermaid diagram source.

    Returns:
        str | None: The image URL, or ``None`` when the input cannot be
        encoded (for example when it is not a string).
    """
    try:
        encoded_code = base64.urlsafe_b64encode(mermaid_code.encode()).decode()
    except (AttributeError, TypeError, UnicodeError) as e:
        print(f"Error encoding mermaid code: {e}")
        return None

    return f"https://mermaid.ink/img/{encoded_code}"
268
+
269
+ ## Gradio Interface
270
+
271
def gradio_interface(text_input):
    """Gradio callback: turn the user's description into code and an image.

    Args:
        text_input (str): Description typed into the textbox.

    Returns:
        tuple: ``(diagram_code, image_path)``. ``image_path`` is ``None`` when
        no image could be produced. If diagram generation raises, the first
        element is an error message and the second is ``None``.
    """
    # Top-level boundary for the UI: any failure is reported in the textbox
    # instead of surfacing as a Gradio error.
    try:
        diagram_code, img_path = text_to_diagram(text_input)
    except Exception as e:
        print(f"Error in Gradio interface: {e}")
        return f"Error generating diagram: {str(e)}", None

    # Display the diagram code for debugging
    print("Generated diagram code:")
    print(diagram_code)
    if img_path:
        print(f"Image saved to: {img_path}")

    # A falsy path means image generation failed; return the code only.
    return diagram_code, img_path or None
291
+
292
# Create the Gradio interface; errors are handled inside gradio_interface,
# which reports them through the code textbox.
iface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Textbox(lines=5, placeholder="Enter your diagram description here..."),
    outputs=[
        gr.Textbox(label="Generated Mermaid Code"),
        gr.Image(label="Diagram Visualization", type="filepath"),
    ],
    title="Text to Diagram Converter",
    description="Convert natural language descriptions to diagrams using AI",
    examples=[
        ["A flowchart showing user login process with success and failure paths"],
        ["A sequence diagram showing client-server authentication"],
        ["A simple entity relationship diagram for a blog system"],
    ],
    allow_flagging="never",
)

# Launch the interface
iface.launch()