Spaces:
Sleeping
Sleeping
| # app.py | |
| import torch | |
| from transformers import AutoTokenizer, EncoderDecoderModel | |
| import gradio as gr | |
| from spaces import GPU | |
# --- Device selection ---
# Depends on the Space's hardware setting: when GPU hardware is assigned,
# CUDA becomes available and is used; otherwise everything runs on CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device: {}".format(device))  # log which device was selected

model_name = "Shuu12121/CodeEncoderDecoderModel-Ghost-large"
print("Loading model: {}".format(model_name))  # marks the start of model loading
# --- Tokenizer loading ---
# The repository keeps the encoder and decoder tokenizers in separate
# subdirectories, selected through the `subfolder` argument.
try:
    encoder_tokenizer = AutoTokenizer.from_pretrained(
        model_name, subfolder="encoder_tokenizer"
    )
    decoder_tokenizer = AutoTokenizer.from_pretrained(
        model_name, subfolder="decoder_tokenizer"
    )
except Exception as e:
    print(f"Error loading tokenizers: {e}")
    raise  # re-raise so the app stops at start-up instead of running without tokenizers
else:
    print("Tokenizers loaded successfully.")
# --- Decoder pad-token setup ---
# generate() is later given decoder_tokenizer.pad_token_id, so make sure the
# decoder tokenizer defines a pad token.
if decoder_tokenizer.pad_token is None:
    if decoder_tokenizer.eos_token is not None:
        # Reuse the existing EOS token as padding; no new vocabulary entry needed.
        decoder_tokenizer.pad_token = decoder_tokenizer.eos_token
        print("Set decoder pad_token to eos_token.")
    else:
        # Neither a pad nor an EOS token exists: register a new '<pad>' token.
        # NOTE(review): the new token's id may lie past the checkpoint's decoder
        # embedding table. If that id is ever fed to the decoder, resize the
        # wrapped decoder after the model is loaded, i.e.
        #     model.decoder.resize_token_embeddings(len(decoder_tokenizer))
        # (EncoderDecoderModel.resize_token_embeddings itself raises
        # NotImplementedError, and `model` does not exist yet at this point.)
        decoder_tokenizer.add_special_tokens({'pad_token': '<pad>'})
        print("Added '<pad>' as pad_token.")
# --- Model loading ---
# Loading by repository id is sufficient, assuming the repo's config.json
# describes both halves: from_pretrained() then rebuilds encoder and decoder.
try:
    model = EncoderDecoderModel.from_pretrained(model_name)
    model = model.to(device)
    model.eval()  # switch to evaluation mode for inference
except Exception as e:
    print(f"Error loading model: {e}")
    raise  # abort start-up; the app is useless without the model
else:
    print("Model loaded successfully and moved to device.")
# --- Docstring generation ---
# On ZeroGPU hardware, spaces.GPU attaches a GPU for the duration of each call;
# per the `spaces` documentation the decorator has no effect on other hardware.
@GPU
def generate_docstring(code: str) -> str:
    """Generate a docstring for a code snippet with the encoder-decoder model.

    The snippet is tokenized with ``encoder_tokenizer`` (truncated to 2048
    tokens) and decoded with 10-beam search (at most 256 tokens, no repeated
    3-grams, a small list of banned words). The result is turned back into
    text with ``decoder_tokenizer``, skipping special tokens.

    Args:
        code: Source code pasted by the user.

    Returns:
        The generated docstring; a prompt message when ``code`` is empty or
        whitespace-only; or an error message if tokenization/generation raised.
        Errors are reported to the user as text instead of propagating.
    """
    print("Received code snippet for docstring generation.")  # call log
    # Empty and whitespace-only input gives the model nothing to document.
    if not code or not code.strip():
        return "Please provide a code snippet."
    try:
        # Prepare the encoder input.
        inputs = encoder_tokenizer(
            code,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=2048,  # upper bound on encoder input; adjust to the model's limit if needed
        ).to(device)
        print(f"Input tokens length: {inputs.input_ids.shape[1]}")

        # Never hand generate() pad_token_id=None: fall back to EOS.
        pad_token_id = (
            decoder_tokenizer.pad_token_id
            if decoder_tokenizer.pad_token_id is not None
            else decoder_tokenizer.eos_token_id
        )
        # Token-id sequences that generate() must not emit.
        # NOTE(review): for BPE tokenizers these ids match the words without a
        # leading space; mid-sentence occurrences may tokenize differently.
        # Transformers documents add_prefix_space=True for this case — confirm
        # against the decoder tokenizer type.
        bad_words_ids = decoder_tokenizer(
            ["sexual", "abuse", "child"], add_special_tokens=False
        ).input_ids

        with torch.no_grad():
            output_ids = model.generate(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_length=256,
                num_beams=10,
                early_stopping=True,
                eos_token_id=decoder_tokenizer.eos_token_id,
                pad_token_id=pad_token_id,
                no_repeat_ngram_size=3,
                bad_words_ids=bad_words_ids,
            )
        print(f"Generated output tokens length: {output_ids.shape[1]}")

        # Decode the best beam back to text.
        generated_docstring = decoder_tokenizer.decode(output_ids[0], skip_special_tokens=True)
        print("Docstring generated successfully.")
        return generated_docstring
    except Exception as e:
        # UI boundary: log the failure and show it to the user instead of crashing the app.
        print(f"Error during generation: {e}")
        return f"An error occurred during generation: {e}"
# --- Gradio UI ---
# Single text box in, single text box out. The default example is Java, so the
# UI text is language-neutral; the model id shown is the one actually loaded.
iface = gr.Interface(
    fn=generate_docstring,
    inputs=gr.Textbox(
        label="Code Snippet",
        lines=10,
        placeholder="Paste your function or code block here...",
        value="""public static String readFileToString(File file, Charset encoding) throws IOException {
try (BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream(file), encoding))) {
StringBuilder sb = new StringBuilder();
String line;
while ((line = reader.readLine()) != null) {
sb.append(line).append("\\n");
}
return sb.toString();
}
}""",
    ),
    outputs=gr.Textbox(label="Generated Docstring"),
    title=f"Code-to-Docstring Generator ({model_name})",
    description=(
        f"This demo uses the {model_name} model to automatically generate "
        "docstrings from code snippets. Paste your code below and click 'Submit'."
    ),
)
# --- Application entry point ---
# On Hugging Face Spaces the app is served directly, so share=True is not needed.
if __name__ == "__main__":
    print("Launching Gradio interface...")
    # When run as a script, launch() blocks the main thread until the server
    # is stopped, so the line after it only runs at shutdown.
    iface.launch()
    print("Gradio interface shut down.")