Spaces:
Sleeping
Sleeping
| import torch | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, GemmaTokenizer | |
| import gradio as gr | |
| from gradio.themes.base import Base | |
| from gradio.themes.utils import colors, fonts, sizes | |
| from typing import Iterable | |
class SQLGEN(Base):
    """Gradio theme used by the SQLGEN demo.

    The subclass adds no behaviour of its own: it only supplies a set of
    default values (stone/green/gray hues, medium spacing and radius, large
    text, and IBM Plex Mono at the head of both font stacks) and hands every
    option through to ``Base.__init__`` unchanged.  All options are
    keyword-only and each accepts either the gradio object or a plain string,
    as the annotations show.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.stone,
        secondary_hue: colors.Color | str = colors.green,
        neutral_hue: colors.Color | str = colors.gray,
        spacing_size: sizes.Size | str = sizes.spacing_md,
        radius_size: sizes.Size | str = sizes.radius_md,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"),
            "ui-sans-serif",
            "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"),
            "ui-monospace",
            "monospace",
        ),
    ):
        # Collect the options once so the hand-off to the base theme is a
        # single, easily audited mapping.
        theme_options = {
            "primary_hue": primary_hue,
            "secondary_hue": secondary_hue,
            "neutral_hue": neutral_hue,
            "spacing_size": spacing_size,
            "radius_size": radius_size,
            "text_size": text_size,
            "font": font,
            "font_mono": font_mono,
        }
        super().__init__(**theme_options)
# Identifier of the checkpoint passed to both from_pretrained() calls below.
model_id = "alibidaran/Gemma2_SQLGEN"

# bnb_config for GPU usage (disabled; kept for reference):
# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_compute_dtype=torch.bfloat16
# )

# Tokenizer first, configured in one place.
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.padding_side = 'right'

# device_map='auto' leaves device placement of the weights to the library.
model = AutoModelForCausalLM.from_pretrained(model_id, device_map='auto')
def generate_sql(query, context):
    """Ask the model to answer ``query`` given ``context`` and return its text.

    The two arguments are substituted into the fixed
    ``##Question / ##Context / ##Answer`` prompt template, the module-level
    ``model`` samples up to 100 new tokens, and only the newly generated
    tokens are decoded.  Sampling is enabled (``do_sample=True``), so repeated
    calls with the same arguments may return different text.

    Args:
        query: Text placed after ``##Question:`` in the prompt.
        context: Text placed after ``##Context:`` in the prompt (presumably
            the table schema the SQL should target -- confirm against the
            model card).

    Returns:
        The decoded continuation produced by the model, with special tokens
        stripped and the prompt itself excluded.
    """
    text = f"<s>##Question: {query} \n ##Context: {context} \n ##Answer:"

    # The model is loaded with device_map='auto', so it may not live on the
    # CPU; send the inputs to the model's own device instead of hard-coding
    # 'cpu', which raises a device-mismatch error on accelerator hosts.
    inputs = tokenizer(text, return_tensors='pt').to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=100,
            do_sample=True,
            top_p=0.99,
            top_k=10,
            temperature=0.5,
        )

    # generate() returns the prompt followed by the continuation; slice off
    # the prompt tokens so only the answer is decoded.
    prompt_length = inputs.input_ids.shape[1]
    completion_ids = outputs[0, prompt_length:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True)
# Two free-text inputs are passed positionally to generate_sql (query first,
# then context); the result is shown in a code component.
interface = gr.Interface(
    fn=generate_sql,
    inputs=['text', 'text'],
    outputs=gr.Code(),
    title='SQLGEN',
    theme=SQLGEN(),
)

# Start the web UI only when this file is run as a script.
if __name__ == '__main__':
    interface.launch()