gokul-pv committed on
Commit
2532a2f
·
1 Parent(s): 02060d5

initial version of the app

Browse files
Files changed (2) hide show
  1. app.py +138 -0
  2. requirements.txt +225 -0
app.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
4
+
5
+ # Initialize model and tokenizer
6
+ MODEL_PATH = "gokul-pv/Llama-3.2-1B-Instruct-16bit-TeSO"
7
+
8
def load_model():
    """Load and cache the TeSO tokenizer and model for CPU inference.

    BUG FIX: the original reloaded the ~1B-parameter checkpoint from the
    Hub/disk on every call — and ``analyze_architecture`` calls this once
    per Gradio request. The loaded pair is now cached on the function
    object, so only the first request pays the load cost.

    Returns:
        tuple: ``(model, tokenizer)`` with the model placed on CPU.
    """
    if not hasattr(load_model, "_cache"):
        tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_PATH,
            torch_dtype=torch.float32,  # Use float32 for CPU
            device_map="cpu",           # Ensure model runs on CPU
        )
        load_model._cache = (model, tokenizer)
    return load_model._cache
16
+
17
class CustomTextStreamer:
    """Streamer for ``model.generate`` that captures only the completion.

    Implements the ``put``/``end`` protocol that transformers expects from
    a streamer object. The first chunk handed to ``put`` after construction
    (or after ``end``) is the echoed prompt and is discarded; every later
    chunk is both printed and accumulated for retrieval.
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.generated_text = []
        self.next_tokens_are_prompt = True

    def put(self, value):
        """Receive one chunk — a token tensor or a plain string."""
        if isinstance(value, torch.Tensor):
            # Collapse a batched (1, seq) tensor to its single row before decoding.
            token_ids = value[0] if value.dim() > 1 else value
            chunk = self.tokenizer.decode(token_ids.tolist(), skip_special_tokens=True)
        else:
            chunk = value

        if self.next_tokens_are_prompt:
            # The prompt is echoed first; drop it and only keep what follows.
            self.next_tokens_are_prompt = False
            return

        self.generated_text.append(chunk)
        print(chunk, end="", flush=True)

    def end(self):
        """Finish the stream: re-arm prompt skipping and terminate the line."""
        self.next_tokens_are_prompt = True
        print("")

    def get_generated_text(self):
        """Return the captured completion as a single string."""
        return "".join(self.generated_text)
44
+
45
def analyze_architecture(code_input, temperature=1.5, max_tokens=512):
    """Run the TeSO model over architecture code and return its analysis.

    Args:
        code_input (str): Architecture code for the model to review.
        temperature (float): Sampling temperature for generation.
        max_tokens (int): Upper bound on newly generated tokens.

    Returns:
        str: The generated analysis text (prompt excluded).
    """
    model, tokenizer = load_model()

    messages = [
        {
            "role": "system",
            "content": "You are an expert in analyzing system architecture written using code. "
            "You check the architecture and provide clear and detailed explanations "
            "regarding how the architecture can be improved for better performance, "
            "scalability, maintainability, and cost-effectiveness. You also check "
            "for possible cybersecurity issues and if the components can be "
            "replaced with newer and better components."
        },
        {
            "role": "user",
            "content": code_input
        }
    ]

    # Render the chat template and tokenize in one step.
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to("cpu")  # Ensure tensors are on CPU, alongside the model

    # Streamer that prints tokens as they arrive and collects the completion.
    text_streamer = CustomTextStreamer(tokenizer)

    with torch.inference_mode():
        model.generate(
            input_ids=inputs,
            streamer=text_streamer,
            max_new_tokens=max_tokens,
            use_cache=True,
            # BUG FIX: without do_sample=True, generate() decodes greedily and
            # silently ignores temperature/min_p — the UI's temperature slider
            # had no effect.
            do_sample=True,
            temperature=temperature,
            min_p=0.1
        )

    return text_streamer.get_generated_text()
90
+
91
+ # Create Gradio interface
92
def create_gradio_interface():
    """Build and return the Gradio Blocks UI for the architecture analyzer."""
    with gr.Blocks() as ui:
        gr.Markdown("# Tech Stack Optimizer - TeSO")

        with gr.Row():
            # Left column: code editor, generation controls, submit button.
            with gr.Column():
                arch_code = gr.Code(
                    label="Input Architecture Code",
                    language="python",
                    lines=10,
                )

                with gr.Row():
                    temp_slider = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        value=1.5,
                        label="Temperature",
                    )
                    tokens_slider = gr.Slider(
                        minimum=64,
                        maximum=2048,
                        value=512,
                        step=64,
                        label="Max Tokens",
                    )

                analyze_btn = gr.Button("Analyze Architecture")

            # Right column: rendered analysis output.
            with gr.Column():
                result_md = gr.Markdown(label="Analysis Results")

        # Wire the button to the model call.
        analyze_btn.click(
            fn=analyze_architecture,
            inputs=[arch_code, temp_slider, tokens_slider],
            outputs=result_md,
        )

    return ui
131
+
132
if __name__ == "__main__":
    # Build the UI, then serve it on all network interfaces at the default
    # Gradio port, with a public share link enabled.
    app = create_gradio_interface()
    app.launch(
        share=True,
        server_name="0.0.0.0",
        server_port=7860,
    )
requirements.txt ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==2.1.0
2
+ accelerate==1.3.0
3
+ aiofiles==23.2.1
4
+ aiohappyeyeballs==2.4.4
5
+ aiohttp==3.11.11
6
+ aiosignal==1.3.2
7
+ annotated-types==0.7.0
8
+ anyascii==0.3.2
9
+ anyio==4.8.0
10
+ argon2-cffi==23.1.0
11
+ argon2-cffi-bindings==21.2.0
12
+ arrow==1.3.0
13
+ asttokens==3.0.0
14
+ async-lru==2.0.4
15
+ async-timeout==5.0.1
16
+ attrs==24.3.0
17
+ babel==2.16.0
18
+ backoff==2.2.1
19
+ beautifulsoup4==4.12.3
20
+ bitsandbytes==0.45.2
21
+ bleach==6.2.0
22
+ boto3==1.36.3
23
+ botocore==1.36.3
24
+ cachetools==5.5.1
25
+ certifi==2024.12.14
26
+ cffi==1.17.1
27
+ charset-normalizer==3.4.1
28
+ click==8.1.8
29
+ comm==0.2.2
30
+ contourpy==1.3.1
31
+ contractions==0.1.73
32
+ cut-cross-entropy==25.1.1
33
+ cycler==0.12.1
34
+ datasets==3.2.0
35
+ debugpy==1.8.12
36
+ decorator==5.1.1
37
+ defusedxml==0.7.1
38
+ dill==0.3.8
39
+ docstring_parser==0.16
40
+ exceptiongroup==1.2.2
41
+ executing==2.1.0
42
+ fastapi==0.115.6
43
+ fastjsonschema==2.21.1
44
+ ffmpy==0.5.0
45
+ filelock==3.17.0
46
+ fire==0.7.0
47
+ fonttools==4.55.4
48
+ fqdn==1.5.1
49
+ frozenlist==1.5.0
50
+ fsspec==2024.9.0
51
+ google-auth==2.37.0
52
+ google-auth-oauthlib==1.2.1
53
+ gradio==5.16.0
54
+ gradio_client==1.7.0
55
+ grpcio==1.69.0
56
+ h11==0.14.0
57
+ hf_transfer==0.1.9
58
+ httpcore==1.0.7
59
+ httptools==0.6.4
60
+ httpx==0.28.1
61
+ huggingface-hub==0.28.1
62
+ idna==3.10
63
+ ipykernel==6.26.0
64
+ ipython==8.17.2
65
+ ipywidgets==8.1.1
66
+ isoduration==20.11.0
67
+ jedi==0.19.2
68
+ Jinja2==3.1.5
69
+ jmespath==1.0.1
70
+ joblib==1.4.2
71
+ json5==0.10.0
72
+ jsonpointer==3.0.0
73
+ jsonschema==4.23.0
74
+ jsonschema-specifications==2024.10.1
75
+ jupyter-events==0.11.0
76
+ jupyter-lsp==2.2.5
77
+ jupyter_client==8.6.3
78
+ jupyter_core==5.7.2
79
+ jupyter_server==2.15.0
80
+ jupyter_server_terminals==0.5.3
81
+ jupyterlab==4.2.0
82
+ jupyterlab_pygments==0.3.0
83
+ jupyterlab_server==2.27.3
84
+ jupyterlab_widgets==3.0.13
85
+ kiwisolver==1.4.8
86
+ lightning==2.5.0.post0
87
+ lightning-cloud==0.5.70
88
+ lightning-utilities==0.11.9
89
+ lightning_sdk==0.1.49
90
+ litdata==0.2.32
91
+ litserve==0.2.6
92
+ Markdown==3.7
93
+ markdown-it-py==3.0.0
94
+ MarkupSafe==2.1.5
95
+ matplotlib==3.8.2
96
+ matplotlib-inline==0.1.7
97
+ mdurl==0.1.2
98
+ mistune==3.1.0
99
+ mpmath==1.3.0
100
+ multidict==6.1.0
101
+ multiprocess==0.70.16
102
+ nbclient==0.10.2
103
+ nbconvert==7.16.5
104
+ nbformat==5.10.4
105
+ nest-asyncio==1.6.0
106
+ networkx==3.4.2
107
+ notebook_shim==0.2.4
108
+ numpy==1.26.4
109
+ nvidia-cublas-cu12==12.1.3.1
110
+ nvidia-cuda-cupti-cu12==12.1.105
111
+ nvidia-cuda-nvrtc-cu12==12.1.105
112
+ nvidia-cuda-runtime-cu12==12.1.105
113
+ nvidia-cudnn-cu12==8.9.2.26
114
+ nvidia-cufft-cu12==11.0.2.54
115
+ nvidia-curand-cu12==10.3.2.106
116
+ nvidia-cusolver-cu12==11.4.5.107
117
+ nvidia-cusparse-cu12==12.1.0.106
118
+ nvidia-nccl-cu12==2.19.3
119
+ nvidia-nvjitlink-cu12==12.6.85
120
+ nvidia-nvtx-cu12==12.1.105
121
+ oauthlib==3.2.2
122
+ orjson==3.10.15
123
+ overrides==7.7.0
124
+ packaging==24.2
125
+ pandas==2.1.4
126
+ pandocfilters==1.5.1
127
+ parso==0.8.4
128
+ peft==0.14.0
129
+ pexpect==4.9.0
130
+ pillow==11.1.0
131
+ platformdirs==4.3.6
132
+ prometheus_client==0.21.1
133
+ prompt_toolkit==3.0.50
134
+ propcache==0.2.1
135
+ protobuf==3.20.3
136
+ psutil==6.1.1
137
+ ptyprocess==0.7.0
138
+ pure_eval==0.2.3
139
+ pyahocorasick==2.1.0
140
+ pyarrow==19.0.0
141
+ pyasn1==0.6.1
142
+ pyasn1_modules==0.4.1
143
+ pycparser==2.22
144
+ pydantic==2.10.5
145
+ pydantic_core==2.27.2
146
+ pydub==0.25.1
147
+ Pygments==2.19.1
148
+ PyJWT==2.10.1
149
+ pyparsing==3.2.1
150
+ python-dateutil==2.9.0.post0
151
+ python-dotenv==1.0.1
152
+ python-json-logger==3.2.1
153
+ python-multipart==0.0.20
154
+ pytorch-lightning==2.5.0.post0
155
+ pytz==2024.2
156
+ PyYAML==6.0.2
157
+ pyzmq==26.2.0
158
+ referencing==0.36.1
159
+ regex==2024.11.6
160
+ requests==2.32.3
161
+ requests-oauthlib==2.0.0
162
+ rfc3339-validator==0.1.4
163
+ rfc3986-validator==0.1.1
164
+ rich==13.9.4
165
+ rpds-py==0.22.3
166
+ rsa==4.9
167
+ ruff==0.9.6
168
+ s3transfer==0.11.1
169
+ safehttpx==0.1.6
170
+ safetensors==0.5.2
171
+ scikit-learn==1.3.2
172
+ scipy==1.11.4
173
+ seaborn==0.13.2
174
+ semantic-version==2.10.0
175
+ Send2Trash==1.8.3
176
+ sentencepiece==0.2.0
177
+ shellingham==1.5.4
178
+ shtab==1.7.1
179
+ simple-term-menu==1.6.6
180
+ six==1.17.0
181
+ sniffio==1.3.1
182
+ soupsieve==2.6
183
+ stack-data==0.6.3
184
+ starlette==0.41.3
185
+ sympy==1.13.3
186
+ tensorboard==2.15.1
187
+ tensorboard-data-server==0.7.2
188
+ termcolor==2.5.0
189
+ terminado==0.18.1
190
+ textsearch==0.0.24
191
+ threadpoolctl==3.5.0
192
+ tinycss2==1.4.0
193
+ tokenizers==0.21.0
194
+ tomli==2.2.1
195
+ tomlkit==0.13.2
196
+ torch==2.2.1+cu121
197
+ torchmetrics==1.3.1
198
+ torchvision==0.17.1+cu121
199
+ tornado==6.4.2
200
+ tqdm==4.67.1
201
+ traitlets==5.14.3
202
+ transformers==4.48.1
203
+ triton==2.2.0
204
+ trl==0.8.6
205
+ typeguard==4.4.1
206
+ typer==0.15.1
207
+ types-python-dateutil==2.9.0.20241206
208
+ typing_extensions==4.12.2
209
+ tyro==0.9.14
210
+ tzdata==2025.1
211
+ uri-template==1.3.0
212
+ urllib3==2.3.0
213
+ uvicorn==0.34.0
214
+ uvloop==0.21.0
215
+ watchfiles==1.0.4
216
+ wcwidth==0.2.13
217
+ webcolors==24.11.1
218
+ webencodings==0.5.1
219
+ websocket-client==1.8.0
220
+ websockets==14.2
221
+ Werkzeug==3.1.3
222
+ widgetsnbextension==4.0.13
223
+ xformers==0.0.29
224
+ xxhash==3.5.0
225
+ yarl==1.18.3